Browse Source

Add ignore_trailing_slashes flag to applications

pull/12145/head
Synrom 7 months ago
parent
commit
d4eaafb804
  1. 14
      fastapi/applications.py
  2. 27
      fastapi/routing.py
  3. 29
      tests/test_ignore_trailing_slash.py

14
fastapi/applications.py

@ -810,6 +810,13 @@ class FastAPI(Starlette):
"""
),
] = True,
ignore_trailing_whitespaces: Annotated[
bool,
Doc(
"""
"""
),
] = False,
**extra: Annotated[
Any,
Doc(
@ -943,6 +950,7 @@ class FastAPI(Starlette):
include_in_schema=include_in_schema,
responses=responses,
generate_unique_id_function=generate_unique_id_function,
ignore_trailing_whitespaces=ignore_trailing_whitespaces,
)
self.exception_handlers: Dict[
Any, Callable[[Request, Any], Union[Response, Awaitable[Response]]]
@ -961,6 +969,12 @@ class FastAPI(Starlette):
[] if middleware is None else list(middleware)
)
self.middleware_stack: Union[ASGIApp, None] = None
if ignore_trailing_whitespaces:
async def middleware_ignore_tailing_whitespace(request: Request, call_next):
request.scope["path"] = request.scope["path"].rstrip("/")
return await call_next(request)
self.add_middleware(BaseHTTPMiddleware, dispatch=middleware_ignore_tailing_whitespace)
self.setup()
def openapi(self) -> Dict[str, Any]:

27
fastapi/routing.py

@ -811,6 +811,13 @@ class APIRouter(routing.Router):
"""
),
] = Default(generate_unique_id),
ignore_trailing_whitespaces: Annotated[
bool,
Doc(
"""
"""
),
] = False,
) -> None:
super().__init__(
routes=routes,
@ -836,6 +843,7 @@ class APIRouter(routing.Router):
self.route_class = route_class
self.default_response_class = default_response_class
self.generate_unique_id_function = generate_unique_id_function
self.ignore_trailing_whitespaces = ignore_trailing_whitespaces
def route(
self,
@ -844,6 +852,8 @@ class APIRouter(routing.Router):
name: Optional[str] = None,
include_in_schema: bool = True,
) -> Callable[[DecoratedCallable], DecoratedCallable]:
if self.ignore_trailing_whitespaces:
path = path.rstrip("/")
def decorator(func: DecoratedCallable) -> DecoratedCallable:
self.add_route(
path,
@ -890,6 +900,8 @@ class APIRouter(routing.Router):
Callable[[APIRoute], str], DefaultPlaceholder
] = Default(generate_unique_id),
) -> None:
if self.ignore_trailing_whitespaces:
path = path.rstrip("/")
route_class = route_class_override or self.route_class
responses = responses or {}
combined_responses = {**self.responses, **responses}
@ -1008,6 +1020,8 @@ class APIRouter(routing.Router):
*,
dependencies: Optional[Sequence[params.Depends]] = None,
) -> None:
if self.ignore_trailing_whitespaces:
path = path.rstrip("/")
current_dependencies = self.dependencies.copy()
if dependencies:
current_dependencies.extend(dependencies)
@ -1091,6 +1105,8 @@ class APIRouter(routing.Router):
def websocket_route(
self, path: str, name: Union[str, None] = None
) -> Callable[[DecoratedCallable], DecoratedCallable]:
if self.ignore_trailing_whitespaces:
path = path.rstrip("/")
def decorator(func: DecoratedCallable) -> DecoratedCallable:
self.add_websocket_route(path, func, name=name)
return func
@ -1248,6 +1264,9 @@ class APIRouter(routing.Router):
if responses is None:
responses = {}
for route in router.routes:
path = route.path
if self.ignore_trailing_whitespaces:
path = path.rstrip("/")
if isinstance(route, APIRoute):
combined_responses = {**responses, **route.responses}
use_response_class = get_value_or_default(
@ -1278,7 +1297,7 @@ class APIRouter(routing.Router):
self.generate_unique_id_function,
)
self.add_api_route(
prefix + route.path,
prefix + path,
route.endpoint,
response_model=route.response_model,
status_code=route.status_code,
@ -1310,7 +1329,7 @@ class APIRouter(routing.Router):
elif isinstance(route, routing.Route):
methods = list(route.methods or [])
self.add_route(
prefix + route.path,
prefix + path,
route.endpoint,
methods=methods,
include_in_schema=route.include_in_schema,
@ -1323,14 +1342,14 @@ class APIRouter(routing.Router):
if route.dependencies:
current_dependencies.extend(route.dependencies)
self.add_api_websocket_route(
prefix + route.path,
prefix + path,
route.endpoint,
dependencies=current_dependencies,
name=route.name,
)
elif isinstance(route, routing.WebSocketRoute):
self.add_websocket_route(
prefix + route.path, route.endpoint, name=route.name
prefix + path, route.endpoint, name=route.name
)
for handler in router.on_startup:
self.add_event_handler("startup", handler)

29
tests/test_ignore_trailing_slash.py

@ -0,0 +1,29 @@
from fastapi import FastAPI
from fastapi.testclient import TestClient
recognizing_app = FastAPI()
ignoring_app = FastAPI(ignore_trailing_whitespaces=True)
@recognizing_app.get("/example")
@ignoring_app.get("/example")
async def return_data():
return {"msg": "Reached the route!"}
recognizing_client = TestClient(recognizing_app)
ignoring_client = TestClient(ignoring_app)
def test_recognizing_trailing_slash():
response = recognizing_client.get("/example", follow_redirects=False)
assert response.status_code == 200
assert response.json()["msg"] == "Reached the route!"
response = recognizing_client.get("/example/", follow_redirects=False)
assert response.status_code == 307
assert response.headers["location"].endswith("/example")
def test_ignoring_trailing_slash():
response = ignoring_client.get("/example", follow_redirects=False)
assert response.status_code == 200
assert response.json()["msg"] == "Reached the route!"
response = ignoring_client.get("/example/", follow_redirects=False)
assert response.status_code == 200
assert response.json()["msg"] == "Reached the route!"
Loading…
Cancel
Save