Browse Source

Formatting

pull/12145/head
Synrom 7 months ago
parent
commit
25accde3c6
  1. 3
      fastapi/applications.py
  2. 6
      fastapi/routing.py
  3. 20
      tests/test_ignore_trailing_slash.py

3
fastapi/applications.py

@ -978,12 +978,15 @@ class FastAPI(Starlette):
self.middleware_stack: Union[ASGIApp, None] = None
if ignore_trailing_slash:
def ignore_trailing_whitespace_middleware(app):
async def ignore_trailing_whitespace_wrapper(scope, receive, send):
if scope["type"] in {"http", "websocket"}:
scope["path"] = scope["path"].rstrip("/")
await app(scope, receive, send)
return ignore_trailing_whitespace_wrapper
self.add_middleware(ignore_trailing_whitespace_middleware)
self.setup()

6
fastapi/routing.py

@ -881,6 +881,7 @@ class APIRouter(routing.Router):
) -> Callable[[DecoratedCallable], DecoratedCallable]:
if self.ignore_trailing_slash:
path = path.rstrip("/")
def decorator(func: DecoratedCallable) -> DecoratedCallable:
self.add_route(
path,
@ -1134,6 +1135,7 @@ class APIRouter(routing.Router):
) -> Callable[[DecoratedCallable], DecoratedCallable]:
if self.ignore_trailing_slash:
path = path.rstrip("/")
def decorator(func: DecoratedCallable) -> DecoratedCallable:
self.add_websocket_route(path, func, name=name)
return func
@ -1375,9 +1377,7 @@ class APIRouter(routing.Router):
name=route.name,
)
elif isinstance(route, routing.WebSocketRoute):
self.add_websocket_route(
prefix + path, route.endpoint, name=route.name
)
self.add_websocket_route(prefix + path, route.endpoint, name=route.name)
for handler in router.on_startup:
self.add_event_handler("startup", handler)
for handler in router.on_shutdown:

20
tests/test_ignore_trailing_slash.py

@ -1,41 +1,49 @@
from fastapi import FastAPI, WebSocket, APIRouter
from fastapi import APIRouter, FastAPI, WebSocket
from fastapi.testclient import TestClient
app = FastAPI(ignore_trailing_slash=True)
router = APIRouter()
@app.get("/example")
async def example_endpoint():
return {"msg": "Example"}
@app.get("/example2/")
async def example_endpoint():
async def example_endpoint_with_slash():
return {"msg": "Example 2"}
@app.websocket("/websocket")
async def websocket_endpoint(websocket: WebSocket):
await websocket.accept()
await websocket.send_text("Websocket")
await websocket.close()
@app.websocket("/websocket2/")
async def websocket_endpoint(websocket: WebSocket):
async def websocket_endpoint_with_slash(websocket: WebSocket):
await websocket.accept()
await websocket.send_text("Websocket 2")
await websocket.close()
@router.get("/example")
def route_endpoint():
return {"msg": "Routing Example"}
@router.get("/example2/")
def route_endpoint():
def route_endpoint_with_slash():
return {"msg": "Routing Example 2"}
app.include_router(router, prefix="/router")
client = TestClient(app)
def test_ignoring_trailing_slash():
response = client.get("/example/", follow_redirects=False)
assert response.status_code == 200
@ -44,16 +52,18 @@ def test_ignoring_trailing_slash():
assert response.status_code == 200
assert response.json()["msg"] == "Example 2"
def test_ignoring_trailing_shlash_ws():
with client.websocket_connect("/websocket/") as websocket:
assert websocket.receive_text() == "Websocket"
with client.websocket_connect("/websocket2") as websocket:
assert websocket.receive_text() == "Websocket 2"
def test_ignoring_trailing_routing():
response = client.get("router/example/", follow_redirects=False)
assert response.status_code == 200
assert response.json()["msg"] == "Routing Example"
response = client.get("router/example2", follow_redirects=False)
assert response.status_code == 200
assert response.json()["msg"] == "Routing Example 2"
assert response.json()["msg"] == "Routing Example 2"

Loading…
Cancel
Save