3 changed files with 186 additions and 776 deletions
@ -1,800 +1,207 @@ |
|||
import dataclasses |
|||
import inspect |
|||
from contextlib import contextmanager |
|||
from copy import deepcopy |
|||
from typing import ( |
|||
Any, |
|||
Callable, |
|||
Coroutine, |
|||
Dict, |
|||
ForwardRef, |
|||
List, |
|||
Mapping, |
|||
Optional, |
|||
Sequence, |
|||
Tuple, |
|||
Type, |
|||
Union, |
|||
cast, |
|||
) |
|||
|
|||
import anyio |
|||
from fastapi import params |
|||
from fastapi.concurrency import ( |
|||
AsyncExitStack, |
|||
asynccontextmanager, |
|||
contextmanager_in_threadpool, |
|||
) |
|||
from fastapi.dependencies.models import Dependant, SecurityRequirement |
|||
from fastapi.dependencies.stubs import GenericTypeStub |
|||
from fastapi.logger import logger |
|||
from fastapi.security.base import SecurityBase |
|||
from fastapi.security.oauth2 import OAuth2, SecurityScopes |
|||
from fastapi.security.open_id_connect_url import OpenIdConnect |
|||
from fastapi.utils import create_response_field, get_path_param_names |
|||
from pydantic import BaseModel, create_model |
|||
from pydantic.error_wrappers import ErrorWrapper |
|||
from pydantic.errors import MissingError |
|||
from pydantic.fields import ( |
|||
SHAPE_FROZENSET, |
|||
SHAPE_LIST, |
|||
SHAPE_SEQUENCE, |
|||
SHAPE_SET, |
|||
SHAPE_SINGLETON, |
|||
SHAPE_TUPLE, |
|||
SHAPE_TUPLE_ELLIPSIS, |
|||
FieldInfo, |
|||
ModelField, |
|||
Required, |
|||
Undefined, |
|||
) |
|||
from pydantic.schema import get_annotation_from_field_info |
|||
from pydantic.typing import evaluate_forwardref |
|||
import functools |
|||
import re |
|||
import warnings |
|||
from dataclasses import is_dataclass |
|||
from enum import Enum |
|||
from typing import TYPE_CHECKING, Any, Dict, Optional, Set, Type, Union, cast |
|||
|
|||
import fastapi |
|||
from fastapi.datastructures import DefaultPlaceholder, DefaultType |
|||
from fastapi.openapi.constants import REF_PREFIX |
|||
from pydantic import BaseConfig, BaseModel, create_model |
|||
from pydantic.class_validators import Validator |
|||
from pydantic.fields import FieldInfo, ModelField, UndefinedType |
|||
from pydantic.schema import model_process_schema |
|||
from pydantic.utils import lenient_issubclass |
|||
from starlette.background import BackgroundTasks |
|||
from starlette.concurrency import run_in_threadpool |
|||
from starlette.datastructures import FormData, Headers, QueryParams, UploadFile |
|||
from starlette.requests import HTTPConnection, Request |
|||
from starlette.responses import Response |
|||
from starlette.websockets import WebSocket |
|||
|
|||
sequence_shapes = { |
|||
SHAPE_LIST, |
|||
SHAPE_SET, |
|||
SHAPE_FROZENSET, |
|||
SHAPE_TUPLE, |
|||
SHAPE_SEQUENCE, |
|||
SHAPE_TUPLE_ELLIPSIS, |
|||
} |
|||
sequence_types = (list, set, tuple) |
|||
sequence_shape_to_type = { |
|||
SHAPE_LIST: list, |
|||
SHAPE_SET: set, |
|||
SHAPE_TUPLE: tuple, |
|||
SHAPE_SEQUENCE: list, |
|||
SHAPE_TUPLE_ELLIPSIS: list, |
|||
} |
|||
|
|||
|
|||
multipart_not_installed_error = ( |
|||
'Form data requires "python-multipart" to be installed. \n' |
|||
'You can install "python-multipart" with: \n\n' |
|||
"pip install python-multipart\n" |
|||
) |
|||
multipart_incorrect_install_error = ( |
|||
'Form data requires "python-multipart" to be installed. ' |
|||
'It seems you installed "multipart" instead. \n' |
|||
'You can remove "multipart" with: \n\n' |
|||
"pip uninstall multipart\n\n" |
|||
'And then install "python-multipart" with: \n\n' |
|||
"pip install python-multipart\n" |
|||
) |
|||
|
|||
|
|||
def check_file_field(field: ModelField) -> None: |
|||
field_info = field.field_info |
|||
if isinstance(field_info, params.Form): |
|||
try: |
|||
# __version__ is available in both multiparts, and can be mocked |
|||
from multipart import __version__ # type: ignore |
|||
|
|||
assert __version__ |
|||
try: |
|||
# parse_options_header is only available in the right multipart |
|||
from multipart.multipart import parse_options_header # type: ignore |
|||
|
|||
assert parse_options_header |
|||
except ImportError: |
|||
logger.error(multipart_incorrect_install_error) |
|||
raise RuntimeError(multipart_incorrect_install_error) from None |
|||
except ImportError: |
|||
logger.error(multipart_not_installed_error) |
|||
raise RuntimeError(multipart_not_installed_error) from None |
|||
|
|||
|
|||
def get_param_sub_dependant( |
|||
*, param: inspect.Parameter, path: str, security_scopes: Optional[List[str]] = None |
|||
) -> Dependant: |
|||
depends: params.Depends = param.default |
|||
if depends.dependency: |
|||
dependency = depends.dependency |
|||
else: |
|||
dependency = param.annotation |
|||
return get_sub_dependant( |
|||
depends=depends, |
|||
dependency=dependency, |
|||
path=path, |
|||
name=param.name, |
|||
security_scopes=security_scopes, |
|||
) |
|||
|
|||
|
|||
def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> Dependant: |
|||
assert callable( |
|||
depends.dependency |
|||
), "A parameter-less dependency must have a callable dependency" |
|||
return get_sub_dependant(depends=depends, dependency=depends.dependency, path=path) |
|||
|
|||
|
|||
def get_sub_dependant( |
|||
*, |
|||
depends: params.Depends, |
|||
dependency: Callable[..., Any], |
|||
path: str, |
|||
name: Optional[str] = None, |
|||
security_scopes: Optional[List[str]] = None, |
|||
) -> Dependant: |
|||
security_requirement = None |
|||
security_scopes = security_scopes or [] |
|||
if isinstance(depends, params.Security): |
|||
dependency_scopes = depends.scopes |
|||
security_scopes.extend(dependency_scopes) |
|||
if isinstance(dependency, SecurityBase): |
|||
use_scopes: List[str] = [] |
|||
if isinstance(dependency, (OAuth2, OpenIdConnect)): |
|||
use_scopes = security_scopes |
|||
security_requirement = SecurityRequirement( |
|||
security_scheme=dependency, scopes=use_scopes |
|||
) |
|||
sub_dependant = get_dependant( |
|||
path=path, |
|||
call=dependency, |
|||
name=name, |
|||
security_scopes=security_scopes, |
|||
use_cache=depends.use_cache, |
|||
) |
|||
if security_requirement: |
|||
sub_dependant.security_requirements.append(security_requirement) |
|||
return sub_dependant |
|||
|
|||
|
|||
CacheKey = Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] |
|||
|
|||
|
|||
def get_flat_dependant( |
|||
dependant: Dependant, |
|||
*, |
|||
skip_repeats: bool = False, |
|||
visited: Optional[List[CacheKey]] = None, |
|||
) -> Dependant: |
|||
if visited is None: |
|||
visited = [] |
|||
visited.append(dependant.cache_key) |
|||
|
|||
flat_dependant = Dependant( |
|||
path_params=dependant.path_params.copy(), |
|||
query_params=dependant.query_params.copy(), |
|||
header_params=dependant.header_params.copy(), |
|||
cookie_params=dependant.cookie_params.copy(), |
|||
body_params=dependant.body_params.copy(), |
|||
security_schemes=dependant.security_requirements.copy(), |
|||
use_cache=dependant.use_cache, |
|||
path=dependant.path, |
|||
) |
|||
for sub_dependant in dependant.dependencies: |
|||
if skip_repeats and sub_dependant.cache_key in visited: |
|||
continue |
|||
flat_sub = get_flat_dependant( |
|||
sub_dependant, skip_repeats=skip_repeats, visited=visited |
|||
) |
|||
flat_dependant.path_params.extend(flat_sub.path_params) |
|||
flat_dependant.query_params.extend(flat_sub.query_params) |
|||
flat_dependant.header_params.extend(flat_sub.header_params) |
|||
flat_dependant.cookie_params.extend(flat_sub.cookie_params) |
|||
flat_dependant.body_params.extend(flat_sub.body_params) |
|||
flat_dependant.security_requirements.extend(flat_sub.security_requirements) |
|||
return flat_dependant |
|||
|
|||
|
|||
def get_flat_params(dependant: Dependant) -> List[ModelField]: |
|||
flat_dependant = get_flat_dependant(dependant, skip_repeats=True) |
|||
return ( |
|||
flat_dependant.path_params |
|||
+ flat_dependant.query_params |
|||
+ flat_dependant.header_params |
|||
+ flat_dependant.cookie_params |
|||
) |
|||
|
|||
|
|||
def is_scalar_field(field: ModelField) -> bool: |
|||
field_info = field.field_info |
|||
if not ( |
|||
field.shape == SHAPE_SINGLETON |
|||
and not lenient_issubclass(field.type_, BaseModel) |
|||
and not lenient_issubclass(field.type_, sequence_types + (dict,)) |
|||
and not dataclasses.is_dataclass(field.type_) |
|||
and not isinstance(field_info, params.Body) |
|||
): |
|||
return False |
|||
if field.sub_fields: |
|||
if not all(is_scalar_field(f) for f in field.sub_fields): |
|||
return False |
|||
return True |
|||
if TYPE_CHECKING: # pragma: nocover |
|||
from .routing import APIRoute |
|||
|
|||
|
|||
def is_scalar_sequence_field(field: ModelField) -> bool: |
|||
if (field.shape in sequence_shapes) and not lenient_issubclass( |
|||
field.type_, BaseModel |
|||
): |
|||
if field.sub_fields is not None: |
|||
for sub_field in field.sub_fields: |
|||
if not is_scalar_field(sub_field): |
|||
return False |
|||
def is_body_allowed_for_status_code(status_code: Union[int, str, None]) -> bool: |
|||
if status_code is None: |
|||
return True |
|||
if lenient_issubclass(field.type_, sequence_types): |
|||
# Ref: https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#patterned-fields-1 |
|||
if status_code in { |
|||
"default", |
|||
"1XX", |
|||
"2XX", |
|||
"3XX", |
|||
"4XX", |
|||
"5XX", |
|||
}: |
|||
return True |
|||
return False |
|||
|
|||
|
|||
def is_generic_type(obj: Any) -> bool: |
|||
return hasattr(obj, "__origin__") and hasattr(obj, "__args__") |
|||
|
|||
|
|||
def substitute_generic_type(annotation: Any, typevars: Dict[str, Any]) -> Any: |
|||
collection_shells = {list: List, set: List, dict: Dict, tuple: Tuple} |
|||
if is_generic_type(annotation): |
|||
args = tuple( |
|||
substitute_generic_type(arg, typevars) for arg in annotation.__args__ |
|||
) |
|||
annotation = collection_shells.get(annotation.__origin__, annotation) |
|||
return annotation[args] |
|||
return typevars.get(annotation.__name__, annotation) |
|||
|
|||
|
|||
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: |
|||
typevars = None |
|||
if is_generic_type(call): |
|||
stub: GenericTypeStub = cast(GenericTypeStub, call) |
|||
typevars = { |
|||
typevar.__name__: value |
|||
for typevar, value in zip(stub.__origin__.__parameters__, stub.__args__) |
|||
} |
|||
call = stub.__origin__.__init__ # type: ignore |
|||
|
|||
signature = inspect.signature(call) |
|||
globalns = getattr(call, "__globals__", {}) |
|||
|
|||
typed_params = [ |
|||
inspect.Parameter( |
|||
name=param.name, |
|||
kind=param.kind, |
|||
default=param.default, |
|||
annotation=get_typed_annotation(param.annotation, globalns, typevars), |
|||
) |
|||
for param in signature.parameters.values() |
|||
if param.name != "self" |
|||
] |
|||
typed_signature = inspect.Signature(typed_params) |
|||
return typed_signature |
|||
current_status_code = int(status_code) |
|||
return not (current_status_code < 200 or current_status_code in {204, 304}) |
|||
|
|||
|
|||
def get_typed_annotation( |
|||
annotation: Any, |
|||
globalns: Dict[str, Any], |
|||
typevars: Optional[Dict[str, type]] = None, |
|||
) -> Any: |
|||
if isinstance(annotation, str): |
|||
annotation = ForwardRef(annotation) |
|||
annotation = evaluate_forwardref(annotation, globalns, globalns) |
|||
if typevars: |
|||
annotation = substitute_generic_type(annotation, typevars) |
|||
return annotation |
|||
|
|||
|
|||
def get_typed_return_annotation(call: Callable[..., Any]) -> Any: |
|||
signature = inspect.signature(call) |
|||
annotation = signature.return_annotation |
|||
|
|||
if annotation is inspect.Signature.empty: |
|||
return None |
|||
|
|||
globalns = getattr(call, "__globals__", {}) |
|||
return get_typed_annotation(annotation, globalns) |
|||
|
|||
|
|||
def get_dependant( |
|||
def get_model_definitions( |
|||
*, |
|||
path: str, |
|||
call: Callable[..., Any], |
|||
name: Optional[str] = None, |
|||
security_scopes: Optional[List[str]] = None, |
|||
use_cache: bool = True, |
|||
) -> Dependant: |
|||
path_param_names = get_path_param_names(path) |
|||
endpoint_signature = get_typed_signature(call) |
|||
signature_params = endpoint_signature.parameters |
|||
dependant = Dependant( |
|||
call=call, |
|||
flat_models: Set[Union[Type[BaseModel], Type[Enum]]], |
|||
model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str], |
|||
) -> Dict[str, Any]: |
|||
definitions: Dict[str, Dict[str, Any]] = {} |
|||
for model in flat_models: |
|||
m_schema, m_definitions, m_nested_models = model_process_schema( |
|||
model, model_name_map=model_name_map, ref_prefix=REF_PREFIX |
|||
) |
|||
definitions.update(m_definitions) |
|||
model_name = model_name_map[model] |
|||
if "description" in m_schema: |
|||
m_schema["description"] = m_schema["description"].split("\f")[0] |
|||
definitions[model_name] = m_schema |
|||
return definitions |
|||
|
|||
|
|||
def get_path_param_names(path: str) -> Set[str]: |
|||
return set(re.findall("{(.*?)}", path)) |
|||
|
|||
|
|||
def create_response_field( |
|||
name: str, |
|||
type_: Type[Any], |
|||
class_validators: Optional[Dict[str, Validator]] = None, |
|||
default: Optional[Any] = None, |
|||
required: Union[bool, UndefinedType] = True, |
|||
model_config: Type[BaseConfig] = BaseConfig, |
|||
field_info: Optional[FieldInfo] = None, |
|||
alias: Optional[str] = None, |
|||
) -> ModelField: |
|||
""" |
|||
Create a new response field. Raises if type_ is invalid. |
|||
""" |
|||
class_validators = class_validators or {} |
|||
field_info = field_info or FieldInfo() |
|||
|
|||
response_field = functools.partial( |
|||
ModelField, |
|||
name=name, |
|||
path=path, |
|||
security_scopes=security_scopes, |
|||
use_cache=use_cache, |
|||
type_=type_, |
|||
class_validators=class_validators, |
|||
default=default, |
|||
required=required, |
|||
model_config=model_config, |
|||
alias=alias, |
|||
) |
|||
for param_name, param in signature_params.items(): |
|||
if isinstance(param.default, params.Depends): |
|||
sub_dependant = get_param_sub_dependant( |
|||
param=param, path=path, security_scopes=security_scopes |
|||
) |
|||
dependant.dependencies.append(sub_dependant) |
|||
continue |
|||
if add_non_field_param_to_dependency(param=param, dependant=dependant): |
|||
continue |
|||
param_field = get_param_field( |
|||
param=param, default_field_info=params.Query, param_name=param_name |
|||
) |
|||
if param_name in path_param_names: |
|||
assert is_scalar_field( |
|||
field=param_field |
|||
), "Path params must be of one of the supported types" |
|||
ignore_default = not isinstance(param.default, params.Path) |
|||
param_field = get_param_field( |
|||
param=param, |
|||
param_name=param_name, |
|||
default_field_info=params.Path, |
|||
force_type=params.ParamTypes.path, |
|||
ignore_default=ignore_default, |
|||
) |
|||
add_param_to_fields(field=param_field, dependant=dependant) |
|||
elif is_scalar_field(field=param_field): |
|||
add_param_to_fields(field=param_field, dependant=dependant) |
|||
elif isinstance( |
|||
param.default, (params.Query, params.Header) |
|||
) and is_scalar_sequence_field(param_field): |
|||
add_param_to_fields(field=param_field, dependant=dependant) |
|||
else: |
|||
field_info = param_field.field_info |
|||
assert isinstance( |
|||
field_info, params.Body |
|||
), f"Param: {param_field.name} can only be a request body, using Body()" |
|||
dependant.body_params.append(param_field) |
|||
return dependant |
|||
|
|||
|
|||
def add_non_field_param_to_dependency( |
|||
*, param: inspect.Parameter, dependant: Dependant |
|||
) -> Optional[bool]: |
|||
if lenient_issubclass(param.annotation, Request): |
|||
dependant.request_param_name = param.name |
|||
return True |
|||
elif lenient_issubclass(param.annotation, WebSocket): |
|||
dependant.websocket_param_name = param.name |
|||
return True |
|||
elif lenient_issubclass(param.annotation, HTTPConnection): |
|||
dependant.http_connection_param_name = param.name |
|||
return True |
|||
elif lenient_issubclass(param.annotation, Response): |
|||
dependant.response_param_name = param.name |
|||
return True |
|||
elif lenient_issubclass(param.annotation, BackgroundTasks): |
|||
dependant.background_tasks_param_name = param.name |
|||
return True |
|||
elif lenient_issubclass(param.annotation, SecurityScopes): |
|||
dependant.security_scopes_param_name = param.name |
|||
return True |
|||
return None |
|||
|
|||
|
|||
def get_param_field( |
|||
try: |
|||
return response_field(field_info=field_info) |
|||
except RuntimeError: |
|||
raise fastapi.exceptions.FastAPIError( |
|||
"Invalid args for response field! Hint: " |
|||
f"check that {type_} is a valid Pydantic field type. " |
|||
"If you are using a return type annotation that is not a valid Pydantic " |
|||
"field (e.g. Union[Response, dict, None]) you can disable generating the " |
|||
"response model from the type annotation with the path operation decorator " |
|||
"parameter response_model=None. Read more: " |
|||
"https://fastapi.tiangolo.com/tutorial/response-model/" |
|||
) from None |
|||
|
|||
|
|||
def create_cloned_field( |
|||
field: ModelField, |
|||
*, |
|||
param: inspect.Parameter, |
|||
param_name: str, |
|||
default_field_info: Type[params.Param] = params.Param, |
|||
force_type: Optional[params.ParamTypes] = None, |
|||
ignore_default: bool = False, |
|||
cloned_types: Optional[Dict[Type[BaseModel], Type[BaseModel]]] = None, |
|||
) -> ModelField: |
|||
default_value: Any = Undefined |
|||
had_schema = False |
|||
if not param.default == param.empty and ignore_default is False: |
|||
default_value = param.default |
|||
if isinstance(default_value, FieldInfo): |
|||
had_schema = True |
|||
field_info = default_value |
|||
default_value = field_info.default |
|||
if ( |
|||
isinstance(field_info, params.Param) |
|||
and getattr(field_info, "in_", None) is None |
|||
): |
|||
field_info.in_ = default_field_info.in_ |
|||
if force_type: |
|||
field_info.in_ = force_type # type: ignore |
|||
else: |
|||
field_info = default_field_info(default=default_value) |
|||
required = True |
|||
if default_value is Required or ignore_default: |
|||
required = True |
|||
default_value = None |
|||
elif default_value is not Undefined: |
|||
required = False |
|||
annotation: Any = Any |
|||
if not param.annotation == param.empty: |
|||
annotation = param.annotation |
|||
annotation = get_annotation_from_field_info(annotation, field_info, param_name) |
|||
if not field_info.alias and getattr(field_info, "convert_underscores", None): |
|||
alias = param.name.replace("_", "-") |
|||
else: |
|||
alias = field_info.alias or param.name |
|||
field = create_response_field( |
|||
name=param.name, |
|||
type_=annotation, |
|||
default=default_value, |
|||
alias=alias, |
|||
required=required, |
|||
field_info=field_info, |
|||
# _cloned_types has already cloned types, to support recursive models |
|||
if cloned_types is None: |
|||
cloned_types = {} |
|||
original_type = field.type_ |
|||
if is_dataclass(original_type) and hasattr(original_type, "__pydantic_model__"): |
|||
original_type = original_type.__pydantic_model__ |
|||
use_type = original_type |
|||
if lenient_issubclass(original_type, BaseModel): |
|||
original_type = cast(Type[BaseModel], original_type) |
|||
use_type = cloned_types.get(original_type) |
|||
if use_type is None: |
|||
use_type = create_model(original_type.__name__, __base__=original_type) |
|||
cloned_types[original_type] = use_type |
|||
for f in original_type.__fields__.values(): |
|||
use_type.__fields__[f.name] = create_cloned_field( |
|||
f, cloned_types=cloned_types |
|||
) |
|||
new_field = create_response_field(name=field.name, type_=use_type) |
|||
new_field.has_alias = field.has_alias |
|||
new_field.alias = field.alias |
|||
new_field.class_validators = field.class_validators |
|||
new_field.default = field.default |
|||
new_field.required = field.required |
|||
new_field.model_config = field.model_config |
|||
new_field.field_info = field.field_info |
|||
new_field.allow_none = field.allow_none |
|||
new_field.validate_always = field.validate_always |
|||
if field.sub_fields: |
|||
new_field.sub_fields = [ |
|||
create_cloned_field(sub_field, cloned_types=cloned_types) |
|||
for sub_field in field.sub_fields |
|||
] |
|||
if field.key_field: |
|||
new_field.key_field = create_cloned_field( |
|||
field.key_field, cloned_types=cloned_types |
|||
) |
|||
new_field.validators = field.validators |
|||
new_field.pre_validators = field.pre_validators |
|||
new_field.post_validators = field.post_validators |
|||
new_field.parse_json = field.parse_json |
|||
new_field.shape = field.shape |
|||
new_field.populate_validators() |
|||
return new_field |
|||
|
|||
|
|||
def generate_operation_id_for_path( |
|||
*, name: str, path: str, method: str |
|||
) -> str: # pragma: nocover |
|||
warnings.warn( |
|||
"fastapi.utils.generate_operation_id_for_path() was deprecated, " |
|||
"it is not used internally, and will be removed soon", |
|||
DeprecationWarning, |
|||
stacklevel=2, |
|||
) |
|||
if not had_schema and not is_scalar_field(field=field): |
|||
field.field_info = params.Body(field_info.default) |
|||
if not had_schema and lenient_issubclass(field.type_, UploadFile): |
|||
field.field_info = params.File(field_info.default) |
|||
|
|||
return field |
|||
|
|||
|
|||
def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None: |
|||
field_info = cast(params.Param, field.field_info) |
|||
if field_info.in_ == params.ParamTypes.path: |
|||
dependant.path_params.append(field) |
|||
elif field_info.in_ == params.ParamTypes.query: |
|||
dependant.query_params.append(field) |
|||
elif field_info.in_ == params.ParamTypes.header: |
|||
dependant.header_params.append(field) |
|||
else: |
|||
assert ( |
|||
field_info.in_ == params.ParamTypes.cookie |
|||
), f"non-body parameters must be in path, query, header or cookie: {field.name}" |
|||
dependant.cookie_params.append(field) |
|||
|
|||
|
|||
def is_coroutine_callable(call: Callable[..., Any]) -> bool: |
|||
if inspect.isroutine(call): |
|||
return inspect.iscoroutinefunction(call) |
|||
if inspect.isclass(call): |
|||
return False |
|||
dunder_call = getattr(call, "__call__", None) # noqa: B004 |
|||
return inspect.iscoroutinefunction(dunder_call) |
|||
|
|||
operation_id = name + path |
|||
operation_id = re.sub(r"\W", "_", operation_id) |
|||
operation_id = operation_id + "_" + method.lower() |
|||
return operation_id |
|||
|
|||
def is_async_gen_callable(call: Callable[..., Any]) -> bool: |
|||
if inspect.isasyncgenfunction(call): |
|||
return True |
|||
dunder_call = getattr(call, "__call__", None) # noqa: B004 |
|||
return inspect.isasyncgenfunction(dunder_call) |
|||
|
|||
|
|||
def is_gen_callable(call: Callable[..., Any]) -> bool: |
|||
if inspect.isgeneratorfunction(call): |
|||
return True |
|||
dunder_call = getattr(call, "__call__", None) # noqa: B004 |
|||
return inspect.isgeneratorfunction(dunder_call) |
|||
|
|||
def generate_unique_id(route: "APIRoute") -> str: |
|||
operation_id = route.name + route.path_format |
|||
operation_id = re.sub(r"\W", "_", operation_id) |
|||
assert route.methods |
|||
operation_id = operation_id + "_" + list(route.methods)[0].lower() |
|||
return operation_id |
|||
|
|||
async def solve_generator( |
|||
*, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any] |
|||
) -> Any: |
|||
if is_gen_callable(call): |
|||
cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values)) |
|||
elif is_async_gen_callable(call): |
|||
cm = asynccontextmanager(call)(**sub_values) |
|||
return await stack.enter_async_context(cm) |
|||
|
|||
|
|||
async def solve_dependencies( |
|||
*, |
|||
request: Union[Request, WebSocket], |
|||
dependant: Dependant, |
|||
body: Optional[Union[Dict[str, Any], FormData]] = None, |
|||
background_tasks: Optional[BackgroundTasks] = None, |
|||
response: Optional[Response] = None, |
|||
dependency_overrides_provider: Optional[Any] = None, |
|||
dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None, |
|||
) -> Tuple[ |
|||
Dict[str, Any], |
|||
List[ErrorWrapper], |
|||
Optional[BackgroundTasks], |
|||
Response, |
|||
Dict[Tuple[Callable[..., Any], Tuple[str]], Any], |
|||
]: |
|||
values: Dict[str, Any] = {} |
|||
errors: List[ErrorWrapper] = [] |
|||
if response is None: |
|||
response = Response() |
|||
del response.headers["content-length"] |
|||
response.status_code = None # type: ignore |
|||
dependency_cache = dependency_cache or {} |
|||
sub_dependant: Dependant |
|||
for sub_dependant in dependant.dependencies: |
|||
sub_dependant.call = cast(Callable[..., Any], sub_dependant.call) |
|||
sub_dependant.cache_key = cast( |
|||
Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key |
|||
) |
|||
call = sub_dependant.call |
|||
use_sub_dependant = sub_dependant |
|||
def deep_dict_update(main_dict: Dict[Any, Any], update_dict: Dict[Any, Any]) -> None: |
|||
for key, value in update_dict.items(): |
|||
if ( |
|||
dependency_overrides_provider |
|||
and dependency_overrides_provider.dependency_overrides |
|||
key in main_dict |
|||
and isinstance(main_dict[key], dict) |
|||
and isinstance(value, dict) |
|||
): |
|||
original_call = sub_dependant.call |
|||
call = getattr( |
|||
dependency_overrides_provider, "dependency_overrides", {} |
|||
).get(original_call, original_call) |
|||
use_path: str = sub_dependant.path # type: ignore |
|||
use_sub_dependant = get_dependant( |
|||
path=use_path, |
|||
call=call, |
|||
name=sub_dependant.name, |
|||
security_scopes=sub_dependant.security_scopes, |
|||
) |
|||
|
|||
solved_result = await solve_dependencies( |
|||
request=request, |
|||
dependant=use_sub_dependant, |
|||
body=body, |
|||
background_tasks=background_tasks, |
|||
response=response, |
|||
dependency_overrides_provider=dependency_overrides_provider, |
|||
dependency_cache=dependency_cache, |
|||
) |
|||
( |
|||
sub_values, |
|||
sub_errors, |
|||
background_tasks, |
|||
_, # the subdependency returns the same response we have |
|||
sub_dependency_cache, |
|||
) = solved_result |
|||
dependency_cache.update(sub_dependency_cache) |
|||
if sub_errors: |
|||
errors.extend(sub_errors) |
|||
continue |
|||
if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache: |
|||
solved = dependency_cache[sub_dependant.cache_key] |
|||
elif is_gen_callable(call) or is_async_gen_callable(call): |
|||
stack = request.scope.get("fastapi_astack") |
|||
assert isinstance(stack, AsyncExitStack) |
|||
solved = await solve_generator( |
|||
call=call, stack=stack, sub_values=sub_values |
|||
) |
|||
elif is_coroutine_callable(call): |
|||
solved = await call(**sub_values) |
|||
else: |
|||
solved = await run_in_threadpool(call, **sub_values) |
|||
if sub_dependant.name is not None: |
|||
values[sub_dependant.name] = solved |
|||
if sub_dependant.cache_key not in dependency_cache: |
|||
dependency_cache[sub_dependant.cache_key] = solved |
|||
path_values, path_errors = request_params_to_args( |
|||
dependant.path_params, request.path_params |
|||
) |
|||
query_values, query_errors = request_params_to_args( |
|||
dependant.query_params, request.query_params |
|||
) |
|||
header_values, header_errors = request_params_to_args( |
|||
dependant.header_params, request.headers |
|||
) |
|||
cookie_values, cookie_errors = request_params_to_args( |
|||
dependant.cookie_params, request.cookies |
|||
) |
|||
values.update(path_values) |
|||
values.update(query_values) |
|||
values.update(header_values) |
|||
values.update(cookie_values) |
|||
errors += path_errors + query_errors + header_errors + cookie_errors |
|||
if dependant.body_params: |
|||
( |
|||
body_values, |
|||
body_errors, |
|||
) = await request_body_to_args( # body_params checked above |
|||
required_params=dependant.body_params, received_body=body |
|||
) |
|||
values.update(body_values) |
|||
errors.extend(body_errors) |
|||
if dependant.http_connection_param_name: |
|||
values[dependant.http_connection_param_name] = request |
|||
if dependant.request_param_name and isinstance(request, Request): |
|||
values[dependant.request_param_name] = request |
|||
elif dependant.websocket_param_name and isinstance(request, WebSocket): |
|||
values[dependant.websocket_param_name] = request |
|||
if dependant.background_tasks_param_name: |
|||
if background_tasks is None: |
|||
background_tasks = BackgroundTasks() |
|||
values[dependant.background_tasks_param_name] = background_tasks |
|||
if dependant.response_param_name: |
|||
values[dependant.response_param_name] = response |
|||
if dependant.security_scopes_param_name: |
|||
values[dependant.security_scopes_param_name] = SecurityScopes( |
|||
scopes=dependant.security_scopes |
|||
) |
|||
return values, errors, background_tasks, response, dependency_cache |
|||
|
|||
|
|||
def request_params_to_args( |
|||
required_params: Sequence[ModelField], |
|||
received_params: Union[Mapping[str, Any], QueryParams, Headers], |
|||
) -> Tuple[Dict[str, Any], List[ErrorWrapper]]: |
|||
values = {} |
|||
errors = [] |
|||
for field in required_params: |
|||
if is_scalar_sequence_field(field) and isinstance( |
|||
received_params, (QueryParams, Headers) |
|||
deep_dict_update(main_dict[key], value) |
|||
elif ( |
|||
key in main_dict |
|||
and isinstance(main_dict[key], list) |
|||
and isinstance(update_dict[key], list) |
|||
): |
|||
value = received_params.getlist(field.alias) or field.default |
|||
else: |
|||
value = received_params.get(field.alias) |
|||
field_info = field.field_info |
|||
assert isinstance( |
|||
field_info, params.Param |
|||
), "Params must be subclasses of Param" |
|||
if value is None: |
|||
if field.required: |
|||
errors.append( |
|||
ErrorWrapper( |
|||
MissingError(), loc=(field_info.in_.value, field.alias) |
|||
) |
|||
) |
|||
else: |
|||
values[field.name] = deepcopy(field.default) |
|||
continue |
|||
v_, errors_ = field.validate( |
|||
value, values, loc=(field_info.in_.value, field.alias) |
|||
) |
|||
if isinstance(errors_, ErrorWrapper): |
|||
errors.append(errors_) |
|||
elif isinstance(errors_, list): |
|||
errors.extend(errors_) |
|||
main_dict[key] = main_dict[key] + update_dict[key] |
|||
else: |
|||
values[field.name] = v_ |
|||
return values, errors |
|||
main_dict[key] = value |
|||
|
|||
|
|||
async def request_body_to_args( |
|||
required_params: List[ModelField], |
|||
received_body: Optional[Union[Dict[str, Any], FormData]], |
|||
) -> Tuple[Dict[str, Any], List[ErrorWrapper]]: |
|||
values = {} |
|||
errors = [] |
|||
if required_params: |
|||
field = required_params[0] |
|||
field_info = field.field_info |
|||
embed = getattr(field_info, "embed", None) |
|||
field_alias_omitted = len(required_params) == 1 and not embed |
|||
if field_alias_omitted: |
|||
received_body = {field.alias: received_body} |
|||
def get_value_or_default( |
|||
first_item: Union[DefaultPlaceholder, DefaultType], |
|||
*extra_items: Union[DefaultPlaceholder, DefaultType], |
|||
) -> Union[DefaultPlaceholder, DefaultType]: |
|||
""" |
|||
Pass items or `DefaultPlaceholder`s by descending priority. |
|||
|
|||
for field in required_params: |
|||
loc: Tuple[str, ...] |
|||
if field_alias_omitted: |
|||
loc = ("body",) |
|||
else: |
|||
loc = ("body", field.alias) |
|||
The first one to _not_ be a `DefaultPlaceholder` will be returned. |
|||
|
|||
value: Optional[Any] = None |
|||
if received_body is not None: |
|||
if ( |
|||
field.shape in sequence_shapes or field.type_ in sequence_types |
|||
) and isinstance(received_body, FormData): |
|||
value = received_body.getlist(field.alias) |
|||
else: |
|||
try: |
|||
value = received_body.get(field.alias) |
|||
except AttributeError: |
|||
errors.append(get_missing_field_error(loc)) |
|||
continue |
|||
if ( |
|||
value is None |
|||
or (isinstance(field_info, params.Form) and value == "") |
|||
or ( |
|||
isinstance(field_info, params.Form) |
|||
and field.shape in sequence_shapes |
|||
and len(value) == 0 |
|||
) |
|||
): |
|||
if field.required: |
|||
errors.append(get_missing_field_error(loc)) |
|||
else: |
|||
values[field.name] = deepcopy(field.default) |
|||
continue |
|||
if ( |
|||
isinstance(field_info, params.File) |
|||
and lenient_issubclass(field.type_, bytes) |
|||
and isinstance(value, UploadFile) |
|||
): |
|||
value = await value.read() |
|||
elif ( |
|||
field.shape in sequence_shapes |
|||
and isinstance(field_info, params.File) |
|||
and lenient_issubclass(field.type_, bytes) |
|||
and isinstance(value, sequence_types) |
|||
): |
|||
results: List[Union[bytes, str]] = [] |
|||
|
|||
async def process_fn( |
|||
fn: Callable[[], Coroutine[Any, Any, Any]] |
|||
) -> None: |
|||
result = await fn() |
|||
results.append(result) |
|||
|
|||
async with anyio.create_task_group() as tg: |
|||
for sub_value in value: |
|||
tg.start_soon(process_fn, sub_value.read) |
|||
value = sequence_shape_to_type[field.shape](results) |
|||
|
|||
v_, errors_ = field.validate(value, values, loc=loc) |
|||
|
|||
if isinstance(errors_, ErrorWrapper): |
|||
errors.append(errors_) |
|||
elif isinstance(errors_, list): |
|||
errors.extend(errors_) |
|||
else: |
|||
values[field.name] = v_ |
|||
return values, errors |
|||
|
|||
|
|||
def get_missing_field_error(loc: Tuple[str, ...]) -> ErrorWrapper: |
|||
missing_field_error = ErrorWrapper(MissingError(), loc=loc) |
|||
return missing_field_error |
|||
|
|||
|
|||
def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]: |
|||
flat_dependant = get_flat_dependant(dependant) |
|||
if not flat_dependant.body_params: |
|||
return None |
|||
first_param = flat_dependant.body_params[0] |
|||
field_info = first_param.field_info |
|||
embed = getattr(field_info, "embed", None) |
|||
body_param_names_set = {param.name for param in flat_dependant.body_params} |
|||
if len(body_param_names_set) == 1 and not embed: |
|||
check_file_field(first_param) |
|||
return first_param |
|||
# If one field requires to embed, all have to be embedded |
|||
# in case a sub-dependency is evaluated with a single unique body field |
|||
# That is combined (embedded) with other body fields |
|||
for param in flat_dependant.body_params: |
|||
setattr(param.field_info, "embed", True) # noqa: B010 |
|||
model_name = "Body_" + name |
|||
BodyModel: Type[BaseModel] = create_model(model_name) |
|||
for f in flat_dependant.body_params: |
|||
BodyModel.__fields__[f.name] = f |
|||
required = any(True for f in flat_dependant.body_params if f.required) |
|||
|
|||
BodyFieldInfo_kwargs: Dict[str, Any] = {"default": None} |
|||
if any(isinstance(f.field_info, params.File) for f in flat_dependant.body_params): |
|||
BodyFieldInfo: Type[params.Body] = params.File |
|||
elif any(isinstance(f.field_info, params.Form) for f in flat_dependant.body_params): |
|||
BodyFieldInfo = params.Form |
|||
else: |
|||
BodyFieldInfo = params.Body |
|||
|
|||
body_param_media_types = [ |
|||
f.field_info.media_type |
|||
for f in flat_dependant.body_params |
|||
if isinstance(f.field_info, params.Body) |
|||
] |
|||
if len(set(body_param_media_types)) == 1: |
|||
BodyFieldInfo_kwargs["media_type"] = body_param_media_types[0] |
|||
final_field = create_response_field( |
|||
name="body", |
|||
type_=BodyModel, |
|||
required=required, |
|||
alias="body", |
|||
field_info=BodyFieldInfo(**BodyFieldInfo_kwargs), |
|||
) |
|||
check_file_field(final_field) |
|||
return final_field |
|||
Otherwise, the first item (a `DefaultPlaceholder`) will be returned. |
|||
""" |
|||
items = (first_item,) + extra_items |
|||
for item in items: |
|||
if not isinstance(item, DefaultPlaceholder): |
|||
return item |
|||
return first_item |
|||
|
Loading…
Reference in new issue