diff --git a/fastapi/stubs.py b/fastapi/stubs.py new file mode 100644 index 000000000..b9ab48a69 --- /dev/null +++ b/fastapi/stubs.py @@ -0,0 +1,9 @@ +from typing import Any, Tuple + + +class GenericTypeStub: + class OriginTypeStub: + __parameters__: Tuple[Any] + + __origin__: OriginTypeStub + __args__: Tuple[Any] diff --git a/fastapi/utils.py b/fastapi/utils.py index 391c47d81..1bd2de317 100644 --- a/fastapi/utils.py +++ b/fastapi/utils.py @@ -1,207 +1,801 @@ -import functools -import re -import warnings -from dataclasses import is_dataclass -from enum import Enum -from typing import TYPE_CHECKING, Any, Dict, Optional, Set, Type, Union, cast - -import fastapi -from fastapi.datastructures import DefaultPlaceholder, DefaultType -from fastapi.openapi.constants import REF_PREFIX -from pydantic import BaseConfig, BaseModel, create_model -from pydantic.class_validators import Validator -from pydantic.fields import FieldInfo, ModelField, UndefinedType -from pydantic.schema import model_process_schema +import dataclasses +import inspect +from contextlib import contextmanager +from copy import deepcopy +from typing import ( + Any, + Callable, + Coroutine, + Dict, + ForwardRef, + List, + Mapping, + Optional, + Sequence, + Tuple, + Type, + Union, + cast, +) + +import anyio +from fastapi import params +from fastapi.concurrency import ( + AsyncExitStack, + asynccontextmanager, + contextmanager_in_threadpool, +) +from fastapi.dependencies.models import Dependant, SecurityRequirement +from fastapi.dependencies.stubs import GenericTypeStub +from fastapi.logger import logger +from fastapi.security.base import SecurityBase +from fastapi.security.oauth2 import OAuth2, SecurityScopes +from fastapi.security.open_id_connect_url import OpenIdConnect +from fastapi.utils import create_response_field, get_path_param_names +from pydantic import BaseModel, create_model +from pydantic.error_wrappers import ErrorWrapper +from pydantic.errors import MissingError +from pydantic.fields import ( + SHAPE_FROZENSET, + SHAPE_LIST, + SHAPE_SEQUENCE, + SHAPE_SET, + SHAPE_SINGLETON, + SHAPE_TUPLE, + SHAPE_TUPLE_ELLIPSIS, + FieldInfo, + ModelField, + Required, + Undefined, +) +from pydantic.schema import get_annotation_from_field_info +from pydantic.typing import evaluate_forwardref from pydantic.utils import lenient_issubclass +from starlette.background import BackgroundTasks +from starlette.concurrency import run_in_threadpool +from starlette.datastructures import FormData, Headers, QueryParams, UploadFile +from starlette.requests import HTTPConnection, Request +from starlette.responses import Response +from starlette.websockets import WebSocket -if TYPE_CHECKING: # pragma: nocover - from .routing import APIRoute +sequence_shapes = { + SHAPE_LIST, + SHAPE_SET, + SHAPE_FROZENSET, + SHAPE_TUPLE, + SHAPE_SEQUENCE, + SHAPE_TUPLE_ELLIPSIS, +} +sequence_types = (list, set, tuple) +sequence_shape_to_type = { + SHAPE_LIST: list, + SHAPE_SET: set, + SHAPE_TUPLE: tuple, + SHAPE_SEQUENCE: list, + SHAPE_TUPLE_ELLIPSIS: list, +} -def is_body_allowed_for_status_code(status_code: Union[int, str, None]) -> bool: - if status_code is None: - return True - # Ref: https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#patterned-fields-1 - if status_code in { - "default", - "1XX", - "2XX", - "3XX", - "4XX", - "5XX", - }: - return True - current_status_code = int(status_code) - return not (current_status_code < 200 or current_status_code in {204, 304}) +multipart_not_installed_error = ( + 'Form data requires "python-multipart" to be installed. \n' + 'You can install "python-multipart" with: \n\n' + "pip install python-multipart\n" +) +multipart_incorrect_install_error = ( + 'Form data requires "python-multipart" to be installed. ' + 'It seems you installed "multipart" instead. \n' + 'You can remove "multipart" with: \n\n' + "pip uninstall multipart\n\n" + 'And then install "python-multipart" with: \n\n' + "pip install python-multipart\n" +) + +def check_file_field(field: ModelField) -> None: + field_info = field.field_info + if isinstance(field_info, params.Form): + try: + # __version__ is available in both multiparts, and can be mocked + from multipart import __version__ # type: ignore -def get_model_definitions( + assert __version__ + try: + # parse_options_header is only available in the right multipart + from multipart.multipart import parse_options_header # type: ignore + + assert parse_options_header + except ImportError: + logger.error(multipart_incorrect_install_error) + raise RuntimeError(multipart_incorrect_install_error) from None + except ImportError: + logger.error(multipart_not_installed_error) + raise RuntimeError(multipart_not_installed_error) from None + + +def get_param_sub_dependant( + *, param: inspect.Parameter, path: str, security_scopes: Optional[List[str]] = None +) -> Dependant: + depends: params.Depends = param.default + if depends.dependency: + dependency = depends.dependency + else: + dependency = param.annotation + return get_sub_dependant( + depends=depends, + dependency=dependency, + path=path, + name=param.name, + security_scopes=security_scopes, + ) + + +def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> Dependant: + assert callable( + depends.dependency + ), "A parameter-less dependency must have a callable dependency" + return get_sub_dependant(depends=depends, dependency=depends.dependency, path=path) + + +def get_sub_dependant( *, - flat_models: Set[Union[Type[BaseModel], Type[Enum]]], - model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str], -) -> Dict[str, Any]: - definitions: Dict[str, Dict[str, Any]] = {} - for model in flat_models: - m_schema, m_definitions, m_nested_models = model_process_schema( - model, model_name_map=model_name_map, ref_prefix=REF_PREFIX + depends: params.Depends, + dependency: Callable[..., Any], + path: str, + name: Optional[str] = None, + security_scopes: Optional[List[str]] = None, +) -> Dependant: + security_requirement = None + security_scopes = security_scopes or [] + if isinstance(depends, params.Security): + dependency_scopes = depends.scopes + security_scopes.extend(dependency_scopes) + if isinstance(dependency, SecurityBase): + use_scopes: List[str] = [] + if isinstance(dependency, (OAuth2, OpenIdConnect)): + use_scopes = security_scopes + security_requirement = SecurityRequirement( + security_scheme=dependency, scopes=use_scopes ) - definitions.update(m_definitions) - model_name = model_name_map[model] - if "description" in m_schema: - m_schema["description"] = m_schema["description"].split("\f")[0] - definitions[model_name] = m_schema - return definitions - - -def get_path_param_names(path: str) -> Set[str]: - return set(re.findall("{(.*?)}", path)) - - -def create_response_field( - name: str, - type_: Type[Any], - class_validators: Optional[Dict[str, Validator]] = None, - default: Optional[Any] = None, - required: Union[bool, UndefinedType] = True, - model_config: Type[BaseConfig] = BaseConfig, - field_info: Optional[FieldInfo] = None, - alias: Optional[str] = None, -) -> ModelField: - """ - Create a new response field. Raises if type_ is invalid. - """ - class_validators = class_validators or {} - field_info = field_info or FieldInfo() - - response_field = functools.partial( - ModelField, + sub_dependant = get_dependant( + path=path, + call=dependency, name=name, - type_=type_, - class_validators=class_validators, - default=default, - required=required, - model_config=model_config, - alias=alias, + security_scopes=security_scopes, + use_cache=depends.use_cache, ) + if security_requirement: + sub_dependant.security_requirements.append(security_requirement) + return sub_dependant - try: - return response_field(field_info=field_info) - except RuntimeError: - raise fastapi.exceptions.FastAPIError( - "Invalid args for response field! Hint: " - f"check that {type_} is a valid Pydantic field type. " - "If you are using a return type annotation that is not a valid Pydantic " - "field (e.g. Union[Response, dict, None]) you can disable generating the " - "response model from the type annotation with the path operation decorator " - "parameter response_model=None. Read more: " - "https://fastapi.tiangolo.com/tutorial/response-model/" - ) from None - - -def create_cloned_field( - field: ModelField, + +CacheKey = Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] + + +def get_flat_dependant( + dependant: Dependant, *, - cloned_types: Optional[Dict[Type[BaseModel], Type[BaseModel]]] = None, -) -> ModelField: - # _cloned_types has already cloned types, to support recursive models - if cloned_types is None: - cloned_types = {} - original_type = field.type_ - if is_dataclass(original_type) and hasattr(original_type, "__pydantic_model__"): - original_type = original_type.__pydantic_model__ - use_type = original_type - if lenient_issubclass(original_type, BaseModel): - original_type = cast(Type[BaseModel], original_type) - use_type = cloned_types.get(original_type) - if use_type is None: - use_type = create_model(original_type.__name__, __base__=original_type) - cloned_types[original_type] = use_type - for f in original_type.__fields__.values(): - use_type.__fields__[f.name] = create_cloned_field( - f, cloned_types=cloned_types - ) - new_field = create_response_field(name=field.name, type_=use_type) - new_field.has_alias = field.has_alias - new_field.alias = field.alias - new_field.class_validators = field.class_validators - new_field.default = field.default - new_field.required = field.required - new_field.model_config = field.model_config - new_field.field_info = field.field_info - new_field.allow_none = field.allow_none - new_field.validate_always = field.validate_always + skip_repeats: bool = False, + visited: Optional[List[CacheKey]] = None, +) -> Dependant: + if visited is None: + visited = [] + visited.append(dependant.cache_key) + + flat_dependant = Dependant( + path_params=dependant.path_params.copy(), + query_params=dependant.query_params.copy(), + header_params=dependant.header_params.copy(), + cookie_params=dependant.cookie_params.copy(), + body_params=dependant.body_params.copy(), + security_schemes=dependant.security_requirements.copy(), + use_cache=dependant.use_cache, + path=dependant.path, + ) + for sub_dependant in dependant.dependencies: + if skip_repeats and sub_dependant.cache_key in visited: + continue + flat_sub = get_flat_dependant( + sub_dependant, skip_repeats=skip_repeats, visited=visited + ) + flat_dependant.path_params.extend(flat_sub.path_params) + flat_dependant.query_params.extend(flat_sub.query_params) + flat_dependant.header_params.extend(flat_sub.header_params) + flat_dependant.cookie_params.extend(flat_sub.cookie_params) + flat_dependant.body_params.extend(flat_sub.body_params) + flat_dependant.security_requirements.extend(flat_sub.security_requirements) + return flat_dependant + + +def get_flat_params(dependant: Dependant) -> List[ModelField]: + flat_dependant = get_flat_dependant(dependant, skip_repeats=True) + return ( + flat_dependant.path_params + + flat_dependant.query_params + + flat_dependant.header_params + + flat_dependant.cookie_params + ) + + +def is_scalar_field(field: ModelField) -> bool: + field_info = field.field_info + if not ( + field.shape == SHAPE_SINGLETON + and not lenient_issubclass(field.type_, BaseModel) + and not lenient_issubclass(field.type_, sequence_types + (dict,)) + and not dataclasses.is_dataclass(field.type_) + and not isinstance(field_info, params.Body) + ): + return False if field.sub_fields: - new_field.sub_fields = [ - create_cloned_field(sub_field, cloned_types=cloned_types) - for sub_field in field.sub_fields - ] - if field.key_field: - new_field.key_field = create_cloned_field( - field.key_field, cloned_types=cloned_types + if not all(is_scalar_field(f) for f in field.sub_fields): + return False + return True + + +def is_scalar_sequence_field(field: ModelField) -> bool: + if (field.shape in sequence_shapes) and not lenient_issubclass( + field.type_, BaseModel + ): + if field.sub_fields is not None: + for sub_field in field.sub_fields: + if not is_scalar_field(sub_field): + return False + return True + if lenient_issubclass(field.type_, sequence_types): + return True + return False + + +def is_generic_type(obj: Any) -> bool: + return hasattr(obj, "__origin__") and hasattr(obj, "__args__") + + +def substitute_generic_type(annotation: Any, typevars: Dict[str, Any]) -> Any: + collection_shells = {list: List, set: List, dict: Dict, tuple: Tuple} + if is_generic_type(annotation): + args = tuple( + substitute_generic_type(arg, typevars) for arg in annotation.__args__ ) - new_field.validators = field.validators - new_field.pre_validators = field.pre_validators - new_field.post_validators = field.post_validators - new_field.parse_json = field.parse_json - new_field.shape = field.shape - new_field.populate_validators() - return new_field - - -def generate_operation_id_for_path( - *, name: str, path: str, method: str -) -> str: # pragma: nocover - warnings.warn( - "fastapi.utils.generate_operation_id_for_path() was deprecated, " - "it is not used internally, and will be removed soon", - DeprecationWarning, - stacklevel=2, + annotation = collection_shells.get(annotation.__origin__, annotation) + return annotation[args] + return typevars.get(annotation.__name__, annotation) + + +def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: + typevars = None + if is_generic_type(call): + stub: GenericTypeStub = cast(GenericTypeStub, call) + typevars = { + typevar.__name__: value + for typevar, value in zip(stub.__origin__.__parameters__, stub.__args__) + } + call = stub.__origin__.__init__ # type: ignore + + signature = inspect.signature(call) + globalns = getattr(call, "__globals__", {}) + + typed_params = [ + inspect.Parameter( + name=param.name, + kind=param.kind, + default=param.default, + annotation=get_typed_annotation(param.annotation, globalns, typevars), + ) + for param in signature.parameters.values() + if param.name != 'self' + ] + typed_signature = inspect.Signature(typed_params) + return typed_signature + + +def get_typed_annotation( + annotation: Any, + globalns: Dict[str, Any], + typevars: Optional[Dict[str, type]] = None, +) -> Any: + if isinstance(annotation, str): + annotation = ForwardRef(annotation) + annotation = evaluate_forwardref(annotation, globalns, globalns) + if typevars: + annotation = substitute_generic_type(annotation, typevars) + return annotation + + +def get_typed_return_annotation(call: Callable[..., Any]) -> Any: + signature = inspect.signature(call) + annotation = signature.return_annotation + + if annotation is inspect.Signature.empty: + return None + + globalns = getattr(call, "__globals__", {}) + return get_typed_annotation(annotation, globalns) + + +def get_dependant( + *, + path: str, + call: Callable[..., Any], + name: Optional[str] = None, + security_scopes: Optional[List[str]] = None, + use_cache: bool = True, +) -> Dependant: + path_param_names = get_path_param_names(path) + endpoint_signature = get_typed_signature(call) + signature_params = endpoint_signature.parameters + dependant = Dependant( + call=call, + name=name, + path=path, + security_scopes=security_scopes, + use_cache=use_cache, ) - operation_id = name + path - operation_id = re.sub(r"\W", "_", operation_id) - operation_id = operation_id + "_" + method.lower() - return operation_id + for param_name, param in signature_params.items(): + if isinstance(param.default, params.Depends): + sub_dependant = get_param_sub_dependant( + param=param, path=path, security_scopes=security_scopes + ) + dependant.dependencies.append(sub_dependant) + continue + if add_non_field_param_to_dependency(param=param, dependant=dependant): + continue + param_field = get_param_field( + param=param, default_field_info=params.Query, param_name=param_name + ) + if param_name in path_param_names: + assert is_scalar_field( + field=param_field + ), "Path params must be of one of the supported types" + ignore_default = not isinstance(param.default, params.Path) + param_field = get_param_field( + param=param, + param_name=param_name, + default_field_info=params.Path, + force_type=params.ParamTypes.path, + ignore_default=ignore_default, + ) + add_param_to_fields(field=param_field, dependant=dependant) + elif is_scalar_field(field=param_field): + add_param_to_fields(field=param_field, dependant=dependant) + elif isinstance( + param.default, (params.Query, params.Header) + ) and is_scalar_sequence_field(param_field): + add_param_to_fields(field=param_field, dependant=dependant) + else: + field_info = param_field.field_info + assert isinstance( + field_info, params.Body + ), f"Param: {param_field.name} can only be a request body, using Body()" + dependant.body_params.append(param_field) + return dependant -def generate_unique_id(route: "APIRoute") -> str: - operation_id = route.name + route.path_format - operation_id = re.sub(r"\W", "_", operation_id) - assert route.methods - operation_id = operation_id + "_" + list(route.methods)[0].lower() - return operation_id +def add_non_field_param_to_dependency( + *, param: inspect.Parameter, dependant: Dependant +) -> Optional[bool]: + if lenient_issubclass(param.annotation, Request): + dependant.request_param_name = param.name + return True + elif lenient_issubclass(param.annotation, WebSocket): + dependant.websocket_param_name = param.name + return True + elif lenient_issubclass(param.annotation, HTTPConnection): + dependant.http_connection_param_name = param.name + return True + elif lenient_issubclass(param.annotation, Response): + dependant.response_param_name = param.name + return True + elif lenient_issubclass(param.annotation, BackgroundTasks): + dependant.background_tasks_param_name = param.name + return True + elif lenient_issubclass(param.annotation, SecurityScopes): + dependant.security_scopes_param_name = param.name + return True + return None + + +def get_param_field( + *, + param: inspect.Parameter, + param_name: str, + default_field_info: Type[params.Param] = params.Param, + force_type: Optional[params.ParamTypes] = None, + ignore_default: bool = False, +) -> ModelField: + default_value: Any = Undefined + had_schema = False + if not param.default == param.empty and ignore_default is False: + default_value = param.default + if isinstance(default_value, FieldInfo): + had_schema = True + field_info = default_value + default_value = field_info.default + if ( + isinstance(field_info, params.Param) + and getattr(field_info, "in_", None) is None + ): + field_info.in_ = default_field_info.in_ + if force_type: + field_info.in_ = force_type # type: ignore + else: + field_info = default_field_info(default=default_value) + required = True + if default_value is Required or ignore_default: + required = True + default_value = None + elif default_value is not Undefined: + required = False + annotation: Any = Any + if not param.annotation == param.empty: + annotation = param.annotation + annotation = get_annotation_from_field_info(annotation, field_info, param_name) + if not field_info.alias and getattr(field_info, "convert_underscores", None): + alias = param.name.replace("_", "-") + else: + alias = field_info.alias or param.name + field = create_response_field( + name=param.name, + type_=annotation, + default=default_value, + alias=alias, + required=required, + field_info=field_info, + ) + if not had_schema and not is_scalar_field(field=field): + field.field_info = params.Body(field_info.default) + if not had_schema and lenient_issubclass(field.type_, UploadFile): + field.field_info = params.File(field_info.default) + + return field + +def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None: + field_info = cast(params.Param, field.field_info) + if field_info.in_ == params.ParamTypes.path: + dependant.path_params.append(field) + elif field_info.in_ == params.ParamTypes.query: + dependant.query_params.append(field) + elif field_info.in_ == params.ParamTypes.header: + dependant.header_params.append(field) + else: + assert ( + field_info.in_ == params.ParamTypes.cookie + ), f"non-body parameters must be in path, query, header or cookie: {field.name}" + dependant.cookie_params.append(field) -def deep_dict_update(main_dict: Dict[Any, Any], update_dict: Dict[Any, Any]) -> None: - for key, value in update_dict.items(): + +def is_coroutine_callable(call: Callable[..., Any]) -> bool: + if inspect.isroutine(call): + return inspect.iscoroutinefunction(call) + if inspect.isclass(call): + return False + dunder_call = getattr(call, "__call__", None) # noqa: B004 + return inspect.iscoroutinefunction(dunder_call) + + +def is_async_gen_callable(call: Callable[..., Any]) -> bool: + if inspect.isasyncgenfunction(call): + return True + dunder_call = getattr(call, "__call__", None) # noqa: B004 + return inspect.isasyncgenfunction(dunder_call) + + +def is_gen_callable(call: Callable[..., Any]) -> bool: + if inspect.isgeneratorfunction(call): + return True + dunder_call = getattr(call, "__call__", None) # noqa: B004 + return inspect.isgeneratorfunction(dunder_call) + + +async def solve_generator( + *, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any] +) -> Any: + if is_gen_callable(call): + cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values)) + elif is_async_gen_callable(call): + cm = asynccontextmanager(call)(**sub_values) + return await stack.enter_async_context(cm) + + +async def solve_dependencies( + *, + request: Union[Request, WebSocket], + dependant: Dependant, + body: Optional[Union[Dict[str, Any], FormData]] = None, + background_tasks: Optional[BackgroundTasks] = None, + response: Optional[Response] = None, + dependency_overrides_provider: Optional[Any] = None, + dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None, +) -> Tuple[ + Dict[str, Any], + List[ErrorWrapper], + Optional[BackgroundTasks], + Response, + Dict[Tuple[Callable[..., Any], Tuple[str]], Any], +]: + values: Dict[str, Any] = {} + errors: List[ErrorWrapper] = [] + if response is None: + response = Response() + del response.headers["content-length"] + response.status_code = None # type: ignore + dependency_cache = dependency_cache or {} + sub_dependant: Dependant + for sub_dependant in dependant.dependencies: + sub_dependant.call = cast(Callable[..., Any], sub_dependant.call) + sub_dependant.cache_key = cast( + Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key + ) + call = sub_dependant.call + use_sub_dependant = sub_dependant if ( - key in main_dict - and isinstance(main_dict[key], dict) - and isinstance(value, dict) + dependency_overrides_provider + and dependency_overrides_provider.dependency_overrides ): - deep_dict_update(main_dict[key], value) - elif ( - key in main_dict - and isinstance(main_dict[key], list) - and isinstance(update_dict[key], list) + original_call = sub_dependant.call + call = getattr( + dependency_overrides_provider, "dependency_overrides", {} + ).get(original_call, original_call) + use_path: str = sub_dependant.path # type: ignore + use_sub_dependant = get_dependant( + path=use_path, + call=call, + name=sub_dependant.name, + security_scopes=sub_dependant.security_scopes, + ) + + solved_result = await solve_dependencies( + request=request, + dependant=use_sub_dependant, + body=body, + background_tasks=background_tasks, + response=response, + dependency_overrides_provider=dependency_overrides_provider, + dependency_cache=dependency_cache, + ) + ( + sub_values, + sub_errors, + background_tasks, + _, # the subdependency returns the same response we have + sub_dependency_cache, + ) = solved_result + dependency_cache.update(sub_dependency_cache) + if sub_errors: + errors.extend(sub_errors) + continue + if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache: + solved = dependency_cache[sub_dependant.cache_key] + elif is_gen_callable(call) or is_async_gen_callable(call): + stack = request.scope.get("fastapi_astack") + assert isinstance(stack, AsyncExitStack) + solved = await solve_generator( + call=call, stack=stack, sub_values=sub_values + ) + elif is_coroutine_callable(call): + solved = await call(**sub_values) + else: + solved = await run_in_threadpool(call, **sub_values) + if sub_dependant.name is not None: + values[sub_dependant.name] = solved + if sub_dependant.cache_key not in dependency_cache: + dependency_cache[sub_dependant.cache_key] = solved + path_values, path_errors = request_params_to_args( + dependant.path_params, request.path_params + ) + query_values, query_errors = request_params_to_args( + dependant.query_params, request.query_params + ) + header_values, header_errors = request_params_to_args( + dependant.header_params, request.headers + ) + cookie_values, cookie_errors = request_params_to_args( + dependant.cookie_params, request.cookies + ) + values.update(path_values) + values.update(query_values) + values.update(header_values) + values.update(cookie_values) + errors += path_errors + query_errors + header_errors + cookie_errors + if dependant.body_params: + ( + body_values, + body_errors, + ) = await request_body_to_args( # body_params checked above + required_params=dependant.body_params, received_body=body + ) + values.update(body_values) + errors.extend(body_errors) + if dependant.http_connection_param_name: + values[dependant.http_connection_param_name] = request + if dependant.request_param_name and isinstance(request, Request): + values[dependant.request_param_name] = request + elif dependant.websocket_param_name and isinstance(request, WebSocket): + values[dependant.websocket_param_name] = request + if dependant.background_tasks_param_name: + if background_tasks is None: + background_tasks = BackgroundTasks() + values[dependant.background_tasks_param_name] = background_tasks + if dependant.response_param_name: + values[dependant.response_param_name] = response + if dependant.security_scopes_param_name: + values[dependant.security_scopes_param_name] = SecurityScopes( + scopes=dependant.security_scopes + ) + return values, errors, background_tasks, response, dependency_cache + + +def request_params_to_args( + required_params: Sequence[ModelField], + received_params: Union[Mapping[str, Any], QueryParams, Headers], +) -> Tuple[Dict[str, Any], List[ErrorWrapper]]: + values = {} + errors = [] + for field in required_params: + if is_scalar_sequence_field(field) and isinstance( + received_params, (QueryParams, Headers) ): - main_dict[key] = main_dict[key] + update_dict[key] + value = received_params.getlist(field.alias) or field.default else: - main_dict[key] = value + value = received_params.get(field.alias) + field_info = field.field_info + assert isinstance( + field_info, params.Param + ), "Params must be subclasses of Param" + if value is None: + if field.required: + errors.append( + ErrorWrapper( + MissingError(), loc=(field_info.in_.value, field.alias) + ) + ) + else: + values[field.name] = deepcopy(field.default) + continue + v_, errors_ = field.validate( + value, values, loc=(field_info.in_.value, field.alias) + ) + if isinstance(errors_, ErrorWrapper): + errors.append(errors_) + elif isinstance(errors_, list): + errors.extend(errors_) + else: + values[field.name] = v_ + return values, errors + +async def request_body_to_args( + required_params: List[ModelField], + received_body: Optional[Union[Dict[str, Any], FormData]], +) -> Tuple[Dict[str, Any], List[ErrorWrapper]]: + values = {} + errors = [] + if required_params: + field = required_params[0] + field_info = field.field_info + embed = getattr(field_info, "embed", None) + field_alias_omitted = len(required_params) == 1 and not embed + if field_alias_omitted: + received_body = {field.alias: received_body} -def get_value_or_default( - first_item: Union[DefaultPlaceholder, DefaultType], - *extra_items: Union[DefaultPlaceholder, DefaultType], -) -> Union[DefaultPlaceholder, DefaultType]: - """ - Pass items or `DefaultPlaceholder`s by descending priority. + for field in required_params: + loc: Tuple[str, ...] + if field_alias_omitted: + loc = ("body",) + else: + loc = ("body", field.alias) - The first one to _not_ be a `DefaultPlaceholder` will be returned. + value: Optional[Any] = None + if received_body is not None: + if ( + field.shape in sequence_shapes or field.type_ in sequence_types + ) and isinstance(received_body, FormData): + value = received_body.getlist(field.alias) + else: + try: + value = received_body.get(field.alias) + except AttributeError: + errors.append(get_missing_field_error(loc)) + continue + if ( + value is None + or (isinstance(field_info, params.Form) and value == "") + or ( + isinstance(field_info, params.Form) + and field.shape in sequence_shapes + and len(value) == 0 + ) + ): + if field.required: + errors.append(get_missing_field_error(loc)) + else: + values[field.name] = deepcopy(field.default) + continue + if ( + isinstance(field_info, params.File) + and lenient_issubclass(field.type_, bytes) + and isinstance(value, UploadFile) + ): + value = await value.read() + elif ( + field.shape in sequence_shapes + and isinstance(field_info, params.File) + and lenient_issubclass(field.type_, bytes) + and isinstance(value, sequence_types) + ): + results: List[Union[bytes, str]] = [] + + async def process_fn( + fn: Callable[[], Coroutine[Any, Any, Any]] + ) -> None: + result = await fn() + results.append(result) + + async with anyio.create_task_group() as tg: + for sub_value in value: + tg.start_soon(process_fn, sub_value.read) + value = sequence_shape_to_type[field.shape](results) + + v_, errors_ = field.validate(value, values, loc=loc) + + if isinstance(errors_, ErrorWrapper): + errors.append(errors_) + elif isinstance(errors_, list): + errors.extend(errors_) + else: + values[field.name] = v_ + return values, errors + + +def get_missing_field_error(loc: Tuple[str, ...]) -> ErrorWrapper: + missing_field_error = ErrorWrapper(MissingError(), loc=loc) + return missing_field_error + + +def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]: + flat_dependant = get_flat_dependant(dependant) + if not flat_dependant.body_params: + return None + first_param = flat_dependant.body_params[0] + field_info = first_param.field_info + embed = getattr(field_info, "embed", None) + body_param_names_set = {param.name for param in flat_dependant.body_params} + if len(body_param_names_set) == 1 and not embed: + check_file_field(first_param) + return first_param + # If one field requires to embed, all have to be embedded + # in case a sub-dependency is evaluated with a single unique body field + # That is combined (embedded) with other body fields + for param in flat_dependant.body_params: + setattr(param.field_info, "embed", True) # noqa: B010 + model_name = "Body_" + name + BodyModel: Type[BaseModel] = create_model(model_name) + for f in flat_dependant.body_params: + BodyModel.__fields__[f.name] = f + required = any(True for f in flat_dependant.body_params if f.required) + + BodyFieldInfo_kwargs: Dict[str, Any] = {"default": None} + if any(isinstance(f.field_info, params.File) for f in flat_dependant.body_params): + BodyFieldInfo: Type[params.Body] = params.File + elif any(isinstance(f.field_info, params.Form) for f in flat_dependant.body_params): + BodyFieldInfo = params.Form + else: + BodyFieldInfo = params.Body + + body_param_media_types = [ + f.field_info.media_type + for f in flat_dependant.body_params + if isinstance(f.field_info, params.Body) + ] + if len(set(body_param_media_types)) == 1: + BodyFieldInfo_kwargs["media_type"] = body_param_media_types[0] + final_field = create_response_field( + name="body", + type_=BodyModel, + required=required, + alias="body", + field_info=BodyFieldInfo(**BodyFieldInfo_kwargs), + ) + check_file_field(final_field) + return final_field - Otherwise, the first item (a `DefaultPlaceholder`) will be returned. - """ - items = (first_item,) + extra_items - for item in items: - if not isinstance(item, DefaultPlaceholder): - return item - return first_item diff --git a/tests/test_dependency_generic_class.py b/tests/test_dependency_generic_class.py index 3eeb8291e..95b81eb84 100644 --- a/tests/test_dependency_generic_class.py +++ b/tests/test_dependency_generic_class.py @@ -19,7 +19,7 @@ class SecondGenericType(Generic[T, C]): simple: T, lst: List[T], dct: Dict[T, C], - custom_class: FirstGenericType[T] = Depends(), + custom_class: FirstGenericType[T] = Depends() ): self.simple = simple self.lst = lst @@ -38,7 +38,7 @@ def depend_generic_type(obj: SecondGenericType[str, int] = Depends()): "dct": obj.dct, "custom_class": { "simple": obj.custom_class.simple, - "lst": obj.custom_class.lst, + "lst": obj.custom_class.lst }, }