Browse Source

Fix: catch starlette.HTTPException, support sync auth callables

Addresses bot review feedback:
- Catch starlette.exceptions.HTTPException (fastapi.HTTPException is a
  subclass, so both are covered). FastAPI's security modules raise
  starlette.HTTPException, which previously fell through unhandled.
- Support sync auth callables via starlette.run_in_threadpool. Detected
  at construction time with inspect.iscoroutinefunction.
- Add tests for sync callables and starlette.HTTPException handling.

Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>
pull/15295/head
faisalsaificode 2 months ago
parent
commit
8f490285f1
  1. 48
      fastapi/staticfiles.py
  2. 44
      tests/test_auth_static_files.py

48
fastapi/staticfiles.py

@ -1,11 +1,16 @@
import inspect
from collections.abc import Awaitable, Callable
from typing import Any
from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.responses import PlainTextResponse, Response
from starlette.staticfiles import StaticFiles as StaticFiles # noqa
from starlette.types import Receive, Scope, Send
AuthCallable = Callable[[Request], Awaitable[Any] | Any]
class AuthStaticFiles(StaticFiles):
"""
@ -42,9 +47,10 @@ class AuthStaticFiles(StaticFiles):
## Parameters
* `auth`: An async callable that takes a `Request` object and performs
authentication. It should raise an `HTTPException` if authentication
fails, or return `None` if authentication succeeds.
* `auth`: A sync or async callable that takes a `Request` object and
performs authentication. It should raise an `HTTPException` if
authentication fails, or return `None` if authentication succeeds.
Sync callables are automatically run in a threadpool.
* `on_error`: An optional callable that takes a `Request` and an
`HTTPException` and returns a `Response`. Use this to customize
error responses (e.g., redirect to login, return HTML instead of
@ -73,8 +79,8 @@ class AuthStaticFiles(StaticFiles):
html: bool = False,
check_dir: bool = True,
follow_symlink: bool = False,
auth: Callable[[Request], Awaitable[Any]],
on_error: Callable[[Request, Any], Awaitable[Response]] | None = None,
auth: AuthCallable,
on_error: Callable[[Request, HTTPException], Awaitable[Response]] | None = None,
) -> None:
super().__init__(
directory=directory,
@ -84,26 +90,26 @@ class AuthStaticFiles(StaticFiles):
follow_symlink=follow_symlink,
)
self.auth = auth
self._auth_is_async = inspect.iscoroutinefunction(auth)
self.on_error = on_error
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] == "http":
request = Request(scope, receive)
try:
await self.auth(request)
except Exception as exc:
from fastapi.exceptions import HTTPException
if isinstance(exc, HTTPException):
if self.on_error is not None:
response = await self.on_error(request, exc)
else:
response = PlainTextResponse(
str(exc.detail),
status_code=exc.status_code,
headers=getattr(exc, "headers", None),
)
await response(scope, receive, send)
return
raise
if self._auth_is_async:
await self.auth(request)
else:
await run_in_threadpool(self.auth, request)
except HTTPException as exc:
if self.on_error is not None:
response = await self.on_error(request, exc)
else:
response = PlainTextResponse(
str(exc.detail),
status_code=exc.status_code,
headers=getattr(exc, "headers", None),
)
await response(scope, receive, send)
return
await super().__call__(scope, receive, send)

44
tests/test_auth_static_files.py

@ -179,6 +179,50 @@ def test_custom_on_error_redirect(static_dir):
assert response.headers["location"] == "/login"
def test_sync_auth_callable(static_dir):
"""A sync auth callable should be supported via run_in_threadpool."""
def sync_verify(request: Request) -> None:
token = request.headers.get("X-Token")
if token != "valid":
raise HTTPException(status_code=401, detail="Bad token")
app = FastAPI()
app.mount(
"/sync",
AuthStaticFiles(directory=str(static_dir), auth=sync_verify),
name="sync",
)
with TestClient(app) as client:
response = client.get("/sync/public.txt")
assert response.status_code == 401
response = client.get("/sync/public.txt", headers={"X-Token": "valid"})
assert response.status_code == 200
assert response.text == "public content"
def test_starlette_httpexception_caught(static_dir):
"""Starlette's HTTPException (used by FastAPI security modules) should be caught."""
from starlette.exceptions import HTTPException as StarletteHTTPException
async def deny_with_starlette(request: Request) -> None:
raise StarletteHTTPException(status_code=401, detail="Starlette error")
app = FastAPI()
app.mount(
"/starlette",
AuthStaticFiles(directory=str(static_dir), auth=deny_with_starlette),
name="starlette",
)
with TestClient(app) as client:
response = client.get("/starlette/public.txt")
assert response.status_code == 401
assert response.text == "Starlette error"
def test_custom_on_error_html(static_dir):
"""on_error can return a custom HTML error page."""

Loading…
Cancel
Save