From b00fbd3427737b80a7e58f556b65a9d7f6f457aa Mon Sep 17 00:00:00 2001 From: Synrom Date: Wed, 4 Sep 2024 19:24:44 +0000 Subject: [PATCH] Fix middleware and add tests --- fastapi/applications.py | 17 +++++---- fastapi/routing.py | 14 ++++---- tests/test_ignore_trailing_slash.py | 56 ++++++++++++++++++----------- 3 files changed, 53 insertions(+), 34 deletions(-) diff --git a/fastapi/applications.py b/fastapi/applications.py index 6fc595634..5a7170dd6 100644 --- a/fastapi/applications.py +++ b/fastapi/applications.py @@ -810,7 +810,7 @@ class FastAPI(Starlette): """ ), ] = True, - ignore_trailing_whitespaces: Annotated[ + ignore_trailing_slash: Annotated[ bool, Doc( """ @@ -950,7 +950,7 @@ class FastAPI(Starlette): include_in_schema=include_in_schema, responses=responses, generate_unique_id_function=generate_unique_id_function, - ignore_trailing_whitespaces=ignore_trailing_whitespaces, + ignore_trailing_slash=ignore_trailing_slash, ) self.exception_handlers: Dict[ Any, Callable[[Request, Any], Union[Response, Awaitable[Response]]] @@ -970,11 +970,14 @@ class FastAPI(Starlette): ) self.middleware_stack: Union[ASGIApp, None] = None - if ignore_trailing_whitespaces: - async def middleware_ignore_tailing_whitespace(request: Request, call_next): - request.scope["path"] = request.scope["path"].rstrip("/") - return await call_next(request) - self.add_middleware(BaseHTTPMiddleware, dispatch=middleware_ignore_tailing_whitespace) + if ignore_trailing_slash: + def ignore_trailing_whitespace_middleware(app): + async def ignore_trailing_whitespace_wrapper(scope, receive, send): + scope["path"] = scope["path"].rstrip("/") + await app(scope, receive, send) + return ignore_trailing_whitespace_wrapper + self.add_middleware(ignore_trailing_whitespace_middleware) + self.setup() def openapi(self) -> Dict[str, Any]: diff --git a/fastapi/routing.py b/fastapi/routing.py index cb14fd399..d59701ef5 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -811,7 +811,7 @@ class APIRouter(routing.Router): """ ), ] = Default(generate_unique_id), - ignore_trailing_whitespaces: Annotated[ + ignore_trailing_slash: Annotated[ bool, Doc( """ @@ -843,7 +843,7 @@ class APIRouter(routing.Router): self.route_class = route_class self.default_response_class = default_response_class self.generate_unique_id_function = generate_unique_id_function - self.ignore_trailing_whitespaces = ignore_trailing_whitespaces + self.ignore_trailing_slash = ignore_trailing_slash def route( self, @@ -852,7 +852,7 @@ class APIRouter(routing.Router): name: Optional[str] = None, include_in_schema: bool = True, ) -> Callable[[DecoratedCallable], DecoratedCallable]: - if self.ignore_trailing_whitespaces: + if self.ignore_trailing_slash: path = path.rstrip("/") def decorator(func: DecoratedCallable) -> DecoratedCallable: self.add_route( @@ -900,7 +900,7 @@ class APIRouter(routing.Router): Callable[[APIRoute], str], DefaultPlaceholder ] = Default(generate_unique_id), ) -> None: - if self.ignore_trailing_whitespaces: + if self.ignore_trailing_slash: path = path.rstrip("/") route_class = route_class_override or self.route_class responses = responses or {} @@ -1020,7 +1020,7 @@ class APIRouter(routing.Router): *, dependencies: Optional[Sequence[params.Depends]] = None, ) -> None: - if self.ignore_trailing_whitespaces: + if self.ignore_trailing_slash: path = path.rstrip("/") current_dependencies = self.dependencies.copy() if dependencies: @@ -1105,7 +1105,7 @@ class APIRouter(routing.Router): def websocket_route( self, path: str, name: Union[str, None] = None ) -> Callable[[DecoratedCallable], DecoratedCallable]: - if self.ignore_trailing_whitespaces: + if self.ignore_trailing_slash: path = path.rstrip("/") def decorator(func: DecoratedCallable) -> DecoratedCallable: self.add_websocket_route(path, func, name=name) @@ -1265,7 +1265,7 @@ class APIRouter(routing.Router): responses = {} for route in router.routes: path = route.path - if self.ignore_trailing_whitespaces: + if self.ignore_trailing_slash: path = path.rstrip("/") if isinstance(route, APIRoute): combined_responses = {**responses, **route.responses} diff --git a/tests/test_ignore_trailing_slash.py b/tests/test_ignore_trailing_slash.py index 551ccd61b..16537d328 100644 --- a/tests/test_ignore_trailing_slash.py +++ b/tests/test_ignore_trailing_slash.py @@ -1,29 +1,45 @@ -from fastapi import FastAPI +from fastapi import FastAPI, WebSocket, APIRouter from fastapi.testclient import TestClient -recognizing_app = FastAPI() -ignoring_app = FastAPI(ignore_trailing_whitespaces=True) +app = FastAPI(ignore_trailing_slash=True) +router = APIRouter() -@recognizing_app.get("/example") -@ignoring_app.get("/example") -async def return_data(): - return {"msg": "Reached the route!"} +@app.get("/example") +async def example_endpoint(): + return {"msg": "Example"} -recognizing_client = TestClient(recognizing_app) -ignoring_client = TestClient(ignoring_app) +@app.websocket("/websocket") +async def websocket_endpoint(websocket: WebSocket): + await websocket.accept() + await websocket.send_text("Websocket") + await websocket.close() -def test_recognizing_trailing_slash(): - response = recognizing_client.get("/example", follow_redirects=False) - assert response.status_code == 200 - assert response.json()["msg"] == "Reached the route!" - response = recognizing_client.get("/example/", follow_redirects=False) - assert response.status_code == 307 - assert response.headers["location"].endswith("/example") +@router.get("/example") +def route_endpoint(): + return {"msg": "Routing Example"} + +app.include_router(router, prefix="/router") + +client = TestClient(app) def test_ignoring_trailing_slash(): - response = ignoring_client.get("/example", follow_redirects=False) + response = client.get("/example", follow_redirects=False) + assert response.status_code == 200 + assert response.json()["msg"] == "Example" + response = client.get("/example/", follow_redirects=False) + assert response.status_code == 200 + assert response.json()["msg"] == "Example" + +def test_ignoring_trailing_shlash_ws(): + with client.websocket_connect("/websocket") as websocket: + assert websocket.receive_text() == "Websocket" + with client.websocket_connect("/websocket/") as websocket: + assert websocket.receive_text() == "Websocket" + +def test_ignoring_trailing_routing(): + response = client.get("router/example", follow_redirects=False) assert response.status_code == 200 - assert response.json()["msg"] == "Reached the route!" - response = ignoring_client.get("/example/", follow_redirects=False) + assert response.json()["msg"] == "Routing Example" + response = client.get("router/example/", follow_redirects=False) assert response.status_code == 200 - assert response.json()["msg"] == "Reached the route!" + assert response.json()["msg"] == "Routing Example"