From f5d7df3c6ce0b7184404a87195393a48e1c6d2c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Mon, 21 Feb 2022 16:51:26 +0100 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Include=20route=20in=20scope=20to?= =?UTF-8?q?=20allow=20middleware=20and=20other=20tools=20to=20extract=20it?= =?UTF-8?q?s=20information=20(#4603)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastapi/routing.py | 17 +++++++++++-- tests/test_route_scope.py | 50 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 2 deletions(-) create mode 100644 tests/test_route_scope.py diff --git a/fastapi/routing.py b/fastapi/routing.py index f6d5370d6..7dae04521 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -13,6 +13,7 @@ from typing import ( Optional, Sequence, Set, + Tuple, Type, Union, ) @@ -44,7 +45,7 @@ from starlette.concurrency import run_in_threadpool from starlette.exceptions import HTTPException from starlette.requests import Request from starlette.responses import JSONResponse, Response -from starlette.routing import BaseRoute +from starlette.routing import BaseRoute, Match from starlette.routing import Mount as Mount # noqa from starlette.routing import ( compile_path, @@ -53,7 +54,7 @@ from starlette.routing import ( websocket_session, ) from starlette.status import WS_1008_POLICY_VIOLATION -from starlette.types import ASGIApp +from starlette.types import ASGIApp, Scope from starlette.websockets import WebSocket @@ -296,6 +297,12 @@ class APIWebSocketRoute(routing.WebSocketRoute): ) self.path_regex, self.path_format, self.param_convertors = compile_path(path) + def matches(self, scope: Scope) -> Tuple[Match, Scope]: + match, child_scope = super().matches(scope) + if match != Match.NONE: + child_scope["route"] = self + return match, child_scope + class APIRoute(routing.Route): def __init__( @@ -432,6 +439,12 @@ class APIRoute(routing.Route): dependency_overrides_provider=self.dependency_overrides_provider, ) + def matches(self, scope: Scope) -> Tuple[Match, Scope]: + match, child_scope = super().matches(scope) + if match != Match.NONE: + child_scope["route"] = self + return match, child_scope + class APIRouter(routing.Router): def __init__( diff --git a/tests/test_route_scope.py b/tests/test_route_scope.py new file mode 100644 index 000000000..a188e9a5f --- /dev/null +++ b/tests/test_route_scope.py @@ -0,0 +1,50 @@ +import pytest +from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect +from fastapi.routing import APIRoute, APIWebSocketRoute +from fastapi.testclient import TestClient + +app = FastAPI() + + +@app.get("/users/{user_id}") +async def get_user(user_id: str, request: Request): + route: APIRoute = request.scope["route"] + return {"user_id": user_id, "path": route.path} + + +@app.websocket("/items/{item_id}") +async def websocket_item(item_id: str, websocket: WebSocket): + route: APIWebSocketRoute = websocket.scope["route"] + await websocket.accept() + await websocket.send_json({"item_id": item_id, "path": route.path}) + + +client = TestClient(app) + + +def test_get(): + response = client.get("/users/rick") + assert response.status_code == 200, response.text + assert response.json() == {"user_id": "rick", "path": "/users/{user_id}"} + + +def test_invalid_method_doesnt_match(): + response = client.post("/users/rick") + assert response.status_code == 405, response.text + + +def test_invalid_path_doesnt_match(): + response = client.post("/usersx/rick") + assert response.status_code == 404, response.text + + +def test_websocket(): + with client.websocket_connect("/items/portal-gun") as websocket: + data = websocket.receive_json() + assert data == {"item_id": "portal-gun", "path": "/items/{item_id}"} + + +def test_websocket_invalid_path_doesnt_match(): + with pytest.raises(WebSocketDisconnect): + with client.websocket_connect("/itemsx/portal-gun") as websocket: + websocket.receive_json()