Browse Source

Add `iter_route_contexts()` for advanced use cases that used to use `router.routes` (e.g. Jupyverse) (#15785)

pull/15694/merge
Sebastián Ramírez 22 hours ago
committed by GitHub
parent
commit
6ac122071d
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 26
      fastapi/openapi/utils.py
  2. 55
      fastapi/routing.py
  3. 105
      tests/test_router_include_context.py

26
fastapi/openapi/utils.py

@ -479,26 +479,22 @@ def get_openapi_path(
def _get_api_route_for_openapi(
route: BaseRoute, route_context: routing._EffectiveRouteContext | None
route_context: routing.RouteContext,
) -> routing._APIRouteLike | None:
if route_context is not None and isinstance(
route_context.original_route, routing.APIRoute
):
if isinstance(route_context.original_route, routing.APIRoute):
return cast(routing._APIRouteLike, route_context)
if isinstance(route, routing.APIRoute):
return cast(routing._APIRouteLike, route)
return None
def get_fields_from_routes(
routes: Sequence[BaseRoute],
routes: Sequence[BaseRoute | routing.RouteContext],
) -> list[ModelField]:
body_fields_from_routes: list[ModelField] = []
responses_from_routes: list[ModelField] = []
request_fields_from_routes: list[ModelField] = []
callback_flat_models: list[ModelField] = []
for route, route_context in routing._iter_routes_with_context(routes):
api_route = _get_api_route_for_openapi(route, route_context)
for route_context in routing.iter_route_contexts(routes):
api_route = _get_api_route_for_openapi(route_context)
if api_route is None:
continue
if api_route.include_in_schema:
@ -531,8 +527,8 @@ def get_openapi(
openapi_version: str = "3.1.0",
summary: str | None = None,
description: str | None = None,
routes: Sequence[BaseRoute],
webhooks: Sequence[BaseRoute] | None = None,
routes: Sequence[BaseRoute | routing.RouteContext],
webhooks: Sequence[BaseRoute | routing.RouteContext] | None = None,
tags: list[dict[str, Any]] | None = None,
servers: list[dict[str, str | Any]] | None = None,
terms_of_service: str | None = None,
@ -567,8 +563,8 @@ def get_openapi(
model_name_map=model_name_map,
separate_input_output_schemas=separate_input_output_schemas,
)
for route, route_context in routing._iter_routes_with_context(routes):
api_route = _get_api_route_for_openapi(route, route_context)
for route_context in routing.iter_route_contexts(routes):
api_route = _get_api_route_for_openapi(route_context)
if api_route is not None:
result = get_openapi_path(
route=api_route,
@ -587,8 +583,8 @@ def get_openapi(
)
if path_definitions:
definitions.update(path_definitions)
for webhook, webhook_context in routing._iter_routes_with_context(webhooks or []):
api_webhook = _get_api_route_for_openapi(webhook, webhook_context)
for webhook_context in routing.iter_route_contexts(webhooks or []):
api_webhook = _get_api_route_for_openapi(webhook_context)
if api_webhook is not None:
result = get_openapi_path(
route=api_webhook,

55
fastapi/routing.py

@ -1454,6 +1454,47 @@ class _EffectiveRouteContext:
return URLPath(path=path, protocol="http")
@dataclass(frozen=True)
class RouteContext:
route: BaseRoute
_route_context: _EffectiveRouteContext | None = field(default=None, repr=False)
@property
def original_route(self) -> BaseRoute:
if self._route_context is not None:
return self._route_context.original_route
return self.route
@property
def _effective_route(self) -> BaseRoute | _EffectiveRouteContext:
if self._route_context is not None:
return self._route_context
return self.route
@property
def path(self) -> str | None:
return getattr(self._effective_route, "path", None)
@property
def path_format(self) -> str | None:
return getattr(self._effective_route, "path_format", None)
@property
def name(self) -> str | None:
return getattr(self._effective_route, "name", None)
@property
def methods(self) -> set[str] | None:
return getattr(self._effective_route, "methods", None)
@property
def endpoint(self) -> Callable[..., Any] | None:
return getattr(self._effective_route, "endpoint", None)
def __getattr__(self, name: str) -> Any:
return getattr(self._effective_route, name)
@dataclass
class _IncludedRouter(BaseRoute):
original_router: "APIRouter"
@ -1654,6 +1695,20 @@ def _iter_included_route_candidates(routes: Sequence[BaseRoute]) -> Iterator[Bas
yield route
def iter_route_contexts(
routes: Sequence[BaseRoute | RouteContext],
) -> Iterator[RouteContext]:
for route in routes:
if isinstance(route, RouteContext):
yield route
continue
for original_route, route_context in _iter_routes_with_context([route]):
if route_context is None:
yield RouteContext(original_route)
else:
yield RouteContext(original_route, route_context)
def _iter_routes_with_context(
routes: Sequence[BaseRoute],
) -> Iterator[tuple[BaseRoute, _EffectiveRouteContext | None]]:

105
tests/test_router_include_context.py

@ -1,16 +1,21 @@
from typing import Annotated, cast
import pytest
from fastapi import APIRouter, Body, Depends, FastAPI, Request
from fastapi import APIRouter, Body, Depends, FastAPI, Request, Security
from fastapi.exceptions import FastAPIError
from fastapi.openapi.utils import get_openapi
from fastapi.responses import HTMLResponse, JSONResponse, PlainTextResponse
from fastapi.routing import (
APIRoute,
RouteContext,
_IncludedRouter,
_iter_included_route_candidates,
_restore_fastapi_scope_key,
iter_route_contexts,
)
from fastapi.security import HTTPBearer
from fastapi.testclient import TestClient
from pydantic import BaseModel
from starlette.routing import BaseRoute, Host, Match, Mount, NoMatchFound, Route, Router
@ -30,6 +35,104 @@ def unique_id_b(route: APIRoute) -> str:
return f"b_{route.name}"
def test_iter_route_contexts_returns_direct_route_context():
router = APIRouter()
@router.get("/items/{item_id}")
def read_item(item_id: str): # pragma: no cover
return {"item_id": item_id}
contexts = list(iter_route_contexts(router.routes))
assert len(contexts) == 1
assert isinstance(contexts[0], RouteContext)
assert contexts[0].original_route is router.routes[0]
assert contexts[0].path == "/items/{item_id}"
assert contexts[0].path_format == "/items/{item_id}"
assert contexts[0].methods == {"GET"}
assert contexts[0].endpoint is read_item
def test_iter_route_contexts_supports_nested_conflict_detection():
existing_router = APIRouter()
nested_router = APIRouter()
@nested_router.get("/{username}")
def read_user(username: str): # pragma: no cover
return {"username": username}
existing_router.include_router(nested_router, prefix="/auth/user")
new_router = APIRouter()
@new_router.get("/auth/user/{username}")
def read_user_again(username: str): # pragma: no cover
return {"username": username}
existing_paths = {
context.path for context in iter_route_contexts(existing_router.routes)
}
new_paths = {context.path for context in iter_route_contexts(new_router.routes)}
assert existing_paths & new_paths == {"/auth/user/{username}"}
def test_get_openapi_accepts_filtered_route_contexts_with_effective_paths():
router = APIRouter()
bearer_scheme = HTTPBearer()
@router.get("/public", tags=["public"])
def read_public(token: Annotated[str, Security(bearer_scheme)]): # pragma: no cover
return {"public": True}
@router.get("/private", tags=["private"])
def read_private(): # pragma: no cover
return {"private": True}
app = FastAPI()
app.include_router(router, prefix="/api")
public_routes = [
context
for context in iter_route_contexts(app.routes)
if "public" in getattr(context, "tags", [])
]
schema = get_openapi(
title="Public API",
version="1.0.0",
routes=public_routes,
)
assert set(schema["paths"]) == {"/api/public"}
assert "HTTPBearer" in schema["components"]["securitySchemes"]
def test_get_openapi_accepts_webhook_route_contexts():
app = FastAPI()
bearer_scheme = HTTPBearer()
class Subscription(BaseModel):
username: str
@app.webhooks.post("new-subscription")
def new_subscription(
body: Subscription, token: Annotated[str, Security(bearer_scheme)]
): # pragma: no cover
return None
webhook_contexts = list(iter_route_contexts(app.webhooks.routes))
schema = get_openapi(
title="Webhook API",
version="1.0.0",
routes=[],
webhooks=webhook_contexts,
)
assert set(schema["webhooks"]) == {"new-subscription"}
assert "HTTPBearer" in schema["components"]["securitySchemes"]
assert "Subscription" in schema["components"]["schemas"]
def test_router_include_context_matches_flattened_include_metadata():
callback_router = APIRouter()

Loading…
Cancel
Save