|
|
@ -1,207 +1,801 @@ |
|
|
|
import functools |
|
|
|
import re |
|
|
|
import warnings |
|
|
|
from dataclasses import is_dataclass |
|
|
|
from enum import Enum |
|
|
|
from typing import TYPE_CHECKING, Any, Dict, Optional, Set, Type, Union, cast |
|
|
|
|
|
|
|
import fastapi |
|
|
|
from fastapi.datastructures import DefaultPlaceholder, DefaultType |
|
|
|
from fastapi.openapi.constants import REF_PREFIX |
|
|
|
from pydantic import BaseConfig, BaseModel, create_model |
|
|
|
from pydantic.class_validators import Validator |
|
|
|
from pydantic.fields import FieldInfo, ModelField, UndefinedType |
|
|
|
from pydantic.schema import model_process_schema |
|
|
|
import dataclasses |
|
|
|
import inspect |
|
|
|
from contextlib import contextmanager |
|
|
|
from copy import deepcopy |
|
|
|
from typing import ( |
|
|
|
Any, |
|
|
|
Callable, |
|
|
|
Coroutine, |
|
|
|
Dict, |
|
|
|
ForwardRef, |
|
|
|
List, |
|
|
|
Mapping, |
|
|
|
Optional, |
|
|
|
Sequence, |
|
|
|
Tuple, |
|
|
|
Type, |
|
|
|
Union, |
|
|
|
cast, |
|
|
|
) |
|
|
|
|
|
|
|
import anyio |
|
|
|
from fastapi import params |
|
|
|
from fastapi.concurrency import ( |
|
|
|
AsyncExitStack, |
|
|
|
asynccontextmanager, |
|
|
|
contextmanager_in_threadpool, |
|
|
|
) |
|
|
|
from fastapi.dependencies.models import Dependant, SecurityRequirement |
|
|
|
from fastapi.dependencies.stubs import GenericTypeStub |
|
|
|
from fastapi.logger import logger |
|
|
|
from fastapi.security.base import SecurityBase |
|
|
|
from fastapi.security.oauth2 import OAuth2, SecurityScopes |
|
|
|
from fastapi.security.open_id_connect_url import OpenIdConnect |
|
|
|
from fastapi.utils import create_response_field, get_path_param_names |
|
|
|
from pydantic import BaseModel, create_model |
|
|
|
from pydantic.error_wrappers import ErrorWrapper |
|
|
|
from pydantic.errors import MissingError |
|
|
|
from pydantic.fields import ( |
|
|
|
SHAPE_FROZENSET, |
|
|
|
SHAPE_LIST, |
|
|
|
SHAPE_SEQUENCE, |
|
|
|
SHAPE_SET, |
|
|
|
SHAPE_SINGLETON, |
|
|
|
SHAPE_TUPLE, |
|
|
|
SHAPE_TUPLE_ELLIPSIS, |
|
|
|
FieldInfo, |
|
|
|
ModelField, |
|
|
|
Required, |
|
|
|
Undefined, |
|
|
|
) |
|
|
|
from pydantic.schema import get_annotation_from_field_info |
|
|
|
from pydantic.typing import evaluate_forwardref |
|
|
|
from pydantic.utils import lenient_issubclass |
|
|
|
from starlette.background import BackgroundTasks |
|
|
|
from starlette.concurrency import run_in_threadpool |
|
|
|
from starlette.datastructures import FormData, Headers, QueryParams, UploadFile |
|
|
|
from starlette.requests import HTTPConnection, Request |
|
|
|
from starlette.responses import Response |
|
|
|
from starlette.websockets import WebSocket |
|
|
|
|
|
|
|
if TYPE_CHECKING: # pragma: nocover |
|
|
|
from .routing import APIRoute |
|
|
|
sequence_shapes = { |
|
|
|
SHAPE_LIST, |
|
|
|
SHAPE_SET, |
|
|
|
SHAPE_FROZENSET, |
|
|
|
SHAPE_TUPLE, |
|
|
|
SHAPE_SEQUENCE, |
|
|
|
SHAPE_TUPLE_ELLIPSIS, |
|
|
|
} |
|
|
|
sequence_types = (list, set, tuple) |
|
|
|
sequence_shape_to_type = { |
|
|
|
SHAPE_LIST: list, |
|
|
|
SHAPE_SET: set, |
|
|
|
SHAPE_TUPLE: tuple, |
|
|
|
SHAPE_SEQUENCE: list, |
|
|
|
SHAPE_TUPLE_ELLIPSIS: list, |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
def is_body_allowed_for_status_code(status_code: Union[int, str, None]) -> bool: |
|
|
|
if status_code is None: |
|
|
|
return True |
|
|
|
# Ref: https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#patterned-fields-1 |
|
|
|
if status_code in { |
|
|
|
"default", |
|
|
|
"1XX", |
|
|
|
"2XX", |
|
|
|
"3XX", |
|
|
|
"4XX", |
|
|
|
"5XX", |
|
|
|
}: |
|
|
|
return True |
|
|
|
current_status_code = int(status_code) |
|
|
|
return not (current_status_code < 200 or current_status_code in {204, 304}) |
|
|
|
multipart_not_installed_error = ( |
|
|
|
'Form data requires "python-multipart" to be installed. \n' |
|
|
|
'You can install "python-multipart" with: \n\n' |
|
|
|
"pip install python-multipart\n" |
|
|
|
) |
|
|
|
multipart_incorrect_install_error = ( |
|
|
|
'Form data requires "python-multipart" to be installed. ' |
|
|
|
'It seems you installed "multipart" instead. \n' |
|
|
|
'You can remove "multipart" with: \n\n' |
|
|
|
"pip uninstall multipart\n\n" |
|
|
|
'And then install "python-multipart" with: \n\n' |
|
|
|
"pip install python-multipart\n" |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
def check_file_field(field: ModelField) -> None: |
|
|
|
field_info = field.field_info |
|
|
|
if isinstance(field_info, params.Form): |
|
|
|
try: |
|
|
|
# __version__ is available in both multiparts, and can be mocked |
|
|
|
from multipart import __version__ # type: ignore |
|
|
|
|
|
|
|
def get_model_definitions( |
|
|
|
assert __version__ |
|
|
|
try: |
|
|
|
# parse_options_header is only available in the right multipart |
|
|
|
from multipart.multipart import parse_options_header # type: ignore |
|
|
|
|
|
|
|
assert parse_options_header |
|
|
|
except ImportError: |
|
|
|
logger.error(multipart_incorrect_install_error) |
|
|
|
raise RuntimeError(multipart_incorrect_install_error) from None |
|
|
|
except ImportError: |
|
|
|
logger.error(multipart_not_installed_error) |
|
|
|
raise RuntimeError(multipart_not_installed_error) from None |
|
|
|
|
|
|
|
|
|
|
|
def get_param_sub_dependant( |
|
|
|
*, param: inspect.Parameter, path: str, security_scopes: Optional[List[str]] = None |
|
|
|
) -> Dependant: |
|
|
|
depends: params.Depends = param.default |
|
|
|
if depends.dependency: |
|
|
|
dependency = depends.dependency |
|
|
|
else: |
|
|
|
dependency = param.annotation |
|
|
|
return get_sub_dependant( |
|
|
|
depends=depends, |
|
|
|
dependency=dependency, |
|
|
|
path=path, |
|
|
|
name=param.name, |
|
|
|
security_scopes=security_scopes, |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> Dependant: |
|
|
|
assert callable( |
|
|
|
depends.dependency |
|
|
|
), "A parameter-less dependency must have a callable dependency" |
|
|
|
return get_sub_dependant(depends=depends, dependency=depends.dependency, path=path) |
|
|
|
|
|
|
|
|
|
|
|
def get_sub_dependant( |
|
|
|
*, |
|
|
|
flat_models: Set[Union[Type[BaseModel], Type[Enum]]], |
|
|
|
model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str], |
|
|
|
) -> Dict[str, Any]: |
|
|
|
definitions: Dict[str, Dict[str, Any]] = {} |
|
|
|
for model in flat_models: |
|
|
|
m_schema, m_definitions, m_nested_models = model_process_schema( |
|
|
|
model, model_name_map=model_name_map, ref_prefix=REF_PREFIX |
|
|
|
depends: params.Depends, |
|
|
|
dependency: Callable[..., Any], |
|
|
|
path: str, |
|
|
|
name: Optional[str] = None, |
|
|
|
security_scopes: Optional[List[str]] = None, |
|
|
|
) -> Dependant: |
|
|
|
security_requirement = None |
|
|
|
security_scopes = security_scopes or [] |
|
|
|
if isinstance(depends, params.Security): |
|
|
|
dependency_scopes = depends.scopes |
|
|
|
security_scopes.extend(dependency_scopes) |
|
|
|
if isinstance(dependency, SecurityBase): |
|
|
|
use_scopes: List[str] = [] |
|
|
|
if isinstance(dependency, (OAuth2, OpenIdConnect)): |
|
|
|
use_scopes = security_scopes |
|
|
|
security_requirement = SecurityRequirement( |
|
|
|
security_scheme=dependency, scopes=use_scopes |
|
|
|
) |
|
|
|
definitions.update(m_definitions) |
|
|
|
model_name = model_name_map[model] |
|
|
|
if "description" in m_schema: |
|
|
|
m_schema["description"] = m_schema["description"].split("\f")[0] |
|
|
|
definitions[model_name] = m_schema |
|
|
|
return definitions |
|
|
|
|
|
|
|
|
|
|
|
def get_path_param_names(path: str) -> Set[str]: |
|
|
|
return set(re.findall("{(.*?)}", path)) |
|
|
|
|
|
|
|
|
|
|
|
def create_response_field( |
|
|
|
name: str, |
|
|
|
type_: Type[Any], |
|
|
|
class_validators: Optional[Dict[str, Validator]] = None, |
|
|
|
default: Optional[Any] = None, |
|
|
|
required: Union[bool, UndefinedType] = True, |
|
|
|
model_config: Type[BaseConfig] = BaseConfig, |
|
|
|
field_info: Optional[FieldInfo] = None, |
|
|
|
alias: Optional[str] = None, |
|
|
|
) -> ModelField: |
|
|
|
""" |
|
|
|
Create a new response field. Raises if type_ is invalid. |
|
|
|
""" |
|
|
|
class_validators = class_validators or {} |
|
|
|
field_info = field_info or FieldInfo() |
|
|
|
|
|
|
|
response_field = functools.partial( |
|
|
|
ModelField, |
|
|
|
sub_dependant = get_dependant( |
|
|
|
path=path, |
|
|
|
call=dependency, |
|
|
|
name=name, |
|
|
|
type_=type_, |
|
|
|
class_validators=class_validators, |
|
|
|
default=default, |
|
|
|
required=required, |
|
|
|
model_config=model_config, |
|
|
|
alias=alias, |
|
|
|
security_scopes=security_scopes, |
|
|
|
use_cache=depends.use_cache, |
|
|
|
) |
|
|
|
if security_requirement: |
|
|
|
sub_dependant.security_requirements.append(security_requirement) |
|
|
|
return sub_dependant |
|
|
|
|
|
|
|
try: |
|
|
|
return response_field(field_info=field_info) |
|
|
|
except RuntimeError: |
|
|
|
raise fastapi.exceptions.FastAPIError( |
|
|
|
"Invalid args for response field! Hint: " |
|
|
|
f"check that {type_} is a valid Pydantic field type. " |
|
|
|
"If you are using a return type annotation that is not a valid Pydantic " |
|
|
|
"field (e.g. Union[Response, dict, None]) you can disable generating the " |
|
|
|
"response model from the type annotation with the path operation decorator " |
|
|
|
"parameter response_model=None. Read more: " |
|
|
|
"https://fastapi.tiangolo.com/tutorial/response-model/" |
|
|
|
) from None |
|
|
|
|
|
|
|
|
|
|
|
def create_cloned_field( |
|
|
|
field: ModelField, |
|
|
|
|
|
|
|
CacheKey = Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] |
|
|
|
|
|
|
|
|
|
|
|
def get_flat_dependant( |
|
|
|
dependant: Dependant, |
|
|
|
*, |
|
|
|
cloned_types: Optional[Dict[Type[BaseModel], Type[BaseModel]]] = None, |
|
|
|
) -> ModelField: |
|
|
|
# _cloned_types has already cloned types, to support recursive models |
|
|
|
if cloned_types is None: |
|
|
|
cloned_types = {} |
|
|
|
original_type = field.type_ |
|
|
|
if is_dataclass(original_type) and hasattr(original_type, "__pydantic_model__"): |
|
|
|
original_type = original_type.__pydantic_model__ |
|
|
|
use_type = original_type |
|
|
|
if lenient_issubclass(original_type, BaseModel): |
|
|
|
original_type = cast(Type[BaseModel], original_type) |
|
|
|
use_type = cloned_types.get(original_type) |
|
|
|
if use_type is None: |
|
|
|
use_type = create_model(original_type.__name__, __base__=original_type) |
|
|
|
cloned_types[original_type] = use_type |
|
|
|
for f in original_type.__fields__.values(): |
|
|
|
use_type.__fields__[f.name] = create_cloned_field( |
|
|
|
f, cloned_types=cloned_types |
|
|
|
) |
|
|
|
new_field = create_response_field(name=field.name, type_=use_type) |
|
|
|
new_field.has_alias = field.has_alias |
|
|
|
new_field.alias = field.alias |
|
|
|
new_field.class_validators = field.class_validators |
|
|
|
new_field.default = field.default |
|
|
|
new_field.required = field.required |
|
|
|
new_field.model_config = field.model_config |
|
|
|
new_field.field_info = field.field_info |
|
|
|
new_field.allow_none = field.allow_none |
|
|
|
new_field.validate_always = field.validate_always |
|
|
|
skip_repeats: bool = False, |
|
|
|
visited: Optional[List[CacheKey]] = None, |
|
|
|
) -> Dependant: |
|
|
|
if visited is None: |
|
|
|
visited = [] |
|
|
|
visited.append(dependant.cache_key) |
|
|
|
|
|
|
|
flat_dependant = Dependant( |
|
|
|
path_params=dependant.path_params.copy(), |
|
|
|
query_params=dependant.query_params.copy(), |
|
|
|
header_params=dependant.header_params.copy(), |
|
|
|
cookie_params=dependant.cookie_params.copy(), |
|
|
|
body_params=dependant.body_params.copy(), |
|
|
|
security_schemes=dependant.security_requirements.copy(), |
|
|
|
use_cache=dependant.use_cache, |
|
|
|
path=dependant.path, |
|
|
|
) |
|
|
|
for sub_dependant in dependant.dependencies: |
|
|
|
if skip_repeats and sub_dependant.cache_key in visited: |
|
|
|
continue |
|
|
|
flat_sub = get_flat_dependant( |
|
|
|
sub_dependant, skip_repeats=skip_repeats, visited=visited |
|
|
|
) |
|
|
|
flat_dependant.path_params.extend(flat_sub.path_params) |
|
|
|
flat_dependant.query_params.extend(flat_sub.query_params) |
|
|
|
flat_dependant.header_params.extend(flat_sub.header_params) |
|
|
|
flat_dependant.cookie_params.extend(flat_sub.cookie_params) |
|
|
|
flat_dependant.body_params.extend(flat_sub.body_params) |
|
|
|
flat_dependant.security_requirements.extend(flat_sub.security_requirements) |
|
|
|
return flat_dependant |
|
|
|
|
|
|
|
|
|
|
|
def get_flat_params(dependant: Dependant) -> List[ModelField]: |
|
|
|
flat_dependant = get_flat_dependant(dependant, skip_repeats=True) |
|
|
|
return ( |
|
|
|
flat_dependant.path_params |
|
|
|
+ flat_dependant.query_params |
|
|
|
+ flat_dependant.header_params |
|
|
|
+ flat_dependant.cookie_params |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
def is_scalar_field(field: ModelField) -> bool: |
|
|
|
field_info = field.field_info |
|
|
|
if not ( |
|
|
|
field.shape == SHAPE_SINGLETON |
|
|
|
and not lenient_issubclass(field.type_, BaseModel) |
|
|
|
and not lenient_issubclass(field.type_, sequence_types + (dict,)) |
|
|
|
and not dataclasses.is_dataclass(field.type_) |
|
|
|
and not isinstance(field_info, params.Body) |
|
|
|
): |
|
|
|
return False |
|
|
|
if field.sub_fields: |
|
|
|
new_field.sub_fields = [ |
|
|
|
create_cloned_field(sub_field, cloned_types=cloned_types) |
|
|
|
for sub_field in field.sub_fields |
|
|
|
] |
|
|
|
if field.key_field: |
|
|
|
new_field.key_field = create_cloned_field( |
|
|
|
field.key_field, cloned_types=cloned_types |
|
|
|
if not all(is_scalar_field(f) for f in field.sub_fields): |
|
|
|
return False |
|
|
|
return True |
|
|
|
|
|
|
|
|
|
|
|
def is_scalar_sequence_field(field: ModelField) -> bool: |
|
|
|
if (field.shape in sequence_shapes) and not lenient_issubclass( |
|
|
|
field.type_, BaseModel |
|
|
|
): |
|
|
|
if field.sub_fields is not None: |
|
|
|
for sub_field in field.sub_fields: |
|
|
|
if not is_scalar_field(sub_field): |
|
|
|
return False |
|
|
|
return True |
|
|
|
if lenient_issubclass(field.type_, sequence_types): |
|
|
|
return True |
|
|
|
return False |
|
|
|
|
|
|
|
|
|
|
|
def is_generic_type(obj: Any) -> bool: |
|
|
|
return hasattr(obj, "__origin__") and hasattr(obj, "__args__") |
|
|
|
|
|
|
|
|
|
|
|
def substitute_generic_type(annotation: Any, typevars: Dict[str, Any]) -> Any: |
|
|
|
collection_shells = {list: List, set: List, dict: Dict, tuple: Tuple} |
|
|
|
if is_generic_type(annotation): |
|
|
|
args = tuple( |
|
|
|
substitute_generic_type(arg, typevars) for arg in annotation.__args__ |
|
|
|
) |
|
|
|
new_field.validators = field.validators |
|
|
|
new_field.pre_validators = field.pre_validators |
|
|
|
new_field.post_validators = field.post_validators |
|
|
|
new_field.parse_json = field.parse_json |
|
|
|
new_field.shape = field.shape |
|
|
|
new_field.populate_validators() |
|
|
|
return new_field |
|
|
|
|
|
|
|
|
|
|
|
def generate_operation_id_for_path( |
|
|
|
*, name: str, path: str, method: str |
|
|
|
) -> str: # pragma: nocover |
|
|
|
warnings.warn( |
|
|
|
"fastapi.utils.generate_operation_id_for_path() was deprecated, " |
|
|
|
"it is not used internally, and will be removed soon", |
|
|
|
DeprecationWarning, |
|
|
|
stacklevel=2, |
|
|
|
annotation = collection_shells.get(annotation.__origin__, annotation) |
|
|
|
return annotation[args] |
|
|
|
return typevars.get(annotation.__name__, annotation) |
|
|
|
|
|
|
|
|
|
|
|
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: |
|
|
|
typevars = None |
|
|
|
if is_generic_type(call): |
|
|
|
stub: GenericTypeStub = cast(GenericTypeStub, call) |
|
|
|
typevars = { |
|
|
|
typevar.__name__: value |
|
|
|
for typevar, value in zip(stub.__origin__.__parameters__, stub.__args__) |
|
|
|
} |
|
|
|
call = stub.__origin__.__init__ # type: ignore |
|
|
|
|
|
|
|
signature = inspect.signature(call) |
|
|
|
globalns = getattr(call, "__globals__", {}) |
|
|
|
|
|
|
|
typed_params = [ |
|
|
|
inspect.Parameter( |
|
|
|
name=param.name, |
|
|
|
kind=param.kind, |
|
|
|
default=param.default, |
|
|
|
annotation=get_typed_annotation(param.annotation, globalns, typevars), |
|
|
|
) |
|
|
|
for param in signature.parameters.values() |
|
|
|
if param.name != 'self' |
|
|
|
] |
|
|
|
typed_signature = inspect.Signature(typed_params) |
|
|
|
return typed_signature |
|
|
|
|
|
|
|
|
|
|
|
def get_typed_annotation( |
|
|
|
annotation: Any, |
|
|
|
globalns: Dict[str, Any], |
|
|
|
typevars: Optional[Dict[str, type]] = None, |
|
|
|
) -> Any: |
|
|
|
if isinstance(annotation, str): |
|
|
|
annotation = ForwardRef(annotation) |
|
|
|
annotation = evaluate_forwardref(annotation, globalns, globalns) |
|
|
|
if typevars: |
|
|
|
annotation = substitute_generic_type(annotation, typevars) |
|
|
|
return annotation |
|
|
|
|
|
|
|
|
|
|
|
def get_typed_return_annotation(call: Callable[..., Any]) -> Any: |
|
|
|
signature = inspect.signature(call) |
|
|
|
annotation = signature.return_annotation |
|
|
|
|
|
|
|
if annotation is inspect.Signature.empty: |
|
|
|
return None |
|
|
|
|
|
|
|
globalns = getattr(call, "__globals__", {}) |
|
|
|
return get_typed_annotation(annotation, globalns) |
|
|
|
|
|
|
|
|
|
|
|
def get_dependant( |
|
|
|
*, |
|
|
|
path: str, |
|
|
|
call: Callable[..., Any], |
|
|
|
name: Optional[str] = None, |
|
|
|
security_scopes: Optional[List[str]] = None, |
|
|
|
use_cache: bool = True, |
|
|
|
) -> Dependant: |
|
|
|
path_param_names = get_path_param_names(path) |
|
|
|
endpoint_signature = get_typed_signature(call) |
|
|
|
signature_params = endpoint_signature.parameters |
|
|
|
dependant = Dependant( |
|
|
|
call=call, |
|
|
|
name=name, |
|
|
|
path=path, |
|
|
|
security_scopes=security_scopes, |
|
|
|
use_cache=use_cache, |
|
|
|
) |
|
|
|
operation_id = name + path |
|
|
|
operation_id = re.sub(r"\W", "_", operation_id) |
|
|
|
operation_id = operation_id + "_" + method.lower() |
|
|
|
return operation_id |
|
|
|
for param_name, param in signature_params.items(): |
|
|
|
if isinstance(param.default, params.Depends): |
|
|
|
sub_dependant = get_param_sub_dependant( |
|
|
|
param=param, path=path, security_scopes=security_scopes |
|
|
|
) |
|
|
|
dependant.dependencies.append(sub_dependant) |
|
|
|
continue |
|
|
|
if add_non_field_param_to_dependency(param=param, dependant=dependant): |
|
|
|
continue |
|
|
|
param_field = get_param_field( |
|
|
|
param=param, default_field_info=params.Query, param_name=param_name |
|
|
|
) |
|
|
|
if param_name in path_param_names: |
|
|
|
assert is_scalar_field( |
|
|
|
field=param_field |
|
|
|
), "Path params must be of one of the supported types" |
|
|
|
ignore_default = not isinstance(param.default, params.Path) |
|
|
|
param_field = get_param_field( |
|
|
|
param=param, |
|
|
|
param_name=param_name, |
|
|
|
default_field_info=params.Path, |
|
|
|
force_type=params.ParamTypes.path, |
|
|
|
ignore_default=ignore_default, |
|
|
|
) |
|
|
|
add_param_to_fields(field=param_field, dependant=dependant) |
|
|
|
elif is_scalar_field(field=param_field): |
|
|
|
add_param_to_fields(field=param_field, dependant=dependant) |
|
|
|
elif isinstance( |
|
|
|
param.default, (params.Query, params.Header) |
|
|
|
) and is_scalar_sequence_field(param_field): |
|
|
|
add_param_to_fields(field=param_field, dependant=dependant) |
|
|
|
else: |
|
|
|
field_info = param_field.field_info |
|
|
|
assert isinstance( |
|
|
|
field_info, params.Body |
|
|
|
), f"Param: {param_field.name} can only be a request body, using Body()" |
|
|
|
dependant.body_params.append(param_field) |
|
|
|
return dependant |
|
|
|
|
|
|
|
|
|
|
|
def generate_unique_id(route: "APIRoute") -> str: |
|
|
|
operation_id = route.name + route.path_format |
|
|
|
operation_id = re.sub(r"\W", "_", operation_id) |
|
|
|
assert route.methods |
|
|
|
operation_id = operation_id + "_" + list(route.methods)[0].lower() |
|
|
|
return operation_id |
|
|
|
def add_non_field_param_to_dependency( |
|
|
|
*, param: inspect.Parameter, dependant: Dependant |
|
|
|
) -> Optional[bool]: |
|
|
|
if lenient_issubclass(param.annotation, Request): |
|
|
|
dependant.request_param_name = param.name |
|
|
|
return True |
|
|
|
elif lenient_issubclass(param.annotation, WebSocket): |
|
|
|
dependant.websocket_param_name = param.name |
|
|
|
return True |
|
|
|
elif lenient_issubclass(param.annotation, HTTPConnection): |
|
|
|
dependant.http_connection_param_name = param.name |
|
|
|
return True |
|
|
|
elif lenient_issubclass(param.annotation, Response): |
|
|
|
dependant.response_param_name = param.name |
|
|
|
return True |
|
|
|
elif lenient_issubclass(param.annotation, BackgroundTasks): |
|
|
|
dependant.background_tasks_param_name = param.name |
|
|
|
return True |
|
|
|
elif lenient_issubclass(param.annotation, SecurityScopes): |
|
|
|
dependant.security_scopes_param_name = param.name |
|
|
|
return True |
|
|
|
return None |
|
|
|
|
|
|
|
|
|
|
|
def get_param_field( |
|
|
|
*, |
|
|
|
param: inspect.Parameter, |
|
|
|
param_name: str, |
|
|
|
default_field_info: Type[params.Param] = params.Param, |
|
|
|
force_type: Optional[params.ParamTypes] = None, |
|
|
|
ignore_default: bool = False, |
|
|
|
) -> ModelField: |
|
|
|
default_value: Any = Undefined |
|
|
|
had_schema = False |
|
|
|
if not param.default == param.empty and ignore_default is False: |
|
|
|
default_value = param.default |
|
|
|
if isinstance(default_value, FieldInfo): |
|
|
|
had_schema = True |
|
|
|
field_info = default_value |
|
|
|
default_value = field_info.default |
|
|
|
if ( |
|
|
|
isinstance(field_info, params.Param) |
|
|
|
and getattr(field_info, "in_", None) is None |
|
|
|
): |
|
|
|
field_info.in_ = default_field_info.in_ |
|
|
|
if force_type: |
|
|
|
field_info.in_ = force_type # type: ignore |
|
|
|
else: |
|
|
|
field_info = default_field_info(default=default_value) |
|
|
|
required = True |
|
|
|
if default_value is Required or ignore_default: |
|
|
|
required = True |
|
|
|
default_value = None |
|
|
|
elif default_value is not Undefined: |
|
|
|
required = False |
|
|
|
annotation: Any = Any |
|
|
|
if not param.annotation == param.empty: |
|
|
|
annotation = param.annotation |
|
|
|
annotation = get_annotation_from_field_info(annotation, field_info, param_name) |
|
|
|
if not field_info.alias and getattr(field_info, "convert_underscores", None): |
|
|
|
alias = param.name.replace("_", "-") |
|
|
|
else: |
|
|
|
alias = field_info.alias or param.name |
|
|
|
field = create_response_field( |
|
|
|
name=param.name, |
|
|
|
type_=annotation, |
|
|
|
default=default_value, |
|
|
|
alias=alias, |
|
|
|
required=required, |
|
|
|
field_info=field_info, |
|
|
|
) |
|
|
|
if not had_schema and not is_scalar_field(field=field): |
|
|
|
field.field_info = params.Body(field_info.default) |
|
|
|
if not had_schema and lenient_issubclass(field.type_, UploadFile): |
|
|
|
field.field_info = params.File(field_info.default) |
|
|
|
|
|
|
|
return field |
|
|
|
|
|
|
|
|
|
|
|
def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None: |
|
|
|
field_info = cast(params.Param, field.field_info) |
|
|
|
if field_info.in_ == params.ParamTypes.path: |
|
|
|
dependant.path_params.append(field) |
|
|
|
elif field_info.in_ == params.ParamTypes.query: |
|
|
|
dependant.query_params.append(field) |
|
|
|
elif field_info.in_ == params.ParamTypes.header: |
|
|
|
dependant.header_params.append(field) |
|
|
|
else: |
|
|
|
assert ( |
|
|
|
field_info.in_ == params.ParamTypes.cookie |
|
|
|
), f"non-body parameters must be in path, query, header or cookie: {field.name}" |
|
|
|
dependant.cookie_params.append(field) |
|
|
|
|
|
|
|
def deep_dict_update(main_dict: Dict[Any, Any], update_dict: Dict[Any, Any]) -> None: |
|
|
|
for key, value in update_dict.items(): |
|
|
|
|
|
|
|
def is_coroutine_callable(call: Callable[..., Any]) -> bool: |
|
|
|
if inspect.isroutine(call): |
|
|
|
return inspect.iscoroutinefunction(call) |
|
|
|
if inspect.isclass(call): |
|
|
|
return False |
|
|
|
dunder_call = getattr(call, "__call__", None) # noqa: B004 |
|
|
|
return inspect.iscoroutinefunction(dunder_call) |
|
|
|
|
|
|
|
|
|
|
|
def is_async_gen_callable(call: Callable[..., Any]) -> bool: |
|
|
|
if inspect.isasyncgenfunction(call): |
|
|
|
return True |
|
|
|
dunder_call = getattr(call, "__call__", None) # noqa: B004 |
|
|
|
return inspect.isasyncgenfunction(dunder_call) |
|
|
|
|
|
|
|
|
|
|
|
def is_gen_callable(call: Callable[..., Any]) -> bool: |
|
|
|
if inspect.isgeneratorfunction(call): |
|
|
|
return True |
|
|
|
dunder_call = getattr(call, "__call__", None) # noqa: B004 |
|
|
|
return inspect.isgeneratorfunction(dunder_call) |
|
|
|
|
|
|
|
|
|
|
|
async def solve_generator( |
|
|
|
*, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any] |
|
|
|
) -> Any: |
|
|
|
if is_gen_callable(call): |
|
|
|
cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values)) |
|
|
|
elif is_async_gen_callable(call): |
|
|
|
cm = asynccontextmanager(call)(**sub_values) |
|
|
|
return await stack.enter_async_context(cm) |
|
|
|
|
|
|
|
|
|
|
|
async def solve_dependencies( |
|
|
|
*, |
|
|
|
request: Union[Request, WebSocket], |
|
|
|
dependant: Dependant, |
|
|
|
body: Optional[Union[Dict[str, Any], FormData]] = None, |
|
|
|
background_tasks: Optional[BackgroundTasks] = None, |
|
|
|
response: Optional[Response] = None, |
|
|
|
dependency_overrides_provider: Optional[Any] = None, |
|
|
|
dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None, |
|
|
|
) -> Tuple[ |
|
|
|
Dict[str, Any], |
|
|
|
List[ErrorWrapper], |
|
|
|
Optional[BackgroundTasks], |
|
|
|
Response, |
|
|
|
Dict[Tuple[Callable[..., Any], Tuple[str]], Any], |
|
|
|
]: |
|
|
|
values: Dict[str, Any] = {} |
|
|
|
errors: List[ErrorWrapper] = [] |
|
|
|
if response is None: |
|
|
|
response = Response() |
|
|
|
del response.headers["content-length"] |
|
|
|
response.status_code = None # type: ignore |
|
|
|
dependency_cache = dependency_cache or {} |
|
|
|
sub_dependant: Dependant |
|
|
|
for sub_dependant in dependant.dependencies: |
|
|
|
sub_dependant.call = cast(Callable[..., Any], sub_dependant.call) |
|
|
|
sub_dependant.cache_key = cast( |
|
|
|
Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key |
|
|
|
) |
|
|
|
call = sub_dependant.call |
|
|
|
use_sub_dependant = sub_dependant |
|
|
|
if ( |
|
|
|
key in main_dict |
|
|
|
and isinstance(main_dict[key], dict) |
|
|
|
and isinstance(value, dict) |
|
|
|
dependency_overrides_provider |
|
|
|
and dependency_overrides_provider.dependency_overrides |
|
|
|
): |
|
|
|
deep_dict_update(main_dict[key], value) |
|
|
|
elif ( |
|
|
|
key in main_dict |
|
|
|
and isinstance(main_dict[key], list) |
|
|
|
and isinstance(update_dict[key], list) |
|
|
|
original_call = sub_dependant.call |
|
|
|
call = getattr( |
|
|
|
dependency_overrides_provider, "dependency_overrides", {} |
|
|
|
).get(original_call, original_call) |
|
|
|
use_path: str = sub_dependant.path # type: ignore |
|
|
|
use_sub_dependant = get_dependant( |
|
|
|
path=use_path, |
|
|
|
call=call, |
|
|
|
name=sub_dependant.name, |
|
|
|
security_scopes=sub_dependant.security_scopes, |
|
|
|
) |
|
|
|
|
|
|
|
solved_result = await solve_dependencies( |
|
|
|
request=request, |
|
|
|
dependant=use_sub_dependant, |
|
|
|
body=body, |
|
|
|
background_tasks=background_tasks, |
|
|
|
response=response, |
|
|
|
dependency_overrides_provider=dependency_overrides_provider, |
|
|
|
dependency_cache=dependency_cache, |
|
|
|
) |
|
|
|
( |
|
|
|
sub_values, |
|
|
|
sub_errors, |
|
|
|
background_tasks, |
|
|
|
_, # the subdependency returns the same response we have |
|
|
|
sub_dependency_cache, |
|
|
|
) = solved_result |
|
|
|
dependency_cache.update(sub_dependency_cache) |
|
|
|
if sub_errors: |
|
|
|
errors.extend(sub_errors) |
|
|
|
continue |
|
|
|
if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache: |
|
|
|
solved = dependency_cache[sub_dependant.cache_key] |
|
|
|
elif is_gen_callable(call) or is_async_gen_callable(call): |
|
|
|
stack = request.scope.get("fastapi_astack") |
|
|
|
assert isinstance(stack, AsyncExitStack) |
|
|
|
solved = await solve_generator( |
|
|
|
call=call, stack=stack, sub_values=sub_values |
|
|
|
) |
|
|
|
elif is_coroutine_callable(call): |
|
|
|
solved = await call(**sub_values) |
|
|
|
else: |
|
|
|
solved = await run_in_threadpool(call, **sub_values) |
|
|
|
if sub_dependant.name is not None: |
|
|
|
values[sub_dependant.name] = solved |
|
|
|
if sub_dependant.cache_key not in dependency_cache: |
|
|
|
dependency_cache[sub_dependant.cache_key] = solved |
|
|
|
path_values, path_errors = request_params_to_args( |
|
|
|
dependant.path_params, request.path_params |
|
|
|
) |
|
|
|
query_values, query_errors = request_params_to_args( |
|
|
|
dependant.query_params, request.query_params |
|
|
|
) |
|
|
|
header_values, header_errors = request_params_to_args( |
|
|
|
dependant.header_params, request.headers |
|
|
|
) |
|
|
|
cookie_values, cookie_errors = request_params_to_args( |
|
|
|
dependant.cookie_params, request.cookies |
|
|
|
) |
|
|
|
values.update(path_values) |
|
|
|
values.update(query_values) |
|
|
|
values.update(header_values) |
|
|
|
values.update(cookie_values) |
|
|
|
errors += path_errors + query_errors + header_errors + cookie_errors |
|
|
|
if dependant.body_params: |
|
|
|
( |
|
|
|
body_values, |
|
|
|
body_errors, |
|
|
|
) = await request_body_to_args( # body_params checked above |
|
|
|
required_params=dependant.body_params, received_body=body |
|
|
|
) |
|
|
|
values.update(body_values) |
|
|
|
errors.extend(body_errors) |
|
|
|
if dependant.http_connection_param_name: |
|
|
|
values[dependant.http_connection_param_name] = request |
|
|
|
if dependant.request_param_name and isinstance(request, Request): |
|
|
|
values[dependant.request_param_name] = request |
|
|
|
elif dependant.websocket_param_name and isinstance(request, WebSocket): |
|
|
|
values[dependant.websocket_param_name] = request |
|
|
|
if dependant.background_tasks_param_name: |
|
|
|
if background_tasks is None: |
|
|
|
background_tasks = BackgroundTasks() |
|
|
|
values[dependant.background_tasks_param_name] = background_tasks |
|
|
|
if dependant.response_param_name: |
|
|
|
values[dependant.response_param_name] = response |
|
|
|
if dependant.security_scopes_param_name: |
|
|
|
values[dependant.security_scopes_param_name] = SecurityScopes( |
|
|
|
scopes=dependant.security_scopes |
|
|
|
) |
|
|
|
return values, errors, background_tasks, response, dependency_cache |
|
|
|
|
|
|
|
|
|
|
|
def request_params_to_args( |
|
|
|
required_params: Sequence[ModelField], |
|
|
|
received_params: Union[Mapping[str, Any], QueryParams, Headers], |
|
|
|
) -> Tuple[Dict[str, Any], List[ErrorWrapper]]: |
|
|
|
values = {} |
|
|
|
errors = [] |
|
|
|
for field in required_params: |
|
|
|
if is_scalar_sequence_field(field) and isinstance( |
|
|
|
received_params, (QueryParams, Headers) |
|
|
|
): |
|
|
|
main_dict[key] = main_dict[key] + update_dict[key] |
|
|
|
value = received_params.getlist(field.alias) or field.default |
|
|
|
else: |
|
|
|
main_dict[key] = value |
|
|
|
value = received_params.get(field.alias) |
|
|
|
field_info = field.field_info |
|
|
|
assert isinstance( |
|
|
|
field_info, params.Param |
|
|
|
), "Params must be subclasses of Param" |
|
|
|
if value is None: |
|
|
|
if field.required: |
|
|
|
errors.append( |
|
|
|
ErrorWrapper( |
|
|
|
MissingError(), loc=(field_info.in_.value, field.alias) |
|
|
|
) |
|
|
|
) |
|
|
|
else: |
|
|
|
values[field.name] = deepcopy(field.default) |
|
|
|
continue |
|
|
|
v_, errors_ = field.validate( |
|
|
|
value, values, loc=(field_info.in_.value, field.alias) |
|
|
|
) |
|
|
|
if isinstance(errors_, ErrorWrapper): |
|
|
|
errors.append(errors_) |
|
|
|
elif isinstance(errors_, list): |
|
|
|
errors.extend(errors_) |
|
|
|
else: |
|
|
|
values[field.name] = v_ |
|
|
|
return values, errors |
|
|
|
|
|
|
|
|
|
|
|
async def request_body_to_args( |
|
|
|
required_params: List[ModelField], |
|
|
|
received_body: Optional[Union[Dict[str, Any], FormData]], |
|
|
|
) -> Tuple[Dict[str, Any], List[ErrorWrapper]]: |
|
|
|
values = {} |
|
|
|
errors = [] |
|
|
|
if required_params: |
|
|
|
field = required_params[0] |
|
|
|
field_info = field.field_info |
|
|
|
embed = getattr(field_info, "embed", None) |
|
|
|
field_alias_omitted = len(required_params) == 1 and not embed |
|
|
|
if field_alias_omitted: |
|
|
|
received_body = {field.alias: received_body} |
|
|
|
|
|
|
|
def get_value_or_default( |
|
|
|
first_item: Union[DefaultPlaceholder, DefaultType], |
|
|
|
*extra_items: Union[DefaultPlaceholder, DefaultType], |
|
|
|
) -> Union[DefaultPlaceholder, DefaultType]: |
|
|
|
""" |
|
|
|
Pass items or `DefaultPlaceholder`s by descending priority. |
|
|
|
for field in required_params: |
|
|
|
loc: Tuple[str, ...] |
|
|
|
if field_alias_omitted: |
|
|
|
loc = ("body",) |
|
|
|
else: |
|
|
|
loc = ("body", field.alias) |
|
|
|
|
|
|
|
The first one to _not_ be a `DefaultPlaceholder` will be returned. |
|
|
|
value: Optional[Any] = None |
|
|
|
if received_body is not None: |
|
|
|
if ( |
|
|
|
field.shape in sequence_shapes or field.type_ in sequence_types |
|
|
|
) and isinstance(received_body, FormData): |
|
|
|
value = received_body.getlist(field.alias) |
|
|
|
else: |
|
|
|
try: |
|
|
|
value = received_body.get(field.alias) |
|
|
|
except AttributeError: |
|
|
|
errors.append(get_missing_field_error(loc)) |
|
|
|
continue |
|
|
|
if ( |
|
|
|
value is None |
|
|
|
or (isinstance(field_info, params.Form) and value == "") |
|
|
|
or ( |
|
|
|
isinstance(field_info, params.Form) |
|
|
|
and field.shape in sequence_shapes |
|
|
|
and len(value) == 0 |
|
|
|
) |
|
|
|
): |
|
|
|
if field.required: |
|
|
|
errors.append(get_missing_field_error(loc)) |
|
|
|
else: |
|
|
|
values[field.name] = deepcopy(field.default) |
|
|
|
continue |
|
|
|
if ( |
|
|
|
isinstance(field_info, params.File) |
|
|
|
and lenient_issubclass(field.type_, bytes) |
|
|
|
and isinstance(value, UploadFile) |
|
|
|
): |
|
|
|
value = await value.read() |
|
|
|
elif ( |
|
|
|
field.shape in sequence_shapes |
|
|
|
and isinstance(field_info, params.File) |
|
|
|
and lenient_issubclass(field.type_, bytes) |
|
|
|
and isinstance(value, sequence_types) |
|
|
|
): |
|
|
|
results: List[Union[bytes, str]] = [] |
|
|
|
|
|
|
|
async def process_fn( |
|
|
|
fn: Callable[[], Coroutine[Any, Any, Any]] |
|
|
|
) -> None: |
|
|
|
result = await fn() |
|
|
|
results.append(result) |
|
|
|
|
|
|
|
async with anyio.create_task_group() as tg: |
|
|
|
for sub_value in value: |
|
|
|
tg.start_soon(process_fn, sub_value.read) |
|
|
|
value = sequence_shape_to_type[field.shape](results) |
|
|
|
|
|
|
|
v_, errors_ = field.validate(value, values, loc=loc) |
|
|
|
|
|
|
|
if isinstance(errors_, ErrorWrapper): |
|
|
|
errors.append(errors_) |
|
|
|
elif isinstance(errors_, list): |
|
|
|
errors.extend(errors_) |
|
|
|
else: |
|
|
|
values[field.name] = v_ |
|
|
|
return values, errors |
|
|
|
|
|
|
|
|
|
|
|
def get_missing_field_error(loc: Tuple[str, ...]) -> ErrorWrapper: |
|
|
|
missing_field_error = ErrorWrapper(MissingError(), loc=loc) |
|
|
|
return missing_field_error |
|
|
|
|
|
|
|
|
|
|
|
def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]: |
|
|
|
flat_dependant = get_flat_dependant(dependant) |
|
|
|
if not flat_dependant.body_params: |
|
|
|
return None |
|
|
|
first_param = flat_dependant.body_params[0] |
|
|
|
field_info = first_param.field_info |
|
|
|
embed = getattr(field_info, "embed", None) |
|
|
|
body_param_names_set = {param.name for param in flat_dependant.body_params} |
|
|
|
if len(body_param_names_set) == 1 and not embed: |
|
|
|
check_file_field(first_param) |
|
|
|
return first_param |
|
|
|
# If one field requires to embed, all have to be embedded |
|
|
|
# in case a sub-dependency is evaluated with a single unique body field |
|
|
|
# That is combined (embedded) with other body fields |
|
|
|
for param in flat_dependant.body_params: |
|
|
|
setattr(param.field_info, "embed", True) # noqa: B010 |
|
|
|
model_name = "Body_" + name |
|
|
|
BodyModel: Type[BaseModel] = create_model(model_name) |
|
|
|
for f in flat_dependant.body_params: |
|
|
|
BodyModel.__fields__[f.name] = f |
|
|
|
required = any(True for f in flat_dependant.body_params if f.required) |
|
|
|
|
|
|
|
BodyFieldInfo_kwargs: Dict[str, Any] = {"default": None} |
|
|
|
if any(isinstance(f.field_info, params.File) for f in flat_dependant.body_params): |
|
|
|
BodyFieldInfo: Type[params.Body] = params.File |
|
|
|
elif any(isinstance(f.field_info, params.Form) for f in flat_dependant.body_params): |
|
|
|
BodyFieldInfo = params.Form |
|
|
|
else: |
|
|
|
BodyFieldInfo = params.Body |
|
|
|
|
|
|
|
body_param_media_types = [ |
|
|
|
f.field_info.media_type |
|
|
|
for f in flat_dependant.body_params |
|
|
|
if isinstance(f.field_info, params.Body) |
|
|
|
] |
|
|
|
if len(set(body_param_media_types)) == 1: |
|
|
|
BodyFieldInfo_kwargs["media_type"] = body_param_media_types[0] |
|
|
|
final_field = create_response_field( |
|
|
|
name="body", |
|
|
|
type_=BodyModel, |
|
|
|
required=required, |
|
|
|
alias="body", |
|
|
|
field_info=BodyFieldInfo(**BodyFieldInfo_kwargs), |
|
|
|
) |
|
|
|
check_file_field(final_field) |
|
|
|
return final_field |
|
|
|
|
|
|
|
Otherwise, the first item (a `DefaultPlaceholder`) will be returned. |
|
|
|
""" |
|
|
|
items = (first_item,) + extra_items |
|
|
|
for item in items: |
|
|
|
if not isinstance(item, DefaultPlaceholder): |
|
|
|
return item |
|
|
|
return first_item |
|
|
|