Browse Source

Merge f5b6e08db9 into 1d434dec47

pull/10589/merge
Jan Vollmer 2 days ago
committed by GitHub
parent
commit
83f77a6635
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 1
      fastapi/applications.py
  2. 158
      fastapi/routing.py
  3. 117
      tests/test_route_deferred_init.py

1
fastapi/applications.py

@ -943,6 +943,7 @@ class FastAPI(Starlette):
include_in_schema=include_in_schema,
responses=responses,
generate_unique_id_function=generate_unique_id_function,
defer_init=False,
)
self.exception_handlers: Dict[
Any, Callable[[Request, Any], Union[Response, Awaitable[Response]]]

158
fastapi/routing.py

@ -5,6 +5,7 @@ import inspect
import json
from contextlib import AsyncExitStack, asynccontextmanager
from enum import Enum, IntEnum
from functools import cached_property
from typing import (
Any,
AsyncIterator,
@ -459,6 +460,7 @@ class APIRoute(routing.Route):
generate_unique_id_function: Union[
Callable[["APIRoute"], str], DefaultPlaceholder
] = Default(generate_unique_id),
defer_init: bool = True,
) -> None:
self.path = path
self.endpoint = endpoint
@ -507,66 +509,16 @@ class APIRoute(routing.Route):
assert is_body_allowed_for_status_code(status_code), (
f"Status code {status_code} must not have a response body"
)
response_name = "Response_" + self.unique_id
self.response_field = create_model_field(
name=response_name,
type_=self.response_model,
mode="serialization",
)
# Create a clone of the field, so that a Pydantic submodel is not returned
# as is just because it's an instance of a subclass of a more limited class
# e.g. UserInDB (containing hashed_password) could be a subclass of User
# that doesn't have the hashed_password. But because it's a subclass, it
# would pass the validation and be returned as is.
# By being a new field, no inheritance will be passed as is. A new model
# will always be created.
# TODO: remove when deprecating Pydantic v1
self.secure_cloned_response_field: Optional[ModelField] = (
create_cloned_field(self.response_field)
)
else:
self.response_field = None # type: ignore
self.secure_cloned_response_field = None
self.dependencies = list(dependencies or [])
self.description = description or inspect.cleandoc(self.endpoint.__doc__ or "")
# if a "form feed" character (page break) is found in the description text,
# truncate description text to the content preceding the first "form feed"
self.description = self.description.split("\f")[0].strip()
response_fields = {}
for additional_status_code, response in self.responses.items():
assert isinstance(response, dict), "An additional response must be a dict"
model = response.get("model")
if model:
assert is_body_allowed_for_status_code(additional_status_code), (
f"Status code {additional_status_code} must not have a response body"
)
response_name = f"Response_{additional_status_code}_{self.unique_id}"
response_field = create_model_field(
name=response_name, type_=model, mode="serialization"
)
response_fields[additional_status_code] = response_field
if response_fields:
self.response_fields: Dict[Union[int, str], ModelField] = response_fields
else:
self.response_fields = {}
assert callable(endpoint), "An endpoint must be a callable"
self.dependant = get_dependant(path=self.path_format, call=self.endpoint)
for depends in self.dependencies[::-1]:
self.dependant.dependencies.insert(
0,
get_parameterless_sub_dependant(depends=depends, path=self.path_format),
)
self._flat_dependant = get_flat_dependant(self.dependant)
self._embed_body_fields = _should_embed_body_fields(
self._flat_dependant.body_params
)
self.body_field = get_body_field(
flat_dependant=self._flat_dependant,
name=self.unique_id,
embed_body_fields=self._embed_body_fields,
)
self.app = request_response(self.get_route_handler())
if not defer_init:
self.init_attributes()
def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]]:
return get_request_handler(
@ -591,6 +543,87 @@ class APIRoute(routing.Route):
child_scope["route"] = self
return match, child_scope
@cached_property
def response_field(self) -> Optional[ModelField]:
if not self.response_model:
return None
response_name = "Response_" + self.unique_id
return create_model_field(
name=response_name,
type_=self.response_model,
mode="serialization",
)
# Create a clone of the field, so that a Pydantic submodel is not returned
# as is just because it's an instance of a subclass of a more limited class
# e.g. UserInDB (containing hashed_password) could be a subclass of User
# that doesn't have the hashed_password. But because it's a subclass, it
# would pass the validation and be returned as is.
# By being a new field, no inheritance will be passed as is. A new model
# will always be created.
# TODO: remove when deprecating Pydantic v1
@cached_property
def secure_cloned_response_field(self) -> Optional[ModelField]:
return create_cloned_field(self.response_field) if self.response_field else None
@cached_property
def response_fields(self) -> Dict[Union[int, str], ModelField]:
response_fields = {}
for additional_status_code, response in self.responses.items():
assert isinstance(response, dict), "An additional response must be a dict"
model = response.get("model")
if model:
assert is_body_allowed_for_status_code(additional_status_code), (
f"Status code {additional_status_code} must not have a response body"
)
response_name = f"Response_{additional_status_code}_{self.unique_id}"
response_field = create_model_field(
name=response_name, type_=model, mode="serialization"
)
response_fields[additional_status_code] = response_field
return response_fields
@cached_property
def dependant(self) -> Dependant:
dependant = get_dependant(path=self.path_format, call=self.endpoint)
for depends in self.dependencies[::-1]:
dependant.dependencies.insert(
0,
get_parameterless_sub_dependant(depends=depends, path=self.path_format),
)
return dependant
@cached_property
def _flat_dependant(self) -> Dependant:
return get_flat_dependant(self.dependant)
@cached_property
def _embed_body_fields(self) -> bool:
return _should_embed_body_fields(self._flat_dependant.body_params)
@cached_property
def body_field(self) -> Optional[ModelField]:
return get_body_field(
flat_dependant=self._flat_dependant,
name=self.unique_id,
embed_body_fields=self._embed_body_fields,
)
@cached_property
def app(self) -> ASGIApp: # type: ignore
return request_response(self.get_route_handler())
def init_attributes(self) -> None:
self.app = self.app
self.dependant = self.dependant
self.response_field = self.response_field
self.response_fields = self.response_fields
self.secure_cloned_response_field = self.secure_cloned_response_field
self.body_field = self.body_field
self._flat_dependant = self._flat_dependant
self._embed_body_fields = self._embed_body_fields
class APIRouter(routing.Router):
"""
@ -833,6 +866,16 @@ class APIRouter(routing.Router):
"""
),
] = Default(generate_unique_id),
defer_init: Annotated[
bool,
Doc(
"""
By default, every route will defer its initialization until the first call.
This flag can be used to deactivate this behavior for the routes defined in this router,
causing the routes to initialize immediately when they are defined.
"""
),
] = True,
) -> None:
super().__init__(
routes=routes,
@ -858,6 +901,9 @@ class APIRouter(routing.Router):
self.route_class = route_class
self.default_response_class = default_response_class
self.generate_unique_id_function = generate_unique_id_function
self.defer_init = defer_init
if not self.defer_init:
self.init_routes()
def route(
self,
@ -957,6 +1003,7 @@ class APIRouter(routing.Router):
callbacks=current_callbacks,
openapi_extra=openapi_extra,
generate_unique_id_function=current_generate_unique_id,
defer_init=self.defer_init,
)
self.routes.append(route)
@ -1119,6 +1166,11 @@ class APIRouter(routing.Router):
return decorator
def init_routes(self) -> None:
for route in self.routes:
if isinstance(route, APIRoute):
route.init_attributes()
def include_router(
self,
router: Annotated["APIRouter", Doc("The `APIRouter` to include.")],

117
tests/test_route_deferred_init.py

@ -0,0 +1,117 @@
from itertools import chain
from typing import List, Optional
from fastapi import APIRouter, Depends, FastAPI
from fastapi.routing import APIRoute
from fastapi.testclient import TestClient
from pydantic import BaseModel
from starlette.routing import BaseRoute
deferred_keys = [
"app",
"response_fields",
"body_field",
"response_field",
"secure_cloned_response_field",
"dependant",
"_flat_dependant",
"_embed_body_fields",
]
def check_if_initialized(route: APIRoute, should_not: bool = False):
for key in deferred_keys:
if should_not:
assert key not in route.__dict__
else:
assert key in route.__dict__
def create_test_router(
routes: Optional[List[BaseRoute]] = None, defer_init: bool = True
):
router = APIRouter(routes=routes or [], defer_init=defer_init)
class UserIdBody(BaseModel):
user_id: int
@router.get("/user_id", dependencies=[Depends(lambda: True)])
async def get_user_id(user_id: int = Depends(lambda: 1)) -> UserIdBody:
return {"user_id": user_id}
return router
def test_route_defers():
app = FastAPI()
router = create_test_router(routes=app.router.routes)
for route in router.routes:
if not isinstance(route, APIRoute):
continue
check_if_initialized(route, should_not=True)
app.router = router
client = TestClient(app)
response = client.get("/user_id")
assert response.status_code == 200
response = client.get("/openapi.json")
assert response.status_code == 200
for route in router.routes:
if not isinstance(route, APIRoute):
continue
check_if_initialized(route)
def test_route_manual_init():
router = create_test_router()
for route in router.routes:
check_if_initialized(route, should_not=True)
route.init_attributes()
check_if_initialized(route)
router = create_test_router()
router.init_routes()
for route in router.routes:
check_if_initialized(route)
def test_router_defer_init_flag():
route = APIRoute("/test", lambda: {"test": True}, defer_init=False)
check_if_initialized(route)
deferring_router = create_test_router()
router = create_test_router(routes=deferring_router.routes, defer_init=False)
for route in router.routes:
check_if_initialized(route)
def test_root_router_always_initialized():
app = FastAPI()
@app.get("/test")
async def test_get():
return {"test": 1}
router = create_test_router()
app.include_router(router)
for route in app.router.routes:
if not isinstance(route, APIRoute):
continue
check_if_initialized(route)
client = TestClient(app)
response = client.get("/test")
assert response.status_code == 200
def test_include_router_no_init():
router1 = create_test_router()
router2 = create_test_router()
router2.include_router(router1)
for route in chain(router1.routes, router2.routes):
check_if_initialized(route, should_not=True)
Loading…
Cancel
Save