From d4eaafb8043a988701ff965c1c061c586b27f9ff Mon Sep 17 00:00:00 2001 From: Synrom Date: Wed, 4 Sep 2024 17:03:56 +0000 Subject: [PATCH 01/12] Add ignore_trailing_slashes flag to applications --- fastapi/applications.py | 14 ++++++++++++++ fastapi/routing.py | 27 +++++++++++++++++++++++---- tests/test_ignore_trailing_slash.py | 29 +++++++++++++++++++++++++++++ 3 files changed, 66 insertions(+), 4 deletions(-) create mode 100644 tests/test_ignore_trailing_slash.py diff --git a/fastapi/applications.py b/fastapi/applications.py index 6d427cdc2..6fc595634 100644 --- a/fastapi/applications.py +++ b/fastapi/applications.py @@ -810,6 +810,13 @@ class FastAPI(Starlette): """ ), ] = True, + ignore_trailing_whitespaces: Annotated[ + bool, + Doc( + """ + """ + ), + ] = False, **extra: Annotated[ Any, Doc( @@ -943,6 +950,7 @@ class FastAPI(Starlette): include_in_schema=include_in_schema, responses=responses, generate_unique_id_function=generate_unique_id_function, + ignore_trailing_whitespaces=ignore_trailing_whitespaces, ) self.exception_handlers: Dict[ Any, Callable[[Request, Any], Union[Response, Awaitable[Response]]] @@ -961,6 +969,12 @@ class FastAPI(Starlette): [] if middleware is None else list(middleware) ) self.middleware_stack: Union[ASGIApp, None] = None + + if ignore_trailing_whitespaces: + async def middleware_ignore_tailing_whitespace(request: Request, call_next): + request.scope["path"] = request.scope["path"].rstrip("/") + return await call_next(request) + self.add_middleware(BaseHTTPMiddleware, dispatch=middleware_ignore_tailing_whitespace) self.setup() def openapi(self) -> Dict[str, Any]: diff --git a/fastapi/routing.py b/fastapi/routing.py index 61a112fc4..cb14fd399 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -811,6 +811,13 @@ class APIRouter(routing.Router): """ ), ] = Default(generate_unique_id), + ignore_trailing_whitespaces: Annotated[ + bool, + Doc( + """ + """ + ), + ] = False, ) -> None: super().__init__( routes=routes, @@ -836,6 +843,7 @@ class APIRouter(routing.Router): self.route_class = route_class self.default_response_class = default_response_class self.generate_unique_id_function = generate_unique_id_function + self.ignore_trailing_whitespaces = ignore_trailing_whitespaces def route( self, @@ -844,6 +852,8 @@ class APIRouter(routing.Router): name: Optional[str] = None, include_in_schema: bool = True, ) -> Callable[[DecoratedCallable], DecoratedCallable]: + if self.ignore_trailing_whitespaces: + path = path.rstrip("/") def decorator(func: DecoratedCallable) -> DecoratedCallable: self.add_route( path, @@ -890,6 +900,8 @@ class APIRouter(routing.Router): Callable[[APIRoute], str], DefaultPlaceholder ] = Default(generate_unique_id), ) -> None: + if self.ignore_trailing_whitespaces: + path = path.rstrip("/") route_class = route_class_override or self.route_class responses = responses or {} combined_responses = {**self.responses, **responses} @@ -1008,6 +1020,8 @@ class APIRouter(routing.Router): *, dependencies: Optional[Sequence[params.Depends]] = None, ) -> None: + if self.ignore_trailing_whitespaces: + path = path.rstrip("/") current_dependencies = self.dependencies.copy() if dependencies: current_dependencies.extend(dependencies) @@ -1091,6 +1105,8 @@ class APIRouter(routing.Router): def websocket_route( self, path: str, name: Union[str, None] = None ) -> Callable[[DecoratedCallable], DecoratedCallable]: + if self.ignore_trailing_whitespaces: + path = path.rstrip("/") def decorator(func: DecoratedCallable) -> DecoratedCallable: self.add_websocket_route(path, func, name=name) return func @@ -1248,6 +1264,9 @@ class APIRouter(routing.Router): if responses is None: responses = {} for route in router.routes: + path = route.path + if self.ignore_trailing_whitespaces: + path = path.rstrip("/") if isinstance(route, APIRoute): combined_responses = {**responses, **route.responses} use_response_class = get_value_or_default( @@ -1278,7 +1297,7 @@ class APIRouter(routing.Router): self.generate_unique_id_function, ) self.add_api_route( - prefix + route.path, + prefix + path, route.endpoint, response_model=route.response_model, status_code=route.status_code, @@ -1310,7 +1329,7 @@ class APIRouter(routing.Router): elif isinstance(route, routing.Route): methods = list(route.methods or []) self.add_route( - prefix + route.path, + prefix + path, route.endpoint, methods=methods, include_in_schema=route.include_in_schema, @@ -1323,14 +1342,14 @@ class APIRouter(routing.Router): if route.dependencies: current_dependencies.extend(route.dependencies) self.add_api_websocket_route( - prefix + route.path, + prefix + path, route.endpoint, dependencies=current_dependencies, name=route.name, ) elif isinstance(route, routing.WebSocketRoute): self.add_websocket_route( - prefix + route.path, route.endpoint, name=route.name + prefix + path, route.endpoint, name=route.name ) for handler in router.on_startup: self.add_event_handler("startup", handler) diff --git a/tests/test_ignore_trailing_slash.py b/tests/test_ignore_trailing_slash.py new file mode 100644 index 000000000..551ccd61b --- /dev/null +++ b/tests/test_ignore_trailing_slash.py @@ -0,0 +1,29 @@ +from fastapi import FastAPI +from fastapi.testclient import TestClient + +recognizing_app = FastAPI() +ignoring_app = FastAPI(ignore_trailing_whitespaces=True) + +@recognizing_app.get("/example") +@ignoring_app.get("/example") +async def return_data(): + return {"msg": "Reached the route!"} + +recognizing_client = TestClient(recognizing_app) +ignoring_client = TestClient(ignoring_app) + +def test_recognizing_trailing_slash(): + response = recognizing_client.get("/example", follow_redirects=False) + assert response.status_code == 200 + assert response.json()["msg"] == "Reached the route!" + response = recognizing_client.get("/example/", follow_redirects=False) + assert response.status_code == 307 + assert response.headers["location"].endswith("/example") + +def test_ignoring_trailing_slash(): + response = ignoring_client.get("/example", follow_redirects=False) + assert response.status_code == 200 + assert response.json()["msg"] == "Reached the route!" + response = ignoring_client.get("/example/", follow_redirects=False) + assert response.status_code == 200 + assert response.json()["msg"] == "Reached the route!" From b00fbd3427737b80a7e58f556b65a9d7f6f457aa Mon Sep 17 00:00:00 2001 From: Synrom Date: Wed, 4 Sep 2024 19:24:44 +0000 Subject: [PATCH 02/12] Fix middleware and add tests --- fastapi/applications.py | 17 +++++---- fastapi/routing.py | 14 ++++---- tests/test_ignore_trailing_slash.py | 56 ++++++++++++++++++----------- 3 files changed, 53 insertions(+), 34 deletions(-) diff --git a/fastapi/applications.py b/fastapi/applications.py index 6fc595634..5a7170dd6 100644 --- a/fastapi/applications.py +++ b/fastapi/applications.py @@ -810,7 +810,7 @@ class FastAPI(Starlette): """ ), ] = True, - ignore_trailing_whitespaces: Annotated[ + ignore_trailing_slash: Annotated[ bool, Doc( """ @@ -950,7 +950,7 @@ class FastAPI(Starlette): include_in_schema=include_in_schema, responses=responses, generate_unique_id_function=generate_unique_id_function, - ignore_trailing_whitespaces=ignore_trailing_whitespaces, + ignore_trailing_slash=ignore_trailing_slash, ) self.exception_handlers: Dict[ Any, Callable[[Request, Any], Union[Response, Awaitable[Response]]] @@ -970,11 +970,14 @@ class FastAPI(Starlette): ) self.middleware_stack: Union[ASGIApp, None] = None - if ignore_trailing_whitespaces: - async def middleware_ignore_tailing_whitespace(request: Request, call_next): - request.scope["path"] = request.scope["path"].rstrip("/") - return await call_next(request) - self.add_middleware(BaseHTTPMiddleware, dispatch=middleware_ignore_tailing_whitespace) + if ignore_trailing_slash: + def ignore_trailing_whitespace_middleware(app): + async def ignore_trailing_whitespace_wrapper(scope, receive, send): + scope["path"] = scope["path"].rstrip("/") + await app(scope, receive, send) + return ignore_trailing_whitespace_wrapper + self.add_middleware(ignore_trailing_whitespace_middleware) + self.setup() def openapi(self) -> Dict[str, Any]: diff --git a/fastapi/routing.py b/fastapi/routing.py index cb14fd399..d59701ef5 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -811,7 +811,7 @@ class APIRouter(routing.Router): """ ), ] = Default(generate_unique_id), - ignore_trailing_whitespaces: Annotated[ + ignore_trailing_slash: Annotated[ bool, Doc( """ @@ -843,7 +843,7 @@ class APIRouter(routing.Router): self.route_class = route_class self.default_response_class = default_response_class self.generate_unique_id_function = generate_unique_id_function - self.ignore_trailing_whitespaces = ignore_trailing_whitespaces + self.ignore_trailing_slash = ignore_trailing_slash def route( self, @@ -852,7 +852,7 @@ class APIRouter(routing.Router): name: Optional[str] = None, include_in_schema: bool = True, ) -> Callable[[DecoratedCallable], DecoratedCallable]: - if self.ignore_trailing_whitespaces: + if self.ignore_trailing_slash: path = path.rstrip("/") def decorator(func: DecoratedCallable) -> DecoratedCallable: self.add_route( @@ -900,7 +900,7 @@ class APIRouter(routing.Router): Callable[[APIRoute], str], DefaultPlaceholder ] = Default(generate_unique_id), ) -> None: - if self.ignore_trailing_whitespaces: + if self.ignore_trailing_slash: path = path.rstrip("/") route_class = route_class_override or self.route_class responses = responses or {} @@ -1020,7 +1020,7 @@ class APIRouter(routing.Router): *, dependencies: Optional[Sequence[params.Depends]] = None, ) -> None: - if self.ignore_trailing_whitespaces: + if self.ignore_trailing_slash: path = path.rstrip("/") current_dependencies = self.dependencies.copy() if dependencies: @@ -1105,7 +1105,7 @@ class APIRouter(routing.Router): def websocket_route( self, path: str, name: Union[str, None] = None ) -> Callable[[DecoratedCallable], DecoratedCallable]: - if self.ignore_trailing_whitespaces: + if self.ignore_trailing_slash: path = path.rstrip("/") def decorator(func: DecoratedCallable) -> DecoratedCallable: self.add_websocket_route(path, func, name=name) @@ -1265,7 +1265,7 @@ class APIRouter(routing.Router): responses = {} for route in router.routes: path = route.path - if self.ignore_trailing_whitespaces: + if self.ignore_trailing_slash: path = path.rstrip("/") if isinstance(route, APIRoute): combined_responses = {**responses, **route.responses} diff --git a/tests/test_ignore_trailing_slash.py b/tests/test_ignore_trailing_slash.py index 551ccd61b..16537d328 100644 --- a/tests/test_ignore_trailing_slash.py +++ b/tests/test_ignore_trailing_slash.py @@ -1,29 +1,45 @@ -from fastapi import FastAPI +from fastapi import FastAPI, WebSocket, APIRouter from fastapi.testclient import TestClient -recognizing_app = FastAPI() -ignoring_app = FastAPI(ignore_trailing_whitespaces=True) +app = FastAPI(ignore_trailing_slash=True) +router = APIRouter() -@recognizing_app.get("/example") -@ignoring_app.get("/example") -async def return_data(): - return {"msg": "Reached the route!"} +@app.get("/example") +async def example_endpoint(): + return {"msg": "Example"} -recognizing_client = TestClient(recognizing_app) -ignoring_client = TestClient(ignoring_app) +@app.websocket("/websocket") +async def websocket_endpoint(websocket: WebSocket): + await websocket.accept() + await websocket.send_text("Websocket") + await websocket.close() -def test_recognizing_trailing_slash(): - response = recognizing_client.get("/example", follow_redirects=False) - assert response.status_code == 200 - assert response.json()["msg"] == "Reached the route!" - response = recognizing_client.get("/example/", follow_redirects=False) - assert response.status_code == 307 - assert response.headers["location"].endswith("/example") +@router.get("/example") +def route_endpoint(): + return {"msg": "Routing Example"} + +app.include_router(router, prefix="/router") + +client = TestClient(app) def test_ignoring_trailing_slash(): - response = ignoring_client.get("/example", follow_redirects=False) + response = client.get("/example", follow_redirects=False) + assert response.status_code == 200 + assert response.json()["msg"] == "Example" + response = client.get("/example/", follow_redirects=False) + assert response.status_code == 200 + assert response.json()["msg"] == "Example" + +def test_ignoring_trailing_shlash_ws(): + with client.websocket_connect("/websocket") as websocket: + assert websocket.receive_text() == "Websocket" + with client.websocket_connect("/websocket/") as websocket: + assert websocket.receive_text() == "Websocket" + +def test_ignoring_trailing_routing(): + response = client.get("router/example", follow_redirects=False) assert response.status_code == 200 - assert response.json()["msg"] == "Reached the route!" - response = ignoring_client.get("/example/", follow_redirects=False) + assert response.json()["msg"] == "Routing Example" + response = client.get("router/example/", follow_redirects=False) assert response.status_code == 200 - assert response.json()["msg"] == "Reached the route!" + assert response.json()["msg"] == "Routing Example" From 6823cb4a2026e6fc22fe973b97fab46acfac74be Mon Sep 17 00:00:00 2001 From: Synrom Date: Thu, 5 Sep 2024 16:49:56 +0000 Subject: [PATCH 03/12] Add docs, more tests and condition in middleware --- fastapi/applications.py | 10 +++++++++- fastapi/routing.py | 7 +++++++ tests/test_ignore_trailing_slash.py | 30 +++++++++++++++++++++-------- 3 files changed, 38 insertions(+), 9 deletions(-) diff --git a/fastapi/applications.py b/fastapi/applications.py index 5a7170dd6..480d03ab8 100644 --- a/fastapi/applications.py +++ b/fastapi/applications.py @@ -814,6 +814,13 @@ class FastAPI(Starlette): bool, Doc( """ + To ignore (or not) trailing slashes at the end of URIs. + + For example, by setting `ignore_trailing_slash` to True, + requests to `/auth` and `/auth/` will have the same behaviour. + + By default (`ignore_trailing_slash` is False), the two requests are treated differently. + One of them will result in a 307-redirect. """ ), ] = False, @@ -973,7 +980,8 @@ class FastAPI(Starlette): if ignore_trailing_slash: def ignore_trailing_whitespace_middleware(app): async def ignore_trailing_whitespace_wrapper(scope, receive, send): - scope["path"] = scope["path"].rstrip("/") + if scope["type"] in {"http", "websocket"}: + scope["path"] = scope["path"].rstrip("/") await app(scope, receive, send) return ignore_trailing_whitespace_wrapper self.add_middleware(ignore_trailing_whitespace_middleware) diff --git a/fastapi/routing.py b/fastapi/routing.py index d59701ef5..2b68cb0d1 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -815,6 +815,13 @@ class APIRouter(routing.Router): bool, Doc( """ + To ignore (or not) trailing slashes at the end of URIs. + + For example, by setting `ignore_trailing_slash` to True, + requests to `/auth` and `/auth/` will have the same behaviour. + + By default (`ignore_trailing_slash` is False), the two requests are treated differently. + One of them will result in a 307-redirect. """ ), ] = False, diff --git a/tests/test_ignore_trailing_slash.py b/tests/test_ignore_trailing_slash.py index 16537d328..8cd3c802a 100644 --- a/tests/test_ignore_trailing_slash.py +++ b/tests/test_ignore_trailing_slash.py @@ -8,38 +8,52 @@ router = APIRouter() async def example_endpoint(): return {"msg": "Example"} +@app.get("/example2/") +async def example_endpoint(): + return {"msg": "Example 2"} + @app.websocket("/websocket") async def websocket_endpoint(websocket: WebSocket): await websocket.accept() await websocket.send_text("Websocket") await websocket.close() +@app.websocket("/websocket2/") +async def websocket_endpoint(websocket: WebSocket): + await websocket.accept() + await websocket.send_text("Websocket 2") + await websocket.close() + @router.get("/example") def route_endpoint(): return {"msg": "Routing Example"} +@router.get("/example2/") +def route_endpoint(): + return {"msg": "Routing Example 2"} + app.include_router(router, prefix="/router") client = TestClient(app) def test_ignoring_trailing_slash(): - response = client.get("/example", follow_redirects=False) - assert response.status_code == 200 - assert response.json()["msg"] == "Example" response = client.get("/example/", follow_redirects=False) assert response.status_code == 200 assert response.json()["msg"] == "Example" + response = client.get("/example2", follow_redirects=False) + assert response.status_code == 200 + assert response.json()["msg"] == "Example 2" def test_ignoring_trailing_shlash_ws(): - with client.websocket_connect("/websocket") as websocket: - assert websocket.receive_text() == "Websocket" with client.websocket_connect("/websocket/") as websocket: assert websocket.receive_text() == "Websocket" + with client.websocket_connect("/websocket2") as websocket: + assert websocket.receive_text() == "Websocket 2" def test_ignoring_trailing_routing(): - response = client.get("router/example", follow_redirects=False) - assert response.status_code == 200 - assert response.json()["msg"] == "Routing Example" response = client.get("router/example/", follow_redirects=False) assert response.status_code == 200 assert response.json()["msg"] == "Routing Example" + response = client.get("router/example2", follow_redirects=False) + assert response.status_code == 200 + assert response.json()["msg"] == "Routing Example 2" \ No newline at end of file From 25accde3c6c7ec03842dd3bc3ded9e154ea83d40 Mon Sep 17 00:00:00 2001 From: Synrom Date: Fri, 6 Sep 2024 10:50:11 +0000 Subject: [PATCH 04/12] Formatting --- fastapi/applications.py | 3 +++ fastapi/routing.py | 6 +++--- tests/test_ignore_trailing_slash.py | 20 +++++++++++++++----- 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/fastapi/applications.py b/fastapi/applications.py index 480d03ab8..d6148b9d5 100644 --- a/fastapi/applications.py +++ b/fastapi/applications.py @@ -978,12 +978,15 @@ class FastAPI(Starlette): self.middleware_stack: Union[ASGIApp, None] = None if ignore_trailing_slash: + def ignore_trailing_whitespace_middleware(app): async def ignore_trailing_whitespace_wrapper(scope, receive, send): if scope["type"] in {"http", "websocket"}: scope["path"] = scope["path"].rstrip("/") await app(scope, receive, send) + return ignore_trailing_whitespace_wrapper + self.add_middleware(ignore_trailing_whitespace_middleware) self.setup() diff --git a/fastapi/routing.py b/fastapi/routing.py index 7daec9b0f..0935a29f2 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -881,6 +881,7 @@ class APIRouter(routing.Router): ) -> Callable[[DecoratedCallable], DecoratedCallable]: if self.ignore_trailing_slash: path = path.rstrip("/") + def decorator(func: DecoratedCallable) -> DecoratedCallable: self.add_route( path, @@ -1134,6 +1135,7 @@ class APIRouter(routing.Router): ) -> Callable[[DecoratedCallable], DecoratedCallable]: if self.ignore_trailing_slash: path = path.rstrip("/") + def decorator(func: DecoratedCallable) -> DecoratedCallable: self.add_websocket_route(path, func, name=name) return func @@ -1375,9 +1377,7 @@ class APIRouter(routing.Router): name=route.name, ) elif isinstance(route, routing.WebSocketRoute): - self.add_websocket_route( - prefix + path, route.endpoint, name=route.name - ) + self.add_websocket_route(prefix + path, route.endpoint, name=route.name) for handler in router.on_startup: self.add_event_handler("startup", handler) for handler in router.on_shutdown: diff --git a/tests/test_ignore_trailing_slash.py b/tests/test_ignore_trailing_slash.py index 8cd3c802a..be35993ff 100644 --- a/tests/test_ignore_trailing_slash.py +++ b/tests/test_ignore_trailing_slash.py @@ -1,41 +1,49 @@ -from fastapi import FastAPI, WebSocket, APIRouter +from fastapi import APIRouter, FastAPI, WebSocket from fastapi.testclient import TestClient app = FastAPI(ignore_trailing_slash=True) router = APIRouter() + @app.get("/example") async def example_endpoint(): return {"msg": "Example"} + @app.get("/example2/") -async def example_endpoint(): +async def example_endpoint_with_slash(): return {"msg": "Example 2"} + @app.websocket("/websocket") async def websocket_endpoint(websocket: WebSocket): await websocket.accept() await websocket.send_text("Websocket") await websocket.close() + @app.websocket("/websocket2/") -async def websocket_endpoint(websocket: WebSocket): +async def websocket_endpoint_with_slash(websocket: WebSocket): await websocket.accept() await websocket.send_text("Websocket 2") await websocket.close() + @router.get("/example") def route_endpoint(): return {"msg": "Routing Example"} + @router.get("/example2/") -def route_endpoint(): +def route_endpoint_with_slash(): return {"msg": "Routing Example 2"} + app.include_router(router, prefix="/router") client = TestClient(app) + def test_ignoring_trailing_slash(): response = client.get("/example/", follow_redirects=False) assert response.status_code == 200 @@ -44,16 +52,18 @@ def test_ignoring_trailing_slash(): assert response.status_code == 200 assert response.json()["msg"] == "Example 2" + def test_ignoring_trailing_shlash_ws(): with client.websocket_connect("/websocket/") as websocket: assert websocket.receive_text() == "Websocket" with client.websocket_connect("/websocket2") as websocket: assert websocket.receive_text() == "Websocket 2" + def test_ignoring_trailing_routing(): response = client.get("router/example/", follow_redirects=False) assert response.status_code == 200 assert response.json()["msg"] == "Routing Example" response = client.get("router/example2", follow_redirects=False) assert response.status_code == 200 - assert response.json()["msg"] == "Routing Example 2" \ No newline at end of file + assert response.json()["msg"] == "Routing Example 2" From baccda3b29009dd8ff0fe3a5043ac0954ad13865 Mon Sep 17 00:00:00 2001 From: Synrom Date: Fri, 6 Sep 2024 11:13:54 +0000 Subject: [PATCH 05/12] Fix linting and formatting --- fastapi/applications.py | 15 +++++++++------ fastapi/routing.py | 15 ++++++++++++--- 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/fastapi/applications.py b/fastapi/applications.py index d6148b9d5..f44f52ed3 100644 --- a/fastapi/applications.py +++ b/fastapi/applications.py @@ -979,15 +979,18 @@ class FastAPI(Starlette): if ignore_trailing_slash: - def ignore_trailing_whitespace_middleware(app): - async def ignore_trailing_whitespace_wrapper(scope, receive, send): + class _IgnoreTrailingWhitespaceMiddleware: + def __init__(self, app: ASGIApp): + self.app = app + + async def __call__( + self, scope: Scope, receive: Receive, send: Send + ) -> None: if scope["type"] in {"http", "websocket"}: scope["path"] = scope["path"].rstrip("/") - await app(scope, receive, send) - - return ignore_trailing_whitespace_wrapper + await self.app(scope, receive, send) - self.add_middleware(ignore_trailing_whitespace_middleware) + self.add_middleware(_IgnoreTrailingWhitespaceMiddleware) self.setup() diff --git a/fastapi/routing.py b/fastapi/routing.py index 0935a29f2..b06704574 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -1293,10 +1293,10 @@ class APIRouter(routing.Router): if responses is None: responses = {} for route in router.routes: - path = route.path - if self.ignore_trailing_slash: - path = path.rstrip("/") if isinstance(route, APIRoute): + path = route.path + if self.ignore_trailing_slash: + path = path.rstrip("/") combined_responses = {**responses, **route.responses} use_response_class = get_value_or_default( route.response_class, @@ -1356,6 +1356,9 @@ class APIRouter(routing.Router): generate_unique_id_function=current_generate_unique_id, ) elif isinstance(route, routing.Route): + path = route.path + if self.ignore_trailing_slash: + path = path.rstrip("/") methods = list(route.methods or []) self.add_route( prefix + path, @@ -1365,6 +1368,9 @@ class APIRouter(routing.Router): name=route.name, ) elif isinstance(route, APIWebSocketRoute): + path = route.path + if self.ignore_trailing_slash: + path = path.rstrip("/") current_dependencies = [] if dependencies: current_dependencies.extend(dependencies) @@ -1377,6 +1383,9 @@ class APIRouter(routing.Router): name=route.name, ) elif isinstance(route, routing.WebSocketRoute): + path = route.path + if self.ignore_trailing_slash: + path = path.rstrip("/") self.add_websocket_route(prefix + path, route.endpoint, name=route.name) for handler in router.on_startup: self.add_event_handler("startup", handler) From 4fb5d22a737d0b36a197673129b83a5d2b87605c Mon Sep 17 00:00:00 2001 From: Synrom Date: Sun, 8 Sep 2024 12:45:27 +0200 Subject: [PATCH 06/12] Support and test app.websocket_route --- fastapi/routing.py | 10 ++++++++++ tests/test_ignore_trailing_slash.py | 18 ++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/fastapi/routing.py b/fastapi/routing.py index b06704574..33e3b7434 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -1040,6 +1040,16 @@ class APIRouter(routing.Router): return decorator + def add_websocket_route( + self, + path: str, + endpoint: Callable[..., Any], + name: Optional[str] = None, + ) -> None: + if self.ignore_trailing_slash: + path = path.rstrip("/") + super().add_websocket_route(path, endpoint, name) + def add_api_websocket_route( self, path: str, diff --git a/tests/test_ignore_trailing_slash.py b/tests/test_ignore_trailing_slash.py index be35993ff..61ef7b496 100644 --- a/tests/test_ignore_trailing_slash.py +++ b/tests/test_ignore_trailing_slash.py @@ -29,6 +29,20 @@ async def websocket_endpoint_with_slash(websocket: WebSocket): await websocket.close() +@app.websocket_route("/websocket_route") +async def websocket_route_endpoint(websocket: WebSocket): + await websocket.accept() + await websocket.send_text("Websocket route") + await websocket.close() + + +@app.websocket_route("/websocket_route_2/") +async def websocket_route_endpoint_with_slash(websocket: WebSocket): + await websocket.accept() + await websocket.send_text("Websocket route 2") + await websocket.close() + + @router.get("/example") def route_endpoint(): return {"msg": "Routing Example"} @@ -58,6 +72,10 @@ def test_ignoring_trailing_shlash_ws(): assert websocket.receive_text() == "Websocket" with client.websocket_connect("/websocket2") as websocket: assert websocket.receive_text() == "Websocket 2" + with client.websocket_connect("/websocket_route/") as websocket: + assert websocket.receive_text() == "Websocket route" + with client.websocket_connect("/websocket_route_2/") as websocket: + assert websocket.receive_text() == "Websocket route 2" def test_ignoring_trailing_routing(): From a2b2114e97b14995b91a74ea74b26008eeeeb263 Mon Sep 17 00:00:00 2001 From: Synrom Date: Sun, 8 Sep 2024 13:11:12 +0200 Subject: [PATCH 07/12] Add more tests --- tests/test_ignore_trailing_slash.py | 55 ++++++++++++++++++++++++++++- 1 file changed, 54 insertions(+), 1 deletion(-) diff --git a/tests/test_ignore_trailing_slash.py b/tests/test_ignore_trailing_slash.py index 61ef7b496..10830077c 100644 --- a/tests/test_ignore_trailing_slash.py +++ b/tests/test_ignore_trailing_slash.py @@ -1,4 +1,5 @@ -from fastapi import APIRouter, FastAPI, WebSocket +from fastapi import APIRouter, FastAPI, Request, WebSocket +from fastapi.responses import JSONResponse from fastapi.testclient import TestClient app = FastAPI(ignore_trailing_slash=True) @@ -53,6 +54,44 @@ def route_endpoint_with_slash(): return {"msg": "Routing Example 2"} +@router.websocket("/websocket") +async def router_websocket_endpoint(websocket: WebSocket): + await websocket.accept() + await websocket.send_text("Websocket") + await websocket.close() + + +@router.websocket("/websocket2/") +async def router_websocket_endpoint_with_slash(websocket: WebSocket): + await websocket.accept() + await websocket.send_text("Websocket 2") + await websocket.close() + + +@router.websocket_route("/websocket_route") +async def router_websocket_route_endpoint(websocket: WebSocket): + await websocket.accept() + await websocket.send_text("Websocket route") + await websocket.close() + + +@router.websocket_route("/websocket_route_2/") +async def router_websocket_route_endpoint_with_slash(websocket: WebSocket): + await websocket.accept() + await websocket.send_text("Websocket route 2") + await websocket.close() + + +@router.route("/starlette_route", ["get"]) +async def starlette_route_endpoint(request: Request): + return JSONResponse({"msg": "Starlette Route"}) + + +@router.route("/starlette_route_2/", ["get"]) +async def starlette_route_endpoint_with_slash(request: Request): + return JSONResponse({"msg": "Starlette Route 2"}) + + app.include_router(router, prefix="/router") client = TestClient(app) @@ -85,3 +124,17 @@ def test_ignoring_trailing_routing(): response = client.get("router/example2", follow_redirects=False) assert response.status_code == 200 assert response.json()["msg"] == "Routing Example 2" + response = client.get("router/starlette_route/", follow_redirects=False) + assert response.status_code == 200 + assert response.json()["msg"] == "Starlette Route" + response = client.get("router/starlette_route_2", follow_redirects=False) + assert response.status_code == 200 + assert response.json()["msg"] == "Starlette Route 2" + with client.websocket_connect("router/websocket/") as websocket: + assert websocket.receive_text() == "Websocket" + with client.websocket_connect("router/websocket2") as websocket: + assert websocket.receive_text() == "Websocket 2" + with client.websocket_connect("router/websocket_route/") as websocket: + assert websocket.receive_text() == "Websocket route" + with client.websocket_connect("router/websocket_route_2/") as websocket: + assert websocket.receive_text() == "Websocket route 2" From 192151a71394834b3f8c0550804a4fbceff73b3c Mon Sep 17 00:00:00 2001 From: Synrom Date: Sun, 8 Sep 2024 13:24:44 +0200 Subject: [PATCH 08/12] Add more tests --- tests/test_ignore_trailing_slash.py | 41 +++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/tests/test_ignore_trailing_slash.py b/tests/test_ignore_trailing_slash.py index 10830077c..5038c6f7b 100644 --- a/tests/test_ignore_trailing_slash.py +++ b/tests/test_ignore_trailing_slash.py @@ -92,7 +92,35 @@ async def starlette_route_endpoint_with_slash(request: Request): return JSONResponse({"msg": "Starlette Route 2"}) +router_ignore = APIRouter(ignore_trailing_slash=True) + + +@router_ignore.route("/example", ["get"]) +async def router_ignore_example(request: Request): + return JSONResponse({"msg": "Router Ignore"}) + + +@router_ignore.route("/example2/", ["get"]) +async def router_ignore_example_with_slash(request: Request): + return JSONResponse({"msg": "Router Ignore 2"}) + + +@router_ignore.websocket_route("/websocket") +async def router_ignore_websocket(websocket: WebSocket): + await websocket.accept() + await websocket.send_text("Router Ignore Websocket") + await websocket.close() + + +@router_ignore.websocket_route("/websocket2/") +async def router_ignore_websocket_with_slash(websocket: WebSocket): + await websocket.accept() + await websocket.send_text("Router Ignore Websocket 2") + await websocket.close() + + app.include_router(router, prefix="/router") +app.include_router(router_ignore, prefix="/router_ignore") client = TestClient(app) @@ -138,3 +166,16 @@ def test_ignoring_trailing_routing(): assert websocket.receive_text() == "Websocket route" with client.websocket_connect("router/websocket_route_2/") as websocket: assert websocket.receive_text() == "Websocket route 2" + + +def test_add_router_with_ignore_flag(): + response = client.get("/router_ignore/example/", follow_redirects=False) + assert response.status_code == 200 + assert response.json()["msg"] == "Router Ignore" + response = client.get("/router_ignore/example2", follow_redirects=False) + assert response.status_code == 200 + assert response.json()["msg"] == "Router Ignore 2" + with client.websocket_connect("/router_ignore/websocket/") as websocket: + assert websocket.receive_text() == "Router Ignore Websocket" + with client.websocket_connect("/router_ignore/websocket2") as websocket: + assert websocket.receive_text() == "Router Ignore Websocket 2" From 031c69fa9e56f794aed5da9c952fa7bc669ca59d Mon Sep 17 00:00:00 2001 From: Synrom Date: Sun, 8 Sep 2024 13:36:26 +0200 Subject: [PATCH 09/12] Use helper method to avoid code duplication --- fastapi/routing.py | 38 ++++++++++++++------------------------ 1 file changed, 14 insertions(+), 24 deletions(-) diff --git a/fastapi/routing.py b/fastapi/routing.py index 33e3b7434..140e6e75a 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -879,12 +879,9 @@ class APIRouter(routing.Router): name: Optional[str] = None, include_in_schema: bool = True, ) -> Callable[[DecoratedCallable], DecoratedCallable]: - if self.ignore_trailing_slash: - path = path.rstrip("/") - def decorator(func: DecoratedCallable) -> DecoratedCallable: self.add_route( - path, + self._normalize_path(path), func, methods=methods, name=name, @@ -928,8 +925,7 @@ class APIRouter(routing.Router): Callable[[APIRoute], str], DefaultPlaceholder ] = Default(generate_unique_id), ) -> None: - if self.ignore_trailing_slash: - path = path.rstrip("/") + path = self._normalize_path(path) route_class = route_class_override or self.route_class responses = responses or {} combined_responses = {**self.responses, **responses} @@ -1046,8 +1042,7 @@ class APIRouter(routing.Router): endpoint: Callable[..., Any], name: Optional[str] = None, ) -> None: - if self.ignore_trailing_slash: - path = path.rstrip("/") + path = self._normalize_path(path) super().add_websocket_route(path, endpoint, name) def add_api_websocket_route( @@ -1058,8 +1053,7 @@ class APIRouter(routing.Router): *, dependencies: Optional[Sequence[params.Depends]] = None, ) -> None: - if self.ignore_trailing_slash: - path = path.rstrip("/") + path = self._normalize_path(path) current_dependencies = self.dependencies.copy() if dependencies: current_dependencies.extend(dependencies) @@ -1143,8 +1137,7 @@ class APIRouter(routing.Router): def websocket_route( self, path: str, name: Union[str, None] = None ) -> Callable[[DecoratedCallable], DecoratedCallable]: - if self.ignore_trailing_slash: - path = path.rstrip("/") + path = self._normalize_path(path) def decorator(func: DecoratedCallable) -> DecoratedCallable: self.add_websocket_route(path, func, name=name) @@ -1304,9 +1297,7 @@ class APIRouter(routing.Router): responses = {} for route in router.routes: if isinstance(route, APIRoute): - path = route.path - if self.ignore_trailing_slash: - path = path.rstrip("/") + path = self._normalize_path(route.path) combined_responses = {**responses, **route.responses} use_response_class = get_value_or_default( route.response_class, @@ -1366,9 +1357,7 @@ class APIRouter(routing.Router): generate_unique_id_function=current_generate_unique_id, ) elif isinstance(route, routing.Route): - path = route.path - if self.ignore_trailing_slash: - path = path.rstrip("/") + path = self._normalize_path(route.path) methods = list(route.methods or []) self.add_route( prefix + path, @@ -1378,9 +1367,7 @@ class APIRouter(routing.Router): name=route.name, ) elif isinstance(route, APIWebSocketRoute): - path = route.path - if self.ignore_trailing_slash: - path = path.rstrip("/") + path = self._normalize_path(route.path) current_dependencies = [] if dependencies: current_dependencies.extend(dependencies) @@ -1393,9 +1380,7 @@ class APIRouter(routing.Router): name=route.name, ) elif isinstance(route, routing.WebSocketRoute): - path = route.path - if self.ignore_trailing_slash: - path = path.rstrip("/") + path = self._normalize_path(route.path) self.add_websocket_route(prefix + path, route.endpoint, name=route.name) for handler in router.on_startup: self.add_event_handler("startup", handler) @@ -1406,6 +1391,11 @@ class APIRouter(routing.Router): router.lifespan_context, ) + def _normalize_path(self, path: str) -> str: + if self.ignore_trailing_slash: + return path.rstrip("/") + return path + def get( self, path: Annotated[ From 0acc056d1be6cf3a3c0eae3ba5498f71e1d11e10 Mon Sep 17 00:00:00 2001 From: Synrom Date: Sun, 8 Sep 2024 13:45:51 +0200 Subject: [PATCH 10/12] Use helper method for routes to avoid code duplication --- fastapi/routing.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/fastapi/routing.py b/fastapi/routing.py index 140e6e75a..7d5e5c21a 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -1296,8 +1296,8 @@ class APIRouter(routing.Router): if responses is None: responses = {} for route in router.routes: + route = self._normalize_route(route) if isinstance(route, APIRoute): - path = self._normalize_path(route.path) combined_responses = {**responses, **route.responses} use_response_class = get_value_or_default( route.response_class, @@ -1327,7 +1327,7 @@ class APIRouter(routing.Router): self.generate_unique_id_function, ) self.add_api_route( - prefix + path, + prefix + route.path, route.endpoint, response_model=route.response_model, status_code=route.status_code, @@ -1357,31 +1357,30 @@ class APIRouter(routing.Router): generate_unique_id_function=current_generate_unique_id, ) elif isinstance(route, routing.Route): - path = self._normalize_path(route.path) methods = list(route.methods or []) self.add_route( - prefix + path, + prefix + route.path, route.endpoint, methods=methods, include_in_schema=route.include_in_schema, name=route.name, ) elif isinstance(route, APIWebSocketRoute): - path = self._normalize_path(route.path) current_dependencies = [] if dependencies: current_dependencies.extend(dependencies) if route.dependencies: current_dependencies.extend(route.dependencies) self.add_api_websocket_route( - prefix + path, + prefix + route.path, route.endpoint, dependencies=current_dependencies, name=route.name, ) elif isinstance(route, routing.WebSocketRoute): - path = self._normalize_path(route.path) - self.add_websocket_route(prefix + path, route.endpoint, name=route.name) + self.add_websocket_route( + prefix + route.path, route.endpoint, name=route.name + ) for handler in router.on_startup: self.add_event_handler("startup", handler) for handler in router.on_shutdown: @@ -1396,6 +1395,11 @@ class APIRouter(routing.Router): return path.rstrip("/") return path + def _normalize_route(self, route: BaseRoute) -> BaseRoute: + if hasattr(route, "path") and isinstance(route.path, str): + route.path = self._normalize_path(route.path) + return route + def get( self, path: Annotated[ From 319fea3781669a3cce17c92d8d3f93c218af77f9 Mon Sep 17 00:00:00 2001 From: Synrom Date: Sun, 15 Sep 2024 13:46:25 +0200 Subject: [PATCH 11/12] :memo: Add more documentation on `ignore_trailing_slash` --- fastapi/applications.py | 5 +++++ fastapi/routing.py | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/fastapi/applications.py b/fastapi/applications.py index f44f52ed3..58d5b268a 100644 --- a/fastapi/applications.py +++ b/fastapi/applications.py @@ -821,6 +821,11 @@ class FastAPI(Starlette): By default (`ignore_trailing_slash` is False), the two requests are treated differently. One of them will result in a 307-redirect. + + It's important to understand that when `ignore_trailing_slash=True`, registering both `/auth` + and `/auth/` as different routes will be treated as if `/auth` was registered twice. + This means that only the first route registered will be used. + Therefore, ensure your route setup does not conflict unintentionally. """ ), ] = False, diff --git a/fastapi/routing.py b/fastapi/routing.py index 7d5e5c21a..cec5075e8 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -842,6 +842,11 @@ class APIRouter(routing.Router): By default (`ignore_trailing_slash` is False), the two requests are treated differently. One of them will result in a 307-redirect. + + It's important to understand that when `ignore_trailing_slash=True`, registering both `/auth` + and `/auth/` as different routes will be treated as if `/auth` was registered twice. + This means that only the first route registered will be used. + Therefore, ensure your route setup does not conflict unintentionally. """ ), ] = False, From 982880694c04c55ec6929593ef67b5b5e0eb18fd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 15 Sep 2024 11:48:02 +0000 Subject: [PATCH 12/12] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20f?= =?UTF-8?q?ormat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastapi/applications.py | 2 +- fastapi/routing.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/fastapi/applications.py b/fastapi/applications.py index 58d5b268a..cae292f37 100644 --- a/fastapi/applications.py +++ b/fastapi/applications.py @@ -823,7 +823,7 @@ class FastAPI(Starlette): One of them will result in a 307-redirect. It's important to understand that when `ignore_trailing_slash=True`, registering both `/auth` - and `/auth/` as different routes will be treated as if `/auth` was registered twice. + and `/auth/` as different routes will be treated as if `/auth` was registered twice. This means that only the first route registered will be used. Therefore, ensure your route setup does not conflict unintentionally. """ diff --git a/fastapi/routing.py b/fastapi/routing.py index cec5075e8..31db7ebef 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -844,7 +844,7 @@ class APIRouter(routing.Router): One of them will result in a 307-redirect. It's important to understand that when `ignore_trailing_slash=True`, registering both `/auth` - and `/auth/` as different routes will be treated as if `/auth` was registered twice. + and `/auth/` as different routes will be treated as if `/auth` was registered twice. This means that only the first route registered will be used. Therefore, ensure your route setup does not conflict unintentionally. """