Browse Source
Co-authored-by: Sebastián Ramírez <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>pull/5307/head
committed by
GitHub
2 changed files with 97 additions and 0 deletions
@ -0,0 +1,95 @@ |
|||||
|
from pathlib import Path |
||||
|
from typing import Optional |
||||
|
|
||||
|
from fastapi import APIRouter, FastAPI, File, UploadFile |
||||
|
from fastapi.exceptions import HTTPException |
||||
|
from fastapi.testclient import TestClient |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
router = APIRouter() |
||||
|
|
||||
|
|
||||
|
class ContentSizeLimitMiddleware: |
||||
|
"""Content size limiting middleware for ASGI applications |
||||
|
Args: |
||||
|
app (ASGI application): ASGI application |
||||
|
max_content_size (optional): the maximum content size allowed in bytes, None for no limit |
||||
|
""" |
||||
|
|
||||
|
def __init__(self, app: APIRouter, max_content_size: Optional[int] = None): |
||||
|
self.app = app |
||||
|
self.max_content_size = max_content_size |
||||
|
|
||||
|
def receive_wrapper(self, receive): |
||||
|
received = 0 |
||||
|
|
||||
|
async def inner(): |
||||
|
nonlocal received |
||||
|
message = await receive() |
||||
|
if message["type"] != "http.request": |
||||
|
return message # pragma: no cover |
||||
|
|
||||
|
body_len = len(message.get("body", b"")) |
||||
|
received += body_len |
||||
|
if received > self.max_content_size: |
||||
|
raise HTTPException( |
||||
|
422, |
||||
|
detail={ |
||||
|
"name": "ContentSizeLimitExceeded", |
||||
|
"code": 999, |
||||
|
"message": "File limit exceeded", |
||||
|
}, |
||||
|
) |
||||
|
return message |
||||
|
|
||||
|
return inner |
||||
|
|
||||
|
async def __call__(self, scope, receive, send): |
||||
|
if scope["type"] != "http" or self.max_content_size is None: |
||||
|
await self.app(scope, receive, send) |
||||
|
return |
||||
|
|
||||
|
wrapper = self.receive_wrapper(receive) |
||||
|
await self.app(scope, wrapper, send) |
||||
|
|
||||
|
|
||||
|
@router.post("/middleware") |
||||
|
def run_middleware(file: UploadFile = File(..., description="Big File")): |
||||
|
return {"message": "OK"} |
||||
|
|
||||
|
|
||||
|
app.include_router(router) |
||||
|
app.add_middleware(ContentSizeLimitMiddleware, max_content_size=2**8) |
||||
|
|
||||
|
|
||||
|
client = TestClient(app) |
||||
|
|
||||
|
|
||||
|
def test_custom_middleware_exception(tmp_path: Path): |
||||
|
default_pydantic_max_size = 2**16 |
||||
|
path = tmp_path / "test.txt" |
||||
|
path.write_bytes(b"x" * (default_pydantic_max_size + 1)) |
||||
|
|
||||
|
with client: |
||||
|
with open(path, "rb") as file: |
||||
|
response = client.post("/middleware", files={"file": file}) |
||||
|
assert response.status_code == 422, response.text |
||||
|
assert response.json() == { |
||||
|
"detail": { |
||||
|
"name": "ContentSizeLimitExceeded", |
||||
|
"code": 999, |
||||
|
"message": "File limit exceeded", |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
|
||||
|
def test_custom_middleware_exception_not_raised(tmp_path: Path): |
||||
|
path = tmp_path / "test.txt" |
||||
|
path.write_bytes(b"<file content>") |
||||
|
|
||||
|
with client: |
||||
|
with open(path, "rb") as file: |
||||
|
response = client.post("/middleware", files={"file": file}) |
||||
|
assert response.status_code == 200, response.text |
||||
|
assert response.json() == {"message": "OK"} |
Loading…
Reference in new issue