Browse Source

Include route in scope to allow middleware and other tools to extract its information (#4603)

pull/4301/merge
Sebastián Ramírez 3 years ago
committed by GitHub
parent
commit
f5d7df3c6c
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 17
      fastapi/routing.py
  2. 50
      tests/test_route_scope.py

17
fastapi/routing.py

@ -13,6 +13,7 @@ from typing import (
Optional, Optional,
Sequence, Sequence,
Set, Set,
Tuple,
Type, Type,
Union, Union,
) )
@ -44,7 +45,7 @@ from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from starlette.requests import Request from starlette.requests import Request
from starlette.responses import JSONResponse, Response from starlette.responses import JSONResponse, Response
from starlette.routing import BaseRoute from starlette.routing import BaseRoute, Match
from starlette.routing import Mount as Mount # noqa from starlette.routing import Mount as Mount # noqa
from starlette.routing import ( from starlette.routing import (
compile_path, compile_path,
@ -53,7 +54,7 @@ from starlette.routing import (
websocket_session, websocket_session,
) )
from starlette.status import WS_1008_POLICY_VIOLATION from starlette.status import WS_1008_POLICY_VIOLATION
from starlette.types import ASGIApp from starlette.types import ASGIApp, Scope
from starlette.websockets import WebSocket from starlette.websockets import WebSocket
@ -296,6 +297,12 @@ class APIWebSocketRoute(routing.WebSocketRoute):
) )
self.path_regex, self.path_format, self.param_convertors = compile_path(path) self.path_regex, self.path_format, self.param_convertors = compile_path(path)
def matches(self, scope: Scope) -> Tuple[Match, Scope]:
match, child_scope = super().matches(scope)
if match != Match.NONE:
child_scope["route"] = self
return match, child_scope
class APIRoute(routing.Route): class APIRoute(routing.Route):
def __init__( def __init__(
@ -432,6 +439,12 @@ class APIRoute(routing.Route):
dependency_overrides_provider=self.dependency_overrides_provider, dependency_overrides_provider=self.dependency_overrides_provider,
) )
def matches(self, scope: Scope) -> Tuple[Match, Scope]:
match, child_scope = super().matches(scope)
if match != Match.NONE:
child_scope["route"] = self
return match, child_scope
class APIRouter(routing.Router): class APIRouter(routing.Router):
def __init__( def __init__(

50
tests/test_route_scope.py

@ -0,0 +1,50 @@
import pytest
from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
from fastapi.routing import APIRoute, APIWebSocketRoute
from fastapi.testclient import TestClient
app = FastAPI()
@app.get("/users/{user_id}")
async def get_user(user_id: str, request: Request):
route: APIRoute = request.scope["route"]
return {"user_id": user_id, "path": route.path}
@app.websocket("/items/{item_id}")
async def websocket_item(item_id: str, websocket: WebSocket):
route: APIWebSocketRoute = websocket.scope["route"]
await websocket.accept()
await websocket.send_json({"item_id": item_id, "path": route.path})
client = TestClient(app)
def test_get():
response = client.get("/users/rick")
assert response.status_code == 200, response.text
assert response.json() == {"user_id": "rick", "path": "/users/{user_id}"}
def test_invalid_method_doesnt_match():
response = client.post("/users/rick")
assert response.status_code == 405, response.text
def test_invalid_path_doesnt_match():
response = client.post("/usersx/rick")
assert response.status_code == 404, response.text
def test_websocket():
with client.websocket_connect("/items/portal-gun") as websocket:
data = websocket.receive_json()
assert data == {"item_id": "portal-gun", "path": "/items/{item_id}"}
def test_websocket_invalid_path_doesnt_match():
with pytest.raises(WebSocketDisconnect):
with client.websocket_connect("/itemsx/portal-gun") as websocket:
websocket.receive_json()
Loading…
Cancel
Save