python, asyncio, api, async, fastapi, framework, json, json-schema, openapi, openapi3, pydantic, python-types, python3, redoc, rest, starlette, swagger, swagger-ui, uvicorn, web
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
208 lines
6.1 KiB
208 lines
6.1 KiB
import pytest
|
|
from fastapi import FastAPI, HTTPException, Request
|
|
from fastapi.responses import RedirectResponse
|
|
from fastapi.staticfiles import AuthStaticFiles
|
|
from fastapi.testclient import TestClient
|
|
from starlette.responses import HTMLResponse, Response
|
|
|
|
|
|
@pytest.fixture(scope="module")
def static_dir(tmp_path_factory):
    """Create a module-scoped directory holding the files served by the mounts.

    The directory contains ``public.txt`` (text ``"public content"``) and
    ``secret.txt`` (text ``"secret content"``). Returns the directory path.
    """
    root = tmp_path_factory.mktemp("static")
    fixtures = (
        ("public.txt", "public content"),
        ("secret.txt", "secret content"),
    )
    for filename, body in fixtures:
        (root / filename).write_text(body)
    return root
|
|
|
|
|
|
async def verify_token(request: Request) -> None:
    """Reject the request unless it carries the expected bearer token.

    Test-only auth callable. Returns ``None`` when the ``Authorization``
    header is exactly ``"Bearer valid-token"``.

    Raises:
        HTTPException: 401 with detail ``"Not authenticated"`` when the
            header is missing or holds any other value.
    """
    authorization = request.headers.get("Authorization")
    if authorization == "Bearer valid-token":
        return None
    raise HTTPException(status_code=401, detail="Not authenticated")
|
|
|
|
|
|
async def _allow_all(request: Request) -> None:
    """Auth callable that accepts every request; it never raises."""
    return None
|
|
|
|
|
|
@pytest.fixture(scope="module")
def app(static_dir):
    """Build an app with a public and a private mount over ``static_dir``.

    ``/public`` uses the allow-all auth callable, and ``/private`` requires
    the bearer token checked by ``verify_token``.
    """
    application = FastAPI()

    # (mount path, auth callable, route name). Registration order is
    # public first, then private.
    # NOTE(review): both mounts serve the same directory, so secret.txt is
    # also reachable under /public; only /private enforces the token.
    mounts = (
        ("/public", _allow_all, "public"),  # no auth
        ("/private", verify_token, "private"),  # requires auth
    )
    for path, auth_callable, mount_name in mounts:
        application.mount(
            path,
            AuthStaticFiles(directory=str(static_dir), auth=auth_callable),
            name=mount_name,
        )

    return application
|
|
|
|
|
|
@pytest.fixture(scope="module")
def client(app):
    """Yield one TestClient for ``app``, shared by every test in the module.

    The client is used as a context manager, so it is closed after the
    module's tests finish.
    """
    with TestClient(app) as session:
        yield session
|
|
|
|
|
|
def test_private_file_without_auth(client: TestClient):
    """A private file requested with no credentials is rejected with 401."""
    resp = client.get("/private/secret.txt")

    assert resp.status_code == 401
    # The body is expected to be exactly the detail raised by verify_token.
    assert resp.text == "Not authenticated"
|
|
|
|
|
|
def test_private_file_with_wrong_token(client: TestClient):
    """A private file requested with an incorrect bearer token yields 401."""
    bad_headers = {"Authorization": "Bearer wrong-token"}
    resp = client.get("/private/secret.txt", headers=bad_headers)

    assert resp.status_code == 401
    assert resp.text == "Not authenticated"
|
|
|
|
|
|
def test_private_file_with_valid_token(client: TestClient):
    """A private file requested with the valid bearer token is served."""
    good_headers = {"Authorization": "Bearer valid-token"}
    resp = client.get("/private/secret.txt", headers=good_headers)

    assert resp.status_code == 200
    assert resp.text == "secret content"
|
|
|
|
|
|
def test_private_file_not_found_with_valid_token(client: TestClient):
    """Passing auth does not hide a missing file: the response is still 404."""
    good_headers = {"Authorization": "Bearer valid-token"}
    resp = client.get("/private/nonexistent.txt", headers=good_headers)

    assert resp.status_code == 404
|
|
|
|
|
|
def test_public_files_accessible(client: TestClient):
    """Files under the allow-all mount are served with no credentials."""
    resp = client.get("/public/public.txt")

    assert resp.status_code == 200
    assert resp.text == "public content"
|
|
|
|
|
|
def test_auth_headers_forwarded(static_dir):
    """Headers attached to the auth HTTPException appear on the response."""

    async def auth_with_headers(request: Request) -> None:
        # Always fails, attaching a WWW-Authenticate challenge header.
        challenge = {"WWW-Authenticate": "Bearer"}
        raise HTTPException(
            status_code=401,
            detail="Login required",
            headers=challenge,
        )

    app = FastAPI()
    protected = AuthStaticFiles(directory=str(static_dir), auth=auth_with_headers)
    app.mount("/protected", protected, name="protected")

    with TestClient(app) as client:
        resp = client.get("/protected/public.txt")

        assert resp.status_code == 401
        assert resp.headers["WWW-Authenticate"] == "Bearer"
        assert resp.text == "Login required"
|
|
|
|
|
|
def test_cookie_based_auth(static_dir):
    """AuthStaticFiles can gate files on a session cookie instead of a header."""

    async def verify_cookie(request: Request) -> None:
        # Only the exact session id "valid-session" is accepted.
        if request.cookies.get("session_id") == "valid-session":
            return None
        raise HTTPException(status_code=403, detail="Forbidden")

    app = FastAPI()
    app.mount(
        "/dashboard",
        AuthStaticFiles(directory=str(static_dir), auth=verify_cookie),
        name="dashboard",
    )
    url = "/dashboard/public.txt"

    with TestClient(app) as client:
        # No cookie on the client yet: verify_cookie rejects the request.
        assert client.get(url).status_code == 403

        # After storing the session cookie on the client, the next request
        # is expected to pass auth and return the file.
        client.cookies.set("session_id", "valid-session")
        authed = client.get(url)
        assert authed.status_code == 200
        assert authed.text == "public content"
|
|
|
|
|
|
def test_custom_on_error_redirect(static_dir):
    """An ``on_error`` handler can turn an auth failure into a login redirect."""

    async def deny_all(request: Request) -> None:
        raise HTTPException(status_code=401, detail="Unauthorized")

    async def redirect_to_login(request: Request, exc: HTTPException) -> Response:
        # The exception is ignored; every failure redirects to /login.
        return RedirectResponse(url="/login", status_code=302)

    app = FastAPI()
    guarded = AuthStaticFiles(
        directory=str(static_dir),
        auth=deny_all,
        on_error=redirect_to_login,
    )
    app.mount("/protected", guarded, name="protected")

    # Redirects are not followed, so the 302 itself can be inspected.
    with TestClient(app, follow_redirects=False) as client:
        resp = client.get("/protected/public.txt")

        assert resp.status_code == 302
        assert resp.headers["location"] == "/login"
|
|
|
|
|
|
def test_custom_on_error_html(static_dir):
    """An ``on_error`` handler can render its own HTML error page."""

    async def deny_all(request: Request) -> None:
        raise HTTPException(status_code=403, detail="Forbidden")

    async def html_error(request: Request, exc: HTTPException) -> Response:
        # Echo the exception's status code and detail in an HTML heading,
        # keeping the exception's status code on the response.
        page = f"<h1>{exc.status_code} {exc.detail}</h1>"
        return HTMLResponse(page, status_code=exc.status_code)

    app = FastAPI()
    guarded = AuthStaticFiles(
        directory=str(static_dir),
        auth=deny_all,
        on_error=html_error,
    )
    app.mount("/protected", guarded, name="protected")

    with TestClient(app) as client:
        resp = client.get("/protected/public.txt")

        assert resp.status_code == 403
        assert "<h1>403 Forbidden</h1>" in resp.text
|
|
|