From 9f5831bea80af3812485456d9be1891032b9e367 Mon Sep 17 00:00:00 2001 From: Jan Vollmer Date: Fri, 6 Sep 2024 01:25:28 +0200 Subject: [PATCH] defer-route-init for better preformance with many nested routers --- fastapi/applications.py | 1 + fastapi/routing.py | 155 ++++++++++++++++++++++++++-------------- 2 files changed, 101 insertions(+), 55 deletions(-) diff --git a/fastapi/applications.py b/fastapi/applications.py index 6d427cdc2..2d9f1059c 100644 --- a/fastapi/applications.py +++ b/fastapi/applications.py @@ -943,6 +943,7 @@ class FastAPI(Starlette): include_in_schema=include_in_schema, responses=responses, generate_unique_id_function=generate_unique_id_function, + defer_init=False, ) self.exception_handlers: Dict[ Any, Callable[[Request, Any], Union[Response, Awaitable[Response]]] diff --git a/fastapi/routing.py b/fastapi/routing.py index 86e303602..e5062979c 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -5,6 +5,7 @@ import inspect import json from contextlib import AsyncExitStack, asynccontextmanager from enum import Enum, IntEnum +from functools import cached_property from typing import ( Any, AsyncIterator, @@ -459,6 +460,7 @@ class APIRoute(routing.Route): generate_unique_id_function: Union[ Callable[["APIRoute"], str], DefaultPlaceholder ] = Default(generate_unique_id), + defer_init: bool = True, ) -> None: self.path = path self.endpoint = endpoint @@ -503,68 +505,16 @@ class APIRoute(routing.Route): if isinstance(status_code, IntEnum): status_code = int(status_code) self.status_code = status_code - if self.response_model: - assert is_body_allowed_for_status_code( - status_code - ), f"Status code {status_code} must not have a response body" - response_name = "Response_" + self.unique_id - self.response_field = create_model_field( - name=response_name, - type_=self.response_model, - mode="serialization", - ) - # Create a clone of the field, so that a Pydantic submodel is not returned - # as is just because it's an instance of a subclass of a more limited class - # e.g. UserInDB (containing hashed_password) could be a subclass of User - # that doesn't have the hashed_password. But because it's a subclass, it - # would pass the validation and be returned as is. - # By being a new field, no inheritance will be passed as is. A new model - # will always be created. - # TODO: remove when deprecating Pydantic v1 - self.secure_cloned_response_field: Optional[ModelField] = ( - create_cloned_field(self.response_field) - ) - else: - self.response_field = None # type: ignore - self.secure_cloned_response_field = None self.dependencies = list(dependencies or []) self.description = description or inspect.cleandoc(self.endpoint.__doc__ or "") # if a "form feed" character (page break) is found in the description text, # truncate description text to the content preceding the first "form feed" self.description = self.description.split("\f")[0].strip() - response_fields = {} - for additional_status_code, response in self.responses.items(): - assert isinstance(response, dict), "An additional response must be a dict" - model = response.get("model") - if model: - assert is_body_allowed_for_status_code( - additional_status_code - ), f"Status code {additional_status_code} must not have a response body" - response_name = f"Response_{additional_status_code}_{self.unique_id}" - response_field = create_model_field(name=response_name, type_=model) - response_fields[additional_status_code] = response_field - if response_fields: - self.response_fields: Dict[Union[int, str], ModelField] = response_fields - else: - self.response_fields = {} assert callable(endpoint), "An endpoint must be a callable" - self.dependant = get_dependant(path=self.path_format, call=self.endpoint) - for depends in self.dependencies[::-1]: - self.dependant.dependencies.insert( - 0, - get_parameterless_sub_dependant(depends=depends, path=self.path_format), - ) - self._flat_dependant = get_flat_dependant(self.dependant) - self._embed_body_fields = _should_embed_body_fields( - self._flat_dependant.body_params - ) - self.body_field = get_body_field( - flat_dependant=self._flat_dependant, - name=self.unique_id, - embed_body_fields=self._embed_body_fields, - ) - self.app = request_response(self.get_route_handler()) + + if not defer_init: + self.init_attributes() def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]]: return get_request_handler( @@ -589,6 +539,85 @@ class APIRoute(routing.Route): child_scope["route"] = self return match, child_scope + @cached_property + def response_field(self) -> Optional[ModelField]: + if not self.response_model: + return None + response_name = "Response_" + self.unique_id + return create_model_field( + name=response_name, + type_=self.response_model, + mode="serialization", + ) + + # Create a clone of the field, so that a Pydantic submodel is not returned + # as is just because it's an instance of a subclass of a more limited class + # e.g. UserInDB (containing hashed_password) could be a subclass of User + # that doesn't have the hashed_password. But because it's a subclass, it + # would pass the validation and be returned as is. + # By being a new field, no inheritance will be passed as is. A new model + # will always be created. + # TODO: remove when deprecating Pydantic v1 + @cached_property + def secure_cloned_response_field(self) -> Optional[ModelField]: + return create_cloned_field(self.response_field) if self.response_field else None + + @cached_property + def response_fields(self) -> Dict[Union[int, str], ModelField]: + response_fields = {} + for additional_status_code, response in self.responses.items(): + assert isinstance(response, dict), "An additional response must be a dict" + model = response.get("model") + if model: + assert is_body_allowed_for_status_code( + additional_status_code + ), f"Status code {additional_status_code} must not have a response body" + + response_name = f"Response_{additional_status_code}_{self.unique_id}" + response_field = create_model_field(name=response_name, type_=model) + response_fields[additional_status_code] = response_field + return response_fields + + @cached_property + def dependant(self) -> Dependant: + dependant = get_dependant(path=self.path_format, call=self.endpoint) + for depends in self.dependencies[::-1]: + dependant.dependencies.insert( + 0, + get_parameterless_sub_dependant(depends=depends, path=self.path_format), + ) + return dependant + + @cached_property + def _flat_dependant(self) -> Dependant: + return get_flat_dependant(self.dependant) + + @cached_property + def _embed_body_fields(self) -> bool: + return _should_embed_body_fields(self._flat_dependant.body_params) + + @cached_property + def body_field(self) -> Optional[ModelField]: + return get_body_field( + flat_dependant=self._flat_dependant, + name=self.unique_id, + embed_body_fields=self._embed_body_fields, + ) + + @cached_property + def app(self) -> ASGIApp: # type: ignore + return request_response(self.get_route_handler()) + + def init_attributes(self) -> None: + self.app = self.app + self.dependant = self.dependant + self.response_field = self.response_field + self.response_fields = self.response_fields + self.secure_cloned_response_field = self.secure_cloned_response_field + self.body_field = self.body_field + self._flat_dependant = self._flat_dependant + self._embed_body_fields = self._embed_body_fields + class APIRouter(routing.Router): """ @@ -831,6 +860,13 @@ class APIRouter(routing.Router): """ ), ] = Default(generate_unique_id), + defer_init: Annotated[ + bool, + Doc(""" + By default every route will defer its initialization upon usage. + This flag disables the behavior for the routes defined in this router, causing the routes to initialize immediately. + """), + ] = True, ) -> None: super().__init__( routes=routes, @@ -856,6 +892,9 @@ class APIRouter(routing.Router): self.route_class = route_class self.default_response_class = default_response_class self.generate_unique_id_function = generate_unique_id_function + self.defer_init = defer_init + if not self.defer_init: + self.init_routes() def route( self, @@ -955,6 +994,7 @@ class APIRouter(routing.Router): callbacks=current_callbacks, openapi_extra=openapi_extra, generate_unique_id_function=current_generate_unique_id, + defer_init=self.defer_init, ) self.routes.append(route) @@ -1117,6 +1157,11 @@ class APIRouter(routing.Router): return decorator + def init_routes(self) -> None: + for route in self.routes: + if isinstance(route, APIRoute): + route.init_attributes() + def include_router( self, router: Annotated["APIRouter", Doc("The `APIRouter` to include.")],