diff --git a/fastapi/applications.py b/fastapi/applications.py index 6d427cdc2..cae292f37 100644 --- a/fastapi/applications.py +++ b/fastapi/applications.py @@ -810,6 +810,25 @@ class FastAPI(Starlette): """ ), ] = True, + ignore_trailing_slash: Annotated[ + bool, + Doc( + """ + To ignore (or not) trailing slashes at the end of URIs. + + For example, by setting `ignore_trailing_slash` to True, + requests to `/auth` and `/auth/` will have the same behaviour. + + By default (`ignore_trailing_slash` is False), the two requests are treated differently. + One of them will result in a 307-redirect. + + It's important to understand that when `ignore_trailing_slash=True`, registering both `/auth` + and `/auth/` as different routes will be treated as if `/auth` was registered twice. + This means that only the first route registered will be used. + Therefore, ensure your route setup does not conflict unintentionally. + """ + ), + ] = False, **extra: Annotated[ Any, Doc( @@ -943,6 +962,7 @@ class FastAPI(Starlette): include_in_schema=include_in_schema, responses=responses, generate_unique_id_function=generate_unique_id_function, + ignore_trailing_slash=ignore_trailing_slash, ) self.exception_handlers: Dict[ Any, Callable[[Request, Any], Union[Response, Awaitable[Response]]] @@ -961,6 +981,22 @@ class FastAPI(Starlette): [] if middleware is None else list(middleware) ) self.middleware_stack: Union[ASGIApp, None] = None + + if ignore_trailing_slash: + + class _IgnoreTrailingWhitespaceMiddleware: + def __init__(self, app: ASGIApp): + self.app = app + + async def __call__( + self, scope: Scope, receive: Receive, send: Send + ) -> None: + if scope["type"] in {"http", "websocket"}: + scope["path"] = scope["path"].rstrip("/") + await self.app(scope, receive, send) + + self.add_middleware(_IgnoreTrailingWhitespaceMiddleware) + self.setup() def openapi(self) -> Dict[str, Any]: diff --git a/fastapi/routing.py b/fastapi/routing.py index 457481e32..71377437a 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -833,6 +833,25 @@ class APIRouter(routing.Router): """ ), ] = Default(generate_unique_id), + ignore_trailing_slash: Annotated[ + bool, + Doc( + """ + To ignore (or not) trailing slashes at the end of URIs. + + For example, by setting `ignore_trailing_slash` to True, + requests to `/auth` and `/auth/` will have the same behaviour. + + By default (`ignore_trailing_slash` is False), the two requests are treated differently. + One of them will result in a 307-redirect. + + It's important to understand that when `ignore_trailing_slash=True`, registering both `/auth` + and `/auth/` as different routes will be treated as if `/auth` was registered twice. + This means that only the first route registered will be used. + Therefore, ensure your route setup does not conflict unintentionally. + """ + ), + ] = False, ) -> None: super().__init__( routes=routes, @@ -858,6 +877,7 @@ class APIRouter(routing.Router): self.route_class = route_class self.default_response_class = default_response_class self.generate_unique_id_function = generate_unique_id_function + self.ignore_trailing_slash = ignore_trailing_slash def route( self, @@ -868,7 +888,7 @@ class APIRouter(routing.Router): ) -> Callable[[DecoratedCallable], DecoratedCallable]: def decorator(func: DecoratedCallable) -> DecoratedCallable: self.add_route( - path, + self._normalize_path(path), func, methods=methods, name=name, @@ -912,6 +932,7 @@ class APIRouter(routing.Router): Callable[[APIRoute], str], DefaultPlaceholder ] = Default(generate_unique_id), ) -> None: + path = self._normalize_path(path) route_class = route_class_override or self.route_class responses = responses or {} combined_responses = {**self.responses, **responses} @@ -1022,6 +1043,15 @@ class APIRouter(routing.Router): return decorator + def add_websocket_route( + self, + path: str, + endpoint: Callable[..., Any], + name: Optional[str] = None, + ) -> None: + path = self._normalize_path(path) + super().add_websocket_route(path, endpoint, name) + def add_api_websocket_route( self, path: str, @@ -1030,6 +1060,7 @@ class APIRouter(routing.Router): *, dependencies: Optional[Sequence[params.Depends]] = None, ) -> None: + path = self._normalize_path(path) current_dependencies = self.dependencies.copy() if dependencies: current_dependencies.extend(dependencies) @@ -1113,6 +1144,8 @@ class APIRouter(routing.Router): def websocket_route( self, path: str, name: Union[str, None] = None ) -> Callable[[DecoratedCallable], DecoratedCallable]: + path = self._normalize_path(path) + def decorator(func: DecoratedCallable) -> DecoratedCallable: self.add_websocket_route(path, func, name=name) return func @@ -1270,6 +1303,7 @@ class APIRouter(routing.Router): if responses is None: responses = {} for route in router.routes: + route = self._normalize_route(route) if isinstance(route, APIRoute): combined_responses = {**responses, **route.responses} use_response_class = get_value_or_default( @@ -1363,6 +1397,16 @@ class APIRouter(routing.Router): router.lifespan_context, ) + def _normalize_path(self, path: str) -> str: + if self.ignore_trailing_slash: + return path.rstrip("/") + return path + + def _normalize_route(self, route: BaseRoute) -> BaseRoute: + if hasattr(route, "path") and isinstance(route.path, str): + route.path = self._normalize_path(route.path) + return route + def get( self, path: Annotated[ diff --git a/tests/test_ignore_trailing_slash.py b/tests/test_ignore_trailing_slash.py new file mode 100644 index 000000000..5038c6f7b --- /dev/null +++ b/tests/test_ignore_trailing_slash.py @@ -0,0 +1,181 @@ +from fastapi import APIRouter, FastAPI, Request, WebSocket +from fastapi.responses import JSONResponse +from fastapi.testclient import TestClient + +app = FastAPI(ignore_trailing_slash=True) +router = APIRouter() + + +@app.get("/example") +async def example_endpoint(): + return {"msg": "Example"} + + +@app.get("/example2/") +async def example_endpoint_with_slash(): + return {"msg": "Example 2"} + + +@app.websocket("/websocket") +async def websocket_endpoint(websocket: WebSocket): + await websocket.accept() + await websocket.send_text("Websocket") + await websocket.close() + + +@app.websocket("/websocket2/") +async def websocket_endpoint_with_slash(websocket: WebSocket): + await websocket.accept() + await websocket.send_text("Websocket 2") + await websocket.close() + + +@app.websocket_route("/websocket_route") +async def websocket_route_endpoint(websocket: WebSocket): + await websocket.accept() + await websocket.send_text("Websocket route") + await websocket.close() + + +@app.websocket_route("/websocket_route_2/") +async def websocket_route_endpoint_with_slash(websocket: WebSocket): + await websocket.accept() + await websocket.send_text("Websocket route 2") + await websocket.close() + + +@router.get("/example") +def route_endpoint(): + return {"msg": "Routing Example"} + + +@router.get("/example2/") +def route_endpoint_with_slash(): + return {"msg": "Routing Example 2"} + + +@router.websocket("/websocket") +async def router_websocket_endpoint(websocket: WebSocket): + await websocket.accept() + await websocket.send_text("Websocket") + await websocket.close() + + +@router.websocket("/websocket2/") +async def router_websocket_endpoint_with_slash(websocket: WebSocket): + await websocket.accept() + await websocket.send_text("Websocket 2") + await websocket.close() + + +@router.websocket_route("/websocket_route") +async def router_websocket_route_endpoint(websocket: WebSocket): + await websocket.accept() + await websocket.send_text("Websocket route") + await websocket.close() + + +@router.websocket_route("/websocket_route_2/") +async def router_websocket_route_endpoint_with_slash(websocket: WebSocket): + await websocket.accept() + await websocket.send_text("Websocket route 2") + await websocket.close() + + +@router.route("/starlette_route", ["get"]) +async def starlette_route_endpoint(request: Request): + return JSONResponse({"msg": "Starlette Route"}) + + +@router.route("/starlette_route_2/", ["get"]) +async def starlette_route_endpoint_with_slash(request: Request): + return JSONResponse({"msg": "Starlette Route 2"}) + + +router_ignore = APIRouter(ignore_trailing_slash=True) + + +@router_ignore.route("/example", ["get"]) +async def router_ignore_example(request: Request): + return JSONResponse({"msg": "Router Ignore"}) + + +@router_ignore.route("/example2/", ["get"]) +async def router_ignore_example_with_slash(request: Request): + return JSONResponse({"msg": "Router Ignore 2"}) + + +@router_ignore.websocket_route("/websocket") +async def router_ignore_websocket(websocket: WebSocket): + await websocket.accept() + await websocket.send_text("Router Ignore Websocket") + await websocket.close() + + +@router_ignore.websocket_route("/websocket2/") +async def router_ignore_websocket_with_slash(websocket: WebSocket): + await websocket.accept() + await websocket.send_text("Router Ignore Websocket 2") + await websocket.close() + + +app.include_router(router, prefix="/router") +app.include_router(router_ignore, prefix="/router_ignore") + +client = TestClient(app) + + +def test_ignoring_trailing_slash(): + response = client.get("/example/", follow_redirects=False) + assert response.status_code == 200 + assert response.json()["msg"] == "Example" + response = client.get("/example2", follow_redirects=False) + assert response.status_code == 200 + assert response.json()["msg"] == "Example 2" + + +def test_ignoring_trailing_shlash_ws(): + with client.websocket_connect("/websocket/") as websocket: + assert websocket.receive_text() == "Websocket" + with client.websocket_connect("/websocket2") as websocket: + assert websocket.receive_text() == "Websocket 2" + with client.websocket_connect("/websocket_route/") as websocket: + assert websocket.receive_text() == "Websocket route" + with client.websocket_connect("/websocket_route_2/") as websocket: + assert websocket.receive_text() == "Websocket route 2" + + +def test_ignoring_trailing_routing(): + response = client.get("router/example/", follow_redirects=False) + assert response.status_code == 200 + assert response.json()["msg"] == "Routing Example" + response = client.get("router/example2", follow_redirects=False) + assert response.status_code == 200 + assert response.json()["msg"] == "Routing Example 2" + response = client.get("router/starlette_route/", follow_redirects=False) + assert response.status_code == 200 + assert response.json()["msg"] == "Starlette Route" + response = client.get("router/starlette_route_2", follow_redirects=False) + assert response.status_code == 200 + assert response.json()["msg"] == "Starlette Route 2" + with client.websocket_connect("router/websocket/") as websocket: + assert websocket.receive_text() == "Websocket" + with client.websocket_connect("router/websocket2") as websocket: + assert websocket.receive_text() == "Websocket 2" + with client.websocket_connect("router/websocket_route/") as websocket: + assert websocket.receive_text() == "Websocket route" + with client.websocket_connect("router/websocket_route_2/") as websocket: + assert websocket.receive_text() == "Websocket route 2" + + +def test_add_router_with_ignore_flag(): + response = client.get("/router_ignore/example/", follow_redirects=False) + assert response.status_code == 200 + assert response.json()["msg"] == "Router Ignore" + response = client.get("/router_ignore/example2", follow_redirects=False) + assert response.status_code == 200 + assert response.json()["msg"] == "Router Ignore 2" + with client.websocket_connect("/router_ignore/websocket/") as websocket: + assert websocket.receive_text() == "Router Ignore Websocket" + with client.websocket_connect("/router_ignore/websocket2") as websocket: + assert websocket.receive_text() == "Router Ignore Websocket 2"