diff --git a/fastapi/staticfiles.py b/fastapi/staticfiles.py index 3dd6a2f720..978f83de11 100644 --- a/fastapi/staticfiles.py +++ b/fastapi/staticfiles.py @@ -2,7 +2,7 @@ from collections.abc import Awaitable, Callable from typing import Any from starlette.requests import Request -from starlette.responses import JSONResponse +from starlette.responses import PlainTextResponse, Response from starlette.staticfiles import StaticFiles as StaticFiles # noqa from starlette.types import Receive, Scope, Send @@ -45,12 +45,23 @@ class AuthStaticFiles(StaticFiles): * `auth`: An async callable that takes a `Request` object and performs authentication. It should raise an `HTTPException` if authentication fails, or return `None` if authentication succeeds. + * `on_error`: An optional callable that takes a `Request` and an + `HTTPException` and returns a `Response`. Use this to customize + error responses (e.g., redirect to login, return HTML instead of + plain text). If not provided, a plain text error response is returned. * `directory`: The directory to serve files from. * `packages`: A list of Python packages to serve files from. * `html`: If `True`, serves `index.html` files for directories. * `check_dir`: If `True`, checks that the directory exists on startup. * `follow_symlink`: If `True`, follows symbolic links. + ## Performance Note + + The `auth` callable runs on **every static file request** (CSS, JS, + images, etc.). Prefer lightweight checks (header presence, JWT signature + verification) over expensive operations (database lookups) to avoid + slowing down page loads. + Ref: https://github.com/fastapi/fastapi/issues/858 """ @@ -63,6 +74,7 @@ class AuthStaticFiles(StaticFiles): check_dir: bool = True, follow_symlink: bool = False, auth: Callable[[Request], Awaitable[Any]], + on_error: Callable[[Request, Any], Awaitable[Response]] | None = None, ) -> None: super().__init__( directory=directory, @@ -72,6 +84,7 @@ class AuthStaticFiles(StaticFiles): follow_symlink=follow_symlink, ) self.auth = auth + self.on_error = on_error async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if scope["type"] == "http": @@ -82,11 +95,14 @@ class AuthStaticFiles(StaticFiles): from fastapi.exceptions import HTTPException if isinstance(exc, HTTPException): - response = JSONResponse( - {"detail": exc.detail}, - status_code=exc.status_code, - headers=getattr(exc, "headers", None), - ) + if self.on_error is not None: + response = await self.on_error(request, exc) + else: + response = PlainTextResponse( + str(exc.detail), + status_code=exc.status_code, + headers=getattr(exc, "headers", None), + ) await response(scope, receive, send) return raise diff --git a/tests/test_auth_static_files.py b/tests/test_auth_static_files.py index 8ecf42ba19..10be23affd 100644 --- a/tests/test_auth_static_files.py +++ b/tests/test_auth_static_files.py @@ -1,7 +1,9 @@ import pytest from fastapi import FastAPI, HTTPException, Request +from fastapi.responses import RedirectResponse from fastapi.staticfiles import AuthStaticFiles from fastapi.testclient import TestClient +from starlette.responses import HTMLResponse, Response @pytest.fixture(scope="module") @@ -19,6 +21,11 @@ async def verify_token(request: Request) -> None: raise HTTPException(status_code=401, detail="Not authenticated") +async def _allow_all(request: Request) -> None: + """Auth function that allows all requests.""" + pass + + @pytest.fixture(scope="module") def app(static_dir): app = FastAPI() @@ -46,11 +53,6 @@ def app(static_dir): return app -async def _allow_all(request: Request) -> None: - """Auth function that allows all requests.""" - pass - - @pytest.fixture(scope="module") def client(app): with TestClient(app) as c: @@ -61,7 +63,7 @@ def test_private_file_without_auth(client: TestClient): """Requesting a private file without auth should return 401.""" response = client.get("/private/secret.txt") assert response.status_code == 401 - assert response.json() == {"detail": "Not authenticated"} + assert response.text == "Not authenticated" def test_private_file_with_wrong_token(client: TestClient): @@ -71,7 +73,7 @@ def test_private_file_with_wrong_token(client: TestClient): headers={"Authorization": "Bearer wrong-token"}, ) assert response.status_code == 401 - assert response.json() == {"detail": "Not authenticated"} + assert response.text == "Not authenticated" def test_private_file_with_valid_token(client: TestClient): @@ -121,7 +123,7 @@ def test_auth_headers_forwarded(static_dir): response = client.get("/protected/public.txt") assert response.status_code == 401 assert response.headers["WWW-Authenticate"] == "Bearer" - assert response.json() == {"detail": "Login required"} + assert response.text == "Login required" def test_cookie_based_auth(static_dir): @@ -149,3 +151,58 @@ def test_cookie_based_auth(static_dir): response = client.get("/dashboard/public.txt") assert response.status_code == 200 assert response.text == "public content" + + +def test_custom_on_error_redirect(static_dir): + """on_error can redirect to a login page.""" + + async def deny_all(request: Request) -> None: + raise HTTPException(status_code=401, detail="Unauthorized") + + async def redirect_to_login(request: Request, exc: HTTPException) -> Response: + return RedirectResponse(url="/login", status_code=302) + + app = FastAPI() + app.mount( + "/protected", + AuthStaticFiles( + directory=str(static_dir), + auth=deny_all, + on_error=redirect_to_login, + ), + name="protected", + ) + + with TestClient(app, follow_redirects=False) as client: + response = client.get("/protected/public.txt") + assert response.status_code == 302 + assert response.headers["location"] == "/login" + + +def test_custom_on_error_html(static_dir): + """on_error can return a custom HTML error page.""" + + async def deny_all(request: Request) -> None: + raise HTTPException(status_code=403, detail="Forbidden") + + async def html_error(request: Request, exc: HTTPException) -> Response: + return HTMLResponse( + f"

{exc.status_code} {exc.detail}

", + status_code=exc.status_code, + ) + + app = FastAPI() + app.mount( + "/protected", + AuthStaticFiles( + directory=str(static_dir), + auth=deny_all, + on_error=html_error, + ), + name="protected", + ) + + with TestClient(app) as client: + response = client.get("/protected/public.txt") + assert response.status_code == 403 + assert "

403 Forbidden

" in response.text