Browse Source

Fix middleware and add tests

pull/12145/head
Synrom 7 months ago
parent
commit
b00fbd3427
  1. 17
      fastapi/applications.py
  2. 14
      fastapi/routing.py
  3. 56
      tests/test_ignore_trailing_slash.py

17
fastapi/applications.py

@ -810,7 +810,7 @@ class FastAPI(Starlette):
"""
),
] = True,
ignore_trailing_whitespaces: Annotated[
ignore_trailing_slash: Annotated[
bool,
Doc(
"""
@ -950,7 +950,7 @@ class FastAPI(Starlette):
include_in_schema=include_in_schema,
responses=responses,
generate_unique_id_function=generate_unique_id_function,
ignore_trailing_whitespaces=ignore_trailing_whitespaces,
ignore_trailing_slash=ignore_trailing_slash,
)
self.exception_handlers: Dict[
Any, Callable[[Request, Any], Union[Response, Awaitable[Response]]]
@ -970,11 +970,14 @@ class FastAPI(Starlette):
)
self.middleware_stack: Union[ASGIApp, None] = None
if ignore_trailing_whitespaces:
async def middleware_ignore_tailing_whitespace(request: Request, call_next):
request.scope["path"] = request.scope["path"].rstrip("/")
return await call_next(request)
self.add_middleware(BaseHTTPMiddleware, dispatch=middleware_ignore_tailing_whitespace)
if ignore_trailing_slash:
def ignore_trailing_whitespace_middleware(app):
async def ignore_trailing_whitespace_wrapper(scope, receive, send):
scope["path"] = scope["path"].rstrip("/")
await app(scope, receive, send)
return ignore_trailing_whitespace_wrapper
self.add_middleware(ignore_trailing_whitespace_middleware)
self.setup()
def openapi(self) -> Dict[str, Any]:

14
fastapi/routing.py

@ -811,7 +811,7 @@ class APIRouter(routing.Router):
"""
),
] = Default(generate_unique_id),
ignore_trailing_whitespaces: Annotated[
ignore_trailing_slash: Annotated[
bool,
Doc(
"""
@ -843,7 +843,7 @@ class APIRouter(routing.Router):
self.route_class = route_class
self.default_response_class = default_response_class
self.generate_unique_id_function = generate_unique_id_function
self.ignore_trailing_whitespaces = ignore_trailing_whitespaces
self.ignore_trailing_slash = ignore_trailing_slash
def route(
self,
@ -852,7 +852,7 @@ class APIRouter(routing.Router):
name: Optional[str] = None,
include_in_schema: bool = True,
) -> Callable[[DecoratedCallable], DecoratedCallable]:
if self.ignore_trailing_whitespaces:
if self.ignore_trailing_slash:
path = path.rstrip("/")
def decorator(func: DecoratedCallable) -> DecoratedCallable:
self.add_route(
@ -900,7 +900,7 @@ class APIRouter(routing.Router):
Callable[[APIRoute], str], DefaultPlaceholder
] = Default(generate_unique_id),
) -> None:
if self.ignore_trailing_whitespaces:
if self.ignore_trailing_slash:
path = path.rstrip("/")
route_class = route_class_override or self.route_class
responses = responses or {}
@ -1020,7 +1020,7 @@ class APIRouter(routing.Router):
*,
dependencies: Optional[Sequence[params.Depends]] = None,
) -> None:
if self.ignore_trailing_whitespaces:
if self.ignore_trailing_slash:
path = path.rstrip("/")
current_dependencies = self.dependencies.copy()
if dependencies:
@ -1105,7 +1105,7 @@ class APIRouter(routing.Router):
def websocket_route(
self, path: str, name: Union[str, None] = None
) -> Callable[[DecoratedCallable], DecoratedCallable]:
if self.ignore_trailing_whitespaces:
if self.ignore_trailing_slash:
path = path.rstrip("/")
def decorator(func: DecoratedCallable) -> DecoratedCallable:
self.add_websocket_route(path, func, name=name)
@ -1265,7 +1265,7 @@ class APIRouter(routing.Router):
responses = {}
for route in router.routes:
path = route.path
if self.ignore_trailing_whitespaces:
if self.ignore_trailing_slash:
path = path.rstrip("/")
if isinstance(route, APIRoute):
combined_responses = {**responses, **route.responses}

56
tests/test_ignore_trailing_slash.py

@ -1,29 +1,45 @@
from fastapi import FastAPI
from fastapi import FastAPI, WebSocket, APIRouter
from fastapi.testclient import TestClient
recognizing_app = FastAPI()
ignoring_app = FastAPI(ignore_trailing_whitespaces=True)
app = FastAPI(ignore_trailing_slash=True)
router = APIRouter()
@recognizing_app.get("/example")
@ignoring_app.get("/example")
async def return_data():
return {"msg": "Reached the route!"}
@app.get("/example")
async def example_endpoint():
return {"msg": "Example"}
recognizing_client = TestClient(recognizing_app)
ignoring_client = TestClient(ignoring_app)
@app.websocket("/websocket")
async def websocket_endpoint(websocket: WebSocket):
await websocket.accept()
await websocket.send_text("Websocket")
await websocket.close()
def test_recognizing_trailing_slash():
response = recognizing_client.get("/example", follow_redirects=False)
assert response.status_code == 200
assert response.json()["msg"] == "Reached the route!"
response = recognizing_client.get("/example/", follow_redirects=False)
assert response.status_code == 307
assert response.headers["location"].endswith("/example")
@router.get("/example")
def route_endpoint():
return {"msg": "Routing Example"}
app.include_router(router, prefix="/router")
client = TestClient(app)
def test_ignoring_trailing_slash():
response = ignoring_client.get("/example", follow_redirects=False)
response = client.get("/example", follow_redirects=False)
assert response.status_code == 200
assert response.json()["msg"] == "Example"
response = client.get("/example/", follow_redirects=False)
assert response.status_code == 200
assert response.json()["msg"] == "Example"
def test_ignoring_trailing_shlash_ws():
with client.websocket_connect("/websocket") as websocket:
assert websocket.receive_text() == "Websocket"
with client.websocket_connect("/websocket/") as websocket:
assert websocket.receive_text() == "Websocket"
def test_ignoring_trailing_routing():
response = client.get("router/example", follow_redirects=False)
assert response.status_code == 200
assert response.json()["msg"] == "Reached the route!"
response = ignoring_client.get("/example/", follow_redirects=False)
assert response.json()["msg"] == "Routing Example"
response = client.get("router/example/", follow_redirects=False)
assert response.status_code == 200
assert response.json()["msg"] == "Reached the route!"
assert response.json()["msg"] == "Routing Example"

Loading…
Cancel
Save