Browse Source

Include route in scope to allow middleware and other tools to extract its information (#4603)

pull/4301/merge
Sebastián Ramírez 3 years ago
committed by GitHub
parent
commit
f5d7df3c6c
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 17
      fastapi/routing.py
  2. 50
      tests/test_route_scope.py

17
fastapi/routing.py

@ -13,6 +13,7 @@ from typing import (
Optional,
Sequence,
Set,
Tuple,
Type,
Union,
)
@ -44,7 +45,7 @@ from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from starlette.routing import BaseRoute
from starlette.routing import BaseRoute, Match
from starlette.routing import Mount as Mount # noqa
from starlette.routing import (
compile_path,
@ -53,7 +54,7 @@ from starlette.routing import (
websocket_session,
)
from starlette.status import WS_1008_POLICY_VIOLATION
from starlette.types import ASGIApp
from starlette.types import ASGIApp, Scope
from starlette.websockets import WebSocket
@ -296,6 +297,12 @@ class APIWebSocketRoute(routing.WebSocketRoute):
)
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
def matches(self, scope: Scope) -> Tuple[Match, Scope]:
match, child_scope = super().matches(scope)
if match != Match.NONE:
child_scope["route"] = self
return match, child_scope
class APIRoute(routing.Route):
def __init__(
@ -432,6 +439,12 @@ class APIRoute(routing.Route):
dependency_overrides_provider=self.dependency_overrides_provider,
)
def matches(self, scope: Scope) -> Tuple[Match, Scope]:
match, child_scope = super().matches(scope)
if match != Match.NONE:
child_scope["route"] = self
return match, child_scope
class APIRouter(routing.Router):
def __init__(

50
tests/test_route_scope.py

@ -0,0 +1,50 @@
import pytest
from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
from fastapi.routing import APIRoute, APIWebSocketRoute
from fastapi.testclient import TestClient
app = FastAPI()
@app.get("/users/{user_id}")
async def get_user(user_id: str, request: Request):
route: APIRoute = request.scope["route"]
return {"user_id": user_id, "path": route.path}
@app.websocket("/items/{item_id}")
async def websocket_item(item_id: str, websocket: WebSocket):
route: APIWebSocketRoute = websocket.scope["route"]
await websocket.accept()
await websocket.send_json({"item_id": item_id, "path": route.path})
client = TestClient(app)
def test_get():
response = client.get("/users/rick")
assert response.status_code == 200, response.text
assert response.json() == {"user_id": "rick", "path": "/users/{user_id}"}
def test_invalid_method_doesnt_match():
response = client.post("/users/rick")
assert response.status_code == 405, response.text
def test_invalid_path_doesnt_match():
response = client.post("/usersx/rick")
assert response.status_code == 404, response.text
def test_websocket():
with client.websocket_connect("/items/portal-gun") as websocket:
data = websocket.receive_json()
assert data == {"item_id": "portal-gun", "path": "/items/{item_id}"}
def test_websocket_invalid_path_doesnt_match():
with pytest.raises(WebSocketDisconnect):
with client.websocket_connect("/itemsx/portal-gun") as websocket:
websocket.receive_json()
Loading…
Cancel
Save