Browse Source

Add ignore_trailing_slashes flag to applications

pull/12145/head
Synrom 7 months ago
parent
commit
d4eaafb804
  1. 14
      fastapi/applications.py
  2. 27
      fastapi/routing.py
  3. 29
      tests/test_ignore_trailing_slash.py

14
fastapi/applications.py

@ -810,6 +810,13 @@ class FastAPI(Starlette):
""" """
), ),
] = True, ] = True,
ignore_trailing_whitespaces: Annotated[
bool,
Doc(
"""
"""
),
] = False,
**extra: Annotated[ **extra: Annotated[
Any, Any,
Doc( Doc(
@ -943,6 +950,7 @@ class FastAPI(Starlette):
include_in_schema=include_in_schema, include_in_schema=include_in_schema,
responses=responses, responses=responses,
generate_unique_id_function=generate_unique_id_function, generate_unique_id_function=generate_unique_id_function,
ignore_trailing_whitespaces=ignore_trailing_whitespaces,
) )
self.exception_handlers: Dict[ self.exception_handlers: Dict[
Any, Callable[[Request, Any], Union[Response, Awaitable[Response]]] Any, Callable[[Request, Any], Union[Response, Awaitable[Response]]]
@ -961,6 +969,12 @@ class FastAPI(Starlette):
[] if middleware is None else list(middleware) [] if middleware is None else list(middleware)
) )
self.middleware_stack: Union[ASGIApp, None] = None self.middleware_stack: Union[ASGIApp, None] = None
if ignore_trailing_whitespaces:
async def middleware_ignore_tailing_whitespace(request: Request, call_next):
request.scope["path"] = request.scope["path"].rstrip("/")
return await call_next(request)
self.add_middleware(BaseHTTPMiddleware, dispatch=middleware_ignore_tailing_whitespace)
self.setup() self.setup()
def openapi(self) -> Dict[str, Any]: def openapi(self) -> Dict[str, Any]:

27
fastapi/routing.py

@ -811,6 +811,13 @@ class APIRouter(routing.Router):
""" """
), ),
] = Default(generate_unique_id), ] = Default(generate_unique_id),
ignore_trailing_whitespaces: Annotated[
bool,
Doc(
"""
"""
),
] = False,
) -> None: ) -> None:
super().__init__( super().__init__(
routes=routes, routes=routes,
@ -836,6 +843,7 @@ class APIRouter(routing.Router):
self.route_class = route_class self.route_class = route_class
self.default_response_class = default_response_class self.default_response_class = default_response_class
self.generate_unique_id_function = generate_unique_id_function self.generate_unique_id_function = generate_unique_id_function
self.ignore_trailing_whitespaces = ignore_trailing_whitespaces
def route( def route(
self, self,
@ -844,6 +852,8 @@ class APIRouter(routing.Router):
name: Optional[str] = None, name: Optional[str] = None,
include_in_schema: bool = True, include_in_schema: bool = True,
) -> Callable[[DecoratedCallable], DecoratedCallable]: ) -> Callable[[DecoratedCallable], DecoratedCallable]:
if self.ignore_trailing_whitespaces:
path = path.rstrip("/")
def decorator(func: DecoratedCallable) -> DecoratedCallable: def decorator(func: DecoratedCallable) -> DecoratedCallable:
self.add_route( self.add_route(
path, path,
@ -890,6 +900,8 @@ class APIRouter(routing.Router):
Callable[[APIRoute], str], DefaultPlaceholder Callable[[APIRoute], str], DefaultPlaceholder
] = Default(generate_unique_id), ] = Default(generate_unique_id),
) -> None: ) -> None:
if self.ignore_trailing_whitespaces:
path = path.rstrip("/")
route_class = route_class_override or self.route_class route_class = route_class_override or self.route_class
responses = responses or {} responses = responses or {}
combined_responses = {**self.responses, **responses} combined_responses = {**self.responses, **responses}
@ -1008,6 +1020,8 @@ class APIRouter(routing.Router):
*, *,
dependencies: Optional[Sequence[params.Depends]] = None, dependencies: Optional[Sequence[params.Depends]] = None,
) -> None: ) -> None:
if self.ignore_trailing_whitespaces:
path = path.rstrip("/")
current_dependencies = self.dependencies.copy() current_dependencies = self.dependencies.copy()
if dependencies: if dependencies:
current_dependencies.extend(dependencies) current_dependencies.extend(dependencies)
@ -1091,6 +1105,8 @@ class APIRouter(routing.Router):
def websocket_route( def websocket_route(
self, path: str, name: Union[str, None] = None self, path: str, name: Union[str, None] = None
) -> Callable[[DecoratedCallable], DecoratedCallable]: ) -> Callable[[DecoratedCallable], DecoratedCallable]:
if self.ignore_trailing_whitespaces:
path = path.rstrip("/")
def decorator(func: DecoratedCallable) -> DecoratedCallable: def decorator(func: DecoratedCallable) -> DecoratedCallable:
self.add_websocket_route(path, func, name=name) self.add_websocket_route(path, func, name=name)
return func return func
@ -1248,6 +1264,9 @@ class APIRouter(routing.Router):
if responses is None: if responses is None:
responses = {} responses = {}
for route in router.routes: for route in router.routes:
path = route.path
if self.ignore_trailing_whitespaces:
path = path.rstrip("/")
if isinstance(route, APIRoute): if isinstance(route, APIRoute):
combined_responses = {**responses, **route.responses} combined_responses = {**responses, **route.responses}
use_response_class = get_value_or_default( use_response_class = get_value_or_default(
@ -1278,7 +1297,7 @@ class APIRouter(routing.Router):
self.generate_unique_id_function, self.generate_unique_id_function,
) )
self.add_api_route( self.add_api_route(
prefix + route.path, prefix + path,
route.endpoint, route.endpoint,
response_model=route.response_model, response_model=route.response_model,
status_code=route.status_code, status_code=route.status_code,
@ -1310,7 +1329,7 @@ class APIRouter(routing.Router):
elif isinstance(route, routing.Route): elif isinstance(route, routing.Route):
methods = list(route.methods or []) methods = list(route.methods or [])
self.add_route( self.add_route(
prefix + route.path, prefix + path,
route.endpoint, route.endpoint,
methods=methods, methods=methods,
include_in_schema=route.include_in_schema, include_in_schema=route.include_in_schema,
@ -1323,14 +1342,14 @@ class APIRouter(routing.Router):
if route.dependencies: if route.dependencies:
current_dependencies.extend(route.dependencies) current_dependencies.extend(route.dependencies)
self.add_api_websocket_route( self.add_api_websocket_route(
prefix + route.path, prefix + path,
route.endpoint, route.endpoint,
dependencies=current_dependencies, dependencies=current_dependencies,
name=route.name, name=route.name,
) )
elif isinstance(route, routing.WebSocketRoute): elif isinstance(route, routing.WebSocketRoute):
self.add_websocket_route( self.add_websocket_route(
prefix + route.path, route.endpoint, name=route.name prefix + path, route.endpoint, name=route.name
) )
for handler in router.on_startup: for handler in router.on_startup:
self.add_event_handler("startup", handler) self.add_event_handler("startup", handler)

29
tests/test_ignore_trailing_slash.py

@ -0,0 +1,29 @@
from fastapi import FastAPI
from fastapi.testclient import TestClient
recognizing_app = FastAPI()
ignoring_app = FastAPI(ignore_trailing_whitespaces=True)
@recognizing_app.get("/example")
@ignoring_app.get("/example")
async def return_data():
return {"msg": "Reached the route!"}
recognizing_client = TestClient(recognizing_app)
ignoring_client = TestClient(ignoring_app)
def test_recognizing_trailing_slash():
response = recognizing_client.get("/example", follow_redirects=False)
assert response.status_code == 200
assert response.json()["msg"] == "Reached the route!"
response = recognizing_client.get("/example/", follow_redirects=False)
assert response.status_code == 307
assert response.headers["location"].endswith("/example")
def test_ignoring_trailing_slash():
response = ignoring_client.get("/example", follow_redirects=False)
assert response.status_code == 200
assert response.json()["msg"] == "Reached the route!"
response = ignoring_client.get("/example/", follow_redirects=False)
assert response.status_code == 200
assert response.json()["msg"] == "Reached the route!"
Loading…
Cancel
Save