Browse Source

Formatting

pull/12145/head
Synrom 7 months ago
parent
commit
25accde3c6
  1. 3
      fastapi/applications.py
  2. 6
      fastapi/routing.py
  3. 20
      tests/test_ignore_trailing_slash.py

3
fastapi/applications.py

@ -978,12 +978,15 @@ class FastAPI(Starlette):
self.middleware_stack: Union[ASGIApp, None] = None self.middleware_stack: Union[ASGIApp, None] = None
if ignore_trailing_slash: if ignore_trailing_slash:
def ignore_trailing_whitespace_middleware(app): def ignore_trailing_whitespace_middleware(app):
async def ignore_trailing_whitespace_wrapper(scope, receive, send): async def ignore_trailing_whitespace_wrapper(scope, receive, send):
if scope["type"] in {"http", "websocket"}: if scope["type"] in {"http", "websocket"}:
scope["path"] = scope["path"].rstrip("/") scope["path"] = scope["path"].rstrip("/")
await app(scope, receive, send) await app(scope, receive, send)
return ignore_trailing_whitespace_wrapper return ignore_trailing_whitespace_wrapper
self.add_middleware(ignore_trailing_whitespace_middleware) self.add_middleware(ignore_trailing_whitespace_middleware)
self.setup() self.setup()

6
fastapi/routing.py

@ -881,6 +881,7 @@ class APIRouter(routing.Router):
) -> Callable[[DecoratedCallable], DecoratedCallable]: ) -> Callable[[DecoratedCallable], DecoratedCallable]:
if self.ignore_trailing_slash: if self.ignore_trailing_slash:
path = path.rstrip("/") path = path.rstrip("/")
def decorator(func: DecoratedCallable) -> DecoratedCallable: def decorator(func: DecoratedCallable) -> DecoratedCallable:
self.add_route( self.add_route(
path, path,
@ -1134,6 +1135,7 @@ class APIRouter(routing.Router):
) -> Callable[[DecoratedCallable], DecoratedCallable]: ) -> Callable[[DecoratedCallable], DecoratedCallable]:
if self.ignore_trailing_slash: if self.ignore_trailing_slash:
path = path.rstrip("/") path = path.rstrip("/")
def decorator(func: DecoratedCallable) -> DecoratedCallable: def decorator(func: DecoratedCallable) -> DecoratedCallable:
self.add_websocket_route(path, func, name=name) self.add_websocket_route(path, func, name=name)
return func return func
@ -1375,9 +1377,7 @@ class APIRouter(routing.Router):
name=route.name, name=route.name,
) )
elif isinstance(route, routing.WebSocketRoute): elif isinstance(route, routing.WebSocketRoute):
self.add_websocket_route( self.add_websocket_route(prefix + path, route.endpoint, name=route.name)
prefix + path, route.endpoint, name=route.name
)
for handler in router.on_startup: for handler in router.on_startup:
self.add_event_handler("startup", handler) self.add_event_handler("startup", handler)
for handler in router.on_shutdown: for handler in router.on_shutdown:

20
tests/test_ignore_trailing_slash.py

@ -1,41 +1,49 @@
from fastapi import FastAPI, WebSocket, APIRouter from fastapi import APIRouter, FastAPI, WebSocket
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
app = FastAPI(ignore_trailing_slash=True) app = FastAPI(ignore_trailing_slash=True)
router = APIRouter() router = APIRouter()
@app.get("/example") @app.get("/example")
async def example_endpoint(): async def example_endpoint():
return {"msg": "Example"} return {"msg": "Example"}
@app.get("/example2/") @app.get("/example2/")
async def example_endpoint(): async def example_endpoint_with_slash():
return {"msg": "Example 2"} return {"msg": "Example 2"}
@app.websocket("/websocket") @app.websocket("/websocket")
async def websocket_endpoint(websocket: WebSocket): async def websocket_endpoint(websocket: WebSocket):
await websocket.accept() await websocket.accept()
await websocket.send_text("Websocket") await websocket.send_text("Websocket")
await websocket.close() await websocket.close()
@app.websocket("/websocket2/") @app.websocket("/websocket2/")
async def websocket_endpoint(websocket: WebSocket): async def websocket_endpoint_with_slash(websocket: WebSocket):
await websocket.accept() await websocket.accept()
await websocket.send_text("Websocket 2") await websocket.send_text("Websocket 2")
await websocket.close() await websocket.close()
@router.get("/example") @router.get("/example")
def route_endpoint(): def route_endpoint():
return {"msg": "Routing Example"} return {"msg": "Routing Example"}
@router.get("/example2/") @router.get("/example2/")
def route_endpoint(): def route_endpoint_with_slash():
return {"msg": "Routing Example 2"} return {"msg": "Routing Example 2"}
app.include_router(router, prefix="/router") app.include_router(router, prefix="/router")
client = TestClient(app) client = TestClient(app)
def test_ignoring_trailing_slash(): def test_ignoring_trailing_slash():
response = client.get("/example/", follow_redirects=False) response = client.get("/example/", follow_redirects=False)
assert response.status_code == 200 assert response.status_code == 200
@ -44,16 +52,18 @@ def test_ignoring_trailing_slash():
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["msg"] == "Example 2" assert response.json()["msg"] == "Example 2"
def test_ignoring_trailing_shlash_ws(): def test_ignoring_trailing_shlash_ws():
with client.websocket_connect("/websocket/") as websocket: with client.websocket_connect("/websocket/") as websocket:
assert websocket.receive_text() == "Websocket" assert websocket.receive_text() == "Websocket"
with client.websocket_connect("/websocket2") as websocket: with client.websocket_connect("/websocket2") as websocket:
assert websocket.receive_text() == "Websocket 2" assert websocket.receive_text() == "Websocket 2"
def test_ignoring_trailing_routing(): def test_ignoring_trailing_routing():
response = client.get("router/example/", follow_redirects=False) response = client.get("router/example/", follow_redirects=False)
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["msg"] == "Routing Example" assert response.json()["msg"] == "Routing Example"
response = client.get("router/example2", follow_redirects=False) response = client.get("router/example2", follow_redirects=False)
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["msg"] == "Routing Example 2" assert response.json()["msg"] == "Routing Example 2"

Loading…
Cancel
Save