Browse Source

Add `iter_route_contexts()` for advanced use cases that used to use `router.routes` (e.g. Jupyverse) (#15785)

pull/15790/head
Sebastián Ramírez 1 week ago
committed by GitHub
parent
commit
6ac122071d
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 26
      fastapi/openapi/utils.py
  2. 55
      fastapi/routing.py
  3. 105
      tests/test_router_include_context.py

26
fastapi/openapi/utils.py

@ -479,26 +479,22 @@ def get_openapi_path(
def _get_api_route_for_openapi( def _get_api_route_for_openapi(
route: BaseRoute, route_context: routing._EffectiveRouteContext | None route_context: routing.RouteContext,
) -> routing._APIRouteLike | None: ) -> routing._APIRouteLike | None:
if route_context is not None and isinstance( if isinstance(route_context.original_route, routing.APIRoute):
route_context.original_route, routing.APIRoute
):
return cast(routing._APIRouteLike, route_context) return cast(routing._APIRouteLike, route_context)
if isinstance(route, routing.APIRoute):
return cast(routing._APIRouteLike, route)
return None return None
def get_fields_from_routes( def get_fields_from_routes(
routes: Sequence[BaseRoute], routes: Sequence[BaseRoute | routing.RouteContext],
) -> list[ModelField]: ) -> list[ModelField]:
body_fields_from_routes: list[ModelField] = [] body_fields_from_routes: list[ModelField] = []
responses_from_routes: list[ModelField] = [] responses_from_routes: list[ModelField] = []
request_fields_from_routes: list[ModelField] = [] request_fields_from_routes: list[ModelField] = []
callback_flat_models: list[ModelField] = [] callback_flat_models: list[ModelField] = []
for route, route_context in routing._iter_routes_with_context(routes): for route_context in routing.iter_route_contexts(routes):
api_route = _get_api_route_for_openapi(route, route_context) api_route = _get_api_route_for_openapi(route_context)
if api_route is None: if api_route is None:
continue continue
if api_route.include_in_schema: if api_route.include_in_schema:
@ -531,8 +527,8 @@ def get_openapi(
openapi_version: str = "3.1.0", openapi_version: str = "3.1.0",
summary: str | None = None, summary: str | None = None,
description: str | None = None, description: str | None = None,
routes: Sequence[BaseRoute], routes: Sequence[BaseRoute | routing.RouteContext],
webhooks: Sequence[BaseRoute] | None = None, webhooks: Sequence[BaseRoute | routing.RouteContext] | None = None,
tags: list[dict[str, Any]] | None = None, tags: list[dict[str, Any]] | None = None,
servers: list[dict[str, str | Any]] | None = None, servers: list[dict[str, str | Any]] | None = None,
terms_of_service: str | None = None, terms_of_service: str | None = None,
@ -567,8 +563,8 @@ def get_openapi(
model_name_map=model_name_map, model_name_map=model_name_map,
separate_input_output_schemas=separate_input_output_schemas, separate_input_output_schemas=separate_input_output_schemas,
) )
for route, route_context in routing._iter_routes_with_context(routes): for route_context in routing.iter_route_contexts(routes):
api_route = _get_api_route_for_openapi(route, route_context) api_route = _get_api_route_for_openapi(route_context)
if api_route is not None: if api_route is not None:
result = get_openapi_path( result = get_openapi_path(
route=api_route, route=api_route,
@ -587,8 +583,8 @@ def get_openapi(
) )
if path_definitions: if path_definitions:
definitions.update(path_definitions) definitions.update(path_definitions)
for webhook, webhook_context in routing._iter_routes_with_context(webhooks or []): for webhook_context in routing.iter_route_contexts(webhooks or []):
api_webhook = _get_api_route_for_openapi(webhook, webhook_context) api_webhook = _get_api_route_for_openapi(webhook_context)
if api_webhook is not None: if api_webhook is not None:
result = get_openapi_path( result = get_openapi_path(
route=api_webhook, route=api_webhook,

55
fastapi/routing.py

@ -1454,6 +1454,47 @@ class _EffectiveRouteContext:
return URLPath(path=path, protocol="http") return URLPath(path=path, protocol="http")
@dataclass(frozen=True)
class RouteContext:
route: BaseRoute
_route_context: _EffectiveRouteContext | None = field(default=None, repr=False)
@property
def original_route(self) -> BaseRoute:
if self._route_context is not None:
return self._route_context.original_route
return self.route
@property
def _effective_route(self) -> BaseRoute | _EffectiveRouteContext:
if self._route_context is not None:
return self._route_context
return self.route
@property
def path(self) -> str | None:
return getattr(self._effective_route, "path", None)
@property
def path_format(self) -> str | None:
return getattr(self._effective_route, "path_format", None)
@property
def name(self) -> str | None:
return getattr(self._effective_route, "name", None)
@property
def methods(self) -> set[str] | None:
return getattr(self._effective_route, "methods", None)
@property
def endpoint(self) -> Callable[..., Any] | None:
return getattr(self._effective_route, "endpoint", None)
def __getattr__(self, name: str) -> Any:
return getattr(self._effective_route, name)
@dataclass @dataclass
class _IncludedRouter(BaseRoute): class _IncludedRouter(BaseRoute):
original_router: "APIRouter" original_router: "APIRouter"
@ -1654,6 +1695,20 @@ def _iter_included_route_candidates(routes: Sequence[BaseRoute]) -> Iterator[Bas
yield route yield route
def iter_route_contexts(
routes: Sequence[BaseRoute | RouteContext],
) -> Iterator[RouteContext]:
for route in routes:
if isinstance(route, RouteContext):
yield route
continue
for original_route, route_context in _iter_routes_with_context([route]):
if route_context is None:
yield RouteContext(original_route)
else:
yield RouteContext(original_route, route_context)
def _iter_routes_with_context( def _iter_routes_with_context(
routes: Sequence[BaseRoute], routes: Sequence[BaseRoute],
) -> Iterator[tuple[BaseRoute, _EffectiveRouteContext | None]]: ) -> Iterator[tuple[BaseRoute, _EffectiveRouteContext | None]]:

105
tests/test_router_include_context.py

@ -1,16 +1,21 @@
from typing import Annotated, cast from typing import Annotated, cast
import pytest import pytest
from fastapi import APIRouter, Body, Depends, FastAPI, Request from fastapi import APIRouter, Body, Depends, FastAPI, Request, Security
from fastapi.exceptions import FastAPIError from fastapi.exceptions import FastAPIError
from fastapi.openapi.utils import get_openapi
from fastapi.responses import HTMLResponse, JSONResponse, PlainTextResponse from fastapi.responses import HTMLResponse, JSONResponse, PlainTextResponse
from fastapi.routing import ( from fastapi.routing import (
APIRoute, APIRoute,
RouteContext,
_IncludedRouter, _IncludedRouter,
_iter_included_route_candidates, _iter_included_route_candidates,
_restore_fastapi_scope_key, _restore_fastapi_scope_key,
iter_route_contexts,
) )
from fastapi.security import HTTPBearer
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from pydantic import BaseModel
from starlette.routing import BaseRoute, Host, Match, Mount, NoMatchFound, Route, Router from starlette.routing import BaseRoute, Host, Match, Mount, NoMatchFound, Route, Router
@ -30,6 +35,104 @@ def unique_id_b(route: APIRoute) -> str:
return f"b_{route.name}" return f"b_{route.name}"
def test_iter_route_contexts_returns_direct_route_context():
router = APIRouter()
@router.get("/items/{item_id}")
def read_item(item_id: str): # pragma: no cover
return {"item_id": item_id}
contexts = list(iter_route_contexts(router.routes))
assert len(contexts) == 1
assert isinstance(contexts[0], RouteContext)
assert contexts[0].original_route is router.routes[0]
assert contexts[0].path == "/items/{item_id}"
assert contexts[0].path_format == "/items/{item_id}"
assert contexts[0].methods == {"GET"}
assert contexts[0].endpoint is read_item
def test_iter_route_contexts_supports_nested_conflict_detection():
existing_router = APIRouter()
nested_router = APIRouter()
@nested_router.get("/{username}")
def read_user(username: str): # pragma: no cover
return {"username": username}
existing_router.include_router(nested_router, prefix="/auth/user")
new_router = APIRouter()
@new_router.get("/auth/user/{username}")
def read_user_again(username: str): # pragma: no cover
return {"username": username}
existing_paths = {
context.path for context in iter_route_contexts(existing_router.routes)
}
new_paths = {context.path for context in iter_route_contexts(new_router.routes)}
assert existing_paths & new_paths == {"/auth/user/{username}"}
def test_get_openapi_accepts_filtered_route_contexts_with_effective_paths():
router = APIRouter()
bearer_scheme = HTTPBearer()
@router.get("/public", tags=["public"])
def read_public(token: Annotated[str, Security(bearer_scheme)]): # pragma: no cover
return {"public": True}
@router.get("/private", tags=["private"])
def read_private(): # pragma: no cover
return {"private": True}
app = FastAPI()
app.include_router(router, prefix="/api")
public_routes = [
context
for context in iter_route_contexts(app.routes)
if "public" in getattr(context, "tags", [])
]
schema = get_openapi(
title="Public API",
version="1.0.0",
routes=public_routes,
)
assert set(schema["paths"]) == {"/api/public"}
assert "HTTPBearer" in schema["components"]["securitySchemes"]
def test_get_openapi_accepts_webhook_route_contexts():
app = FastAPI()
bearer_scheme = HTTPBearer()
class Subscription(BaseModel):
username: str
@app.webhooks.post("new-subscription")
def new_subscription(
body: Subscription, token: Annotated[str, Security(bearer_scheme)]
): # pragma: no cover
return None
webhook_contexts = list(iter_route_contexts(app.webhooks.routes))
schema = get_openapi(
title="Webhook API",
version="1.0.0",
routes=[],
webhooks=webhook_contexts,
)
assert set(schema["webhooks"]) == {"new-subscription"}
assert "HTTPBearer" in schema["components"]["securitySchemes"]
assert "Subscription" in schema["components"]["schemas"]
def test_router_include_context_matches_flattened_include_metadata(): def test_router_include_context_matches_flattened_include_metadata():
callback_router = APIRouter() callback_router = APIRouter()

Loading…
Cancel
Save