From 319be508ce7db9ee5f52c3b9baa68c6cc1037c10 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Wed, 1 Jul 2026 18:12:33 +0200 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Support=20dependencies=20in=20`app.?= =?UTF-8?q?frontend()`,=20e.g.=20for=20automatic=20cookie=20authentication?= =?UTF-8?q?=20for=20the=20frontend=20(#15908)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/en/docs/tutorial/frontend.md | 6 + fastapi/routing.py | 279 ++++++++++++++++++------ tests/test_frontend.py | 345 +++++++++++++++++++++++++++++- 3 files changed, 563 insertions(+), 67 deletions(-) diff --git a/docs/en/docs/tutorial/frontend.md b/docs/en/docs/tutorial/frontend.md index 4cbc21fa13..433cea2751 100644 --- a/docs/en/docs/tutorial/frontend.md +++ b/docs/en/docs/tutorial/frontend.md @@ -126,6 +126,12 @@ In this example, frontend paths are served under `/app`. Any regular *path operations* in the app will still take precedence, including in other routers. +## Dependencies and Middleware { #dependencies-and-middleware } + +Frontend responses run inside the normal **FastAPI** application, so HTTP middleware applies to them. + +Dependencies from the app, from an `APIRouter`, and from `include_router()` also apply to frontend responses. This can be useful for protecting a frontend with cookie authentication or similar. + ## Static Build Output Only { #static-build-output-only } `app.frontend()` serves files already generated by your frontend build. diff --git a/fastapi/routing.py b/fastapi/routing.py index e41ef6a599..c442b122ba 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -797,17 +797,14 @@ class APIWebSocketRoute(routing.WebSocketRoute): self.name = get_name(endpoint) if name is None else name self.dependencies = list(dependencies or []) self.path_regex, self.path_format, self.param_convertors = compile_path(path) - self.dependant = get_dependant( - path=self.path_format, call=self.endpoint, scope="function" - ) - for depends in self.dependencies[::-1]: - self.dependant.dependencies.insert( - 0, - get_parameterless_sub_dependant(depends=depends, path=self.path_format), - ) - self._flat_dependant = get_flat_dependant(self.dependant) - self._embed_body_fields = _should_embed_body_fields( - self._flat_dependant.body_params + ( + self.dependant, + self._flat_dependant, + self._embed_body_fields, + ) = _build_dependant_with_parameterless_dependencies( + path=self.path_format, + call=self.endpoint, + dependencies=self.dependencies, ) self.app = websocket_session( get_websocket_app( @@ -827,6 +824,7 @@ class APIWebSocketRoute(routing.WebSocketRoute): _FASTAPI_SCOPE_KEY = "fastapi" _FASTAPI_EFFECTIVE_ROUTE_CONTEXT_KEY = "effective_route_context" _FASTAPI_FRONTEND_PATH_KEY = "frontend_path" +_FASTAPI_FRONTEND_SPECIFICITY_KEY = "frontend_specificity" _FASTAPI_INCLUDED_ROUTER_KEY = "included_router" _effective_route_context_var: ContextVar[Any | None] = ContextVar( "fastapi_effective_route_context", default=None @@ -834,6 +832,27 @@ _effective_route_context_var: ContextVar[Any | None] = ContextVar( _SCOPE_MISSING = object() +def _frontend_dependency_endpoint() -> None: + pass # pragma: no cover + + +def _build_dependant_with_parameterless_dependencies( + *, + path: str, + call: Callable[..., Any], + dependencies: Sequence[params.Depends], +) -> tuple[Dependant, Dependant, bool]: + dependant = get_dependant(path=path, call=call, scope="function") + for depends in dependencies[::-1]: + dependant.dependencies.insert( + 0, + get_parameterless_sub_dependant(depends=depends, path=path), + ) + flat_dependant = get_flat_dependant(dependant) + embed_body_fields = _should_embed_body_fields(flat_dependant.body_params) + return dependant, flat_dependant, embed_body_fields + + class _RouteWithPath(Protocol): path: str @@ -861,6 +880,15 @@ def _get_scope_included_router(scope: Scope) -> Any | None: return scope.get(_FASTAPI_SCOPE_KEY, {}).get(_FASTAPI_INCLUDED_ROUTER_KEY) +def _frontend_scope_specificity(scope: Scope) -> int | None: + specificity = scope.get(_FASTAPI_SCOPE_KEY, {}).get( + _FASTAPI_FRONTEND_SPECIFICITY_KEY + ) + if isinstance(specificity, int): + return specificity + return None + + def _restore_fastapi_scope_key(scope: Scope, key: str, previous: Any) -> None: fastapi_scope = scope.get(_FASTAPI_SCOPE_KEY) if not isinstance(fastapi_scope, dict): @@ -1053,17 +1081,14 @@ def _populate_api_route_state( route.response_fields = {} assert callable(endpoint), "An endpoint must be a callable" - route.dependant = get_dependant( - path=route.path_format, call=route.endpoint, scope="function" - ) - for depends in route.dependencies[::-1]: - route.dependant.dependencies.insert( - 0, - get_parameterless_sub_dependant(depends=depends, path=route.path_format), - ) - route._flat_dependant = get_flat_dependant(route.dependant) - route._embed_body_fields = _should_embed_body_fields( - route._flat_dependant.body_params + ( + route.dependant, + route._flat_dependant, + route._embed_body_fields, + ) = _build_dependant_with_parameterless_dependencies( + path=route.path_format, + call=route.endpoint, + dependencies=route.dependencies, ) route.body_field = get_body_field( flat_dependant=route._flat_dependant, @@ -1334,6 +1359,7 @@ class _RouterIncludeContext: class _EffectiveRouteContext: original_route: BaseRoute starlette_route: BaseRoute | None = None + frontend_prefix: str = "" path: str = "" endpoint: Callable[..., Any] | None = None stream_item_type: Any | None = None @@ -1436,7 +1462,34 @@ class _EffectiveRouteContext: ) return context + @classmethod + def from_frontend_route_group( + cls, + *, + original_route: "_FrontendRouteGroup", + include_context: _RouterIncludeContext, + ) -> "_EffectiveRouteContext": + dependencies = [*include_context.dependencies, *original_route.dependencies] + context = cls( + original_route=original_route, + frontend_prefix=include_context.prefix, + dependencies=dependencies, + dependency_overrides_provider=include_context.dependency_overrides_provider, + ) + ( + context.dependant, + context._flat_dependant, + context._embed_body_fields, + ) = _build_dependant_with_parameterless_dependencies( + path="", + call=_frontend_dependency_endpoint, + dependencies=dependencies, + ) + return context + def matches(self, scope: Scope) -> tuple[Match, Scope]: + if isinstance(self.original_route, _FrontendRouteGroup): + return self.original_route.matches_with_prefix(scope, self.frontend_prefix) if not isinstance(self.original_route, APIRoute): assert self.starlette_route is not None return self.starlette_route.matches(scope) @@ -1579,9 +1632,9 @@ class _IncludedRouter(BaseRoute): include_context=self.include_context, ) if isinstance(route, _FrontendRouteGroup): - return _EffectiveRouteContext( + return _EffectiveRouteContext.from_frontend_route_group( original_route=route, - starlette_route=route.with_prefix(self.include_context.prefix), + include_context=self.include_context, ) if isinstance(route, routing.Route): starlette_route: BaseRoute = routing.Route( @@ -1970,28 +2023,31 @@ class _FrontendRoute(BaseRoute): directory=directory, fallback=fallback, check_dir=check_dir ) - def with_path(self, path: str) -> "_FrontendRoute": - route = copy.copy(self) - route.path = _normalize_frontend_path(path) - return route - def matches(self, scope: Scope) -> tuple[Match, Scope]: + return self.matches_with_path(scope, self.path) + + def matches_with_path(self, scope: Scope, path: str) -> tuple[Match, Scope]: if scope["type"] != "http": return Match.NONE, {} - frontend_path = self._get_frontend_path(get_route_path(scope)) + frontend_path = self._get_frontend_path(path, get_route_path(scope)) if frontend_path is None: return Match.NONE, {} - child_scope = {_FASTAPI_SCOPE_KEY: {_FASTAPI_FRONTEND_PATH_KEY: frontend_path}} + child_scope = { + _FASTAPI_SCOPE_KEY: { + _FASTAPI_FRONTEND_PATH_KEY: frontend_path, + _FASTAPI_FRONTEND_SPECIFICITY_KEY: _frontend_path_specificity(path), + } + } if scope["method"] not in self.methods: return Match.PARTIAL, child_scope return Match.FULL, child_scope - def _get_frontend_path(self, route_path: str) -> str | None: - if self.path == "/": + def _get_frontend_path(self, path: str, route_path: str) -> str | None: + if path == "/": return route_path.lstrip("/") - if route_path == self.path: + if route_path == path: return "" - prefix = self.path + "/" + prefix = path + "/" if route_path.startswith(prefix): return route_path[len(prefix) :] return None @@ -2004,8 +2060,24 @@ class _FrontendRoute(BaseRoute): class _FrontendRouteGroup(BaseRoute): - def __init__(self) -> None: + def __init__( + self, + *, + dependencies: Sequence[params.Depends] | None = None, + dependency_overrides_provider: Any | None = None, + ) -> None: self.routes: list[_FrontendRoute] = [] + self.dependencies = list(dependencies or []) + self.dependency_overrides_provider = dependency_overrides_provider + ( + self.dependant, + self._flat_dependant, + self._embed_body_fields, + ) = _build_dependant_with_parameterless_dependencies( + path="", + call=_frontend_dependency_endpoint, + dependencies=self.dependencies, + ) def add_frontend_route( self, @@ -2024,51 +2096,116 @@ class _FrontendRouteGroup(BaseRoute): ) ) - def with_prefix(self, prefix: str) -> "_FrontendRouteGroup": - route_group = copy.copy(self) - route_group.routes = [ - route.with_path(_join_frontend_paths(prefix, route.path)) - for route in self.routes - ] - return route_group - def matches(self, scope: Scope) -> tuple[Match, Scope]: - match, child_scope, _ = self._match(scope) + match, child_scope, _ = self._match(scope, prefix="") return match, child_scope - def _match(self, scope: Scope) -> tuple[Match, Scope, _FrontendRoute | None]: - full: tuple[Scope, _FrontendRoute] | None = None - partial: tuple[Scope, _FrontendRoute] | None = None + def matches_with_prefix(self, scope: Scope, prefix: str) -> tuple[Match, Scope]: + match, child_scope, _ = self._match(scope, prefix=prefix) + return match, child_scope + + def _match( + self, scope: Scope, *, prefix: str + ) -> tuple[Match, Scope, _FrontendRoute | None]: + full: tuple[Scope, _FrontendRoute, int] | None = None + partial: tuple[Scope, _FrontendRoute, int] | None = None for route in self.routes: - match, child_scope = route.matches(scope) + path = _join_frontend_paths(prefix, route.path) + match, child_scope = route.matches_with_path(scope, path) + specificity = _frontend_path_specificity(path) if match == Match.FULL: - if full is None or _frontend_path_specificity( - route.path - ) > _frontend_path_specificity(full[1].path): - full = (child_scope, route) + if full is None or specificity > full[2]: + full = (child_scope, route, specificity) elif match == Match.PARTIAL: - if partial is None or _frontend_path_specificity( - route.path - ) > _frontend_path_specificity(partial[1].path): - partial = (child_scope, route) + if partial is None or specificity > partial[2]: + partial = (child_scope, route, specificity) if full is not None: - child_scope, route = full + child_scope, route, _ = full return Match.FULL, child_scope, route if partial is not None: - child_scope, route = partial + child_scope, route, _ = partial return Match.PARTIAL, child_scope, route return Match.NONE, {}, None async def handle(self, scope: Scope, receive: Receive, send: Send) -> None: - match, child_scope, route = self._match(scope) + effective_context = _get_scope_effective_route_context(scope) + if ( + isinstance(effective_context, _EffectiveRouteContext) + and effective_context.original_route is self + ): + prefix = effective_context.frontend_prefix + dependant = effective_context.dependant + dependency_overrides_provider = ( + effective_context.dependency_overrides_provider + ) + embed_body_fields = effective_context._embed_body_fields + else: + prefix = "" + dependant = self.dependant + dependency_overrides_provider = self.dependency_overrides_provider + embed_body_fields = self._embed_body_fields + match, child_scope, route = self._match(scope, prefix=prefix) if match == Match.NONE or route is None: raise HTTPException(status_code=404) _update_scope(scope, child_scope) + if match == Match.FULL and dependant and dependant.dependencies: + async with self._solve_dependencies( + scope, + receive, + send, + dependant=dependant, + dependency_overrides_provider=dependency_overrides_provider, + embed_body_fields=embed_body_fields, + ): + await route.handle(scope, receive, send) + return await route.handle(scope, receive, send) def url_path_for(self, name: str, /, **path_params: Any) -> URLPath: raise NoMatchFound(name, path_params) + # TODO: probably move this out of the Route / Route Group, same in APIRoute + # this should probably be top level FastAPI logic, not part of APIRoute and + # duplicated here + @asynccontextmanager + async def _solve_dependencies( + self, + scope: Scope, + receive: Receive, + send: Send, + *, + dependant: Dependant, + dependency_overrides_provider: Any | None, + embed_body_fields: bool, + ) -> AsyncIterator[None]: + request = Request(scope, receive, send) + previous_inner_astack = scope.get("fastapi_inner_astack", _SCOPE_MISSING) + previous_function_astack = scope.get("fastapi_function_astack", _SCOPE_MISSING) + try: + async with AsyncExitStack() as request_stack: + scope["fastapi_inner_astack"] = request_stack + async with AsyncExitStack() as function_stack: + scope["fastapi_function_astack"] = function_stack + solved_result = await solve_dependencies( + request=request, + dependant=dependant, + dependency_overrides_provider=dependency_overrides_provider, + async_exit_stack=request_stack, + embed_body_fields=embed_body_fields, + ) + if solved_result.errors: + raise RequestValidationError(solved_result.errors) + yield + finally: + if previous_inner_astack is _SCOPE_MISSING: + scope.pop("fastapi_inner_astack", None) + else: + scope["fastapi_inner_astack"] = previous_inner_astack + if previous_function_astack is _SCOPE_MISSING: + scope.pop("fastapi_function_astack", None) + else: + scope["fastapi_function_astack"] = previous_function_astack + class APIRouter(routing.Router): """ @@ -2515,7 +2652,10 @@ class APIRouter(routing.Router): """ normalized_path = _normalize_frontend_path(path) if self._frontend_routes is None: - self._frontend_routes = _FrontendRouteGroup() + self._frontend_routes = _FrontendRouteGroup( + dependencies=self.dependencies, + dependency_overrides_provider=self.dependency_overrides_provider, + ) self._low_priority_routes.append(self._frontend_routes) self._frontend_routes.add_frontend_route( _join_frontend_paths(self.prefix, normalized_path), @@ -2650,10 +2790,14 @@ class APIRouter(routing.Router): match, child_scope = candidate.matches(scope) route = candidate if match == Match.FULL: - if full is None: + if full is None or self._frontend_match_is_more_specific( + child_scope, full[0] + ): full = (child_scope, route, route_context) elif match == Match.PARTIAL: - if partial is None: + if partial is None or self._frontend_match_is_more_specific( + child_scope, partial[0] + ): partial = (child_scope, route, route_context) if full is not None: child_scope, route, route_context = full @@ -2663,6 +2807,15 @@ class APIRouter(routing.Router): return Match.PARTIAL, child_scope, route, route_context return Match.NONE, {}, None, None + def _frontend_match_is_more_specific( + self, child_scope: Scope, previous_child_scope: Scope + ) -> bool: + specificity = _frontend_scope_specificity(child_scope) + previous_specificity = _frontend_scope_specificity(previous_child_scope) + if specificity is None or previous_specificity is None: + return False + return specificity > previous_specificity + def route( self, path: str, diff --git a/tests/test_frontend.py b/tests/test_frontend.py index 81cffc2285..b8dd62fb8a 100644 --- a/tests/test_frontend.py +++ b/tests/test_frontend.py @@ -1,12 +1,13 @@ import errno import os import runpy +from contextlib import AsyncExitStack from pathlib import Path from typing import Literal import anyio import pytest -from fastapi import APIRouter, FastAPI, HTTPException, Request, WebSocket +from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, WebSocket from fastapi.testclient import TestClient from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.responses import PlainTextResponse, Response @@ -18,11 +19,18 @@ def write_file(path: Path, content: str) -> None: path.write_text(content) +def record_dependency(calls: list[str], name: str): + def dependency() -> None: + calls.append(name) + + return dependency + + def test_frontend_exact_prefix_path_serves_index(tmp_path: Path): dist = tmp_path / "dist" write_file(dist / "index.html", "app") app = FastAPI() - app.frontend("/app", directory=dist) + app.frontend("/", directory=dist) response = TestClient(app).get("/app") @@ -109,7 +117,7 @@ def test_frontend_route_group_helpers(tmp_path: Path): dist = tmp_path / "dist" write_file(dist / "index.html", "app") app = FastAPI() - app.frontend("/", directory=dist) + app.frontend("/app", directory=dist) route_group = app.router._frontend_routes assert route_group is not None @@ -117,9 +125,22 @@ def test_frontend_route_group_helpers(tmp_path: Path): assert match == Match.NONE assert child_scope == {} + match, child_scope = route_group.matches_with_prefix( + {"type": "http", "path": "/prefix/app", "method": "GET"}, + "/prefix", + ) + assert match == Match.FULL + assert child_scope["fastapi"]["frontend_path"] == "" + + match, child_scope = route_group.routes[0].matches( + {"type": "http", "path": "/app", "method": "GET"} + ) + assert match == Match.FULL + assert child_scope["fastapi"]["frontend_path"] == "" + with pytest.raises(StarletteHTTPException) as exc_info: anyio.run( - route_group.with_prefix("/app").handle, + route_group.handle, {"type": "http", "path": "/missing", "method": "GET"}, None, None, @@ -246,6 +267,322 @@ def test_basic_file_serving(tmp_path: Path): assert "last-modified" in response.headers +def test_app_frontend_dependencies_protect_root_asset_and_fallback(tmp_path: Path): + calls: list[str] = [] + + def require_cookie(request: Request) -> None: + calls.append(request.url.path) + if request.cookies.get("session") != "ok": + raise HTTPException(status_code=401) + + dist = tmp_path / "dist" + write_file(dist / "index.html", "app") + write_file(dist / "assets" / "app.js", "console.log('ok')") + app = FastAPI(dependencies=[Depends(require_cookie)]) + app.frontend("/", directory=dist, fallback="index.html") + client = TestClient(app) + + response = client.get("/") + assert response.status_code == 401 + + response = client.get("/", headers={"cookie": "session=ok"}) + assert response.status_code == 200 + assert response.text == "app" + + response = client.get("/assets/app.js", headers={"cookie": "session=ok"}) + assert response.status_code == 200 + assert response.text == "console.log('ok')" + + response = client.get( + "/dashboard", + headers={"accept": "text/html", "cookie": "session=ok"}, + ) + assert response.status_code == 200 + assert response.text == "app" + assert calls == ["/", "/", "/assets/app.js", "/dashboard"] + + +def test_apirouter_frontend_dependencies_protect_prefixed_frontend(tmp_path: Path): + def require_cookie(request: Request) -> None: + if request.cookies.get("session") != "ok": + raise HTTPException(status_code=401) + + dist = tmp_path / "dist" + write_file(dist / "index.html", "app") + write_file(dist / "assets" / "app.js", "console.log('ok')") + router = APIRouter(dependencies=[Depends(require_cookie)]) + router.frontend("/", directory=dist, fallback="index.html") + app = FastAPI() + app.include_router(router, prefix="/app") + client = TestClient(app) + + response = client.get("/app/") + assert response.status_code == 401 + + response = client.get("/app/", headers={"cookie": "session=ok"}) + assert response.status_code == 200 + assert response.text == "app" + + response = client.get("/app/assets/app.js", headers={"cookie": "session=ok"}) + assert response.status_code == 200 + assert response.text == "console.log('ok')" + + response = client.get( + "/app/dashboard", + headers={"accept": "text/html", "cookie": "session=ok"}, + ) + assert response.status_code == 200 + assert response.text == "app" + + +def test_included_frontend_does_not_block_url_path_for(tmp_path: Path): + dist = tmp_path / "dist" + write_file(dist / "index.html", "app") + frontend_router = APIRouter() + frontend_router.frontend("/", directory=dist) + api_router = APIRouter() + + @api_router.get("/api", name="read_api") + def read_api(): + return {"ok": True} + + app = FastAPI() + app.include_router(frontend_router, prefix="/app") + app.include_router(api_router) + included_frontend = next( + route + for route in app.router.routes + if hasattr(route, "effective_low_priority_routes") + ) + + with pytest.raises(NoMatchFound): + included_frontend.url_path_for("missing") + assert app.url_path_for("read_api") == "/api" + response = TestClient(app).get("/api") + assert response.status_code == 200 + assert response.json() == {"ok": True} + + +def test_include_router_frontend_dependencies_apply_in_nested_order(tmp_path: Path): + calls: list[str] = [] + + dist = tmp_path / "dist" + write_file(dist / "index.html", "app") + child = APIRouter(dependencies=[Depends(record_dependency(calls, "child"))]) + child.frontend("/ui", directory=dist) + parent = APIRouter(dependencies=[Depends(record_dependency(calls, "parent"))]) + parent.include_router( + child, + prefix="/child", + dependencies=[Depends(record_dependency(calls, "parent-include"))], + ) + app = FastAPI(dependencies=[Depends(record_dependency(calls, "app"))]) + app.include_router( + parent, + prefix="/parent", + dependencies=[Depends(record_dependency(calls, "app-include"))], + ) + + response = TestClient(app).get("/parent/child/ui/") + + assert response.status_code == 200 + assert response.text == "app" + assert calls == ["app", "app-include", "parent", "parent-include", "child"] + + +def test_frontend_dependency_overrides_apply(tmp_path: Path): + calls: list[str] = [] + + def require_cookie() -> None: + raise HTTPException(status_code=401) # pragma: no cover + + def allow_cookie() -> None: + calls.append("override") + + dist = tmp_path / "dist" + write_file(dist / "index.html", "app") + app = FastAPI(dependencies=[Depends(require_cookie)]) + app.dependency_overrides[require_cookie] = allow_cookie + app.frontend("/", directory=dist) + + response = TestClient(app).get("/") + + assert response.status_code == 200 + assert response.text == "app" + assert calls == ["override"] + + +def test_frontend_dependencies_do_not_run_when_api_route_wins(tmp_path: Path): + calls: list[str] = [] + + def frontend_dependency() -> None: + calls.append("frontend") # pragma: no cover + + dist = tmp_path / "dist" + write_file(dist / "api", "frontend") + router = APIRouter(dependencies=[Depends(frontend_dependency)]) + router.frontend("/", directory=dist) + app = FastAPI() + + @app.get("/api") + def read_api(): + return {"source": "api"} + + app.include_router(router) + + response = TestClient(app).get("/api") + + assert response.status_code == 200 + assert response.json() == {"source": "api"} + assert calls == [] + + +def test_only_selected_frontend_mount_dependencies_run(tmp_path: Path): + calls: list[str] = [] + + site = tmp_path / "site" + admin = tmp_path / "admin" + write_file(site / "index.html", "site") + write_file(admin / "index.html", "admin") + site_router = APIRouter() + site_router.frontend("/", directory=site) + admin_router = APIRouter() + admin_router.frontend("/", directory=admin) + app = FastAPI() + app.include_router( + site_router, dependencies=[Depends(record_dependency(calls, "site"))] + ) + app.include_router( + admin_router, + prefix="/admin", + dependencies=[Depends(record_dependency(calls, "admin"))], + ) + + response = TestClient(app).get("/admin/") + + assert response.status_code == 200 + assert response.text == "admin" + assert calls == ["admin"] + + +def test_app_middleware_still_runs_for_frontend_dependencies(tmp_path: Path): + calls: list[str] = [] + + def frontend_dependency() -> None: + calls.append("dependency") + + dist = tmp_path / "dist" + write_file(dist / "index.html", "app") + app = FastAPI(dependencies=[Depends(frontend_dependency)]) + + @app.middleware("http") + async def record_middleware(request: Request, call_next): + calls.append("middleware-before") + response = await call_next(request) + calls.append("middleware-after") + return response + + app.frontend("/", directory=dist) + + response = TestClient(app).get("/") + + assert response.status_code == 200 + assert response.text == "app" + assert calls == ["middleware-before", "dependency", "middleware-after"] + + +def test_frontend_dependency_validation_errors_return_422(tmp_path: Path): + def require_token(token: str) -> None: + pass # pragma: no cover + + dist = tmp_path / "dist" + write_file(dist / "index.html", "app") + app = FastAPI(dependencies=[Depends(require_token)]) + app.frontend("/", directory=dist) + + response = TestClient(app).get("/") + + assert response.status_code == 422 + assert response.json() == { + "detail": [ + { + "type": "missing", + "loc": ["query", "token"], + "msg": "Field required", + "input": None, + } + ] + } + + +@pytest.mark.anyio +async def test_frontend_dependency_restores_existing_dependency_stacks( + tmp_path: Path, +): + def frontend_dependency() -> None: + pass + + dist = tmp_path / "dist" + write_file(dist / "index.html", "app") + app = FastAPI(dependencies=[Depends(frontend_dependency)]) + app.frontend("/", directory=dist) + assert app.router._frontend_routes is not None + inner_stack = AsyncExitStack() + function_stack = AsyncExitStack() + scope = { + "type": "http", + "http_version": "1.1", + "method": "GET", + "scheme": "http", + "path": "/", + "root_path": "", + "query_string": b"", + "headers": [], + "client": ("testclient", 50000), + "server": ("testserver", 80), + "fastapi_inner_astack": inner_stack, + "fastapi_function_astack": function_stack, + } + messages = [] + + async def receive(): + return { # pragma: no cover + "type": "http.request", + "body": b"", + "more_body": False, + } + + async def send(message): + messages.append(message) + + async with inner_stack, function_stack: + await app.router._frontend_routes.handle(scope, receive, send) + + assert scope["fastapi_inner_astack"] is inner_stack + assert scope["fastapi_function_astack"] is function_stack + assert messages[0]["type"] == "http.response.start" + assert messages[0]["status"] == 200 + + +def test_non_frontend_low_priority_route_keeps_order_before_frontend( + tmp_path: Path, +): + async def low_priority_endpoint(request: Request): + return PlainTextResponse("low") + + dist = tmp_path / "dist" + write_file(dist / "index.html", "frontend") + app = FastAPI() + app.router._low_priority_routes.append(Route("/admin", low_priority_endpoint)) + app.router._mark_routes_changed() + app.frontend("/", directory=dist) + + response = TestClient(app).get("/admin") + + assert response.status_code == 200 + assert response.text == "low" + + def test_existing_api_route_wins_over_frontend(tmp_path: Path): dist = tmp_path / "dist" write_file(dist / "api" / "users", "frontend")