diff --git a/fastapi/staticfiles.py b/fastapi/staticfiles.py index 978f83de11..8e09fe9674 100644 --- a/fastapi/staticfiles.py +++ b/fastapi/staticfiles.py @@ -1,11 +1,16 @@ +import inspect from collections.abc import Awaitable, Callable from typing import Any +from starlette.concurrency import run_in_threadpool +from starlette.exceptions import HTTPException from starlette.requests import Request from starlette.responses import PlainTextResponse, Response from starlette.staticfiles import StaticFiles as StaticFiles # noqa from starlette.types import Receive, Scope, Send +AuthCallable = Callable[[Request], Awaitable[Any] | Any] + class AuthStaticFiles(StaticFiles): """ @@ -42,9 +47,10 @@ class AuthStaticFiles(StaticFiles): ## Parameters - * `auth`: An async callable that takes a `Request` object and performs - authentication. It should raise an `HTTPException` if authentication - fails, or return `None` if authentication succeeds. + * `auth`: A sync or async callable that takes a `Request` object and + performs authentication. It should raise an `HTTPException` if + authentication fails, or return `None` if authentication succeeds. + Sync callables are automatically run in a threadpool. * `on_error`: An optional callable that takes a `Request` and an `HTTPException` and returns a `Response`. Use this to customize error responses (e.g., redirect to login, return HTML instead of @@ -73,8 +79,8 @@ class AuthStaticFiles(StaticFiles): html: bool = False, check_dir: bool = True, follow_symlink: bool = False, - auth: Callable[[Request], Awaitable[Any]], - on_error: Callable[[Request, Any], Awaitable[Response]] | None = None, + auth: AuthCallable, + on_error: Callable[[Request, HTTPException], Awaitable[Response]] | None = None, ) -> None: super().__init__( directory=directory, @@ -84,26 +90,26 @@ class AuthStaticFiles(StaticFiles): follow_symlink=follow_symlink, ) self.auth = auth + self._auth_is_async = inspect.iscoroutinefunction(auth) self.on_error = on_error async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if scope["type"] == "http": request = Request(scope, receive) try: - await self.auth(request) - except Exception as exc: - from fastapi.exceptions import HTTPException - - if isinstance(exc, HTTPException): - if self.on_error is not None: - response = await self.on_error(request, exc) - else: - response = PlainTextResponse( - str(exc.detail), - status_code=exc.status_code, - headers=getattr(exc, "headers", None), - ) - await response(scope, receive, send) - return - raise + if self._auth_is_async: + await self.auth(request) + else: + await run_in_threadpool(self.auth, request) + except HTTPException as exc: + if self.on_error is not None: + response = await self.on_error(request, exc) + else: + response = PlainTextResponse( + str(exc.detail), + status_code=exc.status_code, + headers=getattr(exc, "headers", None), + ) + await response(scope, receive, send) + return await super().__call__(scope, receive, send) diff --git a/tests/test_auth_static_files.py b/tests/test_auth_static_files.py index 10be23affd..300d399296 100644 --- a/tests/test_auth_static_files.py +++ b/tests/test_auth_static_files.py @@ -179,6 +179,50 @@ def test_custom_on_error_redirect(static_dir): assert response.headers["location"] == "/login" +def test_sync_auth_callable(static_dir): + """A sync auth callable should be supported via run_in_threadpool.""" + + def sync_verify(request: Request) -> None: + token = request.headers.get("X-Token") + if token != "valid": + raise HTTPException(status_code=401, detail="Bad token") + + app = FastAPI() + app.mount( + "/sync", + AuthStaticFiles(directory=str(static_dir), auth=sync_verify), + name="sync", + ) + + with TestClient(app) as client: + response = client.get("/sync/public.txt") + assert response.status_code == 401 + + response = client.get("/sync/public.txt", headers={"X-Token": "valid"}) + assert response.status_code == 200 + assert response.text == "public content" + + +def test_starlette_httpexception_caught(static_dir): + """Starlette's HTTPException (used by FastAPI security modules) should be caught.""" + from starlette.exceptions import HTTPException as StarletteHTTPException + + async def deny_with_starlette(request: Request) -> None: + raise StarletteHTTPException(status_code=401, detail="Starlette error") + + app = FastAPI() + app.mount( + "/starlette", + AuthStaticFiles(directory=str(static_dir), auth=deny_with_starlette), + name="starlette", + ) + + with TestClient(app) as client: + response = client.get("/starlette/public.txt") + assert response.status_code == 401 + assert response.text == "Starlette error" + + def test_custom_on_error_html(static_dir): """on_error can return a custom HTML error page."""