Browse Source

defer-route-init for better preformance with many nested routers

pull/10589/head
Jan Vollmer 7 months ago
parent
commit
9f5831bea8
No known key found for this signature in database GPG Key ID: 19473D3A5AB433DA
  1. 1
      fastapi/applications.py
  2. 155
      fastapi/routing.py

1
fastapi/applications.py

@ -943,6 +943,7 @@ class FastAPI(Starlette):
include_in_schema=include_in_schema,
responses=responses,
generate_unique_id_function=generate_unique_id_function,
defer_init=False,
)
self.exception_handlers: Dict[
Any, Callable[[Request, Any], Union[Response, Awaitable[Response]]]

155
fastapi/routing.py

@ -5,6 +5,7 @@ import inspect
import json
from contextlib import AsyncExitStack, asynccontextmanager
from enum import Enum, IntEnum
from functools import cached_property
from typing import (
Any,
AsyncIterator,
@ -459,6 +460,7 @@ class APIRoute(routing.Route):
generate_unique_id_function: Union[
Callable[["APIRoute"], str], DefaultPlaceholder
] = Default(generate_unique_id),
defer_init: bool = True,
) -> None:
self.path = path
self.endpoint = endpoint
@ -503,68 +505,16 @@ class APIRoute(routing.Route):
if isinstance(status_code, IntEnum):
status_code = int(status_code)
self.status_code = status_code
if self.response_model:
assert is_body_allowed_for_status_code(
status_code
), f"Status code {status_code} must not have a response body"
response_name = "Response_" + self.unique_id
self.response_field = create_model_field(
name=response_name,
type_=self.response_model,
mode="serialization",
)
# Create a clone of the field, so that a Pydantic submodel is not returned
# as is just because it's an instance of a subclass of a more limited class
# e.g. UserInDB (containing hashed_password) could be a subclass of User
# that doesn't have the hashed_password. But because it's a subclass, it
# would pass the validation and be returned as is.
# By being a new field, no inheritance will be passed as is. A new model
# will always be created.
# TODO: remove when deprecating Pydantic v1
self.secure_cloned_response_field: Optional[ModelField] = (
create_cloned_field(self.response_field)
)
else:
self.response_field = None # type: ignore
self.secure_cloned_response_field = None
self.dependencies = list(dependencies or [])
self.description = description or inspect.cleandoc(self.endpoint.__doc__ or "")
# if a "form feed" character (page break) is found in the description text,
# truncate description text to the content preceding the first "form feed"
self.description = self.description.split("\f")[0].strip()
response_fields = {}
for additional_status_code, response in self.responses.items():
assert isinstance(response, dict), "An additional response must be a dict"
model = response.get("model")
if model:
assert is_body_allowed_for_status_code(
additional_status_code
), f"Status code {additional_status_code} must not have a response body"
response_name = f"Response_{additional_status_code}_{self.unique_id}"
response_field = create_model_field(name=response_name, type_=model)
response_fields[additional_status_code] = response_field
if response_fields:
self.response_fields: Dict[Union[int, str], ModelField] = response_fields
else:
self.response_fields = {}
assert callable(endpoint), "An endpoint must be a callable"
self.dependant = get_dependant(path=self.path_format, call=self.endpoint)
for depends in self.dependencies[::-1]:
self.dependant.dependencies.insert(
0,
get_parameterless_sub_dependant(depends=depends, path=self.path_format),
)
self._flat_dependant = get_flat_dependant(self.dependant)
self._embed_body_fields = _should_embed_body_fields(
self._flat_dependant.body_params
)
self.body_field = get_body_field(
flat_dependant=self._flat_dependant,
name=self.unique_id,
embed_body_fields=self._embed_body_fields,
)
self.app = request_response(self.get_route_handler())
if not defer_init:
self.init_attributes()
def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]]:
return get_request_handler(
@ -589,6 +539,85 @@ class APIRoute(routing.Route):
child_scope["route"] = self
return match, child_scope
@cached_property
def response_field(self) -> Optional[ModelField]:
if not self.response_model:
return None
response_name = "Response_" + self.unique_id
return create_model_field(
name=response_name,
type_=self.response_model,
mode="serialization",
)
# Create a clone of the field, so that a Pydantic submodel is not returned
# as is just because it's an instance of a subclass of a more limited class
# e.g. UserInDB (containing hashed_password) could be a subclass of User
# that doesn't have the hashed_password. But because it's a subclass, it
# would pass the validation and be returned as is.
# By being a new field, no inheritance will be passed as is. A new model
# will always be created.
# TODO: remove when deprecating Pydantic v1
@cached_property
def secure_cloned_response_field(self) -> Optional[ModelField]:
return create_cloned_field(self.response_field) if self.response_field else None
@cached_property
def response_fields(self) -> Dict[Union[int, str], ModelField]:
response_fields = {}
for additional_status_code, response in self.responses.items():
assert isinstance(response, dict), "An additional response must be a dict"
model = response.get("model")
if model:
assert is_body_allowed_for_status_code(
additional_status_code
), f"Status code {additional_status_code} must not have a response body"
response_name = f"Response_{additional_status_code}_{self.unique_id}"
response_field = create_model_field(name=response_name, type_=model)
response_fields[additional_status_code] = response_field
return response_fields
@cached_property
def dependant(self) -> Dependant:
dependant = get_dependant(path=self.path_format, call=self.endpoint)
for depends in self.dependencies[::-1]:
dependant.dependencies.insert(
0,
get_parameterless_sub_dependant(depends=depends, path=self.path_format),
)
return dependant
@cached_property
def _flat_dependant(self) -> Dependant:
return get_flat_dependant(self.dependant)
@cached_property
def _embed_body_fields(self) -> bool:
return _should_embed_body_fields(self._flat_dependant.body_params)
@cached_property
def body_field(self) -> Optional[ModelField]:
return get_body_field(
flat_dependant=self._flat_dependant,
name=self.unique_id,
embed_body_fields=self._embed_body_fields,
)
@cached_property
def app(self) -> ASGIApp: # type: ignore
return request_response(self.get_route_handler())
def init_attributes(self) -> None:
self.app = self.app
self.dependant = self.dependant
self.response_field = self.response_field
self.response_fields = self.response_fields
self.secure_cloned_response_field = self.secure_cloned_response_field
self.body_field = self.body_field
self._flat_dependant = self._flat_dependant
self._embed_body_fields = self._embed_body_fields
class APIRouter(routing.Router):
"""
@ -831,6 +860,13 @@ class APIRouter(routing.Router):
"""
),
] = Default(generate_unique_id),
defer_init: Annotated[
bool,
Doc("""
By default every route will defer its initialization upon usage.
This flag disables the behavior for the routes defined in this router, causing the routes to initialize immediately.
"""),
] = True,
) -> None:
super().__init__(
routes=routes,
@ -856,6 +892,9 @@ class APIRouter(routing.Router):
self.route_class = route_class
self.default_response_class = default_response_class
self.generate_unique_id_function = generate_unique_id_function
self.defer_init = defer_init
if not self.defer_init:
self.init_routes()
def route(
self,
@ -955,6 +994,7 @@ class APIRouter(routing.Router):
callbacks=current_callbacks,
openapi_extra=openapi_extra,
generate_unique_id_function=current_generate_unique_id,
defer_init=self.defer_init,
)
self.routes.append(route)
@ -1117,6 +1157,11 @@ class APIRouter(routing.Router):
return decorator
def init_routes(self) -> None:
for route in self.routes:
if isinstance(route, APIRoute):
route.init_attributes()
def include_router(
self,
router: Annotated["APIRouter", Doc("The `APIRouter` to include.")],

Loading…
Cancel
Save