Browse Source

Add docs, more tests and condition in middleware

pull/12145/head
Synrom 8 months ago
parent
commit
6823cb4a20
  1. 10
      fastapi/applications.py
  2. 7
      fastapi/routing.py
  3. 30
      tests/test_ignore_trailing_slash.py

10
fastapi/applications.py

@ -814,6 +814,13 @@ class FastAPI(Starlette):
bool, bool,
Doc( Doc(
""" """
To ignore (or not) trailing slashes at the end of URIs.
For example, by setting `ignore_trailing_slash` to True,
requests to `/auth` and `/auth/` will have the same behaviour.
By default (`ignore_trailing_slash` is False), the two requests are treated differently.
One of them will result in a 307-redirect.
""" """
), ),
] = False, ] = False,
@ -973,7 +980,8 @@ class FastAPI(Starlette):
if ignore_trailing_slash: if ignore_trailing_slash:
def ignore_trailing_whitespace_middleware(app): def ignore_trailing_whitespace_middleware(app):
async def ignore_trailing_whitespace_wrapper(scope, receive, send): async def ignore_trailing_whitespace_wrapper(scope, receive, send):
scope["path"] = scope["path"].rstrip("/") if scope["type"] in {"http", "websocket"}:
scope["path"] = scope["path"].rstrip("/")
await app(scope, receive, send) await app(scope, receive, send)
return ignore_trailing_whitespace_wrapper return ignore_trailing_whitespace_wrapper
self.add_middleware(ignore_trailing_whitespace_middleware) self.add_middleware(ignore_trailing_whitespace_middleware)

7
fastapi/routing.py

@ -815,6 +815,13 @@ class APIRouter(routing.Router):
bool, bool,
Doc( Doc(
""" """
To ignore (or not) trailing slashes at the end of URIs.
For example, by setting `ignore_trailing_slash` to True,
requests to `/auth` and `/auth/` will have the same behaviour.
By default (`ignore_trailing_slash` is False), the two requests are treated differently.
One of them will result in a 307-redirect.
""" """
), ),
] = False, ] = False,

30
tests/test_ignore_trailing_slash.py

@ -8,38 +8,52 @@ router = APIRouter()
async def example_endpoint(): async def example_endpoint():
return {"msg": "Example"} return {"msg": "Example"}
@app.get("/example2/")
async def example_endpoint():
return {"msg": "Example 2"}
@app.websocket("/websocket") @app.websocket("/websocket")
async def websocket_endpoint(websocket: WebSocket): async def websocket_endpoint(websocket: WebSocket):
await websocket.accept() await websocket.accept()
await websocket.send_text("Websocket") await websocket.send_text("Websocket")
await websocket.close() await websocket.close()
@app.websocket("/websocket2/")
async def websocket_endpoint(websocket: WebSocket):
await websocket.accept()
await websocket.send_text("Websocket 2")
await websocket.close()
@router.get("/example") @router.get("/example")
def route_endpoint(): def route_endpoint():
return {"msg": "Routing Example"} return {"msg": "Routing Example"}
@router.get("/example2/")
def route_endpoint():
return {"msg": "Routing Example 2"}
app.include_router(router, prefix="/router") app.include_router(router, prefix="/router")
client = TestClient(app) client = TestClient(app)
def test_ignoring_trailing_slash(): def test_ignoring_trailing_slash():
response = client.get("/example", follow_redirects=False)
assert response.status_code == 200
assert response.json()["msg"] == "Example"
response = client.get("/example/", follow_redirects=False) response = client.get("/example/", follow_redirects=False)
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["msg"] == "Example" assert response.json()["msg"] == "Example"
response = client.get("/example2", follow_redirects=False)
assert response.status_code == 200
assert response.json()["msg"] == "Example 2"
def test_ignoring_trailing_shlash_ws(): def test_ignoring_trailing_shlash_ws():
with client.websocket_connect("/websocket") as websocket:
assert websocket.receive_text() == "Websocket"
with client.websocket_connect("/websocket/") as websocket: with client.websocket_connect("/websocket/") as websocket:
assert websocket.receive_text() == "Websocket" assert websocket.receive_text() == "Websocket"
with client.websocket_connect("/websocket2") as websocket:
assert websocket.receive_text() == "Websocket 2"
def test_ignoring_trailing_routing(): def test_ignoring_trailing_routing():
response = client.get("router/example", follow_redirects=False)
assert response.status_code == 200
assert response.json()["msg"] == "Routing Example"
response = client.get("router/example/", follow_redirects=False) response = client.get("router/example/", follow_redirects=False)
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["msg"] == "Routing Example" assert response.json()["msg"] == "Routing Example"
response = client.get("router/example2", follow_redirects=False)
assert response.status_code == 200
assert response.json()["msg"] == "Routing Example 2"
Loading…
Cancel
Save