Browse Source

Address review feedback: plain text errors, configurable on_error, perf docs

- Changed default error response from JSON to plain text (browser-friendly)
- Added optional `on_error` callback for custom error responses (redirect
  to login page, HTML error pages, etc.)
- Added performance note about keeping auth checks lightweight
- Added tests for redirect and HTML custom error responses

Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>
pull/15295/head
faisalsaificode 2 months ago
parent
commit
edd5be62d6
  1. 28
      fastapi/staticfiles.py
  2. 73
      tests/test_auth_static_files.py

28
fastapi/staticfiles.py

@ -2,7 +2,7 @@ from collections.abc import Awaitable, Callable
from typing import Any from typing import Any
from starlette.requests import Request from starlette.requests import Request
from starlette.responses import JSONResponse from starlette.responses import PlainTextResponse, Response
from starlette.staticfiles import StaticFiles as StaticFiles # noqa from starlette.staticfiles import StaticFiles as StaticFiles # noqa
from starlette.types import Receive, Scope, Send from starlette.types import Receive, Scope, Send
@ -45,12 +45,23 @@ class AuthStaticFiles(StaticFiles):
* `auth`: An async callable that takes a `Request` object and performs * `auth`: An async callable that takes a `Request` object and performs
authentication. It should raise an `HTTPException` if authentication authentication. It should raise an `HTTPException` if authentication
fails, or return `None` if authentication succeeds. fails, or return `None` if authentication succeeds.
* `on_error`: An optional callable that takes a `Request` and an
`HTTPException` and returns a `Response`. Use this to customize
error responses (e.g., redirect to login, return HTML instead of
plain text). If not provided, a plain text error response is returned.
* `directory`: The directory to serve files from. * `directory`: The directory to serve files from.
* `packages`: A list of Python packages to serve files from. * `packages`: A list of Python packages to serve files from.
* `html`: If `True`, serves `index.html` files for directories. * `html`: If `True`, serves `index.html` files for directories.
* `check_dir`: If `True`, checks that the directory exists on startup. * `check_dir`: If `True`, checks that the directory exists on startup.
* `follow_symlink`: If `True`, follows symbolic links. * `follow_symlink`: If `True`, follows symbolic links.
## Performance Note
The `auth` callable runs on **every static file request** (CSS, JS,
images, etc.). Prefer lightweight checks (header presence, JWT signature
verification) over expensive operations (database lookups) to avoid
slowing down page loads.
Ref: https://github.com/fastapi/fastapi/issues/858 Ref: https://github.com/fastapi/fastapi/issues/858
""" """
@ -63,6 +74,7 @@ class AuthStaticFiles(StaticFiles):
check_dir: bool = True, check_dir: bool = True,
follow_symlink: bool = False, follow_symlink: bool = False,
auth: Callable[[Request], Awaitable[Any]], auth: Callable[[Request], Awaitable[Any]],
on_error: Callable[[Request, Any], Awaitable[Response]] | None = None,
) -> None: ) -> None:
super().__init__( super().__init__(
directory=directory, directory=directory,
@ -72,6 +84,7 @@ class AuthStaticFiles(StaticFiles):
follow_symlink=follow_symlink, follow_symlink=follow_symlink,
) )
self.auth = auth self.auth = auth
self.on_error = on_error
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] == "http": if scope["type"] == "http":
@ -82,11 +95,14 @@ class AuthStaticFiles(StaticFiles):
from fastapi.exceptions import HTTPException from fastapi.exceptions import HTTPException
if isinstance(exc, HTTPException): if isinstance(exc, HTTPException):
response = JSONResponse( if self.on_error is not None:
{"detail": exc.detail}, response = await self.on_error(request, exc)
status_code=exc.status_code, else:
headers=getattr(exc, "headers", None), response = PlainTextResponse(
) str(exc.detail),
status_code=exc.status_code,
headers=getattr(exc, "headers", None),
)
await response(scope, receive, send) await response(scope, receive, send)
return return
raise raise

73
tests/test_auth_static_files.py

@ -1,7 +1,9 @@
import pytest import pytest
from fastapi import FastAPI, HTTPException, Request from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import RedirectResponse
from fastapi.staticfiles import AuthStaticFiles from fastapi.staticfiles import AuthStaticFiles
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from starlette.responses import HTMLResponse, Response
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
@ -19,6 +21,11 @@ async def verify_token(request: Request) -> None:
raise HTTPException(status_code=401, detail="Not authenticated") raise HTTPException(status_code=401, detail="Not authenticated")
async def _allow_all(request: Request) -> None:
"""Auth function that allows all requests."""
pass
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def app(static_dir): def app(static_dir):
app = FastAPI() app = FastAPI()
@ -46,11 +53,6 @@ def app(static_dir):
return app return app
async def _allow_all(request: Request) -> None:
"""Auth function that allows all requests."""
pass
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def client(app): def client(app):
with TestClient(app) as c: with TestClient(app) as c:
@ -61,7 +63,7 @@ def test_private_file_without_auth(client: TestClient):
"""Requesting a private file without auth should return 401.""" """Requesting a private file without auth should return 401."""
response = client.get("/private/secret.txt") response = client.get("/private/secret.txt")
assert response.status_code == 401 assert response.status_code == 401
assert response.json() == {"detail": "Not authenticated"} assert response.text == "Not authenticated"
def test_private_file_with_wrong_token(client: TestClient): def test_private_file_with_wrong_token(client: TestClient):
@ -71,7 +73,7 @@ def test_private_file_with_wrong_token(client: TestClient):
headers={"Authorization": "Bearer wrong-token"}, headers={"Authorization": "Bearer wrong-token"},
) )
assert response.status_code == 401 assert response.status_code == 401
assert response.json() == {"detail": "Not authenticated"} assert response.text == "Not authenticated"
def test_private_file_with_valid_token(client: TestClient): def test_private_file_with_valid_token(client: TestClient):
@ -121,7 +123,7 @@ def test_auth_headers_forwarded(static_dir):
response = client.get("/protected/public.txt") response = client.get("/protected/public.txt")
assert response.status_code == 401 assert response.status_code == 401
assert response.headers["WWW-Authenticate"] == "Bearer" assert response.headers["WWW-Authenticate"] == "Bearer"
assert response.json() == {"detail": "Login required"} assert response.text == "Login required"
def test_cookie_based_auth(static_dir): def test_cookie_based_auth(static_dir):
@ -149,3 +151,58 @@ def test_cookie_based_auth(static_dir):
response = client.get("/dashboard/public.txt") response = client.get("/dashboard/public.txt")
assert response.status_code == 200 assert response.status_code == 200
assert response.text == "public content" assert response.text == "public content"
def test_custom_on_error_redirect(static_dir):
"""on_error can redirect to a login page."""
async def deny_all(request: Request) -> None:
raise HTTPException(status_code=401, detail="Unauthorized")
async def redirect_to_login(request: Request, exc: HTTPException) -> Response:
return RedirectResponse(url="/login", status_code=302)
app = FastAPI()
app.mount(
"/protected",
AuthStaticFiles(
directory=str(static_dir),
auth=deny_all,
on_error=redirect_to_login,
),
name="protected",
)
with TestClient(app, follow_redirects=False) as client:
response = client.get("/protected/public.txt")
assert response.status_code == 302
assert response.headers["location"] == "/login"
def test_custom_on_error_html(static_dir):
"""on_error can return a custom HTML error page."""
async def deny_all(request: Request) -> None:
raise HTTPException(status_code=403, detail="Forbidden")
async def html_error(request: Request, exc: HTTPException) -> Response:
return HTMLResponse(
f"<h1>{exc.status_code} {exc.detail}</h1>",
status_code=exc.status_code,
)
app = FastAPI()
app.mount(
"/protected",
AuthStaticFiles(
directory=str(static_dir),
auth=deny_all,
on_error=html_error,
),
name="protected",
)
with TestClient(app) as client:
response = client.get("/protected/public.txt")
assert response.status_code == 403
assert "<h1>403 Forbidden</h1>" in response.text

Loading…
Cancel
Save