Browse Source

Merge 982880694c into 76b324d95b

pull/12145/merge
Synrom 2 days ago
committed by GitHub
parent
commit
067bbfd57c
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 36
      fastapi/applications.py
  2. 46
      fastapi/routing.py
  3. 181
      tests/test_ignore_trailing_slash.py

36
fastapi/applications.py

@ -810,6 +810,25 @@ class FastAPI(Starlette):
"""
),
] = True,
ignore_trailing_slash: Annotated[
bool,
Doc(
"""
To ignore (or not) trailing slashes at the end of URIs.
For example, by setting `ignore_trailing_slash` to True,
requests to `/auth` and `/auth/` will have the same behaviour.
By default (`ignore_trailing_slash` is False), the two requests are treated differently.
One of them will result in a 307-redirect.
It's important to understand that when `ignore_trailing_slash=True`, registering both `/auth`
and `/auth/` as different routes will be treated as if `/auth` was registered twice.
This means that only the first route registered will be used.
Therefore, ensure your route setup does not conflict unintentionally.
"""
),
] = False,
**extra: Annotated[
Any,
Doc(
@ -943,6 +962,7 @@ class FastAPI(Starlette):
include_in_schema=include_in_schema,
responses=responses,
generate_unique_id_function=generate_unique_id_function,
ignore_trailing_slash=ignore_trailing_slash,
)
self.exception_handlers: Dict[
Any, Callable[[Request, Any], Union[Response, Awaitable[Response]]]
@ -961,6 +981,22 @@ class FastAPI(Starlette):
[] if middleware is None else list(middleware)
)
self.middleware_stack: Union[ASGIApp, None] = None
if ignore_trailing_slash:
class _IgnoreTrailingWhitespaceMiddleware:
def __init__(self, app: ASGIApp):
self.app = app
async def __call__(
self, scope: Scope, receive: Receive, send: Send
) -> None:
if scope["type"] in {"http", "websocket"}:
scope["path"] = scope["path"].rstrip("/")
await self.app(scope, receive, send)
self.add_middleware(_IgnoreTrailingWhitespaceMiddleware)
self.setup()
def openapi(self) -> Dict[str, Any]:

46
fastapi/routing.py

@ -833,6 +833,25 @@ class APIRouter(routing.Router):
"""
),
] = Default(generate_unique_id),
ignore_trailing_slash: Annotated[
bool,
Doc(
"""
To ignore (or not) trailing slashes at the end of URIs.
For example, by setting `ignore_trailing_slash` to True,
requests to `/auth` and `/auth/` will have the same behaviour.
By default (`ignore_trailing_slash` is False), the two requests are treated differently.
One of them will result in a 307-redirect.
It's important to understand that when `ignore_trailing_slash=True`, registering both `/auth`
and `/auth/` as different routes will be treated as if `/auth` was registered twice.
This means that only the first route registered will be used.
Therefore, ensure your route setup does not conflict unintentionally.
"""
),
] = False,
) -> None:
super().__init__(
routes=routes,
@ -858,6 +877,7 @@ class APIRouter(routing.Router):
self.route_class = route_class
self.default_response_class = default_response_class
self.generate_unique_id_function = generate_unique_id_function
self.ignore_trailing_slash = ignore_trailing_slash
def route(
self,
@ -868,7 +888,7 @@ class APIRouter(routing.Router):
) -> Callable[[DecoratedCallable], DecoratedCallable]:
def decorator(func: DecoratedCallable) -> DecoratedCallable:
self.add_route(
path,
self._normalize_path(path),
func,
methods=methods,
name=name,
@ -912,6 +932,7 @@ class APIRouter(routing.Router):
Callable[[APIRoute], str], DefaultPlaceholder
] = Default(generate_unique_id),
) -> None:
path = self._normalize_path(path)
route_class = route_class_override or self.route_class
responses = responses or {}
combined_responses = {**self.responses, **responses}
@ -1022,6 +1043,15 @@ class APIRouter(routing.Router):
return decorator
def add_websocket_route(
self,
path: str,
endpoint: Callable[..., Any],
name: Optional[str] = None,
) -> None:
path = self._normalize_path(path)
super().add_websocket_route(path, endpoint, name)
def add_api_websocket_route(
self,
path: str,
@ -1030,6 +1060,7 @@ class APIRouter(routing.Router):
*,
dependencies: Optional[Sequence[params.Depends]] = None,
) -> None:
path = self._normalize_path(path)
current_dependencies = self.dependencies.copy()
if dependencies:
current_dependencies.extend(dependencies)
@ -1113,6 +1144,8 @@ class APIRouter(routing.Router):
def websocket_route(
self, path: str, name: Union[str, None] = None
) -> Callable[[DecoratedCallable], DecoratedCallable]:
path = self._normalize_path(path)
def decorator(func: DecoratedCallable) -> DecoratedCallable:
self.add_websocket_route(path, func, name=name)
return func
@ -1270,6 +1303,7 @@ class APIRouter(routing.Router):
if responses is None:
responses = {}
for route in router.routes:
route = self._normalize_route(route)
if isinstance(route, APIRoute):
combined_responses = {**responses, **route.responses}
use_response_class = get_value_or_default(
@ -1363,6 +1397,16 @@ class APIRouter(routing.Router):
router.lifespan_context,
)
def _normalize_path(self, path: str) -> str:
if self.ignore_trailing_slash:
return path.rstrip("/")
return path
def _normalize_route(self, route: BaseRoute) -> BaseRoute:
if hasattr(route, "path") and isinstance(route.path, str):
route.path = self._normalize_path(route.path)
return route
def get(
self,
path: Annotated[

181
tests/test_ignore_trailing_slash.py

@ -0,0 +1,181 @@
from fastapi import APIRouter, FastAPI, Request, WebSocket
from fastapi.responses import JSONResponse
from fastapi.testclient import TestClient
app = FastAPI(ignore_trailing_slash=True)
router = APIRouter()
@app.get("/example")
async def example_endpoint():
return {"msg": "Example"}
@app.get("/example2/")
async def example_endpoint_with_slash():
return {"msg": "Example 2"}
@app.websocket("/websocket")
async def websocket_endpoint(websocket: WebSocket):
await websocket.accept()
await websocket.send_text("Websocket")
await websocket.close()
@app.websocket("/websocket2/")
async def websocket_endpoint_with_slash(websocket: WebSocket):
await websocket.accept()
await websocket.send_text("Websocket 2")
await websocket.close()
@app.websocket_route("/websocket_route")
async def websocket_route_endpoint(websocket: WebSocket):
await websocket.accept()
await websocket.send_text("Websocket route")
await websocket.close()
@app.websocket_route("/websocket_route_2/")
async def websocket_route_endpoint_with_slash(websocket: WebSocket):
await websocket.accept()
await websocket.send_text("Websocket route 2")
await websocket.close()
@router.get("/example")
def route_endpoint():
return {"msg": "Routing Example"}
@router.get("/example2/")
def route_endpoint_with_slash():
return {"msg": "Routing Example 2"}
@router.websocket("/websocket")
async def router_websocket_endpoint(websocket: WebSocket):
await websocket.accept()
await websocket.send_text("Websocket")
await websocket.close()
@router.websocket("/websocket2/")
async def router_websocket_endpoint_with_slash(websocket: WebSocket):
await websocket.accept()
await websocket.send_text("Websocket 2")
await websocket.close()
@router.websocket_route("/websocket_route")
async def router_websocket_route_endpoint(websocket: WebSocket):
await websocket.accept()
await websocket.send_text("Websocket route")
await websocket.close()
@router.websocket_route("/websocket_route_2/")
async def router_websocket_route_endpoint_with_slash(websocket: WebSocket):
await websocket.accept()
await websocket.send_text("Websocket route 2")
await websocket.close()
@router.route("/starlette_route", ["get"])
async def starlette_route_endpoint(request: Request):
return JSONResponse({"msg": "Starlette Route"})
@router.route("/starlette_route_2/", ["get"])
async def starlette_route_endpoint_with_slash(request: Request):
return JSONResponse({"msg": "Starlette Route 2"})
router_ignore = APIRouter(ignore_trailing_slash=True)
@router_ignore.route("/example", ["get"])
async def router_ignore_example(request: Request):
return JSONResponse({"msg": "Router Ignore"})
@router_ignore.route("/example2/", ["get"])
async def router_ignore_example_with_slash(request: Request):
return JSONResponse({"msg": "Router Ignore 2"})
@router_ignore.websocket_route("/websocket")
async def router_ignore_websocket(websocket: WebSocket):
await websocket.accept()
await websocket.send_text("Router Ignore Websocket")
await websocket.close()
@router_ignore.websocket_route("/websocket2/")
async def router_ignore_websocket_with_slash(websocket: WebSocket):
await websocket.accept()
await websocket.send_text("Router Ignore Websocket 2")
await websocket.close()
app.include_router(router, prefix="/router")
app.include_router(router_ignore, prefix="/router_ignore")
client = TestClient(app)
def test_ignoring_trailing_slash():
response = client.get("/example/", follow_redirects=False)
assert response.status_code == 200
assert response.json()["msg"] == "Example"
response = client.get("/example2", follow_redirects=False)
assert response.status_code == 200
assert response.json()["msg"] == "Example 2"
def test_ignoring_trailing_shlash_ws():
with client.websocket_connect("/websocket/") as websocket:
assert websocket.receive_text() == "Websocket"
with client.websocket_connect("/websocket2") as websocket:
assert websocket.receive_text() == "Websocket 2"
with client.websocket_connect("/websocket_route/") as websocket:
assert websocket.receive_text() == "Websocket route"
with client.websocket_connect("/websocket_route_2/") as websocket:
assert websocket.receive_text() == "Websocket route 2"
def test_ignoring_trailing_routing():
response = client.get("router/example/", follow_redirects=False)
assert response.status_code == 200
assert response.json()["msg"] == "Routing Example"
response = client.get("router/example2", follow_redirects=False)
assert response.status_code == 200
assert response.json()["msg"] == "Routing Example 2"
response = client.get("router/starlette_route/", follow_redirects=False)
assert response.status_code == 200
assert response.json()["msg"] == "Starlette Route"
response = client.get("router/starlette_route_2", follow_redirects=False)
assert response.status_code == 200
assert response.json()["msg"] == "Starlette Route 2"
with client.websocket_connect("router/websocket/") as websocket:
assert websocket.receive_text() == "Websocket"
with client.websocket_connect("router/websocket2") as websocket:
assert websocket.receive_text() == "Websocket 2"
with client.websocket_connect("router/websocket_route/") as websocket:
assert websocket.receive_text() == "Websocket route"
with client.websocket_connect("router/websocket_route_2/") as websocket:
assert websocket.receive_text() == "Websocket route 2"
def test_add_router_with_ignore_flag():
response = client.get("/router_ignore/example/", follow_redirects=False)
assert response.status_code == 200
assert response.json()["msg"] == "Router Ignore"
response = client.get("/router_ignore/example2", follow_redirects=False)
assert response.status_code == 200
assert response.json()["msg"] == "Router Ignore 2"
with client.websocket_connect("/router_ignore/websocket/") as websocket:
assert websocket.receive_text() == "Router Ignore Websocket"
with client.websocket_connect("/router_ignore/websocket2") as websocket:
assert websocket.receive_text() == "Router Ignore Websocket 2"
Loading…
Cancel
Save