From d4eaafb8043a988701ff965c1c061c586b27f9ff Mon Sep 17 00:00:00 2001 From: Synrom Date: Wed, 4 Sep 2024 17:03:56 +0000 Subject: [PATCH] Add ignore_trailing_slashes flag to applications --- fastapi/applications.py | 14 ++++++++++++++ fastapi/routing.py | 27 +++++++++++++++++++++++---- tests/test_ignore_trailing_slash.py | 29 +++++++++++++++++++++++++++++ 3 files changed, 66 insertions(+), 4 deletions(-) create mode 100644 tests/test_ignore_trailing_slash.py diff --git a/fastapi/applications.py b/fastapi/applications.py index 6d427cdc2..6fc595634 100644 --- a/fastapi/applications.py +++ b/fastapi/applications.py @@ -810,6 +810,13 @@ class FastAPI(Starlette): """ ), ] = True, + ignore_trailing_whitespaces: Annotated[ + bool, + Doc( + """ + """ + ), + ] = False, **extra: Annotated[ Any, Doc( @@ -943,6 +950,7 @@ class FastAPI(Starlette): include_in_schema=include_in_schema, responses=responses, generate_unique_id_function=generate_unique_id_function, + ignore_trailing_whitespaces=ignore_trailing_whitespaces, ) self.exception_handlers: Dict[ Any, Callable[[Request, Any], Union[Response, Awaitable[Response]]] @@ -961,6 +969,12 @@ class FastAPI(Starlette): [] if middleware is None else list(middleware) ) self.middleware_stack: Union[ASGIApp, None] = None + + if ignore_trailing_whitespaces: + async def middleware_ignore_tailing_whitespace(request: Request, call_next): + request.scope["path"] = request.scope["path"].rstrip("/") + return await call_next(request) + self.add_middleware(BaseHTTPMiddleware, dispatch=middleware_ignore_tailing_whitespace) self.setup() def openapi(self) -> Dict[str, Any]: diff --git a/fastapi/routing.py b/fastapi/routing.py index 61a112fc4..cb14fd399 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -811,6 +811,13 @@ class APIRouter(routing.Router): """ ), ] = Default(generate_unique_id), + ignore_trailing_whitespaces: Annotated[ + bool, + Doc( + """ + """ + ), + ] = False, ) -> None: super().__init__( routes=routes, @@ -836,6 +843,7 @@ class APIRouter(routing.Router): self.route_class = route_class self.default_response_class = default_response_class self.generate_unique_id_function = generate_unique_id_function + self.ignore_trailing_whitespaces = ignore_trailing_whitespaces def route( self, @@ -844,6 +852,8 @@ class APIRouter(routing.Router): name: Optional[str] = None, include_in_schema: bool = True, ) -> Callable[[DecoratedCallable], DecoratedCallable]: + if self.ignore_trailing_whitespaces: + path = path.rstrip("/") def decorator(func: DecoratedCallable) -> DecoratedCallable: self.add_route( path, @@ -890,6 +900,8 @@ class APIRouter(routing.Router): Callable[[APIRoute], str], DefaultPlaceholder ] = Default(generate_unique_id), ) -> None: + if self.ignore_trailing_whitespaces: + path = path.rstrip("/") route_class = route_class_override or self.route_class responses = responses or {} combined_responses = {**self.responses, **responses} @@ -1008,6 +1020,8 @@ class APIRouter(routing.Router): *, dependencies: Optional[Sequence[params.Depends]] = None, ) -> None: + if self.ignore_trailing_whitespaces: + path = path.rstrip("/") current_dependencies = self.dependencies.copy() if dependencies: current_dependencies.extend(dependencies) @@ -1091,6 +1105,8 @@ class APIRouter(routing.Router): def websocket_route( self, path: str, name: Union[str, None] = None ) -> Callable[[DecoratedCallable], DecoratedCallable]: + if self.ignore_trailing_whitespaces: + path = path.rstrip("/") def decorator(func: DecoratedCallable) -> DecoratedCallable: self.add_websocket_route(path, func, name=name) return func @@ -1248,6 +1264,9 @@ class APIRouter(routing.Router): if responses is None: responses = {} for route in router.routes: + path = route.path + if self.ignore_trailing_whitespaces: + path = path.rstrip("/") if isinstance(route, APIRoute): combined_responses = {**responses, **route.responses} use_response_class = get_value_or_default( @@ -1278,7 +1297,7 @@ class APIRouter(routing.Router): self.generate_unique_id_function, ) self.add_api_route( - prefix + route.path, + prefix + path, route.endpoint, response_model=route.response_model, status_code=route.status_code, @@ -1310,7 +1329,7 @@ class APIRouter(routing.Router): elif isinstance(route, routing.Route): methods = list(route.methods or []) self.add_route( - prefix + route.path, + prefix + path, route.endpoint, methods=methods, include_in_schema=route.include_in_schema, @@ -1323,14 +1342,14 @@ class APIRouter(routing.Router): if route.dependencies: current_dependencies.extend(route.dependencies) self.add_api_websocket_route( - prefix + route.path, + prefix + path, route.endpoint, dependencies=current_dependencies, name=route.name, ) elif isinstance(route, routing.WebSocketRoute): self.add_websocket_route( - prefix + route.path, route.endpoint, name=route.name + prefix + path, route.endpoint, name=route.name ) for handler in router.on_startup: self.add_event_handler("startup", handler) diff --git a/tests/test_ignore_trailing_slash.py b/tests/test_ignore_trailing_slash.py new file mode 100644 index 000000000..551ccd61b --- /dev/null +++ b/tests/test_ignore_trailing_slash.py @@ -0,0 +1,29 @@ +from fastapi import FastAPI +from fastapi.testclient import TestClient + +recognizing_app = FastAPI() +ignoring_app = FastAPI(ignore_trailing_whitespaces=True) + +@recognizing_app.get("/example") +@ignoring_app.get("/example") +async def return_data(): + return {"msg": "Reached the route!"} + +recognizing_client = TestClient(recognizing_app) +ignoring_client = TestClient(ignoring_app) + +def test_recognizing_trailing_slash(): + response = recognizing_client.get("/example", follow_redirects=False) + assert response.status_code == 200 + assert response.json()["msg"] == "Reached the route!" + response = recognizing_client.get("/example/", follow_redirects=False) + assert response.status_code == 307 + assert response.headers["location"].endswith("/example") + +def test_ignoring_trailing_slash(): + response = ignoring_client.get("/example", follow_redirects=False) + assert response.status_code == 200 + assert response.json()["msg"] == "Reached the route!" + response = ignoring_client.get("/example/", follow_redirects=False) + assert response.status_code == 200 + assert response.json()["msg"] == "Reached the route!"