diff --git a/fastapi/applications.py b/fastapi/applications.py index 6d427cdc2..2d9f1059c 100644 --- a/fastapi/applications.py +++ b/fastapi/applications.py @@ -943,6 +943,7 @@ class FastAPI(Starlette): include_in_schema=include_in_schema, responses=responses, generate_unique_id_function=generate_unique_id_function, + defer_init=False, ) self.exception_handlers: Dict[ Any, Callable[[Request, Any], Union[Response, Awaitable[Response]]] diff --git a/fastapi/routing.py b/fastapi/routing.py index 457481e32..485f35835 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -5,6 +5,7 @@ import inspect import json from contextlib import AsyncExitStack, asynccontextmanager from enum import Enum, IntEnum +from functools import cached_property from typing import ( Any, AsyncIterator, @@ -459,6 +460,7 @@ class APIRoute(routing.Route): generate_unique_id_function: Union[ Callable[["APIRoute"], str], DefaultPlaceholder ] = Default(generate_unique_id), + defer_init: bool = True, ) -> None: self.path = path self.endpoint = endpoint @@ -507,66 +509,16 @@ class APIRoute(routing.Route): assert is_body_allowed_for_status_code(status_code), ( f"Status code {status_code} must not have a response body" ) - response_name = "Response_" + self.unique_id - self.response_field = create_model_field( - name=response_name, - type_=self.response_model, - mode="serialization", - ) - # Create a clone of the field, so that a Pydantic submodel is not returned - # as is just because it's an instance of a subclass of a more limited class - # e.g. UserInDB (containing hashed_password) could be a subclass of User - # that doesn't have the hashed_password. But because it's a subclass, it - # would pass the validation and be returned as is. - # By being a new field, no inheritance will be passed as is. A new model - # will always be created. - # TODO: remove when deprecating Pydantic v1 - self.secure_cloned_response_field: Optional[ModelField] = ( - create_cloned_field(self.response_field) - ) - else: - self.response_field = None # type: ignore - self.secure_cloned_response_field = None self.dependencies = list(dependencies or []) self.description = description or inspect.cleandoc(self.endpoint.__doc__ or "") # if a "form feed" character (page break) is found in the description text, # truncate description text to the content preceding the first "form feed" self.description = self.description.split("\f")[0].strip() - response_fields = {} - for additional_status_code, response in self.responses.items(): - assert isinstance(response, dict), "An additional response must be a dict" - model = response.get("model") - if model: - assert is_body_allowed_for_status_code(additional_status_code), ( - f"Status code {additional_status_code} must not have a response body" - ) - response_name = f"Response_{additional_status_code}_{self.unique_id}" - response_field = create_model_field( - name=response_name, type_=model, mode="serialization" - ) - response_fields[additional_status_code] = response_field - if response_fields: - self.response_fields: Dict[Union[int, str], ModelField] = response_fields - else: - self.response_fields = {} assert callable(endpoint), "An endpoint must be a callable" - self.dependant = get_dependant(path=self.path_format, call=self.endpoint) - for depends in self.dependencies[::-1]: - self.dependant.dependencies.insert( - 0, - get_parameterless_sub_dependant(depends=depends, path=self.path_format), - ) - self._flat_dependant = get_flat_dependant(self.dependant) - self._embed_body_fields = _should_embed_body_fields( - self._flat_dependant.body_params - ) - self.body_field = get_body_field( - flat_dependant=self._flat_dependant, - name=self.unique_id, - embed_body_fields=self._embed_body_fields, - ) - self.app = request_response(self.get_route_handler()) + + if not defer_init: + self.init_attributes() def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]]: return get_request_handler( @@ -591,6 +543,87 @@ class APIRoute(routing.Route): child_scope["route"] = self return match, child_scope + @cached_property + def response_field(self) -> Optional[ModelField]: + if not self.response_model: + return None + response_name = "Response_" + self.unique_id + return create_model_field( + name=response_name, + type_=self.response_model, + mode="serialization", + ) + + # Create a clone of the field, so that a Pydantic submodel is not returned + # as is just because it's an instance of a subclass of a more limited class + # e.g. UserInDB (containing hashed_password) could be a subclass of User + # that doesn't have the hashed_password. But because it's a subclass, it + # would pass the validation and be returned as is. + # By being a new field, no inheritance will be passed as is. A new model + # will always be created. + # TODO: remove when deprecating Pydantic v1 + @cached_property + def secure_cloned_response_field(self) -> Optional[ModelField]: + return create_cloned_field(self.response_field) if self.response_field else None + + @cached_property + def response_fields(self) -> Dict[Union[int, str], ModelField]: + response_fields = {} + for additional_status_code, response in self.responses.items(): + assert isinstance(response, dict), "An additional response must be a dict" + model = response.get("model") + if model: + assert is_body_allowed_for_status_code(additional_status_code), ( + f"Status code {additional_status_code} must not have a response body" + ) + + response_name = f"Response_{additional_status_code}_{self.unique_id}" + response_field = create_model_field( + name=response_name, type_=model, mode="serialization" + ) + response_fields[additional_status_code] = response_field + return response_fields + + @cached_property + def dependant(self) -> Dependant: + dependant = get_dependant(path=self.path_format, call=self.endpoint) + for depends in self.dependencies[::-1]: + dependant.dependencies.insert( + 0, + get_parameterless_sub_dependant(depends=depends, path=self.path_format), + ) + return dependant + + @cached_property + def _flat_dependant(self) -> Dependant: + return get_flat_dependant(self.dependant) + + @cached_property + def _embed_body_fields(self) -> bool: + return _should_embed_body_fields(self._flat_dependant.body_params) + + @cached_property + def body_field(self) -> Optional[ModelField]: + return get_body_field( + flat_dependant=self._flat_dependant, + name=self.unique_id, + embed_body_fields=self._embed_body_fields, + ) + + @cached_property + def app(self) -> ASGIApp: # type: ignore + return request_response(self.get_route_handler()) + + def init_attributes(self) -> None: + self.app = self.app + self.dependant = self.dependant + self.response_field = self.response_field + self.response_fields = self.response_fields + self.secure_cloned_response_field = self.secure_cloned_response_field + self.body_field = self.body_field + self._flat_dependant = self._flat_dependant + self._embed_body_fields = self._embed_body_fields + class APIRouter(routing.Router): """ @@ -833,6 +866,16 @@ class APIRouter(routing.Router): """ ), ] = Default(generate_unique_id), + defer_init: Annotated[ + bool, + Doc( + """ + By default, every route will defer its initialization until the first call. + This flag can be used to deactivate this behavior for the routes defined in this router, + causing the routes to initialize immediately when they are defined. + """ + ), + ] = True, ) -> None: super().__init__( routes=routes, @@ -858,6 +901,9 @@ class APIRouter(routing.Router): self.route_class = route_class self.default_response_class = default_response_class self.generate_unique_id_function = generate_unique_id_function + self.defer_init = defer_init + if not self.defer_init: + self.init_routes() def route( self, @@ -957,6 +1003,7 @@ class APIRouter(routing.Router): callbacks=current_callbacks, openapi_extra=openapi_extra, generate_unique_id_function=current_generate_unique_id, + defer_init=self.defer_init, ) self.routes.append(route) @@ -1119,6 +1166,11 @@ class APIRouter(routing.Router): return decorator + def init_routes(self) -> None: + for route in self.routes: + if isinstance(route, APIRoute): + route.init_attributes() + def include_router( self, router: Annotated["APIRouter", Doc("The `APIRouter` to include.")], diff --git a/tests/test_route_deferred_init.py b/tests/test_route_deferred_init.py new file mode 100644 index 000000000..355db896a --- /dev/null +++ b/tests/test_route_deferred_init.py @@ -0,0 +1,117 @@ +from itertools import chain +from typing import List, Optional + +from fastapi import APIRouter, Depends, FastAPI +from fastapi.routing import APIRoute +from fastapi.testclient import TestClient +from pydantic import BaseModel +from starlette.routing import BaseRoute + +deferred_keys = [ + "app", + "response_fields", + "body_field", + "response_field", + "secure_cloned_response_field", + "dependant", + "_flat_dependant", + "_embed_body_fields", +] + + +def check_if_initialized(route: APIRoute, should_not: bool = False): + for key in deferred_keys: + if should_not: + assert key not in route.__dict__ + else: + assert key in route.__dict__ + + +def create_test_router( + routes: Optional[List[BaseRoute]] = None, defer_init: bool = True +): + router = APIRouter(routes=routes or [], defer_init=defer_init) + + class UserIdBody(BaseModel): + user_id: int + + @router.get("/user_id", dependencies=[Depends(lambda: True)]) + async def get_user_id(user_id: int = Depends(lambda: 1)) -> UserIdBody: + return {"user_id": user_id} + + return router + + +def test_route_defers(): + app = FastAPI() + router = create_test_router(routes=app.router.routes) + + for route in router.routes: + if not isinstance(route, APIRoute): + continue + check_if_initialized(route, should_not=True) + + app.router = router + client = TestClient(app) + response = client.get("/user_id") + assert response.status_code == 200 + response = client.get("/openapi.json") + assert response.status_code == 200 + + for route in router.routes: + if not isinstance(route, APIRoute): + continue + check_if_initialized(route) + + +def test_route_manual_init(): + router = create_test_router() + for route in router.routes: + check_if_initialized(route, should_not=True) + route.init_attributes() + check_if_initialized(route) + + router = create_test_router() + router.init_routes() + for route in router.routes: + check_if_initialized(route) + + +def test_router_defer_init_flag(): + route = APIRoute("/test", lambda: {"test": True}, defer_init=False) + check_if_initialized(route) + + deferring_router = create_test_router() + router = create_test_router(routes=deferring_router.routes, defer_init=False) + + for route in router.routes: + check_if_initialized(route) + + +def test_root_router_always_initialized(): + app = FastAPI() + + @app.get("/test") + async def test_get(): + return {"test": 1} + + router = create_test_router() + app.include_router(router) + for route in app.router.routes: + if not isinstance(route, APIRoute): + continue + check_if_initialized(route) + + client = TestClient(app) + response = client.get("/test") + assert response.status_code == 200 + + +def test_include_router_no_init(): + router1 = create_test_router() + + router2 = create_test_router() + router2.include_router(router1) + + for route in chain(router1.routes, router2.routes): + check_if_initialized(route, should_not=True)