diff --git a/docs/en/docs/advanced/openapi-callbacks.md b/docs/en/docs/advanced/openapi-callbacks.md index 40cf479567..8197091108 100644 --- a/docs/en/docs/advanced/openapi-callbacks.md +++ b/docs/en/docs/advanced/openapi-callbacks.md @@ -167,13 +167,13 @@ Notice how the callback URL used contains the URL received as a query parameter At this point you have the *callback path operation(s)* needed (the one(s) that the *external developer* should implement in the *external API*) in the callback router you created above. -Now use the parameter `callbacks` in *your API's path operation decorator* to pass the attribute `.routes` (that's actually just a `list` of routes/*path operations*) from that callback router: +Now use the parameter `callbacks` in *your API's path operation decorator* to pass the attribute `.routes` from that callback router: {* ../../docs_src/openapi_callbacks/tutorial001_py310.py hl[33] *} /// tip -Notice that you are not passing the router itself (`invoices_callback_router`) to `callback=`, but the attribute `.routes`, as in `invoices_callback_router.routes`. +Notice that you are not passing the router itself (`invoices_callback_router`) to `callbacks=`, but its `.routes`, as in `invoices_callback_router.routes`. FastAPI will use those routes to generate the callback OpenAPI documentation. /// diff --git a/docs/en/docs/advanced/path-operation-advanced-configuration.md b/docs/en/docs/advanced/path-operation-advanced-configuration.md index 800bf305dc..9dca4d7123 100644 --- a/docs/en/docs/advanced/path-operation-advanced-configuration.md +++ b/docs/en/docs/advanced/path-operation-advanced-configuration.md @@ -16,17 +16,11 @@ You would have to make sure that it is unique for each operation. ### Using the *path operation function* name as the operationId { #using-the-path-operation-function-name-as-the-operationid } -If you want to use your APIs' function names as `operationId`s, you can iterate over all of them and override each *path operation's* `operation_id` using their `APIRoute.name`. +If you want to use your APIs' function names as `operationId`s, you can pass a custom `generate_unique_id_function` to `FastAPI`. -You should do it after adding all your *path operations*. +The function receives each `APIRoute` and returns the `operationId` to use for that path operation. -{* ../../docs_src/path_operation_advanced_configuration/tutorial002_py310.py hl[2, 12:21, 24] *} - -/// tip - -If you manually call `app.openapi()`, you should update the `operationId`s before that. - -/// +{* ../../docs_src/path_operation_advanced_configuration/tutorial002_py310.py hl[2,5:6,9] *} /// warning diff --git a/docs/en/docs/how-to/extending-openapi.md b/docs/en/docs/how-to/extending-openapi.md index 65f5844383..8368eea506 100644 --- a/docs/en/docs/how-to/extending-openapi.md +++ b/docs/en/docs/how-to/extending-openapi.md @@ -25,7 +25,15 @@ And that function `get_openapi()` receives as parameters: * `openapi_version`: The version of the OpenAPI specification used. By default, the latest: `3.1.0`. * `summary`: A short summary of the API. * `description`: The description of your API, this can include markdown and will be shown in the docs. -* `routes`: A list of routes, these are each of the registered *path operations*. They are taken from `app.routes`. +* `routes`: The routes from the application, taken from `app.routes`. FastAPI uses them to collect the registered *path operations*, including those from included routers. + +/// tip | Technical Details + +`app.routes` is a lower-level route tree. It can include route candidates that FastAPI uses internally for included routers, not only final `APIRoute` objects. + +You can still pass `app.routes` to `get_openapi()`. FastAPI will traverse that route tree to collect the effective path operations. + +/// /// note diff --git a/docs/en/docs/release-notes.md b/docs/en/docs/release-notes.md index 5303b55dfd..1018045075 100644 --- a/docs/en/docs/release-notes.md +++ b/docs/en/docs/release-notes.md @@ -7,6 +7,55 @@ hide: ## Latest Changes +♻️ Refactor internals to preserve `APIRouter` and `APIRoute` instances + +Unblocks ✨ SO MANY THINGS ✨ + +Before this, `router.include_router(other_router)` would take each path operation from `other_router` and "clone" it, or recreate it from scratch. + +This would mean that in the end there was only one top level router, part of the app. + +The way it is structured here is that there are a few additional classes to handle intermediate metadata for router and route inclusion. That way the information of "router X includes Y and Y includes Z" is stored somewhere, without affecting (recreating / clonning) the final route. + +#### Non Objectives + +Dependencies for 404: previously I intended to support dependencies that would be executed even for 404, but that would conflict with the fact that a router could _not_ find a match, but the next router _did_ find a match. Executing dependencies in the router that did not find a match would not make sense, they could consume the request, body, etc. This original idea was discarded. + +#### Breaking Changes + +Now `router.routes` is no longer a plain list of `APIRoute` objects, it can contain these intermediate objects that can contain additional routers, forming a tree. + +Any logic that depended on iterating on the `router.routes` directly would be affected, that logic cannot expect to be able to extract data from a plain list of routes, as it's no longer a plain list but a tree. + +Additionally, any logic that iterated on `router.routes` to modify them would now also see these new objects, and would not see all the routes in the app. + +`router.routes` should be considered an internal implementation detail, only passed around to the FastAPI functions that need it. + +#### Features + +* Adding routes (path operations) after a router is included now works, they are reflected as they are not copied. +* Including `subrouter` in `mainrouter` can be done before adding routes (path operations) to `subrouter`, because now the the entire object is stored instead of copying the routes. +* As routes are not copied, in some cases that might save some memory. + +#### Alpha Features + +This is not documented yet, so it's not officially supported yet and could change in the future. + +But, as `APIRoute` and `APIRouter` instances are now preserved, they could be customized. + +`APIRouter` has two new methods, `.matches()` and `.handle()`, counterpart to the existing ones in `APIRoute`. With this a router could customize how it matches and handles requests. For example, it could match only requests that include some specific header, for example for handling versions in headers. + +Still, for now, consider this very experimental and potentially changing and breaking in the future. + +#### Future Features Enabled + +* Custom `APIRoute` subclasses (undocumented, but alraedy works as desccribed above) +* Custom `APIRouter` subclasses (undocumented, but already works as described above) +* Dependencies per router +* Exception handlers per router +* Middleware per router +* Other features planned + ### Docs * 📝 Update FastAPI Cloud deployment instructions. PR [#15724](https://github.com/fastapi/fastapi/pull/15724) by [@alejsdev](https://github.com/alejsdev). diff --git a/docs/en/docs/tutorial/bigger-applications.md b/docs/en/docs/tutorial/bigger-applications.md index 8950d59b42..b1a6afb6a1 100644 --- a/docs/en/docs/tutorial/bigger-applications.md +++ b/docs/en/docs/tutorial/bigger-applications.md @@ -396,9 +396,9 @@ It will include all the routes from that router as part of it. /// note | Technical Details -It will actually internally create a *path operation* for each *path operation* that was declared in the `APIRouter`. +FastAPI keeps the original `APIRouter` and its `APIRoute`s active when the router is included in the main application. -So, behind the scenes, it will actually work as if everything was the same single app. +That means custom `APIRouter` and `APIRoute` subclasses can still participate after the router is included. /// @@ -406,7 +406,7 @@ So, behind the scenes, it will actually work as if everything was the same singl You don't have to worry about performance when including routers. -This will take microseconds and will only happen at startup. +This is designed to be lightweight and to avoid adding overhead to each request. So it won't affect performance. ⚡ @@ -461,7 +461,7 @@ The `APIRouter`s are not "mounted", they are not isolated from the rest of the a This is because we want to include their *path operations* in the OpenAPI schema and the user interfaces. -As we cannot just isolate them and "mount" them independently of the rest, the *path operations* are "cloned" (re-created), not included directly. +FastAPI keeps the original routers and path operations active, and combines the router prefixes, dependencies, tags, responses, and other metadata when handling requests and generating OpenAPI. /// @@ -532,4 +532,16 @@ The same way you can include an `APIRouter` in a `FastAPI` application, you can router.include_router(other_router) ``` -Make sure you do it before including `router` in the `FastAPI` app, so that the *path operations* from `other_router` are also included. +You can do this before or after including `router` in the `FastAPI` app. FastAPI will still include the *path operations* from `other_router` in routing and OpenAPI. + +The same applies to *path operations* added later to the routers. They will be visible through the earlier inclusion too. + +/// warning | Technical Details + +Avoid directly mutating `router.routes` after including a router. FastAPI treats router inclusion as live, so the original router and its routes remain part of routing and OpenAPI generation. + +Use documented APIs such as path operation decorators and `.include_router()` to add routes and routers. + +Treat `router.routes` as a lower-level route tree that can contain route definitions and included routers, and avoid relying on it as a flat list of final path operations. + +/// diff --git a/docs_src/path_operation_advanced_configuration/tutorial002_py310.py b/docs_src/path_operation_advanced_configuration/tutorial002_py310.py index 3aaae9b371..5c2257ed68 100644 --- a/docs_src/path_operation_advanced_configuration/tutorial002_py310.py +++ b/docs_src/path_operation_advanced_configuration/tutorial002_py310.py @@ -1,24 +1,14 @@ from fastapi import FastAPI from fastapi.routing import APIRoute -app = FastAPI() - -@app.get("/items/") -async def read_items(): - return [{"item_id": "Foo"}] +def custom_generate_unique_id(route: APIRoute) -> str: + return route.name -def use_route_names_as_operation_ids(app: FastAPI) -> None: - """ - Simplify operation IDs so that generated API clients have simpler function - names. +app = FastAPI(generate_unique_id_function=custom_generate_unique_id) - Should be called only after all routes have been added. - """ - for route in app.routes: - if isinstance(route, APIRoute): - route.operation_id = route.name # in this case, 'read_items' - -use_route_names_as_operation_ids(app) +@app.get("/items/") +async def read_items(): + return [{"item_id": "Foo"}] diff --git a/fastapi/applications.py b/fastapi/applications.py index faac6853fa..c7c551e4e6 100644 --- a/fastapi/applications.py +++ b/fastapi/applications.py @@ -921,6 +921,7 @@ class FastAPI(Starlette): ), ] = "3.1.0" self.openapi_schema: dict[str, Any] | None = None + self._openapi_routes_version: int | None = None if self.openapi_url: assert self.title, "A title must be provided for OpenAPI, e.g.: 'My API'" assert self.version, "A version must be provided for OpenAPI, e.g.: '2.1.0'" @@ -1079,7 +1080,8 @@ class FastAPI(Starlette): Read more in the [FastAPI docs for OpenAPI](https://fastapi.tiangolo.com/how-to/extending-openapi/). """ - if not self.openapi_schema: + routes_version = self.router._get_routes_version() + if not self.openapi_schema or self._openapi_routes_version != routes_version: self.openapi_schema = get_openapi( title=self.title, version=self.version, @@ -1096,6 +1098,7 @@ class FastAPI(Starlette): separate_input_output_schemas=self.separate_input_output_schemas, external_docs=self.openapi_external_docs, ) + self._openapi_routes_version = routes_version return self.openapi_schema def setup(self) -> None: diff --git a/fastapi/openapi/utils.py b/fastapi/openapi/utils.py index 1c7a17c4ca..ab4543d346 100644 --- a/fastapi/openapi/utils.py +++ b/fastapi/openapi/utils.py @@ -213,7 +213,7 @@ def get_openapi_operation_request_body( def generate_operation_id( - *, route: routing.APIRoute, method: str + *, route: routing._APIRouteLike, method: str ) -> str: # pragma: nocover warnings.warn( message="fastapi.openapi.utils.generate_operation_id() was deprecated, " @@ -227,14 +227,14 @@ def generate_operation_id( return generate_operation_id_for_path(name=route.name, path=path, method=method) -def generate_operation_summary(*, route: routing.APIRoute, method: str) -> str: +def generate_operation_summary(*, route: routing._APIRouteLike, method: str) -> str: if route.summary: return route.summary return route.name.replace("_", " ").title() def get_openapi_operation_metadata( - *, route: routing.APIRoute, method: str, operation_ids: set[str] + *, route: routing._APIRouteLike, method: str, operation_ids: set[str] ) -> dict[str, Any]: operation: dict[str, Any] = {} if route.tags: @@ -259,7 +259,7 @@ def get_openapi_operation_metadata( def get_openapi_path( *, - route: routing.APIRoute, + route: routing._APIRouteLike, operation_ids: set[str], model_name_map: ModelNameMap, field_mapping: dict[ @@ -329,7 +329,7 @@ def get_openapi_path( cb_security_schemes, cb_definitions, ) = get_openapi_path( - route=callback, + route=cast(routing._APIRouteLike, callback), operation_ids=operation_ids, model_name_map=model_name_map, field_mapping=field_mapping, @@ -478,6 +478,18 @@ def get_openapi_path( return path, security_schemes, definitions +def _get_api_route_for_openapi( + route: BaseRoute, route_context: routing._EffectiveRouteContext | None +) -> routing._APIRouteLike | None: + if route_context is not None and isinstance( + route_context.original_route, routing.APIRoute + ): + return cast(routing._APIRouteLike, route_context) + if isinstance(route, routing.APIRoute): + return cast(routing._APIRouteLike, route) + return None + + def get_fields_from_routes( routes: Sequence[BaseRoute], ) -> list[ModelField]: @@ -485,24 +497,25 @@ def get_fields_from_routes( responses_from_routes: list[ModelField] = [] request_fields_from_routes: list[ModelField] = [] callback_flat_models: list[ModelField] = [] - for route in routes: - if not isinstance(route, routing.APIRoute): + for route, route_context in routing._iter_routes_with_context(routes): + api_route = _get_api_route_for_openapi(route, route_context) + if api_route is None: continue - if route.include_in_schema: - if route.body_field: - assert isinstance(route.body_field, ModelField), ( + if api_route.include_in_schema: + if api_route.body_field: + assert isinstance(api_route.body_field, ModelField), ( "A request body must be a Pydantic Field" ) - body_fields_from_routes.append(route.body_field) - if route.response_field: - responses_from_routes.append(route.response_field) - if route.response_fields: - responses_from_routes.extend(route.response_fields.values()) - if route.stream_item_field: - responses_from_routes.append(route.stream_item_field) - if route.callbacks: - callback_flat_models.extend(get_fields_from_routes(route.callbacks)) - params = get_flat_params(route.dependant) + body_fields_from_routes.append(api_route.body_field) + if api_route.response_field: + responses_from_routes.append(api_route.response_field) + if api_route.response_fields: + responses_from_routes.extend(api_route.response_fields.values()) + if api_route.stream_item_field: + responses_from_routes.append(api_route.stream_item_field) + if api_route.callbacks: + callback_flat_models.extend(get_fields_from_routes(api_route.callbacks)) + params = get_flat_params(api_route.dependant) request_fields_from_routes.extend(params) flat_models = callback_flat_models + list( @@ -546,7 +559,7 @@ def get_openapi( paths: dict[str, dict[str, Any]] = {} webhook_paths: dict[str, dict[str, Any]] = {} operation_ids: set[str] = set() - all_fields = get_fields_from_routes(list(routes or []) + list(webhooks or [])) + all_fields = get_fields_from_routes(list(routes) + list(webhooks or [])) flat_models = get_flat_models_from_fields(all_fields, known_models=set()) model_name_map = get_model_name_map(flat_models) field_mapping, definitions = get_definitions( @@ -554,10 +567,11 @@ def get_openapi( model_name_map=model_name_map, separate_input_output_schemas=separate_input_output_schemas, ) - for route in routes or []: - if isinstance(route, routing.APIRoute): + for route, route_context in routing._iter_routes_with_context(routes): + api_route = _get_api_route_for_openapi(route, route_context) + if api_route is not None: result = get_openapi_path( - route=route, + route=api_route, operation_ids=operation_ids, model_name_map=model_name_map, field_mapping=field_mapping, @@ -566,17 +580,18 @@ def get_openapi( if result: path, security_schemes, path_definitions = result if path: - paths.setdefault(route.path_format, {}).update(path) + paths.setdefault(api_route.path_format, {}).update(path) if security_schemes: components.setdefault("securitySchemes", {}).update( security_schemes ) if path_definitions: definitions.update(path_definitions) - for webhook in webhooks or []: - if isinstance(webhook, routing.APIRoute): + for webhook, webhook_context in routing._iter_routes_with_context(webhooks or []): + api_webhook = _get_api_route_for_openapi(webhook, webhook_context) + if api_webhook is not None: result = get_openapi_path( - route=webhook, + route=api_webhook, operation_ids=operation_ids, model_name_map=model_name_map, field_mapping=field_mapping, @@ -585,7 +600,7 @@ def get_openapi( if result: path, security_schemes, path_definitions = result if path: - webhook_paths.setdefault(webhook.path_format, {}).update(path) + webhook_paths.setdefault(api_webhook.path_format, {}).update(path) if security_schemes: components.setdefault("securitySchemes", {}).update( security_schemes diff --git a/fastapi/routing.py b/fastapi/routing.py index 21a1385a27..fb47843099 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -1,4 +1,5 @@ import contextlib +import copy import email.message import functools import inspect @@ -21,10 +22,13 @@ from contextlib import ( AsyncExitStack, asynccontextmanager, ) +from contextvars import ContextVar +from dataclasses import dataclass, field from enum import Enum, IntEnum from typing import ( Annotated, Any, + Protocol, TypeVar, cast, ) @@ -74,12 +78,17 @@ from fastapi.utils import ( ) from starlette import routing from starlette._exception_handler import wrap_app_handling_exceptions -from starlette._utils import is_async_callable +from starlette._utils import get_route_path, is_async_callable from starlette.concurrency import iterate_in_threadpool, run_in_threadpool -from starlette.datastructures import FormData +from starlette.datastructures import FormData, URLPath from starlette.exceptions import HTTPException from starlette.requests import Request -from starlette.responses import JSONResponse, Response, StreamingResponse +from starlette.responses import ( + JSONResponse, + PlainTextResponse, + Response, + StreamingResponse, +) from starlette.routing import ( BaseRoute, Match, @@ -808,6 +817,250 @@ class APIWebSocketRoute(routing.WebSocketRoute): return match, child_scope +_FASTAPI_SCOPE_KEY = "fastapi" +_FASTAPI_EFFECTIVE_ROUTE_CONTEXT_KEY = "effective_route_context" +_FASTAPI_INCLUDED_ROUTER_KEY = "included_router" +_effective_route_context_var: ContextVar[Any | None] = ContextVar( + "fastapi_effective_route_context", default=None +) +_SCOPE_MISSING = object() + + +def _get_fastapi_scope(scope: Scope) -> dict[str, Any]: + fastapi_scope = scope.setdefault(_FASTAPI_SCOPE_KEY, {}) + assert isinstance(fastapi_scope, dict) + return fastapi_scope + + +def _get_scope_effective_route_context(scope: Scope) -> Any | None: + return scope.get(_FASTAPI_SCOPE_KEY, {}).get(_FASTAPI_EFFECTIVE_ROUTE_CONTEXT_KEY) + + +def _get_scope_included_router(scope: Scope) -> Any | None: + return scope.get(_FASTAPI_SCOPE_KEY, {}).get(_FASTAPI_INCLUDED_ROUTER_KEY) + + +def _restore_fastapi_scope_key(scope: Scope, key: str, previous: Any) -> None: + fastapi_scope = scope.get(_FASTAPI_SCOPE_KEY) + if not isinstance(fastapi_scope, dict): + return + if previous is _SCOPE_MISSING: + fastapi_scope.pop(key, None) + else: + fastapi_scope[key] = previous + + +class _APIRouteLike(Protocol): + path: str + endpoint: Callable[..., Any] + stream_item_type: Any | None + response_model: Any + summary: str | None + response_description: str + deprecated: bool | None + operation_id: str | None + response_model_include: IncEx | None + response_model_exclude: IncEx | None + response_model_by_alias: bool + response_model_exclude_unset: bool + response_model_exclude_defaults: bool + response_model_exclude_none: bool + include_in_schema: bool + response_class: type[Response] | DefaultPlaceholder + dependency_overrides_provider: Any | None + callbacks: list[BaseRoute] | None + openapi_extra: dict[str, Any] | None + generate_unique_id_function: Callable[[Any], str] | DefaultPlaceholder + strict_content_type: bool | DefaultPlaceholder + tags: list[str | Enum] + responses: dict[int | str, dict[str, Any]] + name: str + path_regex: Any + path_format: str + param_convertors: dict[str, Any] + methods: set[str] + unique_id: str + status_code: int | None + response_field: ModelField | None + stream_item_field: ModelField | None + dependencies: list[params.Depends] + description: str + response_fields: dict[int | str, ModelField] + dependant: Dependant + _flat_dependant: Dependant + _embed_body_fields: bool + body_field: ModelField | None + is_sse_stream: bool + is_json_stream: bool + + +def _populate_api_route_state( + route: _APIRouteLike, + path: str, + endpoint: Callable[..., Any], + *, + response_model: Any = Default(None), + status_code: int | None = None, + tags: list[str | Enum] | None = None, + dependencies: Sequence[params.Depends] | None = None, + summary: str | None = None, + description: str | None = None, + response_description: str = "Successful Response", + responses: dict[int | str, dict[str, Any]] | None = None, + deprecated: bool | None = None, + name: str | None = None, + methods: set[str] | list[str] | None = None, + operation_id: str | None = None, + response_model_include: IncEx | None = None, + response_model_exclude: IncEx | None = None, + response_model_by_alias: bool = True, + response_model_exclude_unset: bool = False, + response_model_exclude_defaults: bool = False, + response_model_exclude_none: bool = False, + include_in_schema: bool = True, + response_class: type[Response] | DefaultPlaceholder = Default(JSONResponse), + dependency_overrides_provider: Any | None = None, + callbacks: list[BaseRoute] | None = None, + openapi_extra: dict[str, Any] | None = None, + generate_unique_id_function: Callable[[Any], str] | DefaultPlaceholder = Default( + generate_unique_id + ), + strict_content_type: bool | DefaultPlaceholder = Default(True), +) -> None: + route.path = path + route.endpoint = endpoint + route.stream_item_type = None + if isinstance(response_model, DefaultPlaceholder): + return_annotation = get_typed_return_annotation(endpoint) + if lenient_issubclass(return_annotation, Response): + response_model = None + else: + stream_item = get_stream_item_type(return_annotation) + if stream_item is not None: + # Extract item type for JSONL or SSE streaming when + # response_class is DefaultPlaceholder (JSONL) or + # EventSourceResponse (SSE). + # ServerSentEvent is excluded: it's a transport + # wrapper, not a data model, so it shouldn't feed + # into validation or OpenAPI schema generation. + if ( + isinstance(response_class, DefaultPlaceholder) + or lenient_issubclass(response_class, EventSourceResponse) + ) and not lenient_issubclass(stream_item, ServerSentEvent): + route.stream_item_type = stream_item + response_model = None + else: + response_model = return_annotation + route.response_model = response_model + route.summary = summary + route.response_description = response_description + route.deprecated = deprecated + route.operation_id = operation_id + route.response_model_include = response_model_include + route.response_model_exclude = response_model_exclude + route.response_model_by_alias = response_model_by_alias + route.response_model_exclude_unset = response_model_exclude_unset + route.response_model_exclude_defaults = response_model_exclude_defaults + route.response_model_exclude_none = response_model_exclude_none + route.include_in_schema = include_in_schema + route.response_class = response_class + route.dependency_overrides_provider = dependency_overrides_provider + route.callbacks = callbacks + route.openapi_extra = openapi_extra + route.generate_unique_id_function = generate_unique_id_function + route.strict_content_type = strict_content_type + route.tags = tags or [] + route.responses = responses or {} + route.name = get_name(endpoint) if name is None else name + route.path_regex, route.path_format, route.param_convertors = compile_path(path) + if methods is None: + methods = ["GET"] + route.methods = {method.upper() for method in methods} + if isinstance(generate_unique_id_function, DefaultPlaceholder): + current_generate_unique_id: Callable[[Any], str] = ( + generate_unique_id_function.value + ) + else: + current_generate_unique_id = generate_unique_id_function + route.unique_id = route.operation_id or current_generate_unique_id(route) + # normalize enums e.g. http.HTTPStatus + if isinstance(status_code, IntEnum): + status_code = int(status_code) + route.status_code = status_code + if route.response_model: + assert is_body_allowed_for_status_code(status_code), ( + f"Status code {status_code} must not have a response body" + ) + response_name = "Response_" + route.unique_id + route.response_field = create_model_field( + name=response_name, + type_=route.response_model, + mode="serialization", + ) + else: + route.response_field = None + if route.stream_item_type: + stream_item_name = "StreamItem_" + route.unique_id + route.stream_item_field = create_model_field( + name=stream_item_name, + type_=route.stream_item_type, + mode="serialization", + ) + else: + route.stream_item_field = None + route.dependencies = list(dependencies or []) + route.description = description or inspect.cleandoc(route.endpoint.__doc__ or "") + # if a "form feed" character (page break) is found in the description text, + # truncate description text to the content preceding the first "form feed" + route.description = route.description.split("\f")[0].strip() + response_fields = {} + for additional_status_code, response in route.responses.items(): + assert isinstance(response, dict), "An additional response must be a dict" + model = response.get("model") + if model: + assert is_body_allowed_for_status_code(additional_status_code), ( + f"Status code {additional_status_code} must not have a response body" + ) + response_name = f"Response_{additional_status_code}_{route.unique_id}" + response_field = create_model_field( + name=response_name, type_=model, mode="serialization" + ) + response_fields[additional_status_code] = response_field + if response_fields: + route.response_fields = response_fields + else: + route.response_fields = {} + + assert callable(endpoint), "An endpoint must be a callable" + route.dependant = get_dependant( + path=route.path_format, call=route.endpoint, scope="function" + ) + for depends in route.dependencies[::-1]: + route.dependant.dependencies.insert( + 0, + get_parameterless_sub_dependant(depends=depends, path=route.path_format), + ) + route._flat_dependant = get_flat_dependant(route.dependant) + route._embed_body_fields = _should_embed_body_fields( + route._flat_dependant.body_params + ) + route.body_field = get_body_field( + flat_dependant=route._flat_dependant, + name=route.unique_id, + embed_body_fields=route._embed_body_fields, + ) + # Detect generator endpoints that should stream as JSONL or SSE + is_generator = ( + route.dependant.is_async_gen_callable or route.dependant.is_gen_callable + ) + route.is_sse_stream = is_generator and lenient_issubclass( + response_class, EventSourceResponse + ) + route.is_json_stream = is_generator and isinstance( + response_class, DefaultPlaceholder + ) + + class APIRoute(routing.Route): def __init__( self, @@ -841,166 +1094,541 @@ class APIRoute(routing.Route): | DefaultPlaceholder = Default(generate_unique_id), strict_content_type: bool | DefaultPlaceholder = Default(True), ) -> None: - self.path = path - self.endpoint = endpoint - self.stream_item_type: Any | None = None - if isinstance(response_model, DefaultPlaceholder): - return_annotation = get_typed_return_annotation(endpoint) - if lenient_issubclass(return_annotation, Response): - response_model = None - else: - stream_item = get_stream_item_type(return_annotation) - if stream_item is not None: - # Extract item type for JSONL or SSE streaming when - # response_class is DefaultPlaceholder (JSONL) or - # EventSourceResponse (SSE). - # ServerSentEvent is excluded: it's a transport - # wrapper, not a data model, so it shouldn't feed - # into validation or OpenAPI schema generation. - if ( - isinstance(response_class, DefaultPlaceholder) - or lenient_issubclass(response_class, EventSourceResponse) - ) and not lenient_issubclass(stream_item, ServerSentEvent): - self.stream_item_type = stream_item - response_model = None - else: - response_model = return_annotation - self.response_model = response_model - self.summary = summary - self.response_description = response_description - self.deprecated = deprecated - self.operation_id = operation_id - self.response_model_include = response_model_include - self.response_model_exclude = response_model_exclude - self.response_model_by_alias = response_model_by_alias - self.response_model_exclude_unset = response_model_exclude_unset - self.response_model_exclude_defaults = response_model_exclude_defaults - self.response_model_exclude_none = response_model_exclude_none - self.include_in_schema = include_in_schema - self.response_class = response_class - self.dependency_overrides_provider = dependency_overrides_provider - self.callbacks = callbacks - self.openapi_extra = openapi_extra - self.generate_unique_id_function = generate_unique_id_function - self.strict_content_type = strict_content_type - self.tags = tags or [] - self.responses = responses or {} - self.name = get_name(endpoint) if name is None else name - self.path_regex, self.path_format, self.param_convertors = compile_path(path) - if methods is None: - methods = ["GET"] - self.methods: set[str] = {method.upper() for method in methods} - if isinstance(generate_unique_id_function, DefaultPlaceholder): - current_generate_unique_id: Callable[[APIRoute], str] = ( - generate_unique_id_function.value - ) - else: - current_generate_unique_id = generate_unique_id_function - self.unique_id = self.operation_id or current_generate_unique_id(self) - # normalize enums e.g. http.HTTPStatus - if isinstance(status_code, IntEnum): - status_code = int(status_code) - self.status_code = status_code - if self.response_model: - assert is_body_allowed_for_status_code(status_code), ( - f"Status code {status_code} must not have a response body" - ) - response_name = "Response_" + self.unique_id - self.response_field = create_model_field( - name=response_name, - type_=self.response_model, - mode="serialization", - ) - else: - self.response_field = None # type: ignore[assignment] - if self.stream_item_type: - stream_item_name = "StreamItem_" + self.unique_id - self.stream_item_field: ModelField | None = create_model_field( - name=stream_item_name, - type_=self.stream_item_type, - mode="serialization", - ) - else: - self.stream_item_field = None - self.dependencies = list(dependencies or []) - self.description = description or inspect.cleandoc(self.endpoint.__doc__ or "") - # if a "form feed" character (page break) is found in the description text, - # truncate description text to the content preceding the first "form feed" - self.description = self.description.split("\f")[0].strip() - response_fields = {} - for additional_status_code, response in self.responses.items(): - assert isinstance(response, dict), "An additional response must be a dict" - model = response.get("model") - if model: - assert is_body_allowed_for_status_code(additional_status_code), ( - f"Status code {additional_status_code} must not have a response body" - ) - response_name = f"Response_{additional_status_code}_{self.unique_id}" - response_field = create_model_field( - name=response_name, type_=model, mode="serialization" - ) - response_fields[additional_status_code] = response_field - if response_fields: - self.response_fields: dict[int | str, ModelField] = response_fields - else: - self.response_fields = {} - - assert callable(endpoint), "An endpoint must be a callable" - self.dependant = get_dependant( - path=self.path_format, call=self.endpoint, scope="function" - ) - for depends in self.dependencies[::-1]: - self.dependant.dependencies.insert( - 0, - get_parameterless_sub_dependant(depends=depends, path=self.path_format), - ) - self._flat_dependant = get_flat_dependant(self.dependant) - self._embed_body_fields = _should_embed_body_fields( - self._flat_dependant.body_params - ) - self.body_field = get_body_field( - flat_dependant=self._flat_dependant, - name=self.unique_id, - embed_body_fields=self._embed_body_fields, - ) - # Detect generator endpoints that should stream as JSONL or SSE - is_generator = ( - self.dependant.is_async_gen_callable or self.dependant.is_gen_callable - ) - self.is_sse_stream = is_generator and lenient_issubclass( - response_class, EventSourceResponse - ) - self.is_json_stream = is_generator and isinstance( - response_class, DefaultPlaceholder + _populate_api_route_state( + cast(_APIRouteLike, self), + path, + endpoint, + response_model=response_model, + status_code=status_code, + tags=tags, + dependencies=dependencies, + summary=summary, + description=description, + response_description=response_description, + responses=responses, + deprecated=deprecated, + name=name, + methods=methods, + operation_id=operation_id, + response_model_include=response_model_include, + response_model_exclude=response_model_exclude, + response_model_by_alias=response_model_by_alias, + response_model_exclude_unset=response_model_exclude_unset, + response_model_exclude_defaults=response_model_exclude_defaults, + response_model_exclude_none=response_model_exclude_none, + include_in_schema=include_in_schema, + response_class=response_class, + dependency_overrides_provider=dependency_overrides_provider, + callbacks=callbacks, + openapi_extra=openapi_extra, + generate_unique_id_function=generate_unique_id_function, + strict_content_type=strict_content_type, ) self.app = request_response(self.get_route_handler()) def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]]: + route = cast(_APIRouteLike, self) + # TODO: Replace or deprecate this no-scope hook so included-route + # effective context can be passed explicitly instead of via ContextVar. + effective_context = _effective_route_context_var.get() + if effective_context is not None and effective_context.original_route is self: + route = cast(_APIRouteLike, effective_context) return get_request_handler( - dependant=self.dependant, - body_field=self.body_field, - status_code=self.status_code, - response_class=self.response_class, - response_field=self.response_field, - response_model_include=self.response_model_include, - response_model_exclude=self.response_model_exclude, - response_model_by_alias=self.response_model_by_alias, - response_model_exclude_unset=self.response_model_exclude_unset, - response_model_exclude_defaults=self.response_model_exclude_defaults, - response_model_exclude_none=self.response_model_exclude_none, - dependency_overrides_provider=self.dependency_overrides_provider, - embed_body_fields=self._embed_body_fields, - strict_content_type=self.strict_content_type, - stream_item_field=self.stream_item_field, - is_json_stream=self.is_json_stream, + dependant=route.dependant, + body_field=route.body_field, + status_code=route.status_code, + response_class=route.response_class, + response_field=route.response_field, + response_model_include=route.response_model_include, + response_model_exclude=route.response_model_exclude, + response_model_by_alias=route.response_model_by_alias, + response_model_exclude_unset=route.response_model_exclude_unset, + response_model_exclude_defaults=route.response_model_exclude_defaults, + response_model_exclude_none=route.response_model_exclude_none, + dependency_overrides_provider=route.dependency_overrides_provider, + embed_body_fields=route._embed_body_fields, + strict_content_type=route.strict_content_type, + stream_item_field=route.stream_item_field, + is_json_stream=route.is_json_stream, ) def matches(self, scope: Scope) -> tuple[Match, Scope]: - match, child_scope = super().matches(scope) + effective_context = _get_scope_effective_route_context(scope) + if effective_context is not None and effective_context.original_route is self: + match, child_scope = effective_context.matches(scope) + else: + match, child_scope = super().matches(scope) if match != Match.NONE: child_scope["route"] = self return match, child_scope + async def handle(self, scope: Scope, receive: Receive, send: Send) -> None: + effective_context = _get_scope_effective_route_context(scope) + if effective_context is not None and effective_context.original_route is self: + methods = effective_context.methods + if methods and scope["method"] not in methods: + headers = {"Allow": ", ".join(methods)} + if "app" in scope: + raise HTTPException(status_code=405, headers=headers) + response = PlainTextResponse( + "Method Not Allowed", status_code=405, headers=headers + ) + await response(scope, receive, send) + return + token = _effective_route_context_var.set(effective_context) + try: + app = request_response(self.get_route_handler()) + finally: + _effective_route_context_var.reset(token) + await app(scope, receive, send) + return + await super().handle(scope, receive, send) + + +@dataclass +class _RouterIncludeContext: + included_router: "APIRouter" + prefix: str = "" + tags: list[str | Enum] = field(default_factory=list) + dependencies: list[params.Depends] = field(default_factory=list) + default_response_class: type[Response] | DefaultPlaceholder = field( + default_factory=lambda: Default(JSONResponse) + ) + responses: dict[int | str, dict[str, Any]] = field(default_factory=dict) + callbacks: list[BaseRoute] = field(default_factory=list) + deprecated: bool | None = None + include_in_schema: bool = True + generate_unique_id_function: Callable[[APIRoute], str] | DefaultPlaceholder = field( + default_factory=lambda: Default(generate_unique_id) + ) + strict_content_type: bool | DefaultPlaceholder = field( + default_factory=lambda: Default(True) + ) + dependency_overrides_provider: Any | None = None + + @classmethod + def for_include( + cls, + *, + parent_router: "APIRouter", + included_router: "APIRouter", + prefix: str = "", + tags: list[str | Enum] | None = None, + dependencies: Sequence[params.Depends] | None = None, + default_response_class: type[Response] | DefaultPlaceholder = Default( + JSONResponse + ), + responses: dict[int | str, dict[str, Any]] | None = None, + callbacks: list[BaseRoute] | None = None, + deprecated: bool | None = None, + include_in_schema: bool = True, + generate_unique_id_function: Callable[[APIRoute], str] + | DefaultPlaceholder = Default(generate_unique_id), + ) -> "_RouterIncludeContext": + return cls( + included_router=included_router, + prefix=parent_router.prefix + prefix, + tags=[*parent_router.tags, *(tags or [])], + dependencies=[*parent_router.dependencies, *(dependencies or [])], + default_response_class=get_value_or_default( + default_response_class, parent_router.default_response_class + ), + responses={**parent_router.responses, **(responses or {})}, + callbacks=[*parent_router.callbacks, *(callbacks or [])], + deprecated=deprecated or parent_router.deprecated, + include_in_schema=parent_router.include_in_schema and include_in_schema, + generate_unique_id_function=get_value_or_default( + generate_unique_id_function, parent_router.generate_unique_id_function + ), + strict_content_type=parent_router.strict_content_type, + dependency_overrides_provider=parent_router.dependency_overrides_provider, + ) + + def combine( + self, child_context: "_RouterIncludeContext" + ) -> "_RouterIncludeContext": + return _RouterIncludeContext( + included_router=child_context.included_router, + prefix=self.prefix + child_context.prefix, + tags=[*self.tags, *child_context.tags], + dependencies=[*self.dependencies, *child_context.dependencies], + default_response_class=get_value_or_default( + child_context.default_response_class, self.default_response_class + ), + responses={**self.responses, **child_context.responses}, + callbacks=[*self.callbacks, *child_context.callbacks], + deprecated=self.deprecated or child_context.deprecated, + include_in_schema=self.include_in_schema + and child_context.include_in_schema, + generate_unique_id_function=get_value_or_default( + child_context.generate_unique_id_function, + self.generate_unique_id_function, + ), + strict_content_type=get_value_or_default( + child_context.strict_content_type, self.strict_content_type + ), + dependency_overrides_provider=self.dependency_overrides_provider, + ) + + def path_for( + self, route: APIRoute | routing.Route | routing.WebSocketRoute | routing.Mount + ) -> str: + return self.prefix + route.path + + +@dataclass +class _EffectiveRouteContext: + original_route: BaseRoute + starlette_route: BaseRoute | None = None + path: str = "" + endpoint: Callable[..., Any] | None = None + stream_item_type: Any | None = None + response_model: Any = None + summary: str | None = None + response_description: str = "Successful Response" + deprecated: bool | None = None + operation_id: str | None = None + response_model_include: IncEx | None = None + response_model_exclude: IncEx | None = None + response_model_by_alias: bool = True + response_model_exclude_unset: bool = False + response_model_exclude_defaults: bool = False + response_model_exclude_none: bool = False + include_in_schema: bool = True + response_class: type[Response] | DefaultPlaceholder = field( + default_factory=lambda: Default(JSONResponse) + ) + dependency_overrides_provider: Any | None = None + callbacks: list[BaseRoute] | None = None + openapi_extra: dict[str, Any] | None = None + generate_unique_id_function: Callable[[Any], str] | DefaultPlaceholder = field( + default_factory=lambda: Default(generate_unique_id) + ) + strict_content_type: bool | DefaultPlaceholder = field( + default_factory=lambda: Default(True) + ) + tags: list[str | Enum] = field(default_factory=list) + responses: dict[int | str, dict[str, Any]] = field(default_factory=dict) + name: str = "" + path_regex: Any = None + path_format: str = "" + param_convertors: dict[str, Any] = field(default_factory=dict) + methods: set[str] = field(default_factory=set) + unique_id: str = "" + status_code: int | None = None + response_field: ModelField | None = None + stream_item_field: ModelField | None = None + dependencies: list[params.Depends] = field(default_factory=list) + description: str = "" + response_fields: dict[int | str, ModelField] = field(default_factory=dict) + dependant: Dependant | None = None + _flat_dependant: Dependant | None = None + _embed_body_fields: bool = False + body_field: ModelField | None = None + is_sse_stream: bool = False + is_json_stream: bool = False + + @classmethod + def from_api_route( + cls, + *, + original_route: APIRoute, + include_context: _RouterIncludeContext, + ) -> "_EffectiveRouteContext": + route = cast(_APIRouteLike, original_route) + context = cls(original_route=original_route) + _populate_api_route_state( + cast(_APIRouteLike, context), + include_context.path_for(original_route), + route.endpoint, + response_model=route.response_model, + status_code=route.status_code, + tags=[*include_context.tags, *route.tags], + dependencies=[*include_context.dependencies, *route.dependencies], + summary=route.summary, + description=route.description, + response_description=route.response_description, + responses={**include_context.responses, **route.responses}, + deprecated=route.deprecated or include_context.deprecated, + methods=route.methods, + operation_id=route.operation_id, + response_model_include=route.response_model_include, + response_model_exclude=route.response_model_exclude, + response_model_by_alias=route.response_model_by_alias, + response_model_exclude_unset=route.response_model_exclude_unset, + response_model_exclude_defaults=route.response_model_exclude_defaults, + response_model_exclude_none=route.response_model_exclude_none, + include_in_schema=route.include_in_schema + and include_context.include_in_schema, + response_class=get_value_or_default( + route.response_class, + include_context.included_router.default_response_class, + include_context.default_response_class, + ), + name=route.name, + dependency_overrides_provider=include_context.dependency_overrides_provider, + callbacks=[*include_context.callbacks, *(route.callbacks or [])], + openapi_extra=route.openapi_extra, + generate_unique_id_function=get_value_or_default( + route.generate_unique_id_function, + include_context.included_router.generate_unique_id_function, + include_context.generate_unique_id_function, + ), + strict_content_type=get_value_or_default( + route.strict_content_type, + include_context.included_router.strict_content_type, + include_context.strict_content_type, + ), + ) + return context + + def matches(self, scope: Scope) -> tuple[Match, Scope]: + if not isinstance(self.original_route, APIRoute): + assert self.starlette_route is not None + return self.starlette_route.matches(scope) + if scope["type"] != "http": + return Match.NONE, {} + route_path = get_route_path(scope) + match = self.path_regex.match(route_path) + if not match: + return Match.NONE, {} + matched_params = match.groupdict() + for key, value in matched_params.items(): + matched_params[key] = self.param_convertors[key].convert(value) + path_params = dict(scope.get("path_params", {})) + path_params.update(matched_params) + child_scope = {"endpoint": self.endpoint, "path_params": path_params} + methods = self.methods + if methods and scope["method"] not in methods: + return Match.PARTIAL, child_scope + return Match.FULL, child_scope + + def url_path_for(self, name: str, /, **path_params: Any) -> Any: + if not isinstance(self.original_route, APIRoute): + assert self.starlette_route is not None + return self.starlette_route.url_path_for(name, **path_params) + seen_params = set(path_params.keys()) + param_convertors = self.param_convertors + expected_params = set(param_convertors.keys()) + if name != self.name or seen_params != expected_params: + raise routing.NoMatchFound(name, path_params) + path, remaining_params = routing.replace_params( + self.path_format, param_convertors, path_params + ) + assert not remaining_params + return URLPath(path=path, protocol="http") + + +@dataclass +class _IncludedRouter(BaseRoute): + original_router: "APIRouter" + include_context: _RouterIncludeContext + _effective_candidates: list["_EffectiveRouteContext | _IncludedRouter"] = field( + default_factory=list + ) + _effective_candidates_version: int | None = None + + def effective_candidates(self) -> list["_EffectiveRouteContext | _IncludedRouter"]: + routes_version = self.original_router._get_routes_version() + if routes_version == self._effective_candidates_version: + return self._effective_candidates + self._effective_candidates = [] + candidates = self.original_router.routes + for route in candidates: + if isinstance(route, _IncludedRouter): + child_context = self.include_context.combine(route.include_context) + child_branch = _IncludedRouter( + original_router=route.original_router, + include_context=child_context, + ) + self._effective_candidates.append(child_branch) + continue + route_context = self._build_effective_context(route) + if route_context is not None: + self._effective_candidates.append(route_context) + self._effective_candidates_version = routes_version + return self._effective_candidates + + def _build_effective_context( + self, route: BaseRoute + ) -> _EffectiveRouteContext | None: + if isinstance(route, APIRoute): + return _EffectiveRouteContext.from_api_route( + original_route=route, + include_context=self.include_context, + ) + if isinstance(route, routing.Route): + starlette_route: BaseRoute = routing.Route( + self.include_context.path_for(route), + endpoint=route.endpoint, + methods=list(route.methods or []), + name=route.name, + include_in_schema=route.include_in_schema, + ) + return _EffectiveRouteContext( + original_route=route, + starlette_route=starlette_route, + ) + if isinstance(route, APIWebSocketRoute): + starlette_route = APIWebSocketRoute( + self.include_context.path_for(route), + endpoint=route.endpoint, + name=route.name, + dependencies=[*self.include_context.dependencies, *route.dependencies], + dependency_overrides_provider=( + self.include_context.dependency_overrides_provider + ), + ) + return _EffectiveRouteContext( + original_route=route, + starlette_route=starlette_route, + ) + if isinstance(route, routing.WebSocketRoute): + starlette_route = routing.WebSocketRoute( + self.include_context.path_for(route), route.endpoint, name=route.name + ) + return _EffectiveRouteContext( + original_route=route, + starlette_route=starlette_route, + ) + if isinstance(route, routing.Mount): + starlette_route = copy.copy(route) + starlette_route.path = self.include_context.path_for(route).rstrip("/") + ( + starlette_route.path_regex, + starlette_route.path_format, + starlette_route.param_convertors, + ) = compile_path(starlette_route.path + "/{path:path}") + return _EffectiveRouteContext( + original_route=route, + starlette_route=starlette_route, + ) + if isinstance(route, routing.Host): + if self.include_context.prefix: + prefixed_app: ASGIApp = routing.Router( + routes=[routing.Mount(self.include_context.prefix, app=route.app)] + ) + else: + prefixed_app = route.app + starlette_route = routing.Host( + route.host, app=prefixed_app, name=route.name + ) + return _EffectiveRouteContext( + original_route=route, + starlette_route=starlette_route, + ) + return None + + def _match( + self, scope: Scope + ) -> tuple[Match, Scope, BaseRoute | None, _EffectiveRouteContext | None]: + partial: tuple[Scope, BaseRoute, _EffectiveRouteContext | None] | None = None + for candidate in self.effective_candidates(): + if isinstance(candidate, _IncludedRouter): + match, child_scope = candidate.matches(scope) + route: BaseRoute = candidate + route_context = None + elif isinstance(candidate.original_route, APIRoute): + route_context = candidate + fastapi_scope = _get_fastapi_scope(scope) + previous_context = fastapi_scope.get( + _FASTAPI_EFFECTIVE_ROUTE_CONTEXT_KEY, _SCOPE_MISSING + ) + fastapi_scope[_FASTAPI_EFFECTIVE_ROUTE_CONTEXT_KEY] = route_context + try: + match, child_scope = candidate.original_route.matches(scope) + finally: + _restore_fastapi_scope_key( + scope, _FASTAPI_EFFECTIVE_ROUTE_CONTEXT_KEY, previous_context + ) + route = candidate.original_route + else: + route_context = candidate + match, child_scope = candidate.matches(scope) + route = candidate.starlette_route or candidate.original_route + if match == Match.FULL: + return match, child_scope, route, route_context + if match == Match.PARTIAL and partial is None: + partial = (child_scope, route, route_context) + if partial is not None: + child_scope, route, route_context = partial + return Match.PARTIAL, child_scope, route, route_context + return Match.NONE, {}, None, None + + def matches(self, scope: Scope) -> tuple[Match, Scope]: + fastapi_scope = _get_fastapi_scope(scope) + previous_router = fastapi_scope.get( + _FASTAPI_INCLUDED_ROUTER_KEY, _SCOPE_MISSING + ) + fastapi_scope[_FASTAPI_INCLUDED_ROUTER_KEY] = self + try: + match, _ = self.original_router.matches(scope) + return match, {} + finally: + _restore_fastapi_scope_key( + scope, _FASTAPI_INCLUDED_ROUTER_KEY, previous_router + ) + + async def handle(self, scope: Scope, receive: Receive, send: Send) -> None: + _get_fastapi_scope(scope)[_FASTAPI_INCLUDED_ROUTER_KEY] = self + await self.original_router.handle(scope, receive, send) + + async def _handle_selected( + self, scope: Scope, receive: Receive, send: Send + ) -> None: + match, child_scope, route, effective_context = self._match(scope) + if match == Match.NONE or route is None: + await self.original_router.default(scope, receive, send) + return + scope.update(child_scope) + if isinstance(route, _IncludedRouter): + await route.handle(scope, receive, send) + return + if effective_context is not None: + _get_fastapi_scope(scope)[_FASTAPI_EFFECTIVE_ROUTE_CONTEXT_KEY] = ( + effective_context + ) + original_route = effective_context.original_route + if isinstance(original_route, APIRoute): + scope["route"] = original_route + await original_route.handle(scope, receive, send) + return + await route.handle(scope, receive, send) + + def effective_route_contexts(self) -> Iterator[_EffectiveRouteContext]: + for candidate in self.effective_candidates(): + if isinstance(candidate, _IncludedRouter): + yield from candidate.effective_route_contexts() + else: + yield candidate + + def url_path_for(self, name: str, /, **path_params: Any) -> Any: + for route_context in self.effective_route_contexts(): + try: + return route_context.url_path_for(name, **path_params) + except routing.NoMatchFound: + pass + raise routing.NoMatchFound(name, path_params) + + +def _iter_included_route_candidates(routes: Sequence[BaseRoute]) -> Iterator[BaseRoute]: + for route, route_context in _iter_routes_with_context(routes): + if route_context is not None and route_context.starlette_route is not None: + yield route_context.starlette_route + else: + yield route + + +def _iter_routes_with_context( + routes: Sequence[BaseRoute], +) -> Iterator[tuple[BaseRoute, _EffectiveRouteContext | None]]: + for route in routes: + if isinstance(route, _IncludedRouter): + for route_context in route.effective_route_contexts(): + yield route_context.original_route, route_context + else: + yield route, None + class APIRouter(routing.Router): """ @@ -1313,6 +1941,87 @@ class APIRouter(routing.Router): self.default_response_class = default_response_class self.generate_unique_id_function = generate_unique_id_function self.strict_content_type = strict_content_type + self._routes_version = 0 + + def _mark_routes_changed(self) -> None: + self._routes_version += 1 + + def _get_routes_version(self, seen: set[int] | None = None) -> int: + if seen is None: + seen = set() + router_id = id(self) + if router_id in seen: + return self._routes_version + seen.add(router_id) + version = self._routes_version + for route in self.routes: + if isinstance(route, _IncludedRouter): + version += route.original_router._get_routes_version(seen) + return version + + def _contains_router( + self, router: "APIRouter", seen: set[int] | None = None + ) -> bool: + if seen is None: + seen = set() + router_id = id(self) + if router_id in seen: + return False + seen.add(router_id) + for route in self.routes: + if not isinstance(route, _IncludedRouter): + continue + if route.original_router is router: + return True + if route.original_router._contains_router(router, seen): + return True + return False + + def add_route( + self, + path: str, + endpoint: Callable[[Request], Awaitable[Response] | Response], + methods: Collection[str] | None = None, + name: str | None = None, + include_in_schema: bool = True, + ) -> None: + super().add_route( + path, + endpoint, + methods=methods, + name=name, + include_in_schema=include_in_schema, + ) + self._mark_routes_changed() + + def add_websocket_route( + self, + path: str, + endpoint: Callable[[WebSocket], Awaitable[None]], + name: str | None = None, + ) -> None: + super().add_websocket_route(path, endpoint, name=name) + self._mark_routes_changed() + + async def handle(self, scope: Scope, receive: Receive, send: Send) -> None: + included_router = _get_scope_included_router(scope) + if ( + isinstance(included_router, _IncludedRouter) + and included_router.original_router is self + ): + await included_router._handle_selected(scope, receive, send) + return + await self.app(scope, receive, send) + + def matches(self, scope: Scope) -> tuple[Match, Scope]: + included_router = _get_scope_included_router(scope) + if ( + isinstance(included_router, _IncludedRouter) + and included_router.original_router is self + ): + match, child_scope, _, _ = included_router._match(scope) + return match, child_scope + return Match.NONE, {} def route( self, @@ -1415,6 +2124,7 @@ class APIRouter(routing.Router): ), ) self.routes.append(route) + self._mark_routes_changed() def api_route( self, @@ -1498,6 +2208,7 @@ class APIRouter(routing.Router): dependency_overrides_provider=self.dependency_overrides_provider, ) self.routes.append(route) + self._mark_routes_changed() def websocket( self, @@ -1714,111 +2425,40 @@ class APIRouter(routing.Router): "Cannot include the same APIRouter instance into itself. " "Did you mean to include a different router?" ) + assert not router._contains_router(self), ( + "Cannot include an APIRouter instance that already includes this router. " + "Did you mean to include a different router?" + ) if prefix: assert prefix.startswith("/"), "A path prefix must start with '/'" assert not prefix.endswith("/"), ( "A path prefix must not end with '/', as the routes will start with '/'" ) else: - for r in router.routes: - path = getattr(r, "path") # noqa: B009 + for r in _iter_included_route_candidates(router.routes): + path = getattr(r, "path", None) name = getattr(r, "name", "unknown") if path is not None and not path: raise FastAPIError( f"Prefix and path cannot be both empty (path operation: {name})" ) - if responses is None: - responses = {} - for route in router.routes: - if isinstance(route, APIRoute): - combined_responses = {**responses, **route.responses} - use_response_class = get_value_or_default( - route.response_class, - router.default_response_class, - default_response_class, - self.default_response_class, - ) - current_tags = [] - if tags: - current_tags.extend(tags) - if route.tags: - current_tags.extend(route.tags) - current_dependencies: list[params.Depends] = [] - if dependencies: - current_dependencies.extend(dependencies) - if route.dependencies: - current_dependencies.extend(route.dependencies) - current_callbacks = [] - if callbacks: - current_callbacks.extend(callbacks) - if route.callbacks: - current_callbacks.extend(route.callbacks) - current_generate_unique_id = get_value_or_default( - route.generate_unique_id_function, - router.generate_unique_id_function, - generate_unique_id_function, - self.generate_unique_id_function, - ) - self.add_api_route( - prefix + route.path, - route.endpoint, - response_model=route.response_model, - status_code=route.status_code, - tags=current_tags, - dependencies=current_dependencies, - summary=route.summary, - description=route.description, - response_description=route.response_description, - responses=combined_responses, - deprecated=route.deprecated or deprecated or self.deprecated, - methods=route.methods, - operation_id=route.operation_id, - response_model_include=route.response_model_include, - response_model_exclude=route.response_model_exclude, - response_model_by_alias=route.response_model_by_alias, - response_model_exclude_unset=route.response_model_exclude_unset, - response_model_exclude_defaults=route.response_model_exclude_defaults, - response_model_exclude_none=route.response_model_exclude_none, - include_in_schema=route.include_in_schema - and self.include_in_schema - and include_in_schema, - response_class=use_response_class, - name=route.name, - route_class_override=type(route), - callbacks=current_callbacks, - openapi_extra=route.openapi_extra, - generate_unique_id_function=current_generate_unique_id, - strict_content_type=get_value_or_default( - route.strict_content_type, - router.strict_content_type, - self.strict_content_type, - ), - ) - elif isinstance(route, routing.Route): - methods = list(route.methods or []) - self.add_route( - prefix + route.path, - route.endpoint, - methods=methods, - include_in_schema=route.include_in_schema, - name=route.name, - ) - elif isinstance(route, APIWebSocketRoute): - current_dependencies = [] - if dependencies: - current_dependencies.extend(dependencies) - if route.dependencies: - current_dependencies.extend(route.dependencies) - self.add_api_websocket_route( - prefix + route.path, - route.endpoint, - dependencies=current_dependencies, - name=route.name, - ) - elif isinstance(route, routing.WebSocketRoute): - self.add_websocket_route( - prefix + route.path, route.endpoint, name=route.name - ) + include_context = _RouterIncludeContext.for_include( + parent_router=self, + included_router=router, + prefix=prefix, + tags=tags, + dependencies=dependencies, + default_response_class=default_response_class, + responses=responses, + callbacks=callbacks, + deprecated=deprecated, + include_in_schema=include_in_schema, + generate_unique_id_function=generate_unique_id_function, + ) + self.routes.append( + _IncludedRouter(original_router=router, include_context=include_context) + ) + self._mark_routes_changed() for handler in router.on_startup: self.add_event_handler("startup", handler) for handler in router.on_shutdown: diff --git a/tests/test_custom_route_class.py b/tests/test_custom_route_class.py index 786c1efc31..de5e1da903 100644 --- a/tests/test_custom_route_class.py +++ b/tests/test_custom_route_class.py @@ -3,7 +3,6 @@ from fastapi import APIRouter, FastAPI from fastapi.routing import APIRoute from fastapi.testclient import TestClient from inline_snapshot import snapshot -from starlette.routing import Route app = FastAPI() @@ -63,13 +62,9 @@ def test_get_path(path, expected_status, expected_response): def test_route_classes(): - routes = {} - for r in app.router.routes: - assert isinstance(r, Route) - routes[r.path] = r - assert getattr(routes["/a/"], "x_type") == "A" # noqa: B009 - assert getattr(routes["/a/b/"], "x_type") == "B" # noqa: B009 - assert getattr(routes["/a/b/c/"], "x_type") == "C" # noqa: B009 + assert isinstance(router_a.routes[0], APIRouteA) + assert isinstance(router_b.routes[0], APIRouteB) + assert isinstance(router_c.routes[0], APIRouteC) def test_openapi_schema(): diff --git a/tests/test_router_include_context.py b/tests/test_router_include_context.py new file mode 100644 index 0000000000..408cdd3f11 --- /dev/null +++ b/tests/test_router_include_context.py @@ -0,0 +1,855 @@ +from typing import Annotated, cast + +import pytest +from fastapi import APIRouter, Body, Depends, FastAPI, Request +from fastapi.responses import HTMLResponse, JSONResponse, PlainTextResponse +from fastapi.routing import ( + APIRoute, + _IncludedRouter, + _iter_included_route_candidates, + _restore_fastapi_scope_key, +) +from fastapi.testclient import TestClient +from starlette.routing import BaseRoute, Host, Match, Mount, NoMatchFound, Route, Router + + +def dependency_a(): + return "a" + + +def dependency_b(): + return "b" + + +def dependency_c(): + return "c" + + +def unique_id_b(route: APIRoute) -> str: + return f"b_{route.name}" + + +def test_router_include_context_matches_flattened_include_metadata(): + callback_router = APIRouter() + + @callback_router.post("/callback") + def callback(): # pragma: no cover + return {"ok": True} + + callback_route = callback_router.routes[0] + + parent_router = APIRouter() + included_router = APIRouter( + prefix="/items", + tags=["router"], + dependencies=[Depends(dependency_a)], + responses={401: {"description": "Unauthorized"}}, + callbacks=[callback_route], + default_response_class=HTMLResponse, + strict_content_type=False, + ) + + @included_router.get( + "/{item_id}", + tags=["route"], + dependencies=[Depends(dependency_b)], + responses={404: {"description": "Missing"}}, + callbacks=[callback_route], + generate_unique_id_function=unique_id_b, + ) + def read_item(item_id: str, request: Request): + context = request.scope["fastapi"]["effective_route_context"] + return JSONResponse( + { + "path": context.path, + "tags": context.tags, + "dependency_count": len(context.dependencies), + "response_codes": sorted(context.responses), + "callback_count": len(context.callbacks or []), + "deprecated": context.deprecated, + "include_in_schema": context.include_in_schema, + "response_class": context.response_class.__name__, + "generate_unique_id": context.generate_unique_id_function(context), + "strict_content_type": context.strict_content_type, + "has_dependency_overrides_provider": ( + context.dependency_overrides_provider + is app.router.dependency_overrides_provider + ), + } + ) + + parent_router.include_router( + included_router, + prefix="/api", + tags=["include"], + dependencies=[Depends(dependency_c)], + responses={400: {"description": "Bad request"}}, + callbacks=[callback_route], + deprecated=True, + include_in_schema=False, + ) + + app = FastAPI() + app.include_router(parent_router) + response = TestClient(app).get("/api/items/foo") + + assert response.status_code == 200 + assert response.json() == { + "path": "/api/items/{item_id}", + "tags": ["include", "router", "route"], + "dependency_count": 3, + "response_codes": [400, 401, 404], + "callback_count": 3, + "deprecated": True, + "include_in_schema": False, + "response_class": "HTMLResponse", + "generate_unique_id": "b_read_item", + "strict_content_type": False, + "has_dependency_overrides_provider": True, + } + + +def test_live_route_addition_uses_include_metadata_for_runtime_and_openapi(): + calls: list[str] = [] + + def included_dependency(): + calls.append("dependency") + + router = APIRouter() + app = FastAPI() + app.include_router( + router, + prefix="/api", + tags=["included"], + dependencies=[Depends(included_dependency)], + responses={418: {"description": "Teapot"}}, + ) + + @router.get("/later") + def read_later(): + return {"later": True} + + client = TestClient(app) + response = client.get("/api/later") + + assert response.status_code == 200 + assert response.json() == {"later": True} + assert calls == ["dependency"] + operation = client.get("/openapi.json").json()["paths"]["/api/later"]["get"] + assert operation["tags"] == ["included"] + assert operation["responses"]["418"] == {"description": "Teapot"} + + +def test_openapi_cache_updates_after_live_route_addition(): + router = APIRouter() + app = FastAPI() + app.include_router(router, prefix="/api") + client = TestClient(app) + + first_schema = client.get("/openapi.json").json() + assert "/api/later" not in first_schema["paths"] + + @router.get("/later") + def read_later(): # pragma: no cover + return {"later": True} + + second_schema = client.get("/openapi.json").json() + assert "/api/later" in second_schema["paths"] + + +def test_nested_router_added_after_parent_inclusion_is_live(): + parent_router = APIRouter() + child_router = APIRouter() + app = FastAPI() + app.include_router(parent_router, prefix="/api") + parent_router.include_router(child_router, prefix="/child", tags=["child"]) + + @child_router.get("/items") + def read_items(): + return ["item"] + + client = TestClient(app) + response = client.get("/api/child/items") + + assert response.status_code == 200 + assert response.json() == ["item"] + operation = client.get("/openapi.json").json()["paths"]["/api/child/items"]["get"] + assert operation["tags"] == ["child"] + + +def test_repeated_deep_inclusions_handle_all_concrete_paths(): + shared_router = APIRouter() + + @shared_router.get("/items") + def read_items(): + return [] + + parent_router = APIRouter() + parent_router.include_router(shared_router, prefix="/a") + parent_router.include_router(shared_router, prefix="/b") + + app = FastAPI() + app.include_router(parent_router, prefix="/v1") + app.include_router(parent_router, prefix="/v2") + + client = TestClient(app) + paths = ["/v1/a/items", "/v1/b/items", "/v2/a/items", "/v2/b/items"] + for path in paths: + response = client.get(path) + assert response.status_code == 200 + assert response.json() == [] + assert set(client.get("/openapi.json").json()["paths"]) == set(paths) + + +def test_url_path_for_uses_effective_context_for_live_included_route(): + router = APIRouter() + app = FastAPI() + app.include_router(router, prefix="/api") + + @router.get("/items/{item_id}", name="read_item") + def read_item(item_id: str): # pragma: no cover + return {"item_id": item_id} + + assert app.url_path_for("read_item", item_id="abc") == "/api/items/abc" + + +def test_url_path_for_uses_distinct_repeated_inclusion_contexts(): + router = APIRouter() + + @router.get("/items/{item_id}", name="read_item") + def read_item(item_id: str): # pragma: no cover + return {"item_id": item_id} + + parent_router = APIRouter() + parent_router.include_router(router, prefix="/v1") + parent_router.include_router(router, prefix="/v2") + + assert parent_router.url_path_for("read_item", item_id="abc") == "/v1/items/abc" + assert ( + parent_router.routes[1].url_path_for("read_item", item_id="abc") + == "/v2/items/abc" + ) + + +def test_indirect_router_inclusion_cycles_are_rejected(): + parent_router = APIRouter() + child_router = APIRouter() + + parent_router.include_router(child_router, prefix="/child") + + with pytest.raises(AssertionError, match="already includes this router"): + child_router.include_router(parent_router, prefix="/parent") + + parent_router = APIRouter() + child_router = APIRouter() + grandchild_router = APIRouter() + + parent_router.include_router(child_router, prefix="/child") + child_router.include_router(grandchild_router, prefix="/grandchild") + + with pytest.raises(AssertionError, match="already includes this router"): + grandchild_router.include_router(parent_router, prefix="/parent") + + +def test_original_api_route_subclass_instance_is_called_after_inclusion(): + class TrackingRoute(APIRoute): + calls = 0 + + async def handle(self, scope, receive, send): + self.calls += 1 + await super().handle(scope, receive, send) + + router = APIRouter(route_class=TrackingRoute) + + @router.get("/items") + def read_items(): + return [] + + original_route = router.routes[0] + assert isinstance(original_route, TrackingRoute) + + app = FastAPI() + app.include_router(router, prefix="/api") + + response = TestClient(app).get("/api/items") + + assert response.status_code == 200 + assert original_route.calls == 1 + + +def test_original_api_route_get_route_handler_is_called_after_inclusion(): + class TrackingRoute(APIRoute): + calls = 0 + + def get_route_handler(self): + handler = super().get_route_handler() + + async def custom_handler(request): + self.calls += 1 + return await handler(request) + + return custom_handler + + router = APIRouter(route_class=TrackingRoute) + + @router.get("/items") + def read_items(): + return [] + + original_route = router.routes[0] + assert isinstance(original_route, TrackingRoute) + original_route.calls = 0 + + app = FastAPI() + app.include_router(router, prefix="/api") + + response = TestClient(app).get("/api/items") + + assert response.status_code == 200 + assert original_route.calls == 1 + + +def test_original_api_route_matches_is_called_after_inclusion(): + class HeaderRoute(APIRoute): + calls = 0 + + def matches(self, scope): + self.calls += 1 + headers = dict(scope.get("headers", [])) + if headers.get(b"x-match") != b"yes": + return Match.NONE, {} + return super().matches(scope) + + router = APIRouter(route_class=HeaderRoute) + + @router.get("/items") + def read_items(): + return [] + + original_route = router.routes[0] + assert isinstance(original_route, HeaderRoute) + original_route.calls = 0 + + app = FastAPI() + app.include_router(router, prefix="/api") + client = TestClient(app) + + assert client.get("/api/items").status_code == 404 + assert client.get("/api/items", headers={"x-match": "yes"}).status_code == 200 + assert original_route.calls >= 2 + + +def test_effective_route_context_is_available_in_scope_during_request(): + router = APIRouter() + + @router.get("/items") + def read_items(request: Request): + fastapi_scope = request.scope.get("fastapi") + assert isinstance(fastapi_scope, dict) + return { + "has_context": "effective_route_context" in fastapi_scope, + "path": fastapi_scope["effective_route_context"].path, + } + + app = FastAPI() + app.include_router(router, prefix="/api") + + response = TestClient(app).get("/api/items") + + assert response.status_code == 200 + assert response.json() == {"has_context": True, "path": "/api/items"} + + +def test_original_api_router_matches_is_called_after_inclusion(): + class HeaderRouter(APIRouter): + calls = 0 + + def matches(self, scope): + self.calls += 1 + headers = dict(scope.get("headers", [])) + if headers.get(b"x-router-match") != b"yes": + return Match.NONE, {} + return super().matches(scope) + + router = HeaderRouter() + + @router.get("/items") + def read_items(): + return [] + + app = FastAPI() + app.include_router(router, prefix="/api") + client = TestClient(app) + + assert client.get("/api/items").status_code == 404 + assert ( + client.get("/api/items", headers={"x-router-match": "yes"}).status_code == 200 + ) + assert router.calls >= 2 + + +def test_original_nested_api_router_subclasses_are_called_after_inclusion(): + class TrackingRouter(APIRouter): + calls = 0 + + async def handle(self, scope, receive, send): + self.calls += 1 + await super().handle(scope, receive, send) + + parent_router = TrackingRouter() + child_router = TrackingRouter() + + @child_router.get("/items") + def read_items(): + return [] + + parent_router.include_router(child_router, prefix="/child") + app = FastAPI() + app.include_router(parent_router, prefix="/api") + + response = TestClient(app).get("/api/child/items") + + assert response.status_code == 200 + assert parent_router.calls == 1 + assert child_router.calls == 1 + + +def test_router_and_include_prefix_path_params_reach_endpoint_and_openapi(): + router = APIRouter(prefix="/tenants/{tenant_id}") + + @router.get("/items/{item_id}") + def read_item(version: int, tenant_id: int, item_id: int): + return {"version": version, "tenant_id": tenant_id, "item_id": item_id} + + app = FastAPI() + app.include_router(router, prefix="/api/{version}") + + client = TestClient(app) + response = client.get("/api/1/tenants/2/items/3") + + assert response.status_code == 200 + assert response.json() == {"version": 1, "tenant_id": 2, "item_id": 3} + + operation = client.get("/openapi.json").json()["paths"][ + "/api/{version}/tenants/{tenant_id}/items/{item_id}" + ]["get"] + assert {parameter["name"] for parameter in operation["parameters"]} == { + "version", + "tenant_id", + "item_id", + } + + +def test_effective_body_fields_from_app_router_include_and_route_match_openapi(): + def app_body_dependency(app_body: Annotated[str, Body()]): + return app_body + + def router_body_dependency(router_body: Annotated[int, Body()]): + return router_body + + def include_body_dependency(include_body: Annotated[bool, Body()]): + return include_body + + app = FastAPI(dependencies=[Depends(app_body_dependency)]) + router = APIRouter(dependencies=[Depends(router_body_dependency)]) + + @router.post("/items") + def create_item(route_body: Annotated[float, Body()]): + return {"route_body": route_body} + + app.include_router( + router, + prefix="/api", + dependencies=[Depends(include_body_dependency)], + ) + + client = TestClient(app) + response = client.post( + "/api/items", + json={ + "app_body": "app", + "router_body": 1, + "include_body": True, + "route_body": 2.5, + }, + ) + + assert response.status_code == 200 + assert response.json() == {"route_body": 2.5} + + schema = client.get("/openapi.json").json() + request_body_schema = schema["paths"]["/api/items"]["post"]["requestBody"][ + "content" + ]["application/json"]["schema"] + body_ref = request_body_schema["$ref"].removeprefix("#/components/schemas/") + body_schema = schema["components"]["schemas"][body_ref] + assert set(body_schema["required"]) == { + "app_body", + "router_body", + "include_body", + "route_body", + } + assert set(body_schema["properties"]) == { + "app_body", + "router_body", + "include_body", + "route_body", + } + + +def test_later_full_match_wins_over_earlier_included_partial_match(): + get_router = APIRouter() + post_router = APIRouter() + + @get_router.get("/items") + def read_items(): # pragma: no cover + return {"method": "get"} + + @post_router.post("/items") + def create_item(): + return {"method": "post"} + + app = FastAPI() + app.include_router(get_router, prefix="/api") + app.include_router(post_router, prefix="/api") + + response = TestClient(app).post("/api/items") + + assert response.status_code == 200 + assert response.json() == {"method": "post"} + + +def test_included_partial_match_returns_405_when_no_later_full_match_exists(): + router = APIRouter() + + @router.get("/items") + def read_items(): # pragma: no cover + return [] + + app = FastAPI() + app.include_router(router, prefix="/api") + + response = TestClient(app).post("/api/items") + + assert response.status_code == 405 + assert response.headers["allow"] == "GET" + + +def test_included_slash_redirect_does_not_block_later_exact_match(): + redirect_router = APIRouter() + exact_router = APIRouter() + + @redirect_router.get("/items/") + def read_items_with_slash(): # pragma: no cover + return {"path": "slash"} + + @exact_router.get("/items") + def read_items_without_slash(): + return {"path": "exact"} + + app = FastAPI() + app.include_router(redirect_router, prefix="/api") + app.include_router(exact_router, prefix="/api") + + response = TestClient(app).get("/api/items", follow_redirects=False) + + assert response.status_code == 200 + assert response.json() == {"path": "exact"} + + +def test_failed_included_match_does_not_leak_effective_context_to_later_route(): + class RejectingRoute(APIRoute): + def matches(self, scope): + return Match.NONE, {} + + rejecting_router = APIRouter(route_class=RejectingRoute) + fallback_router = APIRouter() + + @rejecting_router.get("/items") + def rejected_item(): # pragma: no cover + return {"source": "rejected"} + + @fallback_router.get("/items") + def fallback_item(request: Request): + fastapi_scope = request.scope.get("fastapi", {}) + context = fastapi_scope.get("effective_route_context") + return { + "source": "fallback", + "context_path": getattr(context, "path", None), + } + + app = FastAPI() + app.include_router(rejecting_router, prefix="/api") + app.include_router(fallback_router, prefix="/api") + + response = TestClient(app).get("/api/items") + + assert response.status_code == 200 + assert response.json() == {"source": "fallback", "context_path": "/api/items"} + + +def test_included_starlette_mount_keeps_prefix_runtime_and_url_path_for(): + def mounted_endpoint(request): + return PlainTextResponse("mounted") + + router = APIRouter( + routes=[ + Mount( + "/mounted", + routes=[Route("/items/{item_id}", mounted_endpoint, name="read_item")], + name="mounted", + ) + ] + ) + app = FastAPI() + app.include_router(router, prefix="/api") + + client = TestClient(app) + response = client.get("/api/mounted/items/abc") + + assert response.status_code == 200 + assert response.text == "mounted" + assert ( + app.url_path_for("mounted:read_item", item_id="abc") == "/api/mounted/items/abc" + ) + + +def test_included_starlette_host_keeps_prefix_runtime_and_url_path_for(): + def hosted_endpoint(request): + return PlainTextResponse("hosted") + + hosted_app = Router( + routes=[Route("/items/{item_id}", hosted_endpoint, name="read_item")] + ) + router = APIRouter( + routes=[Host("{subdomain}.example.com", hosted_app, name="hosted")] + ) + app = FastAPI() + app.include_router(router, prefix="/api") + + client = TestClient(app, base_url="http://api.example.com") + response = client.get("/api/items/abc") + + assert response.status_code == 200 + assert response.text == "hosted" + url = app.url_path_for("hosted:read_item", subdomain="api", item_id="abc") + assert str(url) == "/api/items/abc" + assert url.host == "api.example.com" + + +def test_restore_fastapi_scope_key_ignores_non_dict_fastapi_scope(): + scope = {"fastapi": "not-a-dict"} + + _restore_fastapi_scope_key(scope, "effective_route_context", object()) + + assert scope == {"fastapi": "not-a-dict"} + + +@pytest.mark.anyio +async def test_included_api_route_without_app_scope_returns_405_response(): + router = APIRouter() + + @router.get("/items") + def read_items(): # pragma: no cover + return {"items": []} + + app = FastAPI() + app.include_router(router, prefix="/api") + included_router = cast(_IncludedRouter, app.router.routes[-1]) + effective_context = next(included_router.effective_route_contexts()) + route = effective_context.original_route + messages = [] + + async def receive(): # pragma: no cover + return {"type": "http.request", "body": b"", "more_body": False} + + async def send(message): + messages.append(message) + + scope = { + "type": "http", + "method": "POST", + "path": "/api/items", + "raw_path": b"/api/items", + "root_path": "", + "scheme": "http", + "query_string": b"", + "headers": [], + "fastapi": {"effective_route_context": effective_context}, + } + + await route.handle(scope, receive, send) + + assert messages[0]["type"] == "http.response.start" + assert messages[0]["status"] == 405 + assert dict(messages[0]["headers"])[b"allow"] == b"GET" + + +def test_effective_api_route_context_does_not_match_websocket_scope(): + router = APIRouter() + + @router.get("/items") + def read_items(): # pragma: no cover + return {"items": []} + + app = FastAPI() + app.include_router(router, prefix="/api") + included_router = cast(_IncludedRouter, app.router.routes[-1]) + effective_context = next(included_router.effective_route_contexts()) + + match, child_scope = effective_context.matches( + { + "type": "websocket", + "path": "/api/items", + "root_path": "", + } + ) + + assert match == Match.NONE + assert child_scope == {} + + +def test_effective_api_route_context_url_path_for_no_match(): + router = APIRouter() + + @router.get("/items/{item_id}") + def read_item(item_id: str): # pragma: no cover + return {"item_id": item_id} + + app = FastAPI() + app.include_router(router, prefix="/api") + included_router = cast(_IncludedRouter, app.router.routes[-1]) + effective_context = next(included_router.effective_route_contexts()) + + with pytest.raises(NoMatchFound): + effective_context.url_path_for("missing", item_id="abc") + + with pytest.raises(NoMatchFound): + included_router.url_path_for("missing", item_id="abc") + + +def test_included_starlette_host_without_prefix_keeps_original_app(): + def hosted_endpoint(request): + return PlainTextResponse("hosted") + + hosted_app = Router( + routes=[Route("/items/{item_id}", hosted_endpoint, name="read_item")] + ) + router = APIRouter( + routes=[Host("{subdomain}.example.com", hosted_app, name="hosted")] + ) + app = FastAPI() + app.include_router(router) + + client = TestClient(app, base_url="http://api.example.com") + response = client.get("/items/abc") + + assert response.status_code == 200 + assert response.text == "hosted" + + +class UnknownRoute(BaseRoute): + def matches(self, scope): # pragma: no cover + return Match.NONE, {} + + async def handle(self, scope, receive, send): # pragma: no cover + raise AssertionError("UnknownRoute should not be handled") + + def url_path_for(self, name, /, **path_params): # pragma: no cover + raise NoMatchFound(name, path_params) + + +@pytest.mark.anyio +async def test_included_unknown_route_is_ignored_and_can_return_default_404(): + router = APIRouter(routes=[UnknownRoute()]) + app = FastAPI() + app.include_router(router, prefix="/api") + included_router = cast(_IncludedRouter, app.router.routes[-1]) + + assert included_router.effective_candidates() == [] + + messages = [] + + async def receive(): # pragma: no cover + return {"type": "http.request", "body": b"", "more_body": False} + + async def send(message): + messages.append(message) + + scope = { + "type": "http", + "method": "GET", + "path": "/api/missing", + "raw_path": b"/api/missing", + "root_path": "", + "scheme": "http", + "query_string": b"", + "headers": [], + "fastapi": {}, + } + + await included_router._handle_selected(scope, receive, send) + + assert messages[0]["type"] == "http.response.start" + assert messages[0]["status"] == 404 + + +def test_no_prefix_include_validation_sees_effective_starlette_route_candidates(): + def endpoint(request): # pragma: no cover + return PlainTextResponse("ok") + + child_router = APIRouter(routes=[Route("/items", endpoint, name="read_items")]) + parent_router = APIRouter() + parent_router.include_router(child_router, prefix="/child") + + candidates = list(_iter_included_route_candidates(parent_router.routes)) + + assert cast(Route, candidates[0]).path == "/child/items" + + +def test_apirouter_matches_fallback_without_include_context(): + router = APIRouter() + + def read_items(request): # pragma: no cover + return PlainTextResponse("items") + + router.add_route("/items", read_items) + + assert router.matches({"type": "http", "path": "/items", "root_path": ""}) == ( + Match.NONE, + {}, + ) + + +@pytest.mark.anyio +async def test_apirouter_handle_fallback_without_include_context(): + router = APIRouter() + + def read_items(request): + return PlainTextResponse("items") + + router.add_route("/items", read_items) + messages = [] + + async def receive(): # pragma: no cover + return {"type": "http.request", "body": b"", "more_body": False} + + async def send(message): + messages.append(message) + + scope = { + "type": "http", + "method": "GET", + "path": "/items", + "raw_path": b"/items", + "root_path": "", + "scheme": "http", + "query_string": b"", + "headers": [], + } + + await router.handle(scope, receive, send) + + assert messages[0]["type"] == "http.response.start" + assert messages[0]["status"] == 200 + assert messages[1]["body"] == b"items"