import inspect from contextlib import AsyncExitStack, contextmanager from copy import copy, deepcopy from dataclasses import dataclass from typing import ( Any, Callable, Coroutine, Dict, ForwardRef, List, Mapping, Optional, Sequence, Tuple, Type, Union, cast, ) import anyio from fastapi import params from fastapi._compat import ( PYDANTIC_V2, ErrorWrapper, ModelField, Required, Undefined, _regenerate_error_with_loc, copy_field_info, create_body_model, evaluate_forwardref, field_annotation_is_scalar, get_annotation_from_field_info, get_missing_field_error, is_bytes_field, is_bytes_sequence_field, is_scalar_field, is_scalar_sequence_field, is_sequence_field, is_uploadfile_or_nonable_uploadfile_annotation, is_uploadfile_sequence_annotation, lenient_issubclass, sequence_types, serialize_sequence_value, value_is_sequence, ) from fastapi.background import BackgroundTasks from fastapi.concurrency import ( asynccontextmanager, contextmanager_in_threadpool, run_in_threadpool, ) from fastapi.dependencies.models import Dependant, SecurityRequirement from fastapi.logger import logger from fastapi.security.base import SecurityBase from fastapi.security.oauth2 import OAuth2, SecurityScopes from fastapi.security.open_id_connect_url import OpenIdConnect from fastapi.utils import create_model_field, get_path_param_names from pydantic.fields import FieldInfo from starlette.background import BackgroundTasks as StarletteBackgroundTasks from starlette.datastructures import FormData, Headers, QueryParams, UploadFile from starlette.requests import HTTPConnection, Request from starlette.responses import Response from starlette.websockets import WebSocket from typing_extensions import Annotated, get_args, get_origin multipart_not_installed_error = ( 'Form data requires "python-multipart" to be installed. \n' 'You can install "python-multipart" with: \n\n' "pip install python-multipart\n" ) multipart_incorrect_install_error = ( 'Form data requires "python-multipart" to be installed. ' 'It seems you installed "multipart" instead. \n' 'You can remove "multipart" with: \n\n' "pip uninstall multipart\n\n" 'And then install "python-multipart" with: \n\n' "pip install python-multipart\n" ) def ensure_multipart_is_installed() -> None: try: # __version__ is available in both multiparts, and can be mocked from multipart import __version__ # type: ignore assert __version__ try: # parse_options_header is only available in the right multipart from multipart.multipart import parse_options_header # type: ignore assert parse_options_header except ImportError: logger.error(multipart_incorrect_install_error) raise RuntimeError(multipart_incorrect_install_error) from None except ImportError: logger.error(multipart_not_installed_error) raise RuntimeError(multipart_not_installed_error) from None def get_param_sub_dependant( *, param_name: str, depends: params.Depends, path: str, security_scopes: Optional[List[str]] = None, ) -> Dependant: assert depends.dependency return get_sub_dependant( depends=depends, dependency=depends.dependency, path=path, name=param_name, security_scopes=security_scopes, ) def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> Dependant: assert callable( depends.dependency ), "A parameter-less dependency must have a callable dependency" return get_sub_dependant(depends=depends, dependency=depends.dependency, path=path) def get_sub_dependant( *, depends: params.Depends, dependency: Callable[..., Any], path: str, name: Optional[str] = None, security_scopes: Optional[List[str]] = None, ) -> Dependant: security_requirement = None security_scopes = security_scopes or [] if isinstance(depends, params.Security): dependency_scopes = depends.scopes security_scopes.extend(dependency_scopes) if isinstance(dependency, SecurityBase): use_scopes: List[str] = [] if isinstance(dependency, (OAuth2, OpenIdConnect)): use_scopes = security_scopes security_requirement = SecurityRequirement( security_scheme=dependency, scopes=use_scopes ) sub_dependant = get_dependant( path=path, call=dependency, name=name, security_scopes=security_scopes, use_cache=depends.use_cache, limiter=depends.limiter, ) if security_requirement: sub_dependant.security_requirements.append(security_requirement) return sub_dependant CacheKey = Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] def get_flat_dependant( dependant: Dependant, *, skip_repeats: bool = False, visited: Optional[List[CacheKey]] = None, ) -> Dependant: if visited is None: visited = [] visited.append(dependant.cache_key) flat_dependant = Dependant( path_params=dependant.path_params.copy(), query_params=dependant.query_params.copy(), header_params=dependant.header_params.copy(), cookie_params=dependant.cookie_params.copy(), body_params=dependant.body_params.copy(), security_requirements=dependant.security_requirements.copy(), use_cache=dependant.use_cache, limiter=dependant.limiter, path=dependant.path, ) for sub_dependant in dependant.dependencies: if skip_repeats and sub_dependant.cache_key in visited: continue flat_sub = get_flat_dependant( sub_dependant, skip_repeats=skip_repeats, visited=visited ) flat_dependant.path_params.extend(flat_sub.path_params) flat_dependant.query_params.extend(flat_sub.query_params) flat_dependant.header_params.extend(flat_sub.header_params) flat_dependant.cookie_params.extend(flat_sub.cookie_params) flat_dependant.body_params.extend(flat_sub.body_params) flat_dependant.security_requirements.extend(flat_sub.security_requirements) return flat_dependant def get_flat_params(dependant: Dependant) -> List[ModelField]: flat_dependant = get_flat_dependant(dependant, skip_repeats=True) return ( flat_dependant.path_params + flat_dependant.query_params + flat_dependant.header_params + flat_dependant.cookie_params ) def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: signature = inspect.signature(call) globalns = getattr(call, "__globals__", {}) typed_params = [ inspect.Parameter( name=param.name, kind=param.kind, default=param.default, annotation=get_typed_annotation(param.annotation, globalns), ) for param in signature.parameters.values() ] typed_signature = inspect.Signature(typed_params) return typed_signature def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any: if isinstance(annotation, str): annotation = ForwardRef(annotation) annotation = evaluate_forwardref(annotation, globalns, globalns) return annotation def get_typed_return_annotation(call: Callable[..., Any]) -> Any: signature = inspect.signature(call) annotation = signature.return_annotation if annotation is inspect.Signature.empty: return None globalns = getattr(call, "__globals__", {}) return get_typed_annotation(annotation, globalns) def get_dependant( *, path: str, call: Callable[..., Any], name: Optional[str] = None, security_scopes: Optional[List[str]] = None, use_cache: bool = True, limiter: Optional[anyio.CapacityLimiter] = None, ) -> Dependant: path_param_names = get_path_param_names(path) endpoint_signature = get_typed_signature(call) signature_params = endpoint_signature.parameters dependant = Dependant( call=call, name=name, path=path, security_scopes=security_scopes, use_cache=use_cache, limiter=limiter, ) for param_name, param in signature_params.items(): is_path_param = param_name in path_param_names param_details = analyze_param( param_name=param_name, annotation=param.annotation, value=param.default, is_path_param=is_path_param, ) if param_details.depends is not None: sub_dependant = get_param_sub_dependant( param_name=param_name, depends=param_details.depends, path=path, security_scopes=security_scopes, ) dependant.dependencies.append(sub_dependant) continue if add_non_field_param_to_dependency( param_name=param_name, type_annotation=param_details.type_annotation, dependant=dependant, ): assert ( param_details.field is None ), f"Cannot specify multiple FastAPI annotations for {param_name!r}" continue assert param_details.field is not None if is_body_param(param_field=param_details.field, is_path_param=is_path_param): dependant.body_params.append(param_details.field) else: add_param_to_fields(field=param_details.field, dependant=dependant) return dependant def add_non_field_param_to_dependency( *, param_name: str, type_annotation: Any, dependant: Dependant ) -> Optional[bool]: if lenient_issubclass(type_annotation, Request): dependant.request_param_name = param_name return True elif lenient_issubclass(type_annotation, WebSocket): dependant.websocket_param_name = param_name return True elif lenient_issubclass(type_annotation, HTTPConnection): dependant.http_connection_param_name = param_name return True elif lenient_issubclass(type_annotation, Response): dependant.response_param_name = param_name return True elif lenient_issubclass(type_annotation, StarletteBackgroundTasks): dependant.background_tasks_param_name = param_name return True elif lenient_issubclass(type_annotation, SecurityScopes): dependant.security_scopes_param_name = param_name return True return None @dataclass class ParamDetails: type_annotation: Any depends: Optional[params.Depends] field: Optional[ModelField] def analyze_param( *, param_name: str, annotation: Any, value: Any, is_path_param: bool, ) -> ParamDetails: field_info = None depends = None type_annotation: Any = Any use_annotation: Any = Any if annotation is not inspect.Signature.empty: use_annotation = annotation type_annotation = annotation # Extract Annotated info if get_origin(use_annotation) is Annotated: annotated_args = get_args(annotation) type_annotation = annotated_args[0] fastapi_annotations = [ arg for arg in annotated_args[1:] if isinstance(arg, (FieldInfo, params.Depends)) ] fastapi_specific_annotations = [ arg for arg in fastapi_annotations if isinstance(arg, (params.Param, params.Body, params.Depends)) ] if fastapi_specific_annotations: fastapi_annotation: Union[FieldInfo, params.Depends, None] = ( fastapi_specific_annotations[-1] ) else: fastapi_annotation = None # Set default for Annotated FieldInfo if isinstance(fastapi_annotation, FieldInfo): # Copy `field_info` because we mutate `field_info.default` below. field_info = copy_field_info( field_info=fastapi_annotation, annotation=use_annotation ) assert field_info.default is Undefined or field_info.default is Required, ( f"`{field_info.__class__.__name__}` default value cannot be set in" f" `Annotated` for {param_name!r}. Set the default value with `=` instead." ) if value is not inspect.Signature.empty: assert not is_path_param, "Path parameters cannot have default values" field_info.default = value else: field_info.default = Required # Get Annotated Depends elif isinstance(fastapi_annotation, params.Depends): depends = fastapi_annotation # Get Depends from default value if isinstance(value, params.Depends): assert depends is None, ( "Cannot specify `Depends` in `Annotated` and default value" f" together for {param_name!r}" ) assert field_info is None, ( "Cannot specify a FastAPI annotation in `Annotated` and `Depends` as a" f" default value together for {param_name!r}" ) depends = value # Get FieldInfo from default value elif isinstance(value, FieldInfo): assert field_info is None, ( "Cannot specify FastAPI annotations in `Annotated` and default value" f" together for {param_name!r}" ) field_info = value if PYDANTIC_V2: field_info.annotation = type_annotation # Get Depends from type annotation if depends is not None and depends.dependency is None: # Copy `depends` before mutating it depends = copy(depends) depends.dependency = type_annotation # Handle non-param type annotations like Request if lenient_issubclass( type_annotation, ( Request, WebSocket, HTTPConnection, Response, StarletteBackgroundTasks, SecurityScopes, ), ): assert depends is None, f"Cannot specify `Depends` for type {type_annotation!r}" assert ( field_info is None ), f"Cannot specify FastAPI annotation for type {type_annotation!r}" # Handle default assignations, neither field_info nor depends was not found in Annotated nor default value elif field_info is None and depends is None: default_value = value if value is not inspect.Signature.empty else Required if is_path_param: # We might check here that `default_value is Required`, but the fact is that the same # parameter might sometimes be a path parameter and sometimes not. See # `tests/test_infer_param_optionality.py` for an example. field_info = params.Path(annotation=use_annotation) elif is_uploadfile_or_nonable_uploadfile_annotation( type_annotation ) or is_uploadfile_sequence_annotation(type_annotation): field_info = params.File(annotation=use_annotation, default=default_value) elif not field_annotation_is_scalar(annotation=type_annotation): field_info = params.Body(annotation=use_annotation, default=default_value) else: field_info = params.Query(annotation=use_annotation, default=default_value) field = None # It's a field_info, not a dependency if field_info is not None: # Handle field_info.in_ if is_path_param: assert isinstance(field_info, params.Path), ( f"Cannot use `{field_info.__class__.__name__}` for path param" f" {param_name!r}" ) elif ( isinstance(field_info, params.Param) and getattr(field_info, "in_", None) is None ): field_info.in_ = params.ParamTypes.query use_annotation_from_field_info = get_annotation_from_field_info( use_annotation, field_info, param_name, ) if isinstance(field_info, params.Form): ensure_multipart_is_installed() if not field_info.alias and getattr(field_info, "convert_underscores", None): alias = param_name.replace("_", "-") else: alias = field_info.alias or param_name field_info.alias = alias field = create_model_field( name=param_name, type_=use_annotation_from_field_info, default=field_info.default, alias=alias, required=field_info.default in (Required, Undefined), field_info=field_info, ) return ParamDetails(type_annotation=type_annotation, depends=depends, field=field) def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool: if is_path_param: assert is_scalar_field( field=param_field ), "Path params must be of one of the supported types" return False elif is_scalar_field(field=param_field): return False elif isinstance( param_field.field_info, (params.Query, params.Header) ) and is_scalar_sequence_field(param_field): return False else: assert isinstance( param_field.field_info, params.Body ), f"Param: {param_field.name} can only be a request body, using Body()" return True def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None: field_info = field.field_info field_info_in = getattr(field_info, "in_", None) if field_info_in == params.ParamTypes.path: dependant.path_params.append(field) elif field_info_in == params.ParamTypes.query: dependant.query_params.append(field) elif field_info_in == params.ParamTypes.header: dependant.header_params.append(field) else: assert ( field_info_in == params.ParamTypes.cookie ), f"non-body parameters must be in path, query, header or cookie: {field.name}" dependant.cookie_params.append(field) def is_coroutine_callable(call: Callable[..., Any]) -> bool: if inspect.isroutine(call): return inspect.iscoroutinefunction(call) if inspect.isclass(call): return False dunder_call = getattr(call, "__call__", None) # noqa: B004 return inspect.iscoroutinefunction(dunder_call) def is_async_gen_callable(call: Callable[..., Any]) -> bool: if inspect.isasyncgenfunction(call): return True dunder_call = getattr(call, "__call__", None) # noqa: B004 return inspect.isasyncgenfunction(dunder_call) def is_gen_callable(call: Callable[..., Any]) -> bool: if inspect.isgeneratorfunction(call): return True dunder_call = getattr(call, "__call__", None) # noqa: B004 return inspect.isgeneratorfunction(dunder_call) async def solve_generator( *, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any], limiter: Optional[anyio.CapacityLimiter] = None, ) -> Any: if is_gen_callable(call): cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values), limiter=limiter) elif is_async_gen_callable(call): cm = asynccontextmanager(call)(**sub_values) return await stack.enter_async_context(cm) @dataclass class SolvedDependency: values: Dict[str, Any] errors: List[Any] background_tasks: Optional[StarletteBackgroundTasks] response: Response dependency_cache: Dict[Tuple[Callable[..., Any], Tuple[str]], Any] async def solve_dependencies( *, request: Union[Request, WebSocket], dependant: Dependant, body: Optional[Union[Dict[str, Any], FormData]] = None, background_tasks: Optional[StarletteBackgroundTasks] = None, response: Optional[Response] = None, dependency_overrides_provider: Optional[Any] = None, dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None, async_exit_stack: AsyncExitStack, ) -> SolvedDependency: values: Dict[str, Any] = {} errors: List[Any] = [] if response is None: response = Response() del response.headers["content-length"] response.status_code = None # type: ignore dependency_cache = dependency_cache or {} sub_dependant: Dependant for sub_dependant in dependant.dependencies: sub_dependant.call = cast(Callable[..., Any], sub_dependant.call) sub_dependant.cache_key = cast( Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key ) call = sub_dependant.call use_sub_dependant = sub_dependant if ( dependency_overrides_provider and dependency_overrides_provider.dependency_overrides ): original_call = sub_dependant.call call = getattr( dependency_overrides_provider, "dependency_overrides", {} ).get(original_call, original_call) use_path: str = sub_dependant.path # type: ignore use_sub_dependant = get_dependant( path=use_path, call=call, name=sub_dependant.name, security_scopes=sub_dependant.security_scopes, limiter=sub_dependant.limiter, ) solved_result = await solve_dependencies( request=request, dependant=use_sub_dependant, body=body, background_tasks=background_tasks, response=response, dependency_overrides_provider=dependency_overrides_provider, dependency_cache=dependency_cache, async_exit_stack=async_exit_stack, ) background_tasks = solved_result.background_tasks dependency_cache.update(solved_result.dependency_cache) if solved_result.errors: errors.extend(solved_result.errors) continue if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache: solved = dependency_cache[sub_dependant.cache_key] elif is_gen_callable(call) or is_async_gen_callable(call): solved = await solve_generator( call=call, stack=async_exit_stack, sub_values=solved_result.values, limiter=sub_dependant.limiter, ) elif is_coroutine_callable(call): solved = await call(**solved_result.values) else: solved = await run_in_threadpool(call, _limiter=sub_dependant.limiter, **solved_result.values) if sub_dependant.name is not None: values[sub_dependant.name] = solved if sub_dependant.cache_key not in dependency_cache: dependency_cache[sub_dependant.cache_key] = solved path_values, path_errors = request_params_to_args( dependant.path_params, request.path_params ) query_values, query_errors = request_params_to_args( dependant.query_params, request.query_params ) header_values, header_errors = request_params_to_args( dependant.header_params, request.headers ) cookie_values, cookie_errors = request_params_to_args( dependant.cookie_params, request.cookies ) values.update(path_values) values.update(query_values) values.update(header_values) values.update(cookie_values) errors += path_errors + query_errors + header_errors + cookie_errors if dependant.body_params: ( body_values, body_errors, ) = await request_body_to_args( # body_params checked above required_params=dependant.body_params, received_body=body ) values.update(body_values) errors.extend(body_errors) if dependant.http_connection_param_name: values[dependant.http_connection_param_name] = request if dependant.request_param_name and isinstance(request, Request): values[dependant.request_param_name] = request elif dependant.websocket_param_name and isinstance(request, WebSocket): values[dependant.websocket_param_name] = request if dependant.background_tasks_param_name: if background_tasks is None: background_tasks = BackgroundTasks() values[dependant.background_tasks_param_name] = background_tasks if dependant.response_param_name: values[dependant.response_param_name] = response if dependant.security_scopes_param_name: values[dependant.security_scopes_param_name] = SecurityScopes( scopes=dependant.security_scopes ) return SolvedDependency( values=values, errors=errors, background_tasks=background_tasks, response=response, dependency_cache=dependency_cache, ) def request_params_to_args( required_params: Sequence[ModelField], received_params: Union[Mapping[str, Any], QueryParams, Headers], ) -> Tuple[Dict[str, Any], List[Any]]: values = {} errors = [] for field in required_params: if is_scalar_sequence_field(field) and isinstance( received_params, (QueryParams, Headers) ): value = received_params.getlist(field.alias) or field.default else: value = received_params.get(field.alias) field_info = field.field_info assert isinstance( field_info, params.Param ), "Params must be subclasses of Param" loc = (field_info.in_.value, field.alias) if value is None: if field.required: errors.append(get_missing_field_error(loc=loc)) else: values[field.name] = deepcopy(field.default) continue v_, errors_ = field.validate(value, values, loc=loc) if isinstance(errors_, ErrorWrapper): errors.append(errors_) elif isinstance(errors_, list): new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=()) errors.extend(new_errors) else: values[field.name] = v_ return values, errors async def request_body_to_args( required_params: List[ModelField], received_body: Optional[Union[Dict[str, Any], FormData]], ) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]: values = {} errors: List[Dict[str, Any]] = [] if required_params: field = required_params[0] field_info = field.field_info embed = getattr(field_info, "embed", None) field_alias_omitted = len(required_params) == 1 and not embed if field_alias_omitted: received_body = {field.alias: received_body} for field in required_params: loc: Tuple[str, ...] if field_alias_omitted: loc = ("body",) else: loc = ("body", field.alias) value: Optional[Any] = None if received_body is not None: if (is_sequence_field(field)) and isinstance(received_body, FormData): value = received_body.getlist(field.alias) else: try: value = received_body.get(field.alias) except AttributeError: errors.append(get_missing_field_error(loc)) continue if ( value is None or (isinstance(field_info, params.Form) and value == "") or ( isinstance(field_info, params.Form) and is_sequence_field(field) and len(value) == 0 ) ): if field.required: errors.append(get_missing_field_error(loc)) else: values[field.name] = deepcopy(field.default) continue if ( isinstance(field_info, params.File) and is_bytes_field(field) and isinstance(value, UploadFile) ): value = await value.read() elif ( is_bytes_sequence_field(field) and isinstance(field_info, params.File) and value_is_sequence(value) ): # For types assert isinstance(value, sequence_types) # type: ignore[arg-type] results: List[Union[bytes, str]] = [] async def process_fn( fn: Callable[[], Coroutine[Any, Any, Any]], ) -> None: result = await fn() results.append(result) # noqa: B023 async with anyio.create_task_group() as tg: for sub_value in value: tg.start_soon(process_fn, sub_value.read) value = serialize_sequence_value(field=field, value=results) v_, errors_ = field.validate(value, values, loc=loc) if isinstance(errors_, list): errors.extend(errors_) elif errors_: errors.append(errors_) else: values[field.name] = v_ return values, errors def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]: flat_dependant = get_flat_dependant(dependant) if not flat_dependant.body_params: return None first_param = flat_dependant.body_params[0] field_info = first_param.field_info embed = getattr(field_info, "embed", None) body_param_names_set = {param.name for param in flat_dependant.body_params} if len(body_param_names_set) == 1 and not embed: return first_param # If one field requires to embed, all have to be embedded # in case a sub-dependency is evaluated with a single unique body field # That is combined (embedded) with other body fields for param in flat_dependant.body_params: setattr(param.field_info, "embed", True) # noqa: B010 model_name = "Body_" + name BodyModel = create_body_model( fields=flat_dependant.body_params, model_name=model_name ) required = any(True for f in flat_dependant.body_params if f.required) BodyFieldInfo_kwargs: Dict[str, Any] = { "annotation": BodyModel, "alias": "body", } if not required: BodyFieldInfo_kwargs["default"] = None if any(isinstance(f.field_info, params.File) for f in flat_dependant.body_params): BodyFieldInfo: Type[params.Body] = params.File elif any(isinstance(f.field_info, params.Form) for f in flat_dependant.body_params): BodyFieldInfo = params.Form else: BodyFieldInfo = params.Body body_param_media_types = [ f.field_info.media_type for f in flat_dependant.body_params if isinstance(f.field_info, params.Body) ] if len(set(body_param_media_types)) == 1: BodyFieldInfo_kwargs["media_type"] = body_param_media_types[0] final_field = create_model_field( name="body", type_=BodyModel, required=required, alias="body", field_info=BodyFieldInfo(**BodyFieldInfo_kwargs), ) return final_field