diff --git a/fastapi/applications.py b/fastapi/applications.py index 480d03ab8..d6148b9d5 100644 --- a/fastapi/applications.py +++ b/fastapi/applications.py @@ -978,12 +978,15 @@ class FastAPI(Starlette): self.middleware_stack: Union[ASGIApp, None] = None if ignore_trailing_slash: + def ignore_trailing_whitespace_middleware(app): async def ignore_trailing_whitespace_wrapper(scope, receive, send): if scope["type"] in {"http", "websocket"}: scope["path"] = scope["path"].rstrip("/") await app(scope, receive, send) + return ignore_trailing_whitespace_wrapper + self.add_middleware(ignore_trailing_whitespace_middleware) self.setup() diff --git a/fastapi/routing.py b/fastapi/routing.py index 7daec9b0f..0935a29f2 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -881,6 +881,7 @@ class APIRouter(routing.Router): ) -> Callable[[DecoratedCallable], DecoratedCallable]: if self.ignore_trailing_slash: path = path.rstrip("/") + def decorator(func: DecoratedCallable) -> DecoratedCallable: self.add_route( path, @@ -1134,6 +1135,7 @@ class APIRouter(routing.Router): ) -> Callable[[DecoratedCallable], DecoratedCallable]: if self.ignore_trailing_slash: path = path.rstrip("/") + def decorator(func: DecoratedCallable) -> DecoratedCallable: self.add_websocket_route(path, func, name=name) return func @@ -1375,9 +1377,7 @@ class APIRouter(routing.Router): name=route.name, ) elif isinstance(route, routing.WebSocketRoute): - self.add_websocket_route( - prefix + path, route.endpoint, name=route.name - ) + self.add_websocket_route(prefix + path, route.endpoint, name=route.name) for handler in router.on_startup: self.add_event_handler("startup", handler) for handler in router.on_shutdown: diff --git a/tests/test_ignore_trailing_slash.py b/tests/test_ignore_trailing_slash.py index 8cd3c802a..be35993ff 100644 --- a/tests/test_ignore_trailing_slash.py +++ b/tests/test_ignore_trailing_slash.py @@ -1,41 +1,49 @@ -from fastapi import FastAPI, WebSocket, APIRouter +from fastapi import APIRouter, FastAPI, WebSocket from fastapi.testclient import TestClient app = FastAPI(ignore_trailing_slash=True) router = APIRouter() + @app.get("/example") async def example_endpoint(): return {"msg": "Example"} + @app.get("/example2/") -async def example_endpoint(): +async def example_endpoint_with_slash(): return {"msg": "Example 2"} + @app.websocket("/websocket") async def websocket_endpoint(websocket: WebSocket): await websocket.accept() await websocket.send_text("Websocket") await websocket.close() + @app.websocket("/websocket2/") -async def websocket_endpoint(websocket: WebSocket): +async def websocket_endpoint_with_slash(websocket: WebSocket): await websocket.accept() await websocket.send_text("Websocket 2") await websocket.close() + @router.get("/example") def route_endpoint(): return {"msg": "Routing Example"} + @router.get("/example2/") -def route_endpoint(): +def route_endpoint_with_slash(): return {"msg": "Routing Example 2"} + app.include_router(router, prefix="/router") client = TestClient(app) + def test_ignoring_trailing_slash(): response = client.get("/example/", follow_redirects=False) assert response.status_code == 200 @@ -44,16 +52,18 @@ def test_ignoring_trailing_slash(): assert response.status_code == 200 assert response.json()["msg"] == "Example 2" + def test_ignoring_trailing_shlash_ws(): with client.websocket_connect("/websocket/") as websocket: assert websocket.receive_text() == "Websocket" with client.websocket_connect("/websocket2") as websocket: assert websocket.receive_text() == "Websocket 2" + def test_ignoring_trailing_routing(): response = client.get("router/example/", follow_redirects=False) assert response.status_code == 200 assert response.json()["msg"] == "Routing Example" response = client.get("router/example2", follow_redirects=False) assert response.status_code == 200 - assert response.json()["msg"] == "Routing Example 2" \ No newline at end of file + assert response.json()["msg"] == "Routing Example 2"