|
|
@ -5,6 +5,7 @@ import inspect |
|
|
|
import json |
|
|
|
from contextlib import AsyncExitStack, asynccontextmanager |
|
|
|
from enum import Enum, IntEnum |
|
|
|
from functools import cached_property |
|
|
|
from typing import ( |
|
|
|
Any, |
|
|
|
AsyncIterator, |
|
|
@ -459,6 +460,7 @@ class APIRoute(routing.Route): |
|
|
|
generate_unique_id_function: Union[ |
|
|
|
Callable[["APIRoute"], str], DefaultPlaceholder |
|
|
|
] = Default(generate_unique_id), |
|
|
|
defer_init: bool = True, |
|
|
|
) -> None: |
|
|
|
self.path = path |
|
|
|
self.endpoint = endpoint |
|
|
@ -503,68 +505,16 @@ class APIRoute(routing.Route): |
|
|
|
if isinstance(status_code, IntEnum): |
|
|
|
status_code = int(status_code) |
|
|
|
self.status_code = status_code |
|
|
|
if self.response_model: |
|
|
|
assert is_body_allowed_for_status_code( |
|
|
|
status_code |
|
|
|
), f"Status code {status_code} must not have a response body" |
|
|
|
response_name = "Response_" + self.unique_id |
|
|
|
self.response_field = create_model_field( |
|
|
|
name=response_name, |
|
|
|
type_=self.response_model, |
|
|
|
mode="serialization", |
|
|
|
) |
|
|
|
# Create a clone of the field, so that a Pydantic submodel is not returned |
|
|
|
# as is just because it's an instance of a subclass of a more limited class |
|
|
|
# e.g. UserInDB (containing hashed_password) could be a subclass of User |
|
|
|
# that doesn't have the hashed_password. But because it's a subclass, it |
|
|
|
# would pass the validation and be returned as is. |
|
|
|
# By being a new field, no inheritance will be passed as is. A new model |
|
|
|
# will always be created. |
|
|
|
# TODO: remove when deprecating Pydantic v1 |
|
|
|
self.secure_cloned_response_field: Optional[ModelField] = ( |
|
|
|
create_cloned_field(self.response_field) |
|
|
|
) |
|
|
|
else: |
|
|
|
self.response_field = None # type: ignore |
|
|
|
self.secure_cloned_response_field = None |
|
|
|
self.dependencies = list(dependencies or []) |
|
|
|
self.description = description or inspect.cleandoc(self.endpoint.__doc__ or "") |
|
|
|
# if a "form feed" character (page break) is found in the description text, |
|
|
|
# truncate description text to the content preceding the first "form feed" |
|
|
|
self.description = self.description.split("\f")[0].strip() |
|
|
|
response_fields = {} |
|
|
|
for additional_status_code, response in self.responses.items(): |
|
|
|
assert isinstance(response, dict), "An additional response must be a dict" |
|
|
|
model = response.get("model") |
|
|
|
if model: |
|
|
|
assert is_body_allowed_for_status_code( |
|
|
|
additional_status_code |
|
|
|
), f"Status code {additional_status_code} must not have a response body" |
|
|
|
response_name = f"Response_{additional_status_code}_{self.unique_id}" |
|
|
|
response_field = create_model_field(name=response_name, type_=model) |
|
|
|
response_fields[additional_status_code] = response_field |
|
|
|
if response_fields: |
|
|
|
self.response_fields: Dict[Union[int, str], ModelField] = response_fields |
|
|
|
else: |
|
|
|
self.response_fields = {} |
|
|
|
|
|
|
|
assert callable(endpoint), "An endpoint must be a callable" |
|
|
|
self.dependant = get_dependant(path=self.path_format, call=self.endpoint) |
|
|
|
for depends in self.dependencies[::-1]: |
|
|
|
self.dependant.dependencies.insert( |
|
|
|
0, |
|
|
|
get_parameterless_sub_dependant(depends=depends, path=self.path_format), |
|
|
|
) |
|
|
|
self._flat_dependant = get_flat_dependant(self.dependant) |
|
|
|
self._embed_body_fields = _should_embed_body_fields( |
|
|
|
self._flat_dependant.body_params |
|
|
|
) |
|
|
|
self.body_field = get_body_field( |
|
|
|
flat_dependant=self._flat_dependant, |
|
|
|
name=self.unique_id, |
|
|
|
embed_body_fields=self._embed_body_fields, |
|
|
|
) |
|
|
|
self.app = request_response(self.get_route_handler()) |
|
|
|
|
|
|
|
if not defer_init: |
|
|
|
self.init_attributes() |
|
|
|
|
|
|
|
def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]]: |
|
|
|
return get_request_handler( |
|
|
@ -589,6 +539,85 @@ class APIRoute(routing.Route): |
|
|
|
child_scope["route"] = self |
|
|
|
return match, child_scope |
|
|
|
|
|
|
|
@cached_property |
|
|
|
def response_field(self) -> Optional[ModelField]: |
|
|
|
if not self.response_model: |
|
|
|
return None |
|
|
|
response_name = "Response_" + self.unique_id |
|
|
|
return create_model_field( |
|
|
|
name=response_name, |
|
|
|
type_=self.response_model, |
|
|
|
mode="serialization", |
|
|
|
) |
|
|
|
|
|
|
|
# Create a clone of the field, so that a Pydantic submodel is not returned |
|
|
|
# as is just because it's an instance of a subclass of a more limited class |
|
|
|
# e.g. UserInDB (containing hashed_password) could be a subclass of User |
|
|
|
# that doesn't have the hashed_password. But because it's a subclass, it |
|
|
|
# would pass the validation and be returned as is. |
|
|
|
# By being a new field, no inheritance will be passed as is. A new model |
|
|
|
# will always be created. |
|
|
|
# TODO: remove when deprecating Pydantic v1 |
|
|
|
@cached_property |
|
|
|
def secure_cloned_response_field(self) -> Optional[ModelField]: |
|
|
|
return create_cloned_field(self.response_field) if self.response_field else None |
|
|
|
|
|
|
|
@cached_property |
|
|
|
def response_fields(self) -> Dict[Union[int, str], ModelField]: |
|
|
|
response_fields = {} |
|
|
|
for additional_status_code, response in self.responses.items(): |
|
|
|
assert isinstance(response, dict), "An additional response must be a dict" |
|
|
|
model = response.get("model") |
|
|
|
if model: |
|
|
|
assert is_body_allowed_for_status_code( |
|
|
|
additional_status_code |
|
|
|
), f"Status code {additional_status_code} must not have a response body" |
|
|
|
|
|
|
|
response_name = f"Response_{additional_status_code}_{self.unique_id}" |
|
|
|
response_field = create_model_field(name=response_name, type_=model) |
|
|
|
response_fields[additional_status_code] = response_field |
|
|
|
return response_fields |
|
|
|
|
|
|
|
@cached_property |
|
|
|
def dependant(self) -> Dependant: |
|
|
|
dependant = get_dependant(path=self.path_format, call=self.endpoint) |
|
|
|
for depends in self.dependencies[::-1]: |
|
|
|
dependant.dependencies.insert( |
|
|
|
0, |
|
|
|
get_parameterless_sub_dependant(depends=depends, path=self.path_format), |
|
|
|
) |
|
|
|
return dependant |
|
|
|
|
|
|
|
@cached_property |
|
|
|
def _flat_dependant(self) -> Dependant: |
|
|
|
return get_flat_dependant(self.dependant) |
|
|
|
|
|
|
|
@cached_property |
|
|
|
def _embed_body_fields(self) -> bool: |
|
|
|
return _should_embed_body_fields(self._flat_dependant.body_params) |
|
|
|
|
|
|
|
@cached_property |
|
|
|
def body_field(self) -> Optional[ModelField]: |
|
|
|
return get_body_field( |
|
|
|
flat_dependant=self._flat_dependant, |
|
|
|
name=self.unique_id, |
|
|
|
embed_body_fields=self._embed_body_fields, |
|
|
|
) |
|
|
|
|
|
|
|
@cached_property |
|
|
|
def app(self) -> ASGIApp: # type: ignore |
|
|
|
return request_response(self.get_route_handler()) |
|
|
|
|
|
|
|
def init_attributes(self) -> None: |
|
|
|
self.app = self.app |
|
|
|
self.dependant = self.dependant |
|
|
|
self.response_field = self.response_field |
|
|
|
self.response_fields = self.response_fields |
|
|
|
self.secure_cloned_response_field = self.secure_cloned_response_field |
|
|
|
self.body_field = self.body_field |
|
|
|
self._flat_dependant = self._flat_dependant |
|
|
|
self._embed_body_fields = self._embed_body_fields |
|
|
|
|
|
|
|
|
|
|
|
class APIRouter(routing.Router): |
|
|
|
""" |
|
|
@ -831,6 +860,13 @@ class APIRouter(routing.Router): |
|
|
|
""" |
|
|
|
), |
|
|
|
] = Default(generate_unique_id), |
|
|
|
defer_init: Annotated[ |
|
|
|
bool, |
|
|
|
Doc(""" |
|
|
|
By default every route will defer its initialization upon usage. |
|
|
|
This flag disables the behavior for the routes defined in this router, causing the routes to initialize immediately. |
|
|
|
"""), |
|
|
|
] = True, |
|
|
|
) -> None: |
|
|
|
super().__init__( |
|
|
|
routes=routes, |
|
|
@ -856,6 +892,9 @@ class APIRouter(routing.Router): |
|
|
|
self.route_class = route_class |
|
|
|
self.default_response_class = default_response_class |
|
|
|
self.generate_unique_id_function = generate_unique_id_function |
|
|
|
self.defer_init = defer_init |
|
|
|
if not self.defer_init: |
|
|
|
self.init_routes() |
|
|
|
|
|
|
|
def route( |
|
|
|
self, |
|
|
@ -955,6 +994,7 @@ class APIRouter(routing.Router): |
|
|
|
callbacks=current_callbacks, |
|
|
|
openapi_extra=openapi_extra, |
|
|
|
generate_unique_id_function=current_generate_unique_id, |
|
|
|
defer_init=self.defer_init, |
|
|
|
) |
|
|
|
self.routes.append(route) |
|
|
|
|
|
|
@ -1117,6 +1157,11 @@ class APIRouter(routing.Router): |
|
|
|
|
|
|
|
return decorator |
|
|
|
|
|
|
|
def init_routes(self) -> None: |
|
|
|
for route in self.routes: |
|
|
|
if isinstance(route, APIRoute): |
|
|
|
route.init_attributes() |
|
|
|
|
|
|
|
def include_router( |
|
|
|
self, |
|
|
|
router: Annotated["APIRouter", Doc("The `APIRouter` to include.")], |
|
|
|