From 743cca58c3816ac3dca5ec3f513f72c25c537826 Mon Sep 17 00:00:00 2001 From: yyklimenko Date: Wed, 22 Feb 2023 21:38:55 +0600 Subject: [PATCH 1/6] Added support for injection of custom generic classes via Depends() --- fastapi/dependencies/utils.py | 41 ++++++++++++++-- tests/test_dependency_generic_class.py | 66 ++++++++++++++++++++++++++ 2 files changed, 104 insertions(+), 3 deletions(-) create mode 100644 tests/test_dependency_generic_class.py diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 32e171f18..dd10a2a2a 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -245,26 +245,60 @@ def is_scalar_sequence_field(field: ModelField) -> bool: return False +def is_generic_type(obj: Any) -> bool: + return hasattr(obj, '__origin__') and hasattr(obj, '__args__') + + +def substitute_generic_type( + annotation: Any, typevars: Dict[str, Any] +) -> Any: + collection_shells = {list: List, set: List, dict: Dict, tuple: Tuple} + if is_generic_type(annotation): + args = tuple( + substitute_generic_type(arg, typevars) for arg in annotation.__args__ + ) + annotation = collection_shells.get(annotation.__origin__, annotation) + return annotation[args] + return typevars.get(annotation.__name__, annotation) + + def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: + typevars = None + if is_generic_type(call): + origin = getattr(call, '__origin__') + typevars = { + typevar.__name__: value for typevar, value in zip( + origin.__parameters__, getattr(call, '__args__') + ) + } + call = origin.__init__ + signature = inspect.signature(call) globalns = getattr(call, "__globals__", {}) + typed_params = [ inspect.Parameter( name=param.name, kind=param.kind, default=param.default, - annotation=get_typed_annotation(param.annotation, globalns), + annotation=get_typed_annotation(param.annotation, globalns, typevars), ) - for param in signature.parameters.values() + for param in signature.parameters.values() if param.name != 'self' ] typed_signature = inspect.Signature(typed_params) return typed_signature -def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any: +def get_typed_annotation( + annotation: Any, + globalns: Dict[str, Any], + typevars: Optional[Dict[str, type]] = None, +) -> Any: if isinstance(annotation, str): annotation = ForwardRef(annotation) annotation = evaluate_forwardref(annotation, globalns, globalns) + if typevars: + annotation = substitute_generic_type(annotation, typevars) return annotation @@ -765,3 +799,4 @@ def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]: ) check_file_field(final_field) return final_field + diff --git a/tests/test_dependency_generic_class.py b/tests/test_dependency_generic_class.py new file mode 100644 index 000000000..1cddbf242 --- /dev/null +++ b/tests/test_dependency_generic_class.py @@ -0,0 +1,66 @@ +from typing import Generic, List, TypeVar, Dict + +from starlette.testclient import TestClient + +from fastapi import Depends, FastAPI + +T = TypeVar("T") +C = TypeVar("C") + + +class FirstGenericType(Generic[T]): + + def __init__(self, simple: T, lst: List[T]): + self.simple = simple + self.lst = lst + + +class SecondGenericType(Generic[T, C]): + + def __init__( + self, + simple: T, + lst: List[T], + dct: Dict[T, C], + custom_class: FirstGenericType[T] = Depends() + ): + self.simple = simple + self.lst = lst + self.dct = dct + self.custom_class = custom_class + + +app = FastAPI() + + +@app.post("/test_generic_class") +def depend_generic_type(obj: SecondGenericType[str, int] = Depends()): + return { + "simple": obj.simple, + "lst": obj.lst, + "dct": obj.dct, + "custom_class": { + "simple": obj.custom_class.simple, + "lst": obj.custom_class.lst + } + } + + +client = TestClient(app) + + +def test_generic_class_dependency(): + response = client.post("/test_generic_class?simple=simple", json={ + "lst": ["string_1", "string_2"], + "dct": {"key": 1}, + }) + assert response.status_code == 200, response.json() + assert response.json() == { + "custom_class": { + "lst": ["string_1", "string_2"], + "simple": "simple", + }, + "lst": ["string_1", "string_2"], + "dct": {"key": 1}, + "simple": "simple", + } From 410424b4d1ab6d36a76900b4212e9b48d554a60e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 22 Feb 2023 16:19:05 +0000 Subject: [PATCH 2/6] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20for?= =?UTF-8?q?mat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastapi/dependencies/utils.py | 23 ++++++++---------- tests/test_dependency_generic_class.py | 32 +++++++++++++------------- 2 files changed, 26 insertions(+), 29 deletions(-) diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index dd10a2a2a..1bf780cd6 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -246,12 +246,10 @@ def is_scalar_sequence_field(field: ModelField) -> bool: def is_generic_type(obj: Any) -> bool: - return hasattr(obj, '__origin__') and hasattr(obj, '__args__') + return hasattr(obj, "__origin__") and hasattr(obj, "__args__") -def substitute_generic_type( - annotation: Any, typevars: Dict[str, Any] -) -> Any: +def substitute_generic_type(annotation: Any, typevars: Dict[str, Any]) -> Any: collection_shells = {list: List, set: List, dict: Dict, tuple: Tuple} if is_generic_type(annotation): args = tuple( @@ -265,11 +263,10 @@ def substitute_generic_type( def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: typevars = None if is_generic_type(call): - origin = getattr(call, '__origin__') + origin = call.__origin__ typevars = { - typevar.__name__: value for typevar, value in zip( - origin.__parameters__, getattr(call, '__args__') - ) + typevar.__name__: value + for typevar, value in zip(origin.__parameters__, call.__args__) } call = origin.__init__ @@ -283,16 +280,17 @@ def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: default=param.default, annotation=get_typed_annotation(param.annotation, globalns, typevars), ) - for param in signature.parameters.values() if param.name != 'self' + for param in signature.parameters.values() + if param.name != "self" ] typed_signature = inspect.Signature(typed_params) return typed_signature def get_typed_annotation( - annotation: Any, - globalns: Dict[str, Any], - typevars: Optional[Dict[str, type]] = None, + annotation: Any, + globalns: Dict[str, Any], + typevars: Optional[Dict[str, type]] = None, ) -> Any: if isinstance(annotation, str): annotation = ForwardRef(annotation) @@ -799,4 +797,3 @@ def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]: ) check_file_field(final_field) return final_field - diff --git a/tests/test_dependency_generic_class.py b/tests/test_dependency_generic_class.py index 1cddbf242..3eeb8291e 100644 --- a/tests/test_dependency_generic_class.py +++ b/tests/test_dependency_generic_class.py @@ -1,28 +1,25 @@ -from typing import Generic, List, TypeVar, Dict - -from starlette.testclient import TestClient +from typing import Dict, Generic, List, TypeVar from fastapi import Depends, FastAPI +from starlette.testclient import TestClient T = TypeVar("T") C = TypeVar("C") class FirstGenericType(Generic[T]): - def __init__(self, simple: T, lst: List[T]): self.simple = simple self.lst = lst class SecondGenericType(Generic[T, C]): - def __init__( - self, - simple: T, - lst: List[T], - dct: Dict[T, C], - custom_class: FirstGenericType[T] = Depends() + self, + simple: T, + lst: List[T], + dct: Dict[T, C], + custom_class: FirstGenericType[T] = Depends(), ): self.simple = simple self.lst = lst @@ -41,8 +38,8 @@ def depend_generic_type(obj: SecondGenericType[str, int] = Depends()): "dct": obj.dct, "custom_class": { "simple": obj.custom_class.simple, - "lst": obj.custom_class.lst - } + "lst": obj.custom_class.lst, + }, } @@ -50,10 +47,13 @@ client = TestClient(app) def test_generic_class_dependency(): - response = client.post("/test_generic_class?simple=simple", json={ - "lst": ["string_1", "string_2"], - "dct": {"key": 1}, - }) + response = client.post( + "/test_generic_class?simple=simple", + json={ + "lst": ["string_1", "string_2"], + "dct": {"key": 1}, + }, + ) assert response.status_code == 200, response.json() assert response.json() == { "custom_class": { From 9551ad1970be9f8cc6ef4c2abfb5a1bfba149db9 Mon Sep 17 00:00:00 2001 From: yyklimenko Date: Thu, 23 Feb 2023 22:04:12 +0600 Subject: [PATCH 3/6] Added stubs file to mypy checking --- fastapi/stubs.py | 9 + fastapi/utils.py | 954 ++++++++++++++++++++----- tests/test_dependency_generic_class.py | 4 +- 3 files changed, 785 insertions(+), 182 deletions(-) create mode 100644 fastapi/stubs.py diff --git a/fastapi/stubs.py b/fastapi/stubs.py new file mode 100644 index 000000000..b9ab48a69 --- /dev/null +++ b/fastapi/stubs.py @@ -0,0 +1,9 @@ +from typing import Any, Tuple + + +class GenericTypeStub: + class OriginTypeStub: + __parameters__: Tuple[Any] + + __origin__: OriginTypeStub + __args__: Tuple[Any] diff --git a/fastapi/utils.py b/fastapi/utils.py index 391c47d81..1bd2de317 100644 --- a/fastapi/utils.py +++ b/fastapi/utils.py @@ -1,207 +1,801 @@ -import functools -import re -import warnings -from dataclasses import is_dataclass -from enum import Enum -from typing import TYPE_CHECKING, Any, Dict, Optional, Set, Type, Union, cast - -import fastapi -from fastapi.datastructures import DefaultPlaceholder, DefaultType -from fastapi.openapi.constants import REF_PREFIX -from pydantic import BaseConfig, BaseModel, create_model -from pydantic.class_validators import Validator -from pydantic.fields import FieldInfo, ModelField, UndefinedType -from pydantic.schema import model_process_schema +import dataclasses +import inspect +from contextlib import contextmanager +from copy import deepcopy +from typing import ( + Any, + Callable, + Coroutine, + Dict, + ForwardRef, + List, + Mapping, + Optional, + Sequence, + Tuple, + Type, + Union, + cast, +) + +import anyio +from fastapi import params +from fastapi.concurrency import ( + AsyncExitStack, + asynccontextmanager, + contextmanager_in_threadpool, +) +from fastapi.dependencies.models import Dependant, SecurityRequirement +from fastapi.dependencies.stubs import GenericTypeStub +from fastapi.logger import logger +from fastapi.security.base import SecurityBase +from fastapi.security.oauth2 import OAuth2, SecurityScopes +from fastapi.security.open_id_connect_url import OpenIdConnect +from fastapi.utils import create_response_field, get_path_param_names +from pydantic import BaseModel, create_model +from pydantic.error_wrappers import ErrorWrapper +from pydantic.errors import MissingError +from pydantic.fields import ( + SHAPE_FROZENSET, + SHAPE_LIST, + SHAPE_SEQUENCE, + SHAPE_SET, + SHAPE_SINGLETON, + SHAPE_TUPLE, + SHAPE_TUPLE_ELLIPSIS, + FieldInfo, + ModelField, + Required, + Undefined, +) +from pydantic.schema import get_annotation_from_field_info +from pydantic.typing import evaluate_forwardref from pydantic.utils import lenient_issubclass +from starlette.background import BackgroundTasks +from starlette.concurrency import run_in_threadpool +from starlette.datastructures import FormData, Headers, QueryParams, UploadFile +from starlette.requests import HTTPConnection, Request +from starlette.responses import Response +from starlette.websockets import WebSocket -if TYPE_CHECKING: # pragma: nocover - from .routing import APIRoute +sequence_shapes = { + SHAPE_LIST, + SHAPE_SET, + SHAPE_FROZENSET, + SHAPE_TUPLE, + SHAPE_SEQUENCE, + SHAPE_TUPLE_ELLIPSIS, +} +sequence_types = (list, set, tuple) +sequence_shape_to_type = { + SHAPE_LIST: list, + SHAPE_SET: set, + SHAPE_TUPLE: tuple, + SHAPE_SEQUENCE: list, + SHAPE_TUPLE_ELLIPSIS: list, +} -def is_body_allowed_for_status_code(status_code: Union[int, str, None]) -> bool: - if status_code is None: - return True - # Ref: https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#patterned-fields-1 - if status_code in { - "default", - "1XX", - "2XX", - "3XX", - "4XX", - "5XX", - }: - return True - current_status_code = int(status_code) - return not (current_status_code < 200 or current_status_code in {204, 304}) +multipart_not_installed_error = ( + 'Form data requires "python-multipart" to be installed. \n' + 'You can install "python-multipart" with: \n\n' + "pip install python-multipart\n" +) +multipart_incorrect_install_error = ( + 'Form data requires "python-multipart" to be installed. ' + 'It seems you installed "multipart" instead. \n' + 'You can remove "multipart" with: \n\n' + "pip uninstall multipart\n\n" + 'And then install "python-multipart" with: \n\n' + "pip install python-multipart\n" +) + +def check_file_field(field: ModelField) -> None: + field_info = field.field_info + if isinstance(field_info, params.Form): + try: + # __version__ is available in both multiparts, and can be mocked + from multipart import __version__ # type: ignore -def get_model_definitions( + assert __version__ + try: + # parse_options_header is only available in the right multipart + from multipart.multipart import parse_options_header # type: ignore + + assert parse_options_header + except ImportError: + logger.error(multipart_incorrect_install_error) + raise RuntimeError(multipart_incorrect_install_error) from None + except ImportError: + logger.error(multipart_not_installed_error) + raise RuntimeError(multipart_not_installed_error) from None + + +def get_param_sub_dependant( + *, param: inspect.Parameter, path: str, security_scopes: Optional[List[str]] = None +) -> Dependant: + depends: params.Depends = param.default + if depends.dependency: + dependency = depends.dependency + else: + dependency = param.annotation + return get_sub_dependant( + depends=depends, + dependency=dependency, + path=path, + name=param.name, + security_scopes=security_scopes, + ) + + +def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> Dependant: + assert callable( + depends.dependency + ), "A parameter-less dependency must have a callable dependency" + return get_sub_dependant(depends=depends, dependency=depends.dependency, path=path) + + +def get_sub_dependant( *, - flat_models: Set[Union[Type[BaseModel], Type[Enum]]], - model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str], -) -> Dict[str, Any]: - definitions: Dict[str, Dict[str, Any]] = {} - for model in flat_models: - m_schema, m_definitions, m_nested_models = model_process_schema( - model, model_name_map=model_name_map, ref_prefix=REF_PREFIX + depends: params.Depends, + dependency: Callable[..., Any], + path: str, + name: Optional[str] = None, + security_scopes: Optional[List[str]] = None, +) -> Dependant: + security_requirement = None + security_scopes = security_scopes or [] + if isinstance(depends, params.Security): + dependency_scopes = depends.scopes + security_scopes.extend(dependency_scopes) + if isinstance(dependency, SecurityBase): + use_scopes: List[str] = [] + if isinstance(dependency, (OAuth2, OpenIdConnect)): + use_scopes = security_scopes + security_requirement = SecurityRequirement( + security_scheme=dependency, scopes=use_scopes ) - definitions.update(m_definitions) - model_name = model_name_map[model] - if "description" in m_schema: - m_schema["description"] = m_schema["description"].split("\f")[0] - definitions[model_name] = m_schema - return definitions - - -def get_path_param_names(path: str) -> Set[str]: - return set(re.findall("{(.*?)}", path)) - - -def create_response_field( - name: str, - type_: Type[Any], - class_validators: Optional[Dict[str, Validator]] = None, - default: Optional[Any] = None, - required: Union[bool, UndefinedType] = True, - model_config: Type[BaseConfig] = BaseConfig, - field_info: Optional[FieldInfo] = None, - alias: Optional[str] = None, -) -> ModelField: - """ - Create a new response field. Raises if type_ is invalid. - """ - class_validators = class_validators or {} - field_info = field_info or FieldInfo() - - response_field = functools.partial( - ModelField, + sub_dependant = get_dependant( + path=path, + call=dependency, name=name, - type_=type_, - class_validators=class_validators, - default=default, - required=required, - model_config=model_config, - alias=alias, + security_scopes=security_scopes, + use_cache=depends.use_cache, ) + if security_requirement: + sub_dependant.security_requirements.append(security_requirement) + return sub_dependant - try: - return response_field(field_info=field_info) - except RuntimeError: - raise fastapi.exceptions.FastAPIError( - "Invalid args for response field! Hint: " - f"check that {type_} is a valid Pydantic field type. " - "If you are using a return type annotation that is not a valid Pydantic " - "field (e.g. Union[Response, dict, None]) you can disable generating the " - "response model from the type annotation with the path operation decorator " - "parameter response_model=None. Read more: " - "https://fastapi.tiangolo.com/tutorial/response-model/" - ) from None - - -def create_cloned_field( - field: ModelField, + +CacheKey = Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] + + +def get_flat_dependant( + dependant: Dependant, *, - cloned_types: Optional[Dict[Type[BaseModel], Type[BaseModel]]] = None, -) -> ModelField: - # _cloned_types has already cloned types, to support recursive models - if cloned_types is None: - cloned_types = {} - original_type = field.type_ - if is_dataclass(original_type) and hasattr(original_type, "__pydantic_model__"): - original_type = original_type.__pydantic_model__ - use_type = original_type - if lenient_issubclass(original_type, BaseModel): - original_type = cast(Type[BaseModel], original_type) - use_type = cloned_types.get(original_type) - if use_type is None: - use_type = create_model(original_type.__name__, __base__=original_type) - cloned_types[original_type] = use_type - for f in original_type.__fields__.values(): - use_type.__fields__[f.name] = create_cloned_field( - f, cloned_types=cloned_types - ) - new_field = create_response_field(name=field.name, type_=use_type) - new_field.has_alias = field.has_alias - new_field.alias = field.alias - new_field.class_validators = field.class_validators - new_field.default = field.default - new_field.required = field.required - new_field.model_config = field.model_config - new_field.field_info = field.field_info - new_field.allow_none = field.allow_none - new_field.validate_always = field.validate_always + skip_repeats: bool = False, + visited: Optional[List[CacheKey]] = None, +) -> Dependant: + if visited is None: + visited = [] + visited.append(dependant.cache_key) + + flat_dependant = Dependant( + path_params=dependant.path_params.copy(), + query_params=dependant.query_params.copy(), + header_params=dependant.header_params.copy(), + cookie_params=dependant.cookie_params.copy(), + body_params=dependant.body_params.copy(), + security_schemes=dependant.security_requirements.copy(), + use_cache=dependant.use_cache, + path=dependant.path, + ) + for sub_dependant in dependant.dependencies: + if skip_repeats and sub_dependant.cache_key in visited: + continue + flat_sub = get_flat_dependant( + sub_dependant, skip_repeats=skip_repeats, visited=visited + ) + flat_dependant.path_params.extend(flat_sub.path_params) + flat_dependant.query_params.extend(flat_sub.query_params) + flat_dependant.header_params.extend(flat_sub.header_params) + flat_dependant.cookie_params.extend(flat_sub.cookie_params) + flat_dependant.body_params.extend(flat_sub.body_params) + flat_dependant.security_requirements.extend(flat_sub.security_requirements) + return flat_dependant + + +def get_flat_params(dependant: Dependant) -> List[ModelField]: + flat_dependant = get_flat_dependant(dependant, skip_repeats=True) + return ( + flat_dependant.path_params + + flat_dependant.query_params + + flat_dependant.header_params + + flat_dependant.cookie_params + ) + + +def is_scalar_field(field: ModelField) -> bool: + field_info = field.field_info + if not ( + field.shape == SHAPE_SINGLETON + and not lenient_issubclass(field.type_, BaseModel) + and not lenient_issubclass(field.type_, sequence_types + (dict,)) + and not dataclasses.is_dataclass(field.type_) + and not isinstance(field_info, params.Body) + ): + return False if field.sub_fields: - new_field.sub_fields = [ - create_cloned_field(sub_field, cloned_types=cloned_types) - for sub_field in field.sub_fields - ] - if field.key_field: - new_field.key_field = create_cloned_field( - field.key_field, cloned_types=cloned_types + if not all(is_scalar_field(f) for f in field.sub_fields): + return False + return True + + +def is_scalar_sequence_field(field: ModelField) -> bool: + if (field.shape in sequence_shapes) and not lenient_issubclass( + field.type_, BaseModel + ): + if field.sub_fields is not None: + for sub_field in field.sub_fields: + if not is_scalar_field(sub_field): + return False + return True + if lenient_issubclass(field.type_, sequence_types): + return True + return False + + +def is_generic_type(obj: Any) -> bool: + return hasattr(obj, "__origin__") and hasattr(obj, "__args__") + + +def substitute_generic_type(annotation: Any, typevars: Dict[str, Any]) -> Any: + collection_shells = {list: List, set: List, dict: Dict, tuple: Tuple} + if is_generic_type(annotation): + args = tuple( + substitute_generic_type(arg, typevars) for arg in annotation.__args__ ) - new_field.validators = field.validators - new_field.pre_validators = field.pre_validators - new_field.post_validators = field.post_validators - new_field.parse_json = field.parse_json - new_field.shape = field.shape - new_field.populate_validators() - return new_field - - -def generate_operation_id_for_path( - *, name: str, path: str, method: str -) -> str: # pragma: nocover - warnings.warn( - "fastapi.utils.generate_operation_id_for_path() was deprecated, " - "it is not used internally, and will be removed soon", - DeprecationWarning, - stacklevel=2, + annotation = collection_shells.get(annotation.__origin__, annotation) + return annotation[args] + return typevars.get(annotation.__name__, annotation) + + +def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: + typevars = None + if is_generic_type(call): + stub: GenericTypeStub = cast(GenericTypeStub, call) + typevars = { + typevar.__name__: value + for typevar, value in zip(stub.__origin__.__parameters__, stub.__args__) + } + call = stub.__origin__.__init__ # type: ignore + + signature = inspect.signature(call) + globalns = getattr(call, "__globals__", {}) + + typed_params = [ + inspect.Parameter( + name=param.name, + kind=param.kind, + default=param.default, + annotation=get_typed_annotation(param.annotation, globalns, typevars), + ) + for param in signature.parameters.values() + if param.name != 'self' + ] + typed_signature = inspect.Signature(typed_params) + return typed_signature + + +def get_typed_annotation( + annotation: Any, + globalns: Dict[str, Any], + typevars: Optional[Dict[str, type]] = None, +) -> Any: + if isinstance(annotation, str): + annotation = ForwardRef(annotation) + annotation = evaluate_forwardref(annotation, globalns, globalns) + if typevars: + annotation = substitute_generic_type(annotation, typevars) + return annotation + + +def get_typed_return_annotation(call: Callable[..., Any]) -> Any: + signature = inspect.signature(call) + annotation = signature.return_annotation + + if annotation is inspect.Signature.empty: + return None + + globalns = getattr(call, "__globals__", {}) + return get_typed_annotation(annotation, globalns) + + +def get_dependant( + *, + path: str, + call: Callable[..., Any], + name: Optional[str] = None, + security_scopes: Optional[List[str]] = None, + use_cache: bool = True, +) -> Dependant: + path_param_names = get_path_param_names(path) + endpoint_signature = get_typed_signature(call) + signature_params = endpoint_signature.parameters + dependant = Dependant( + call=call, + name=name, + path=path, + security_scopes=security_scopes, + use_cache=use_cache, ) - operation_id = name + path - operation_id = re.sub(r"\W", "_", operation_id) - operation_id = operation_id + "_" + method.lower() - return operation_id + for param_name, param in signature_params.items(): + if isinstance(param.default, params.Depends): + sub_dependant = get_param_sub_dependant( + param=param, path=path, security_scopes=security_scopes + ) + dependant.dependencies.append(sub_dependant) + continue + if add_non_field_param_to_dependency(param=param, dependant=dependant): + continue + param_field = get_param_field( + param=param, default_field_info=params.Query, param_name=param_name + ) + if param_name in path_param_names: + assert is_scalar_field( + field=param_field + ), "Path params must be of one of the supported types" + ignore_default = not isinstance(param.default, params.Path) + param_field = get_param_field( + param=param, + param_name=param_name, + default_field_info=params.Path, + force_type=params.ParamTypes.path, + ignore_default=ignore_default, + ) + add_param_to_fields(field=param_field, dependant=dependant) + elif is_scalar_field(field=param_field): + add_param_to_fields(field=param_field, dependant=dependant) + elif isinstance( + param.default, (params.Query, params.Header) + ) and is_scalar_sequence_field(param_field): + add_param_to_fields(field=param_field, dependant=dependant) + else: + field_info = param_field.field_info + assert isinstance( + field_info, params.Body + ), f"Param: {param_field.name} can only be a request body, using Body()" + dependant.body_params.append(param_field) + return dependant -def generate_unique_id(route: "APIRoute") -> str: - operation_id = route.name + route.path_format - operation_id = re.sub(r"\W", "_", operation_id) - assert route.methods - operation_id = operation_id + "_" + list(route.methods)[0].lower() - return operation_id +def add_non_field_param_to_dependency( + *, param: inspect.Parameter, dependant: Dependant +) -> Optional[bool]: + if lenient_issubclass(param.annotation, Request): + dependant.request_param_name = param.name + return True + elif lenient_issubclass(param.annotation, WebSocket): + dependant.websocket_param_name = param.name + return True + elif lenient_issubclass(param.annotation, HTTPConnection): + dependant.http_connection_param_name = param.name + return True + elif lenient_issubclass(param.annotation, Response): + dependant.response_param_name = param.name + return True + elif lenient_issubclass(param.annotation, BackgroundTasks): + dependant.background_tasks_param_name = param.name + return True + elif lenient_issubclass(param.annotation, SecurityScopes): + dependant.security_scopes_param_name = param.name + return True + return None + + +def get_param_field( + *, + param: inspect.Parameter, + param_name: str, + default_field_info: Type[params.Param] = params.Param, + force_type: Optional[params.ParamTypes] = None, + ignore_default: bool = False, +) -> ModelField: + default_value: Any = Undefined + had_schema = False + if not param.default == param.empty and ignore_default is False: + default_value = param.default + if isinstance(default_value, FieldInfo): + had_schema = True + field_info = default_value + default_value = field_info.default + if ( + isinstance(field_info, params.Param) + and getattr(field_info, "in_", None) is None + ): + field_info.in_ = default_field_info.in_ + if force_type: + field_info.in_ = force_type # type: ignore + else: + field_info = default_field_info(default=default_value) + required = True + if default_value is Required or ignore_default: + required = True + default_value = None + elif default_value is not Undefined: + required = False + annotation: Any = Any + if not param.annotation == param.empty: + annotation = param.annotation + annotation = get_annotation_from_field_info(annotation, field_info, param_name) + if not field_info.alias and getattr(field_info, "convert_underscores", None): + alias = param.name.replace("_", "-") + else: + alias = field_info.alias or param.name + field = create_response_field( + name=param.name, + type_=annotation, + default=default_value, + alias=alias, + required=required, + field_info=field_info, + ) + if not had_schema and not is_scalar_field(field=field): + field.field_info = params.Body(field_info.default) + if not had_schema and lenient_issubclass(field.type_, UploadFile): + field.field_info = params.File(field_info.default) + + return field + +def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None: + field_info = cast(params.Param, field.field_info) + if field_info.in_ == params.ParamTypes.path: + dependant.path_params.append(field) + elif field_info.in_ == params.ParamTypes.query: + dependant.query_params.append(field) + elif field_info.in_ == params.ParamTypes.header: + dependant.header_params.append(field) + else: + assert ( + field_info.in_ == params.ParamTypes.cookie + ), f"non-body parameters must be in path, query, header or cookie: {field.name}" + dependant.cookie_params.append(field) -def deep_dict_update(main_dict: Dict[Any, Any], update_dict: Dict[Any, Any]) -> None: - for key, value in update_dict.items(): + +def is_coroutine_callable(call: Callable[..., Any]) -> bool: + if inspect.isroutine(call): + return inspect.iscoroutinefunction(call) + if inspect.isclass(call): + return False + dunder_call = getattr(call, "__call__", None) # noqa: B004 + return inspect.iscoroutinefunction(dunder_call) + + +def is_async_gen_callable(call: Callable[..., Any]) -> bool: + if inspect.isasyncgenfunction(call): + return True + dunder_call = getattr(call, "__call__", None) # noqa: B004 + return inspect.isasyncgenfunction(dunder_call) + + +def is_gen_callable(call: Callable[..., Any]) -> bool: + if inspect.isgeneratorfunction(call): + return True + dunder_call = getattr(call, "__call__", None) # noqa: B004 + return inspect.isgeneratorfunction(dunder_call) + + +async def solve_generator( + *, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any] +) -> Any: + if is_gen_callable(call): + cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values)) + elif is_async_gen_callable(call): + cm = asynccontextmanager(call)(**sub_values) + return await stack.enter_async_context(cm) + + +async def solve_dependencies( + *, + request: Union[Request, WebSocket], + dependant: Dependant, + body: Optional[Union[Dict[str, Any], FormData]] = None, + background_tasks: Optional[BackgroundTasks] = None, + response: Optional[Response] = None, + dependency_overrides_provider: Optional[Any] = None, + dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None, +) -> Tuple[ + Dict[str, Any], + List[ErrorWrapper], + Optional[BackgroundTasks], + Response, + Dict[Tuple[Callable[..., Any], Tuple[str]], Any], +]: + values: Dict[str, Any] = {} + errors: List[ErrorWrapper] = [] + if response is None: + response = Response() + del response.headers["content-length"] + response.status_code = None # type: ignore + dependency_cache = dependency_cache or {} + sub_dependant: Dependant + for sub_dependant in dependant.dependencies: + sub_dependant.call = cast(Callable[..., Any], sub_dependant.call) + sub_dependant.cache_key = cast( + Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key + ) + call = sub_dependant.call + use_sub_dependant = sub_dependant if ( - key in main_dict - and isinstance(main_dict[key], dict) - and isinstance(value, dict) + dependency_overrides_provider + and dependency_overrides_provider.dependency_overrides ): - deep_dict_update(main_dict[key], value) - elif ( - key in main_dict - and isinstance(main_dict[key], list) - and isinstance(update_dict[key], list) + original_call = sub_dependant.call + call = getattr( + dependency_overrides_provider, "dependency_overrides", {} + ).get(original_call, original_call) + use_path: str = sub_dependant.path # type: ignore + use_sub_dependant = get_dependant( + path=use_path, + call=call, + name=sub_dependant.name, + security_scopes=sub_dependant.security_scopes, + ) + + solved_result = await solve_dependencies( + request=request, + dependant=use_sub_dependant, + body=body, + background_tasks=background_tasks, + response=response, + dependency_overrides_provider=dependency_overrides_provider, + dependency_cache=dependency_cache, + ) + ( + sub_values, + sub_errors, + background_tasks, + _, # the subdependency returns the same response we have + sub_dependency_cache, + ) = solved_result + dependency_cache.update(sub_dependency_cache) + if sub_errors: + errors.extend(sub_errors) + continue + if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache: + solved = dependency_cache[sub_dependant.cache_key] + elif is_gen_callable(call) or is_async_gen_callable(call): + stack = request.scope.get("fastapi_astack") + assert isinstance(stack, AsyncExitStack) + solved = await solve_generator( + call=call, stack=stack, sub_values=sub_values + ) + elif is_coroutine_callable(call): + solved = await call(**sub_values) + else: + solved = await run_in_threadpool(call, **sub_values) + if sub_dependant.name is not None: + values[sub_dependant.name] = solved + if sub_dependant.cache_key not in dependency_cache: + dependency_cache[sub_dependant.cache_key] = solved + path_values, path_errors = request_params_to_args( + dependant.path_params, request.path_params + ) + query_values, query_errors = request_params_to_args( + dependant.query_params, request.query_params + ) + header_values, header_errors = request_params_to_args( + dependant.header_params, request.headers + ) + cookie_values, cookie_errors = request_params_to_args( + dependant.cookie_params, request.cookies + ) + values.update(path_values) + values.update(query_values) + values.update(header_values) + values.update(cookie_values) + errors += path_errors + query_errors + header_errors + cookie_errors + if dependant.body_params: + ( + body_values, + body_errors, + ) = await request_body_to_args( # body_params checked above + required_params=dependant.body_params, received_body=body + ) + values.update(body_values) + errors.extend(body_errors) + if dependant.http_connection_param_name: + values[dependant.http_connection_param_name] = request + if dependant.request_param_name and isinstance(request, Request): + values[dependant.request_param_name] = request + elif dependant.websocket_param_name and isinstance(request, WebSocket): + values[dependant.websocket_param_name] = request + if dependant.background_tasks_param_name: + if background_tasks is None: + background_tasks = BackgroundTasks() + values[dependant.background_tasks_param_name] = background_tasks + if dependant.response_param_name: + values[dependant.response_param_name] = response + if dependant.security_scopes_param_name: + values[dependant.security_scopes_param_name] = SecurityScopes( + scopes=dependant.security_scopes + ) + return values, errors, background_tasks, response, dependency_cache + + +def request_params_to_args( + required_params: Sequence[ModelField], + received_params: Union[Mapping[str, Any], QueryParams, Headers], +) -> Tuple[Dict[str, Any], List[ErrorWrapper]]: + values = {} + errors = [] + for field in required_params: + if is_scalar_sequence_field(field) and isinstance( + received_params, (QueryParams, Headers) ): - main_dict[key] = main_dict[key] + update_dict[key] + value = received_params.getlist(field.alias) or field.default else: - main_dict[key] = value + value = received_params.get(field.alias) + field_info = field.field_info + assert isinstance( + field_info, params.Param + ), "Params must be subclasses of Param" + if value is None: + if field.required: + errors.append( + ErrorWrapper( + MissingError(), loc=(field_info.in_.value, field.alias) + ) + ) + else: + values[field.name] = deepcopy(field.default) + continue + v_, errors_ = field.validate( + value, values, loc=(field_info.in_.value, field.alias) + ) + if isinstance(errors_, ErrorWrapper): + errors.append(errors_) + elif isinstance(errors_, list): + errors.extend(errors_) + else: + values[field.name] = v_ + return values, errors + +async def request_body_to_args( + required_params: List[ModelField], + received_body: Optional[Union[Dict[str, Any], FormData]], +) -> Tuple[Dict[str, Any], List[ErrorWrapper]]: + values = {} + errors = [] + if required_params: + field = required_params[0] + field_info = field.field_info + embed = getattr(field_info, "embed", None) + field_alias_omitted = len(required_params) == 1 and not embed + if field_alias_omitted: + received_body = {field.alias: received_body} -def get_value_or_default( - first_item: Union[DefaultPlaceholder, DefaultType], - *extra_items: Union[DefaultPlaceholder, DefaultType], -) -> Union[DefaultPlaceholder, DefaultType]: - """ - Pass items or `DefaultPlaceholder`s by descending priority. + for field in required_params: + loc: Tuple[str, ...] + if field_alias_omitted: + loc = ("body",) + else: + loc = ("body", field.alias) - The first one to _not_ be a `DefaultPlaceholder` will be returned. + value: Optional[Any] = None + if received_body is not None: + if ( + field.shape in sequence_shapes or field.type_ in sequence_types + ) and isinstance(received_body, FormData): + value = received_body.getlist(field.alias) + else: + try: + value = received_body.get(field.alias) + except AttributeError: + errors.append(get_missing_field_error(loc)) + continue + if ( + value is None + or (isinstance(field_info, params.Form) and value == "") + or ( + isinstance(field_info, params.Form) + and field.shape in sequence_shapes + and len(value) == 0 + ) + ): + if field.required: + errors.append(get_missing_field_error(loc)) + else: + values[field.name] = deepcopy(field.default) + continue + if ( + isinstance(field_info, params.File) + and lenient_issubclass(field.type_, bytes) + and isinstance(value, UploadFile) + ): + value = await value.read() + elif ( + field.shape in sequence_shapes + and isinstance(field_info, params.File) + and lenient_issubclass(field.type_, bytes) + and isinstance(value, sequence_types) + ): + results: List[Union[bytes, str]] = [] + + async def process_fn( + fn: Callable[[], Coroutine[Any, Any, Any]] + ) -> None: + result = await fn() + results.append(result) + + async with anyio.create_task_group() as tg: + for sub_value in value: + tg.start_soon(process_fn, sub_value.read) + value = sequence_shape_to_type[field.shape](results) + + v_, errors_ = field.validate(value, values, loc=loc) + + if isinstance(errors_, ErrorWrapper): + errors.append(errors_) + elif isinstance(errors_, list): + errors.extend(errors_) + else: + values[field.name] = v_ + return values, errors + + +def get_missing_field_error(loc: Tuple[str, ...]) -> ErrorWrapper: + missing_field_error = ErrorWrapper(MissingError(), loc=loc) + return missing_field_error + + +def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]: + flat_dependant = get_flat_dependant(dependant) + if not flat_dependant.body_params: + return None + first_param = flat_dependant.body_params[0] + field_info = first_param.field_info + embed = getattr(field_info, "embed", None) + body_param_names_set = {param.name for param in flat_dependant.body_params} + if len(body_param_names_set) == 1 and not embed: + check_file_field(first_param) + return first_param + # If one field requires to embed, all have to be embedded + # in case a sub-dependency is evaluated with a single unique body field + # That is combined (embedded) with other body fields + for param in flat_dependant.body_params: + setattr(param.field_info, "embed", True) # noqa: B010 + model_name = "Body_" + name + BodyModel: Type[BaseModel] = create_model(model_name) + for f in flat_dependant.body_params: + BodyModel.__fields__[f.name] = f + required = any(True for f in flat_dependant.body_params if f.required) + + BodyFieldInfo_kwargs: Dict[str, Any] = {"default": None} + if any(isinstance(f.field_info, params.File) for f in flat_dependant.body_params): + BodyFieldInfo: Type[params.Body] = params.File + elif any(isinstance(f.field_info, params.Form) for f in flat_dependant.body_params): + BodyFieldInfo = params.Form + else: + BodyFieldInfo = params.Body + + body_param_media_types = [ + f.field_info.media_type + for f in flat_dependant.body_params + if isinstance(f.field_info, params.Body) + ] + if len(set(body_param_media_types)) == 1: + BodyFieldInfo_kwargs["media_type"] = body_param_media_types[0] + final_field = create_response_field( + name="body", + type_=BodyModel, + required=required, + alias="body", + field_info=BodyFieldInfo(**BodyFieldInfo_kwargs), + ) + check_file_field(final_field) + return final_field - Otherwise, the first item (a `DefaultPlaceholder`) will be returned. - """ - items = (first_item,) + extra_items - for item in items: - if not isinstance(item, DefaultPlaceholder): - return item - return first_item diff --git a/tests/test_dependency_generic_class.py b/tests/test_dependency_generic_class.py index 3eeb8291e..95b81eb84 100644 --- a/tests/test_dependency_generic_class.py +++ b/tests/test_dependency_generic_class.py @@ -19,7 +19,7 @@ class SecondGenericType(Generic[T, C]): simple: T, lst: List[T], dct: Dict[T, C], - custom_class: FirstGenericType[T] = Depends(), + custom_class: FirstGenericType[T] = Depends() ): self.simple = simple self.lst = lst @@ -38,7 +38,7 @@ def depend_generic_type(obj: SecondGenericType[str, int] = Depends()): "dct": obj.dct, "custom_class": { "simple": obj.custom_class.simple, - "lst": obj.custom_class.lst, + "lst": obj.custom_class.lst }, } From f7aadd24b9b7c27b0b88765e90dd3ec120cc62e6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 23 Feb 2023 16:09:18 +0000 Subject: [PATCH 4/6] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20for?= =?UTF-8?q?mat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastapi/utils.py | 9 ++++----- tests/test_dependency_generic_class.py | 4 ++-- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/fastapi/utils.py b/fastapi/utils.py index 1bd2de317..721331131 100644 --- a/fastapi/utils.py +++ b/fastapi/utils.py @@ -282,16 +282,16 @@ def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: annotation=get_typed_annotation(param.annotation, globalns, typevars), ) for param in signature.parameters.values() - if param.name != 'self' + if param.name != "self" ] typed_signature = inspect.Signature(typed_params) return typed_signature def get_typed_annotation( - annotation: Any, - globalns: Dict[str, Any], - typevars: Optional[Dict[str, type]] = None, + annotation: Any, + globalns: Dict[str, Any], + typevars: Optional[Dict[str, type]] = None, ) -> Any: if isinstance(annotation, str): annotation = ForwardRef(annotation) @@ -798,4 +798,3 @@ def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]: ) check_file_field(final_field) return final_field - diff --git a/tests/test_dependency_generic_class.py b/tests/test_dependency_generic_class.py index 95b81eb84..3eeb8291e 100644 --- a/tests/test_dependency_generic_class.py +++ b/tests/test_dependency_generic_class.py @@ -19,7 +19,7 @@ class SecondGenericType(Generic[T, C]): simple: T, lst: List[T], dct: Dict[T, C], - custom_class: FirstGenericType[T] = Depends() + custom_class: FirstGenericType[T] = Depends(), ): self.simple = simple self.lst = lst @@ -38,7 +38,7 @@ def depend_generic_type(obj: SecondGenericType[str, int] = Depends()): "dct": obj.dct, "custom_class": { "simple": obj.custom_class.simple, - "lst": obj.custom_class.lst + "lst": obj.custom_class.lst, }, } From 40965c982586f1262150766a6210459e5d1126e0 Mon Sep 17 00:00:00 2001 From: yyklimenko Date: Sun, 26 Feb 2023 01:55:16 +0600 Subject: [PATCH 5/6] fixed bug in file structure --- fastapi/{ => dependencies}/stubs.py | 0 fastapi/dependencies/utils.py | 9 +- fastapi/utils.py | 953 ++++++---------------------- 3 files changed, 186 insertions(+), 776 deletions(-) rename fastapi/{ => dependencies}/stubs.py (100%) diff --git a/fastapi/stubs.py b/fastapi/dependencies/stubs.py similarity index 100% rename from fastapi/stubs.py rename to fastapi/dependencies/stubs.py diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 1bf780cd6..ade00816f 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -26,6 +26,7 @@ from fastapi.concurrency import ( contextmanager_in_threadpool, ) from fastapi.dependencies.models import Dependant, SecurityRequirement +from fastapi.dependencies.stubs import GenericTypeStub from fastapi.logger import logger from fastapi.security.base import SecurityBase from fastapi.security.oauth2 import OAuth2, SecurityScopes @@ -263,16 +264,18 @@ def substitute_generic_type(annotation: Any, typevars: Dict[str, Any]) -> Any: def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: typevars = None if is_generic_type(call): - origin = call.__origin__ + generic: GenericTypeStub = cast(GenericTypeStub, call) typevars = { typevar.__name__: value - for typevar, value in zip(origin.__parameters__, call.__args__) + for typevar, value in zip( + generic.__origin__.__parameters__, generic.__args__ + ) } + origin: Any = generic.__origin__ call = origin.__init__ signature = inspect.signature(call) globalns = getattr(call, "__globals__", {}) - typed_params = [ inspect.Parameter( name=param.name, diff --git a/fastapi/utils.py b/fastapi/utils.py index 721331131..391c47d81 100644 --- a/fastapi/utils.py +++ b/fastapi/utils.py @@ -1,800 +1,207 @@ -import dataclasses -import inspect -from contextlib import contextmanager -from copy import deepcopy -from typing import ( - Any, - Callable, - Coroutine, - Dict, - ForwardRef, - List, - Mapping, - Optional, - Sequence, - Tuple, - Type, - Union, - cast, -) - -import anyio -from fastapi import params -from fastapi.concurrency import ( - AsyncExitStack, - asynccontextmanager, - contextmanager_in_threadpool, -) -from fastapi.dependencies.models import Dependant, SecurityRequirement -from fastapi.dependencies.stubs import GenericTypeStub -from fastapi.logger import logger -from fastapi.security.base import SecurityBase -from fastapi.security.oauth2 import OAuth2, SecurityScopes -from fastapi.security.open_id_connect_url import OpenIdConnect -from fastapi.utils import create_response_field, get_path_param_names -from pydantic import BaseModel, create_model -from pydantic.error_wrappers import ErrorWrapper -from pydantic.errors import MissingError -from pydantic.fields import ( - SHAPE_FROZENSET, - SHAPE_LIST, - SHAPE_SEQUENCE, - SHAPE_SET, - SHAPE_SINGLETON, - SHAPE_TUPLE, - SHAPE_TUPLE_ELLIPSIS, - FieldInfo, - ModelField, - Required, - Undefined, -) -from pydantic.schema import get_annotation_from_field_info -from pydantic.typing import evaluate_forwardref +import functools +import re +import warnings +from dataclasses import is_dataclass +from enum import Enum +from typing import TYPE_CHECKING, Any, Dict, Optional, Set, Type, Union, cast + +import fastapi +from fastapi.datastructures import DefaultPlaceholder, DefaultType +from fastapi.openapi.constants import REF_PREFIX +from pydantic import BaseConfig, BaseModel, create_model +from pydantic.class_validators import Validator +from pydantic.fields import FieldInfo, ModelField, UndefinedType +from pydantic.schema import model_process_schema from pydantic.utils import lenient_issubclass -from starlette.background import BackgroundTasks -from starlette.concurrency import run_in_threadpool -from starlette.datastructures import FormData, Headers, QueryParams, UploadFile -from starlette.requests import HTTPConnection, Request -from starlette.responses import Response -from starlette.websockets import WebSocket - -sequence_shapes = { - SHAPE_LIST, - SHAPE_SET, - SHAPE_FROZENSET, - SHAPE_TUPLE, - SHAPE_SEQUENCE, - SHAPE_TUPLE_ELLIPSIS, -} -sequence_types = (list, set, tuple) -sequence_shape_to_type = { - SHAPE_LIST: list, - SHAPE_SET: set, - SHAPE_TUPLE: tuple, - SHAPE_SEQUENCE: list, - SHAPE_TUPLE_ELLIPSIS: list, -} - - -multipart_not_installed_error = ( - 'Form data requires "python-multipart" to be installed. \n' - 'You can install "python-multipart" with: \n\n' - "pip install python-multipart\n" -) -multipart_incorrect_install_error = ( - 'Form data requires "python-multipart" to be installed. ' - 'It seems you installed "multipart" instead. \n' - 'You can remove "multipart" with: \n\n' - "pip uninstall multipart\n\n" - 'And then install "python-multipart" with: \n\n' - "pip install python-multipart\n" -) - - -def check_file_field(field: ModelField) -> None: - field_info = field.field_info - if isinstance(field_info, params.Form): - try: - # __version__ is available in both multiparts, and can be mocked - from multipart import __version__ # type: ignore - - assert __version__ - try: - # parse_options_header is only available in the right multipart - from multipart.multipart import parse_options_header # type: ignore - - assert parse_options_header - except ImportError: - logger.error(multipart_incorrect_install_error) - raise RuntimeError(multipart_incorrect_install_error) from None - except ImportError: - logger.error(multipart_not_installed_error) - raise RuntimeError(multipart_not_installed_error) from None - - -def get_param_sub_dependant( - *, param: inspect.Parameter, path: str, security_scopes: Optional[List[str]] = None -) -> Dependant: - depends: params.Depends = param.default - if depends.dependency: - dependency = depends.dependency - else: - dependency = param.annotation - return get_sub_dependant( - depends=depends, - dependency=dependency, - path=path, - name=param.name, - security_scopes=security_scopes, - ) - - -def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> Dependant: - assert callable( - depends.dependency - ), "A parameter-less dependency must have a callable dependency" - return get_sub_dependant(depends=depends, dependency=depends.dependency, path=path) - - -def get_sub_dependant( - *, - depends: params.Depends, - dependency: Callable[..., Any], - path: str, - name: Optional[str] = None, - security_scopes: Optional[List[str]] = None, -) -> Dependant: - security_requirement = None - security_scopes = security_scopes or [] - if isinstance(depends, params.Security): - dependency_scopes = depends.scopes - security_scopes.extend(dependency_scopes) - if isinstance(dependency, SecurityBase): - use_scopes: List[str] = [] - if isinstance(dependency, (OAuth2, OpenIdConnect)): - use_scopes = security_scopes - security_requirement = SecurityRequirement( - security_scheme=dependency, scopes=use_scopes - ) - sub_dependant = get_dependant( - path=path, - call=dependency, - name=name, - security_scopes=security_scopes, - use_cache=depends.use_cache, - ) - if security_requirement: - sub_dependant.security_requirements.append(security_requirement) - return sub_dependant - - -CacheKey = Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] - - -def get_flat_dependant( - dependant: Dependant, - *, - skip_repeats: bool = False, - visited: Optional[List[CacheKey]] = None, -) -> Dependant: - if visited is None: - visited = [] - visited.append(dependant.cache_key) - - flat_dependant = Dependant( - path_params=dependant.path_params.copy(), - query_params=dependant.query_params.copy(), - header_params=dependant.header_params.copy(), - cookie_params=dependant.cookie_params.copy(), - body_params=dependant.body_params.copy(), - security_schemes=dependant.security_requirements.copy(), - use_cache=dependant.use_cache, - path=dependant.path, - ) - for sub_dependant in dependant.dependencies: - if skip_repeats and sub_dependant.cache_key in visited: - continue - flat_sub = get_flat_dependant( - sub_dependant, skip_repeats=skip_repeats, visited=visited - ) - flat_dependant.path_params.extend(flat_sub.path_params) - flat_dependant.query_params.extend(flat_sub.query_params) - flat_dependant.header_params.extend(flat_sub.header_params) - flat_dependant.cookie_params.extend(flat_sub.cookie_params) - flat_dependant.body_params.extend(flat_sub.body_params) - flat_dependant.security_requirements.extend(flat_sub.security_requirements) - return flat_dependant - -def get_flat_params(dependant: Dependant) -> List[ModelField]: - flat_dependant = get_flat_dependant(dependant, skip_repeats=True) - return ( - flat_dependant.path_params - + flat_dependant.query_params - + flat_dependant.header_params - + flat_dependant.cookie_params - ) - - -def is_scalar_field(field: ModelField) -> bool: - field_info = field.field_info - if not ( - field.shape == SHAPE_SINGLETON - and not lenient_issubclass(field.type_, BaseModel) - and not lenient_issubclass(field.type_, sequence_types + (dict,)) - and not dataclasses.is_dataclass(field.type_) - and not isinstance(field_info, params.Body) - ): - return False - if field.sub_fields: - if not all(is_scalar_field(f) for f in field.sub_fields): - return False - return True +if TYPE_CHECKING: # pragma: nocover + from .routing import APIRoute -def is_scalar_sequence_field(field: ModelField) -> bool: - if (field.shape in sequence_shapes) and not lenient_issubclass( - field.type_, BaseModel - ): - if field.sub_fields is not None: - for sub_field in field.sub_fields: - if not is_scalar_field(sub_field): - return False +def is_body_allowed_for_status_code(status_code: Union[int, str, None]) -> bool: + if status_code is None: return True - if lenient_issubclass(field.type_, sequence_types): + # Ref: https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#patterned-fields-1 + if status_code in { + "default", + "1XX", + "2XX", + "3XX", + "4XX", + "5XX", + }: return True - return False - - -def is_generic_type(obj: Any) -> bool: - return hasattr(obj, "__origin__") and hasattr(obj, "__args__") - - -def substitute_generic_type(annotation: Any, typevars: Dict[str, Any]) -> Any: - collection_shells = {list: List, set: List, dict: Dict, tuple: Tuple} - if is_generic_type(annotation): - args = tuple( - substitute_generic_type(arg, typevars) for arg in annotation.__args__ - ) - annotation = collection_shells.get(annotation.__origin__, annotation) - return annotation[args] - return typevars.get(annotation.__name__, annotation) - - -def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: - typevars = None - if is_generic_type(call): - stub: GenericTypeStub = cast(GenericTypeStub, call) - typevars = { - typevar.__name__: value - for typevar, value in zip(stub.__origin__.__parameters__, stub.__args__) - } - call = stub.__origin__.__init__ # type: ignore - - signature = inspect.signature(call) - globalns = getattr(call, "__globals__", {}) - - typed_params = [ - inspect.Parameter( - name=param.name, - kind=param.kind, - default=param.default, - annotation=get_typed_annotation(param.annotation, globalns, typevars), - ) - for param in signature.parameters.values() - if param.name != "self" - ] - typed_signature = inspect.Signature(typed_params) - return typed_signature + current_status_code = int(status_code) + return not (current_status_code < 200 or current_status_code in {204, 304}) -def get_typed_annotation( - annotation: Any, - globalns: Dict[str, Any], - typevars: Optional[Dict[str, type]] = None, -) -> Any: - if isinstance(annotation, str): - annotation = ForwardRef(annotation) - annotation = evaluate_forwardref(annotation, globalns, globalns) - if typevars: - annotation = substitute_generic_type(annotation, typevars) - return annotation - - -def get_typed_return_annotation(call: Callable[..., Any]) -> Any: - signature = inspect.signature(call) - annotation = signature.return_annotation - - if annotation is inspect.Signature.empty: - return None - - globalns = getattr(call, "__globals__", {}) - return get_typed_annotation(annotation, globalns) - - -def get_dependant( +def get_model_definitions( *, - path: str, - call: Callable[..., Any], - name: Optional[str] = None, - security_scopes: Optional[List[str]] = None, - use_cache: bool = True, -) -> Dependant: - path_param_names = get_path_param_names(path) - endpoint_signature = get_typed_signature(call) - signature_params = endpoint_signature.parameters - dependant = Dependant( - call=call, + flat_models: Set[Union[Type[BaseModel], Type[Enum]]], + model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str], +) -> Dict[str, Any]: + definitions: Dict[str, Dict[str, Any]] = {} + for model in flat_models: + m_schema, m_definitions, m_nested_models = model_process_schema( + model, model_name_map=model_name_map, ref_prefix=REF_PREFIX + ) + definitions.update(m_definitions) + model_name = model_name_map[model] + if "description" in m_schema: + m_schema["description"] = m_schema["description"].split("\f")[0] + definitions[model_name] = m_schema + return definitions + + +def get_path_param_names(path: str) -> Set[str]: + return set(re.findall("{(.*?)}", path)) + + +def create_response_field( + name: str, + type_: Type[Any], + class_validators: Optional[Dict[str, Validator]] = None, + default: Optional[Any] = None, + required: Union[bool, UndefinedType] = True, + model_config: Type[BaseConfig] = BaseConfig, + field_info: Optional[FieldInfo] = None, + alias: Optional[str] = None, +) -> ModelField: + """ + Create a new response field. Raises if type_ is invalid. + """ + class_validators = class_validators or {} + field_info = field_info or FieldInfo() + + response_field = functools.partial( + ModelField, name=name, - path=path, - security_scopes=security_scopes, - use_cache=use_cache, + type_=type_, + class_validators=class_validators, + default=default, + required=required, + model_config=model_config, + alias=alias, ) - for param_name, param in signature_params.items(): - if isinstance(param.default, params.Depends): - sub_dependant = get_param_sub_dependant( - param=param, path=path, security_scopes=security_scopes - ) - dependant.dependencies.append(sub_dependant) - continue - if add_non_field_param_to_dependency(param=param, dependant=dependant): - continue - param_field = get_param_field( - param=param, default_field_info=params.Query, param_name=param_name - ) - if param_name in path_param_names: - assert is_scalar_field( - field=param_field - ), "Path params must be of one of the supported types" - ignore_default = not isinstance(param.default, params.Path) - param_field = get_param_field( - param=param, - param_name=param_name, - default_field_info=params.Path, - force_type=params.ParamTypes.path, - ignore_default=ignore_default, - ) - add_param_to_fields(field=param_field, dependant=dependant) - elif is_scalar_field(field=param_field): - add_param_to_fields(field=param_field, dependant=dependant) - elif isinstance( - param.default, (params.Query, params.Header) - ) and is_scalar_sequence_field(param_field): - add_param_to_fields(field=param_field, dependant=dependant) - else: - field_info = param_field.field_info - assert isinstance( - field_info, params.Body - ), f"Param: {param_field.name} can only be a request body, using Body()" - dependant.body_params.append(param_field) - return dependant - - -def add_non_field_param_to_dependency( - *, param: inspect.Parameter, dependant: Dependant -) -> Optional[bool]: - if lenient_issubclass(param.annotation, Request): - dependant.request_param_name = param.name - return True - elif lenient_issubclass(param.annotation, WebSocket): - dependant.websocket_param_name = param.name - return True - elif lenient_issubclass(param.annotation, HTTPConnection): - dependant.http_connection_param_name = param.name - return True - elif lenient_issubclass(param.annotation, Response): - dependant.response_param_name = param.name - return True - elif lenient_issubclass(param.annotation, BackgroundTasks): - dependant.background_tasks_param_name = param.name - return True - elif lenient_issubclass(param.annotation, SecurityScopes): - dependant.security_scopes_param_name = param.name - return True - return None - -def get_param_field( + try: + return response_field(field_info=field_info) + except RuntimeError: + raise fastapi.exceptions.FastAPIError( + "Invalid args for response field! Hint: " + f"check that {type_} is a valid Pydantic field type. " + "If you are using a return type annotation that is not a valid Pydantic " + "field (e.g. Union[Response, dict, None]) you can disable generating the " + "response model from the type annotation with the path operation decorator " + "parameter response_model=None. Read more: " + "https://fastapi.tiangolo.com/tutorial/response-model/" + ) from None + + +def create_cloned_field( + field: ModelField, *, - param: inspect.Parameter, - param_name: str, - default_field_info: Type[params.Param] = params.Param, - force_type: Optional[params.ParamTypes] = None, - ignore_default: bool = False, + cloned_types: Optional[Dict[Type[BaseModel], Type[BaseModel]]] = None, ) -> ModelField: - default_value: Any = Undefined - had_schema = False - if not param.default == param.empty and ignore_default is False: - default_value = param.default - if isinstance(default_value, FieldInfo): - had_schema = True - field_info = default_value - default_value = field_info.default - if ( - isinstance(field_info, params.Param) - and getattr(field_info, "in_", None) is None - ): - field_info.in_ = default_field_info.in_ - if force_type: - field_info.in_ = force_type # type: ignore - else: - field_info = default_field_info(default=default_value) - required = True - if default_value is Required or ignore_default: - required = True - default_value = None - elif default_value is not Undefined: - required = False - annotation: Any = Any - if not param.annotation == param.empty: - annotation = param.annotation - annotation = get_annotation_from_field_info(annotation, field_info, param_name) - if not field_info.alias and getattr(field_info, "convert_underscores", None): - alias = param.name.replace("_", "-") - else: - alias = field_info.alias or param.name - field = create_response_field( - name=param.name, - type_=annotation, - default=default_value, - alias=alias, - required=required, - field_info=field_info, + # _cloned_types has already cloned types, to support recursive models + if cloned_types is None: + cloned_types = {} + original_type = field.type_ + if is_dataclass(original_type) and hasattr(original_type, "__pydantic_model__"): + original_type = original_type.__pydantic_model__ + use_type = original_type + if lenient_issubclass(original_type, BaseModel): + original_type = cast(Type[BaseModel], original_type) + use_type = cloned_types.get(original_type) + if use_type is None: + use_type = create_model(original_type.__name__, __base__=original_type) + cloned_types[original_type] = use_type + for f in original_type.__fields__.values(): + use_type.__fields__[f.name] = create_cloned_field( + f, cloned_types=cloned_types + ) + new_field = create_response_field(name=field.name, type_=use_type) + new_field.has_alias = field.has_alias + new_field.alias = field.alias + new_field.class_validators = field.class_validators + new_field.default = field.default + new_field.required = field.required + new_field.model_config = field.model_config + new_field.field_info = field.field_info + new_field.allow_none = field.allow_none + new_field.validate_always = field.validate_always + if field.sub_fields: + new_field.sub_fields = [ + create_cloned_field(sub_field, cloned_types=cloned_types) + for sub_field in field.sub_fields + ] + if field.key_field: + new_field.key_field = create_cloned_field( + field.key_field, cloned_types=cloned_types + ) + new_field.validators = field.validators + new_field.pre_validators = field.pre_validators + new_field.post_validators = field.post_validators + new_field.parse_json = field.parse_json + new_field.shape = field.shape + new_field.populate_validators() + return new_field + + +def generate_operation_id_for_path( + *, name: str, path: str, method: str +) -> str: # pragma: nocover + warnings.warn( + "fastapi.utils.generate_operation_id_for_path() was deprecated, " + "it is not used internally, and will be removed soon", + DeprecationWarning, + stacklevel=2, ) - if not had_schema and not is_scalar_field(field=field): - field.field_info = params.Body(field_info.default) - if not had_schema and lenient_issubclass(field.type_, UploadFile): - field.field_info = params.File(field_info.default) - - return field - - -def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None: - field_info = cast(params.Param, field.field_info) - if field_info.in_ == params.ParamTypes.path: - dependant.path_params.append(field) - elif field_info.in_ == params.ParamTypes.query: - dependant.query_params.append(field) - elif field_info.in_ == params.ParamTypes.header: - dependant.header_params.append(field) - else: - assert ( - field_info.in_ == params.ParamTypes.cookie - ), f"non-body parameters must be in path, query, header or cookie: {field.name}" - dependant.cookie_params.append(field) - - -def is_coroutine_callable(call: Callable[..., Any]) -> bool: - if inspect.isroutine(call): - return inspect.iscoroutinefunction(call) - if inspect.isclass(call): - return False - dunder_call = getattr(call, "__call__", None) # noqa: B004 - return inspect.iscoroutinefunction(dunder_call) - + operation_id = name + path + operation_id = re.sub(r"\W", "_", operation_id) + operation_id = operation_id + "_" + method.lower() + return operation_id -def is_async_gen_callable(call: Callable[..., Any]) -> bool: - if inspect.isasyncgenfunction(call): - return True - dunder_call = getattr(call, "__call__", None) # noqa: B004 - return inspect.isasyncgenfunction(dunder_call) - - -def is_gen_callable(call: Callable[..., Any]) -> bool: - if inspect.isgeneratorfunction(call): - return True - dunder_call = getattr(call, "__call__", None) # noqa: B004 - return inspect.isgeneratorfunction(dunder_call) +def generate_unique_id(route: "APIRoute") -> str: + operation_id = route.name + route.path_format + operation_id = re.sub(r"\W", "_", operation_id) + assert route.methods + operation_id = operation_id + "_" + list(route.methods)[0].lower() + return operation_id -async def solve_generator( - *, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any] -) -> Any: - if is_gen_callable(call): - cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values)) - elif is_async_gen_callable(call): - cm = asynccontextmanager(call)(**sub_values) - return await stack.enter_async_context(cm) - -async def solve_dependencies( - *, - request: Union[Request, WebSocket], - dependant: Dependant, - body: Optional[Union[Dict[str, Any], FormData]] = None, - background_tasks: Optional[BackgroundTasks] = None, - response: Optional[Response] = None, - dependency_overrides_provider: Optional[Any] = None, - dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None, -) -> Tuple[ - Dict[str, Any], - List[ErrorWrapper], - Optional[BackgroundTasks], - Response, - Dict[Tuple[Callable[..., Any], Tuple[str]], Any], -]: - values: Dict[str, Any] = {} - errors: List[ErrorWrapper] = [] - if response is None: - response = Response() - del response.headers["content-length"] - response.status_code = None # type: ignore - dependency_cache = dependency_cache or {} - sub_dependant: Dependant - for sub_dependant in dependant.dependencies: - sub_dependant.call = cast(Callable[..., Any], sub_dependant.call) - sub_dependant.cache_key = cast( - Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key - ) - call = sub_dependant.call - use_sub_dependant = sub_dependant +def deep_dict_update(main_dict: Dict[Any, Any], update_dict: Dict[Any, Any]) -> None: + for key, value in update_dict.items(): if ( - dependency_overrides_provider - and dependency_overrides_provider.dependency_overrides + key in main_dict + and isinstance(main_dict[key], dict) + and isinstance(value, dict) ): - original_call = sub_dependant.call - call = getattr( - dependency_overrides_provider, "dependency_overrides", {} - ).get(original_call, original_call) - use_path: str = sub_dependant.path # type: ignore - use_sub_dependant = get_dependant( - path=use_path, - call=call, - name=sub_dependant.name, - security_scopes=sub_dependant.security_scopes, - ) - - solved_result = await solve_dependencies( - request=request, - dependant=use_sub_dependant, - body=body, - background_tasks=background_tasks, - response=response, - dependency_overrides_provider=dependency_overrides_provider, - dependency_cache=dependency_cache, - ) - ( - sub_values, - sub_errors, - background_tasks, - _, # the subdependency returns the same response we have - sub_dependency_cache, - ) = solved_result - dependency_cache.update(sub_dependency_cache) - if sub_errors: - errors.extend(sub_errors) - continue - if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache: - solved = dependency_cache[sub_dependant.cache_key] - elif is_gen_callable(call) or is_async_gen_callable(call): - stack = request.scope.get("fastapi_astack") - assert isinstance(stack, AsyncExitStack) - solved = await solve_generator( - call=call, stack=stack, sub_values=sub_values - ) - elif is_coroutine_callable(call): - solved = await call(**sub_values) - else: - solved = await run_in_threadpool(call, **sub_values) - if sub_dependant.name is not None: - values[sub_dependant.name] = solved - if sub_dependant.cache_key not in dependency_cache: - dependency_cache[sub_dependant.cache_key] = solved - path_values, path_errors = request_params_to_args( - dependant.path_params, request.path_params - ) - query_values, query_errors = request_params_to_args( - dependant.query_params, request.query_params - ) - header_values, header_errors = request_params_to_args( - dependant.header_params, request.headers - ) - cookie_values, cookie_errors = request_params_to_args( - dependant.cookie_params, request.cookies - ) - values.update(path_values) - values.update(query_values) - values.update(header_values) - values.update(cookie_values) - errors += path_errors + query_errors + header_errors + cookie_errors - if dependant.body_params: - ( - body_values, - body_errors, - ) = await request_body_to_args( # body_params checked above - required_params=dependant.body_params, received_body=body - ) - values.update(body_values) - errors.extend(body_errors) - if dependant.http_connection_param_name: - values[dependant.http_connection_param_name] = request - if dependant.request_param_name and isinstance(request, Request): - values[dependant.request_param_name] = request - elif dependant.websocket_param_name and isinstance(request, WebSocket): - values[dependant.websocket_param_name] = request - if dependant.background_tasks_param_name: - if background_tasks is None: - background_tasks = BackgroundTasks() - values[dependant.background_tasks_param_name] = background_tasks - if dependant.response_param_name: - values[dependant.response_param_name] = response - if dependant.security_scopes_param_name: - values[dependant.security_scopes_param_name] = SecurityScopes( - scopes=dependant.security_scopes - ) - return values, errors, background_tasks, response, dependency_cache - - -def request_params_to_args( - required_params: Sequence[ModelField], - received_params: Union[Mapping[str, Any], QueryParams, Headers], -) -> Tuple[Dict[str, Any], List[ErrorWrapper]]: - values = {} - errors = [] - for field in required_params: - if is_scalar_sequence_field(field) and isinstance( - received_params, (QueryParams, Headers) + deep_dict_update(main_dict[key], value) + elif ( + key in main_dict + and isinstance(main_dict[key], list) + and isinstance(update_dict[key], list) ): - value = received_params.getlist(field.alias) or field.default - else: - value = received_params.get(field.alias) - field_info = field.field_info - assert isinstance( - field_info, params.Param - ), "Params must be subclasses of Param" - if value is None: - if field.required: - errors.append( - ErrorWrapper( - MissingError(), loc=(field_info.in_.value, field.alias) - ) - ) - else: - values[field.name] = deepcopy(field.default) - continue - v_, errors_ = field.validate( - value, values, loc=(field_info.in_.value, field.alias) - ) - if isinstance(errors_, ErrorWrapper): - errors.append(errors_) - elif isinstance(errors_, list): - errors.extend(errors_) + main_dict[key] = main_dict[key] + update_dict[key] else: - values[field.name] = v_ - return values, errors + main_dict[key] = value -async def request_body_to_args( - required_params: List[ModelField], - received_body: Optional[Union[Dict[str, Any], FormData]], -) -> Tuple[Dict[str, Any], List[ErrorWrapper]]: - values = {} - errors = [] - if required_params: - field = required_params[0] - field_info = field.field_info - embed = getattr(field_info, "embed", None) - field_alias_omitted = len(required_params) == 1 and not embed - if field_alias_omitted: - received_body = {field.alias: received_body} +def get_value_or_default( + first_item: Union[DefaultPlaceholder, DefaultType], + *extra_items: Union[DefaultPlaceholder, DefaultType], +) -> Union[DefaultPlaceholder, DefaultType]: + """ + Pass items or `DefaultPlaceholder`s by descending priority. - for field in required_params: - loc: Tuple[str, ...] - if field_alias_omitted: - loc = ("body",) - else: - loc = ("body", field.alias) + The first one to _not_ be a `DefaultPlaceholder` will be returned. - value: Optional[Any] = None - if received_body is not None: - if ( - field.shape in sequence_shapes or field.type_ in sequence_types - ) and isinstance(received_body, FormData): - value = received_body.getlist(field.alias) - else: - try: - value = received_body.get(field.alias) - except AttributeError: - errors.append(get_missing_field_error(loc)) - continue - if ( - value is None - or (isinstance(field_info, params.Form) and value == "") - or ( - isinstance(field_info, params.Form) - and field.shape in sequence_shapes - and len(value) == 0 - ) - ): - if field.required: - errors.append(get_missing_field_error(loc)) - else: - values[field.name] = deepcopy(field.default) - continue - if ( - isinstance(field_info, params.File) - and lenient_issubclass(field.type_, bytes) - and isinstance(value, UploadFile) - ): - value = await value.read() - elif ( - field.shape in sequence_shapes - and isinstance(field_info, params.File) - and lenient_issubclass(field.type_, bytes) - and isinstance(value, sequence_types) - ): - results: List[Union[bytes, str]] = [] - - async def process_fn( - fn: Callable[[], Coroutine[Any, Any, Any]] - ) -> None: - result = await fn() - results.append(result) - - async with anyio.create_task_group() as tg: - for sub_value in value: - tg.start_soon(process_fn, sub_value.read) - value = sequence_shape_to_type[field.shape](results) - - v_, errors_ = field.validate(value, values, loc=loc) - - if isinstance(errors_, ErrorWrapper): - errors.append(errors_) - elif isinstance(errors_, list): - errors.extend(errors_) - else: - values[field.name] = v_ - return values, errors - - -def get_missing_field_error(loc: Tuple[str, ...]) -> ErrorWrapper: - missing_field_error = ErrorWrapper(MissingError(), loc=loc) - return missing_field_error - - -def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]: - flat_dependant = get_flat_dependant(dependant) - if not flat_dependant.body_params: - return None - first_param = flat_dependant.body_params[0] - field_info = first_param.field_info - embed = getattr(field_info, "embed", None) - body_param_names_set = {param.name for param in flat_dependant.body_params} - if len(body_param_names_set) == 1 and not embed: - check_file_field(first_param) - return first_param - # If one field requires to embed, all have to be embedded - # in case a sub-dependency is evaluated with a single unique body field - # That is combined (embedded) with other body fields - for param in flat_dependant.body_params: - setattr(param.field_info, "embed", True) # noqa: B010 - model_name = "Body_" + name - BodyModel: Type[BaseModel] = create_model(model_name) - for f in flat_dependant.body_params: - BodyModel.__fields__[f.name] = f - required = any(True for f in flat_dependant.body_params if f.required) - - BodyFieldInfo_kwargs: Dict[str, Any] = {"default": None} - if any(isinstance(f.field_info, params.File) for f in flat_dependant.body_params): - BodyFieldInfo: Type[params.Body] = params.File - elif any(isinstance(f.field_info, params.Form) for f in flat_dependant.body_params): - BodyFieldInfo = params.Form - else: - BodyFieldInfo = params.Body - - body_param_media_types = [ - f.field_info.media_type - for f in flat_dependant.body_params - if isinstance(f.field_info, params.Body) - ] - if len(set(body_param_media_types)) == 1: - BodyFieldInfo_kwargs["media_type"] = body_param_media_types[0] - final_field = create_response_field( - name="body", - type_=BodyModel, - required=required, - alias="body", - field_info=BodyFieldInfo(**BodyFieldInfo_kwargs), - ) - check_file_field(final_field) - return final_field + Otherwise, the first item (a `DefaultPlaceholder`) will be returned. + """ + items = (first_item,) + extra_items + for item in items: + if not isinstance(item, DefaultPlaceholder): + return item + return first_item From d2ec5d6620d8183a2e3ab972d2d1b7e88d76c36a Mon Sep 17 00:00:00 2001 From: yyklimenko Date: Mon, 27 Feb 2023 19:39:41 +0600 Subject: [PATCH 6/6] The code has been refactored --- fastapi/dependencies/protocols.py | 11 +++++++++++ fastapi/dependencies/stubs.py | 9 --------- fastapi/dependencies/utils.py | 10 ++++------ 3 files changed, 15 insertions(+), 15 deletions(-) create mode 100644 fastapi/dependencies/protocols.py delete mode 100644 fastapi/dependencies/stubs.py diff --git a/fastapi/dependencies/protocols.py b/fastapi/dependencies/protocols.py new file mode 100644 index 000000000..857e2169e --- /dev/null +++ b/fastapi/dependencies/protocols.py @@ -0,0 +1,11 @@ +from typing import Any, Tuple + +from typing_extensions import Protocol + + +class GenericTypeProtocol(Protocol): + class OriginTypeProtocol(Protocol): + __parameters__: Tuple[Any] + + __origin__: OriginTypeProtocol + __args__: Tuple[Any] diff --git a/fastapi/dependencies/stubs.py b/fastapi/dependencies/stubs.py deleted file mode 100644 index b9ab48a69..000000000 --- a/fastapi/dependencies/stubs.py +++ /dev/null @@ -1,9 +0,0 @@ -from typing import Any, Tuple - - -class GenericTypeStub: - class OriginTypeStub: - __parameters__: Tuple[Any] - - __origin__: OriginTypeStub - __args__: Tuple[Any] diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index ade00816f..a321a33fc 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -26,7 +26,7 @@ from fastapi.concurrency import ( contextmanager_in_threadpool, ) from fastapi.dependencies.models import Dependant, SecurityRequirement -from fastapi.dependencies.stubs import GenericTypeStub +from fastapi.dependencies.protocols import GenericTypeProtocol from fastapi.logger import logger from fastapi.security.base import SecurityBase from fastapi.security.oauth2 import OAuth2, SecurityScopes @@ -264,14 +264,12 @@ def substitute_generic_type(annotation: Any, typevars: Dict[str, Any]) -> Any: def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: typevars = None if is_generic_type(call): - generic: GenericTypeStub = cast(GenericTypeStub, call) + generic: GenericTypeProtocol = cast(GenericTypeProtocol, call) + origin: Any = generic.__origin__ typevars = { typevar.__name__: value - for typevar, value in zip( - generic.__origin__.__parameters__, generic.__args__ - ) + for typevar, value in zip(origin.__parameters__, generic.__args__) } - origin: Any = generic.__origin__ call = origin.__init__ signature = inspect.signature(call)