From 9f5831bea80af3812485456d9be1891032b9e367 Mon Sep 17 00:00:00 2001 From: Jan Vollmer Date: Fri, 6 Sep 2024 01:25:28 +0200 Subject: [PATCH 1/7] defer-route-init for better preformance with many nested routers --- fastapi/applications.py | 1 + fastapi/routing.py | 155 ++++++++++++++++++++++++++-------------- 2 files changed, 101 insertions(+), 55 deletions(-) diff --git a/fastapi/applications.py b/fastapi/applications.py index 6d427cdc2..2d9f1059c 100644 --- a/fastapi/applications.py +++ b/fastapi/applications.py @@ -943,6 +943,7 @@ class FastAPI(Starlette): include_in_schema=include_in_schema, responses=responses, generate_unique_id_function=generate_unique_id_function, + defer_init=False, ) self.exception_handlers: Dict[ Any, Callable[[Request, Any], Union[Response, Awaitable[Response]]] diff --git a/fastapi/routing.py b/fastapi/routing.py index 86e303602..e5062979c 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -5,6 +5,7 @@ import inspect import json from contextlib import AsyncExitStack, asynccontextmanager from enum import Enum, IntEnum +from functools import cached_property from typing import ( Any, AsyncIterator, @@ -459,6 +460,7 @@ class APIRoute(routing.Route): generate_unique_id_function: Union[ Callable[["APIRoute"], str], DefaultPlaceholder ] = Default(generate_unique_id), + defer_init: bool = True, ) -> None: self.path = path self.endpoint = endpoint @@ -503,68 +505,16 @@ class APIRoute(routing.Route): if isinstance(status_code, IntEnum): status_code = int(status_code) self.status_code = status_code - if self.response_model: - assert is_body_allowed_for_status_code( - status_code - ), f"Status code {status_code} must not have a response body" - response_name = "Response_" + self.unique_id - self.response_field = create_model_field( - name=response_name, - type_=self.response_model, - mode="serialization", - ) - # Create a clone of the field, so that a Pydantic submodel is not returned - # as is just because it's an instance of a subclass of a more limited class - # e.g. UserInDB (containing hashed_password) could be a subclass of User - # that doesn't have the hashed_password. But because it's a subclass, it - # would pass the validation and be returned as is. - # By being a new field, no inheritance will be passed as is. A new model - # will always be created. - # TODO: remove when deprecating Pydantic v1 - self.secure_cloned_response_field: Optional[ModelField] = ( - create_cloned_field(self.response_field) - ) - else: - self.response_field = None # type: ignore - self.secure_cloned_response_field = None self.dependencies = list(dependencies or []) self.description = description or inspect.cleandoc(self.endpoint.__doc__ or "") # if a "form feed" character (page break) is found in the description text, # truncate description text to the content preceding the first "form feed" self.description = self.description.split("\f")[0].strip() - response_fields = {} - for additional_status_code, response in self.responses.items(): - assert isinstance(response, dict), "An additional response must be a dict" - model = response.get("model") - if model: - assert is_body_allowed_for_status_code( - additional_status_code - ), f"Status code {additional_status_code} must not have a response body" - response_name = f"Response_{additional_status_code}_{self.unique_id}" - response_field = create_model_field(name=response_name, type_=model) - response_fields[additional_status_code] = response_field - if response_fields: - self.response_fields: Dict[Union[int, str], ModelField] = response_fields - else: - self.response_fields = {} assert callable(endpoint), "An endpoint must be a callable" - self.dependant = get_dependant(path=self.path_format, call=self.endpoint) - for depends in self.dependencies[::-1]: - self.dependant.dependencies.insert( - 0, - get_parameterless_sub_dependant(depends=depends, path=self.path_format), - ) - self._flat_dependant = get_flat_dependant(self.dependant) - self._embed_body_fields = _should_embed_body_fields( - self._flat_dependant.body_params - ) - self.body_field = get_body_field( - flat_dependant=self._flat_dependant, - name=self.unique_id, - embed_body_fields=self._embed_body_fields, - ) - self.app = request_response(self.get_route_handler()) + + if not defer_init: + self.init_attributes() def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]]: return get_request_handler( @@ -589,6 +539,85 @@ class APIRoute(routing.Route): child_scope["route"] = self return match, child_scope + @cached_property + def response_field(self) -> Optional[ModelField]: + if not self.response_model: + return None + response_name = "Response_" + self.unique_id + return create_model_field( + name=response_name, + type_=self.response_model, + mode="serialization", + ) + + # Create a clone of the field, so that a Pydantic submodel is not returned + # as is just because it's an instance of a subclass of a more limited class + # e.g. UserInDB (containing hashed_password) could be a subclass of User + # that doesn't have the hashed_password. But because it's a subclass, it + # would pass the validation and be returned as is. + # By being a new field, no inheritance will be passed as is. A new model + # will always be created. + # TODO: remove when deprecating Pydantic v1 + @cached_property + def secure_cloned_response_field(self) -> Optional[ModelField]: + return create_cloned_field(self.response_field) if self.response_field else None + + @cached_property + def response_fields(self) -> Dict[Union[int, str], ModelField]: + response_fields = {} + for additional_status_code, response in self.responses.items(): + assert isinstance(response, dict), "An additional response must be a dict" + model = response.get("model") + if model: + assert is_body_allowed_for_status_code( + additional_status_code + ), f"Status code {additional_status_code} must not have a response body" + + response_name = f"Response_{additional_status_code}_{self.unique_id}" + response_field = create_model_field(name=response_name, type_=model) + response_fields[additional_status_code] = response_field + return response_fields + + @cached_property + def dependant(self) -> Dependant: + dependant = get_dependant(path=self.path_format, call=self.endpoint) + for depends in self.dependencies[::-1]: + dependant.dependencies.insert( + 0, + get_parameterless_sub_dependant(depends=depends, path=self.path_format), + ) + return dependant + + @cached_property + def _flat_dependant(self) -> Dependant: + return get_flat_dependant(self.dependant) + + @cached_property + def _embed_body_fields(self) -> bool: + return _should_embed_body_fields(self._flat_dependant.body_params) + + @cached_property + def body_field(self) -> Optional[ModelField]: + return get_body_field( + flat_dependant=self._flat_dependant, + name=self.unique_id, + embed_body_fields=self._embed_body_fields, + ) + + @cached_property + def app(self) -> ASGIApp: # type: ignore + return request_response(self.get_route_handler()) + + def init_attributes(self) -> None: + self.app = self.app + self.dependant = self.dependant + self.response_field = self.response_field + self.response_fields = self.response_fields + self.secure_cloned_response_field = self.secure_cloned_response_field + self.body_field = self.body_field + self._flat_dependant = self._flat_dependant + self._embed_body_fields = self._embed_body_fields + class APIRouter(routing.Router): """ @@ -831,6 +860,13 @@ class APIRouter(routing.Router): """ ), ] = Default(generate_unique_id), + defer_init: Annotated[ + bool, + Doc(""" + By default every route will defer its initialization upon usage. + This flag disables the behavior for the routes defined in this router, causing the routes to initialize immediately. + """), + ] = True, ) -> None: super().__init__( routes=routes, @@ -856,6 +892,9 @@ class APIRouter(routing.Router): self.route_class = route_class self.default_response_class = default_response_class self.generate_unique_id_function = generate_unique_id_function + self.defer_init = defer_init + if not self.defer_init: + self.init_routes() def route( self, @@ -955,6 +994,7 @@ class APIRouter(routing.Router): callbacks=current_callbacks, openapi_extra=openapi_extra, generate_unique_id_function=current_generate_unique_id, + defer_init=self.defer_init, ) self.routes.append(route) @@ -1117,6 +1157,11 @@ class APIRouter(routing.Router): return decorator + def init_routes(self) -> None: + for route in self.routes: + if isinstance(route, APIRoute): + route.init_attributes() + def include_router( self, router: Annotated["APIRouter", Doc("The `APIRouter` to include.")], From 81ea988f9d4abda7980bf2dfd8cac94f05ff0853 Mon Sep 17 00:00:00 2001 From: Jan Vollmer Date: Sat, 7 Sep 2024 00:15:20 +0200 Subject: [PATCH 2/7] recover lost status code check --- fastapi/routing.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/fastapi/routing.py b/fastapi/routing.py index e5062979c..e8c1ebd7a 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -505,6 +505,10 @@ class APIRoute(routing.Route): if isinstance(status_code, IntEnum): status_code = int(status_code) self.status_code = status_code + if self.response_model: + assert is_body_allowed_for_status_code( + status_code + ), f"Status code {status_code} must not have a response body" self.dependencies = list(dependencies or []) self.description = description or inspect.cleandoc(self.endpoint.__doc__ or "") # if a "form feed" character (page break) is found in the description text, From a4486a15f1a5bdf76b2104176f22308f0d986af0 Mon Sep 17 00:00:00 2001 From: Jan Vollmer Date: Sat, 7 Sep 2024 00:16:34 +0200 Subject: [PATCH 3/7] add tests for deferred route initialization remove useless checks in test test defer_init flag on route add route to test app fix coverage again --- tests/test_route_deferred_init.py | 106 ++++++++++++++++++++++++++++++ 1 file changed, 106 insertions(+) create mode 100644 tests/test_route_deferred_init.py diff --git a/tests/test_route_deferred_init.py b/tests/test_route_deferred_init.py new file mode 100644 index 000000000..e0cd68426 --- /dev/null +++ b/tests/test_route_deferred_init.py @@ -0,0 +1,106 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, FastAPI +from fastapi.routing import APIRoute +from fastapi.testclient import TestClient +from pydantic import BaseModel +from starlette.routing import BaseRoute + +deferred_keys = [ + "app", + "response_fields", + "body_field", + "response_field", + "secure_cloned_response_field", + "dependant", + "_flat_dependant", + "_embed_body_fields", +] + + +def check_if_initialized(route: APIRoute, should_not: bool = False): + for key in deferred_keys: + if should_not: + assert key not in route.__dict__ + else: + assert key in route.__dict__ + + +def create_test_router( + routes: Optional[List[BaseRoute]] = None, defer_init: bool = True +): + router = APIRouter(routes=routes or [], defer_init=defer_init) + + class UserIdBody(BaseModel): + user_id: int + + @router.get("/user_id", dependencies=[Depends(lambda: True)]) + async def get_user_id(user_id: int = Depends(lambda: 1)) -> UserIdBody: + return {"user_id": user_id} + + return router + + +def test_route_defers(): + app = FastAPI() + router = create_test_router(routes=app.router.routes) + + for route in router.routes: + if not isinstance(route, APIRoute): + continue + check_if_initialized(route, should_not=True) + + app.router = router + client = TestClient(app) + response = client.get("/user_id") + assert response.status_code == 200 + response = client.get("/openapi.json") + assert response.status_code == 200 + + for route in router.routes: + if not isinstance(route, APIRoute): + continue + check_if_initialized(route) + + +def test_route_manual_init(): + router = create_test_router() + for route in router.routes: + check_if_initialized(route, should_not=True) + route.init_attributes() + check_if_initialized(route) + + router = create_test_router() + router.init_routes() + for route in router.routes: + check_if_initialized(route) + + +def test_router_defer_init_flag(): + route = APIRoute("/test", lambda: {"test": True}, defer_init=False) + check_if_initialized(route) + + deferring_router = create_test_router() + router = create_test_router(routes=deferring_router.routes, defer_init=False) + + for route in router.routes: + check_if_initialized(route) + + +def test_root_router_always_initialized(): + app = FastAPI() + + @app.get("/test") + async def test_get(): + return {"test": 1} + + router = create_test_router() + app.include_router(router) + for route in app.router.routes: + if not isinstance(route, APIRoute): + continue + check_if_initialized(route) + + client = TestClient(app) + response = client.get("/test") + assert response.status_code == 200 From 424a512a3ac3da6a83ca48e212656894721f72b3 Mon Sep 17 00:00:00 2001 From: Jan Vollmer Date: Mon, 9 Sep 2024 22:10:21 +0200 Subject: [PATCH 4/7] assert that include_router does not cause initialization --- tests/test_route_deferred_init.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/test_route_deferred_init.py b/tests/test_route_deferred_init.py index e0cd68426..355db896a 100644 --- a/tests/test_route_deferred_init.py +++ b/tests/test_route_deferred_init.py @@ -1,3 +1,4 @@ +from itertools import chain from typing import List, Optional from fastapi import APIRouter, Depends, FastAPI @@ -104,3 +105,13 @@ def test_root_router_always_initialized(): client = TestClient(app) response = client.get("/test") assert response.status_code == 200 + + +def test_include_router_no_init(): + router1 = create_test_router() + + router2 = create_test_router() + router2.include_router(router1) + + for route in chain(router1.routes, router2.routes): + check_if_initialized(route, should_not=True) From 64d0123a3fd36aed3506bfbe1ab473cee415dab1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 12 Oct 2024 14:52:41 +0000 Subject: [PATCH 5/7] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20for?= =?UTF-8?q?mat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastapi/routing.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/fastapi/routing.py b/fastapi/routing.py index 6246c6e9f..9d335ad10 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -578,7 +578,9 @@ class APIRoute(routing.Route): ), f"Status code {additional_status_code} must not have a response body" response_name = f"Response_{additional_status_code}_{self.unique_id}" - response_field = create_model_field(name=response_name, type_=model, mode="serialization") + response_field = create_model_field( + name=response_name, type_=model, mode="serialization" + ) response_fields[additional_status_code] = response_field return response_fields From d70de472f653a94e74d222bb0a8a6112a9d47d59 Mon Sep 17 00:00:00 2001 From: Jan Vollmer Date: Tue, 14 Jan 2025 21:44:03 +0100 Subject: [PATCH 6/7] improve formatting for and re-word doc string --- fastapi/routing.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/fastapi/routing.py b/fastapi/routing.py index 9d335ad10..af597cada 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -868,10 +868,13 @@ class APIRouter(routing.Router): ] = Default(generate_unique_id), defer_init: Annotated[ bool, - Doc(""" - By default every route will defer its initialization upon usage. - This flag disables the behavior for the routes defined in this router, causing the routes to initialize immediately. - """), + Doc( + """ + By default, every route will defer its initialization until the first call. + This flag can be used to deactivate this behavior for the routes defined in this router, + causing the routes to initialize immediately when they are defined. + """ + ), ] = True, ) -> None: super().__init__( From f5b6e08db9e715698310576545f18d869fa905e7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 8 Mar 2025 17:50:21 +0000 Subject: [PATCH 7/7] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20for?= =?UTF-8?q?mat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastapi/routing.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/fastapi/routing.py b/fastapi/routing.py index 2fdeea7b5..485f35835 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -506,9 +506,9 @@ class APIRoute(routing.Route): status_code = int(status_code) self.status_code = status_code if self.response_model: - assert is_body_allowed_for_status_code( - status_code - ), f"Status code {status_code} must not have a response body" + assert is_body_allowed_for_status_code(status_code), ( + f"Status code {status_code} must not have a response body" + ) self.dependencies = list(dependencies or []) self.description = description or inspect.cleandoc(self.endpoint.__doc__ or "") # if a "form feed" character (page break) is found in the description text, @@ -573,9 +573,9 @@ class APIRoute(routing.Route): assert isinstance(response, dict), "An additional response must be a dict" model = response.get("model") if model: - assert is_body_allowed_for_status_code( - additional_status_code - ), f"Status code {additional_status_code} must not have a response body" + assert is_body_allowed_for_status_code(additional_status_code), ( + f"Status code {additional_status_code} must not have a response body" + ) response_name = f"Response_{additional_status_code}_{self.unique_id}" response_field = create_model_field(