Browse Source

fixed bug in file structure

pull/6038/head
yyklimenko 2 years ago
parent
commit
40965c9825
  1. 0
      fastapi/dependencies/stubs.py
  2. 9
      fastapi/dependencies/utils.py
  3. 953
      fastapi/utils.py

0
fastapi/stubs.py → fastapi/dependencies/stubs.py

9
fastapi/dependencies/utils.py

@ -26,6 +26,7 @@ from fastapi.concurrency import (
contextmanager_in_threadpool, contextmanager_in_threadpool,
) )
from fastapi.dependencies.models import Dependant, SecurityRequirement from fastapi.dependencies.models import Dependant, SecurityRequirement
from fastapi.dependencies.stubs import GenericTypeStub
from fastapi.logger import logger from fastapi.logger import logger
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from fastapi.security.oauth2 import OAuth2, SecurityScopes from fastapi.security.oauth2 import OAuth2, SecurityScopes
@ -263,16 +264,18 @@ def substitute_generic_type(annotation: Any, typevars: Dict[str, Any]) -> Any:
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
typevars = None typevars = None
if is_generic_type(call): if is_generic_type(call):
origin = call.__origin__ generic: GenericTypeStub = cast(GenericTypeStub, call)
typevars = { typevars = {
typevar.__name__: value typevar.__name__: value
for typevar, value in zip(origin.__parameters__, call.__args__) for typevar, value in zip(
generic.__origin__.__parameters__, generic.__args__
)
} }
origin: Any = generic.__origin__
call = origin.__init__ call = origin.__init__
signature = inspect.signature(call) signature = inspect.signature(call)
globalns = getattr(call, "__globals__", {}) globalns = getattr(call, "__globals__", {})
typed_params = [ typed_params = [
inspect.Parameter( inspect.Parameter(
name=param.name, name=param.name,

953
fastapi/utils.py

@ -1,800 +1,207 @@
import dataclasses import functools
import inspect import re
from contextlib import contextmanager import warnings
from copy import deepcopy from dataclasses import is_dataclass
from typing import ( from enum import Enum
Any, from typing import TYPE_CHECKING, Any, Dict, Optional, Set, Type, Union, cast
Callable,
Coroutine, import fastapi
Dict, from fastapi.datastructures import DefaultPlaceholder, DefaultType
ForwardRef, from fastapi.openapi.constants import REF_PREFIX
List, from pydantic import BaseConfig, BaseModel, create_model
Mapping, from pydantic.class_validators import Validator
Optional, from pydantic.fields import FieldInfo, ModelField, UndefinedType
Sequence, from pydantic.schema import model_process_schema
Tuple,
Type,
Union,
cast,
)
import anyio
from fastapi import params
from fastapi.concurrency import (
AsyncExitStack,
asynccontextmanager,
contextmanager_in_threadpool,
)
from fastapi.dependencies.models import Dependant, SecurityRequirement
from fastapi.dependencies.stubs import GenericTypeStub
from fastapi.logger import logger
from fastapi.security.base import SecurityBase
from fastapi.security.oauth2 import OAuth2, SecurityScopes
from fastapi.security.open_id_connect_url import OpenIdConnect
from fastapi.utils import create_response_field, get_path_param_names
from pydantic import BaseModel, create_model
from pydantic.error_wrappers import ErrorWrapper
from pydantic.errors import MissingError
from pydantic.fields import (
SHAPE_FROZENSET,
SHAPE_LIST,
SHAPE_SEQUENCE,
SHAPE_SET,
SHAPE_SINGLETON,
SHAPE_TUPLE,
SHAPE_TUPLE_ELLIPSIS,
FieldInfo,
ModelField,
Required,
Undefined,
)
from pydantic.schema import get_annotation_from_field_info
from pydantic.typing import evaluate_forwardref
from pydantic.utils import lenient_issubclass from pydantic.utils import lenient_issubclass
from starlette.background import BackgroundTasks
from starlette.concurrency import run_in_threadpool
from starlette.datastructures import FormData, Headers, QueryParams, UploadFile
from starlette.requests import HTTPConnection, Request
from starlette.responses import Response
from starlette.websockets import WebSocket
sequence_shapes = {
SHAPE_LIST,
SHAPE_SET,
SHAPE_FROZENSET,
SHAPE_TUPLE,
SHAPE_SEQUENCE,
SHAPE_TUPLE_ELLIPSIS,
}
sequence_types = (list, set, tuple)
sequence_shape_to_type = {
SHAPE_LIST: list,
SHAPE_SET: set,
SHAPE_TUPLE: tuple,
SHAPE_SEQUENCE: list,
SHAPE_TUPLE_ELLIPSIS: list,
}
multipart_not_installed_error = (
'Form data requires "python-multipart" to be installed. \n'
'You can install "python-multipart" with: \n\n'
"pip install python-multipart\n"
)
multipart_incorrect_install_error = (
'Form data requires "python-multipart" to be installed. '
'It seems you installed "multipart" instead. \n'
'You can remove "multipart" with: \n\n'
"pip uninstall multipart\n\n"
'And then install "python-multipart" with: \n\n'
"pip install python-multipart\n"
)
def check_file_field(field: ModelField) -> None:
field_info = field.field_info
if isinstance(field_info, params.Form):
try:
# __version__ is available in both multiparts, and can be mocked
from multipart import __version__ # type: ignore
assert __version__
try:
# parse_options_header is only available in the right multipart
from multipart.multipart import parse_options_header # type: ignore
assert parse_options_header
except ImportError:
logger.error(multipart_incorrect_install_error)
raise RuntimeError(multipart_incorrect_install_error) from None
except ImportError:
logger.error(multipart_not_installed_error)
raise RuntimeError(multipart_not_installed_error) from None
def get_param_sub_dependant(
*, param: inspect.Parameter, path: str, security_scopes: Optional[List[str]] = None
) -> Dependant:
depends: params.Depends = param.default
if depends.dependency:
dependency = depends.dependency
else:
dependency = param.annotation
return get_sub_dependant(
depends=depends,
dependency=dependency,
path=path,
name=param.name,
security_scopes=security_scopes,
)
def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> Dependant:
assert callable(
depends.dependency
), "A parameter-less dependency must have a callable dependency"
return get_sub_dependant(depends=depends, dependency=depends.dependency, path=path)
def get_sub_dependant(
*,
depends: params.Depends,
dependency: Callable[..., Any],
path: str,
name: Optional[str] = None,
security_scopes: Optional[List[str]] = None,
) -> Dependant:
security_requirement = None
security_scopes = security_scopes or []
if isinstance(depends, params.Security):
dependency_scopes = depends.scopes
security_scopes.extend(dependency_scopes)
if isinstance(dependency, SecurityBase):
use_scopes: List[str] = []
if isinstance(dependency, (OAuth2, OpenIdConnect)):
use_scopes = security_scopes
security_requirement = SecurityRequirement(
security_scheme=dependency, scopes=use_scopes
)
sub_dependant = get_dependant(
path=path,
call=dependency,
name=name,
security_scopes=security_scopes,
use_cache=depends.use_cache,
)
if security_requirement:
sub_dependant.security_requirements.append(security_requirement)
return sub_dependant
CacheKey = Tuple[Optional[Callable[..., Any]], Tuple[str, ...]]
def get_flat_dependant(
dependant: Dependant,
*,
skip_repeats: bool = False,
visited: Optional[List[CacheKey]] = None,
) -> Dependant:
if visited is None:
visited = []
visited.append(dependant.cache_key)
flat_dependant = Dependant(
path_params=dependant.path_params.copy(),
query_params=dependant.query_params.copy(),
header_params=dependant.header_params.copy(),
cookie_params=dependant.cookie_params.copy(),
body_params=dependant.body_params.copy(),
security_schemes=dependant.security_requirements.copy(),
use_cache=dependant.use_cache,
path=dependant.path,
)
for sub_dependant in dependant.dependencies:
if skip_repeats and sub_dependant.cache_key in visited:
continue
flat_sub = get_flat_dependant(
sub_dependant, skip_repeats=skip_repeats, visited=visited
)
flat_dependant.path_params.extend(flat_sub.path_params)
flat_dependant.query_params.extend(flat_sub.query_params)
flat_dependant.header_params.extend(flat_sub.header_params)
flat_dependant.cookie_params.extend(flat_sub.cookie_params)
flat_dependant.body_params.extend(flat_sub.body_params)
flat_dependant.security_requirements.extend(flat_sub.security_requirements)
return flat_dependant
def get_flat_params(dependant: Dependant) -> List[ModelField]: if TYPE_CHECKING: # pragma: nocover
flat_dependant = get_flat_dependant(dependant, skip_repeats=True) from .routing import APIRoute
return (
flat_dependant.path_params
+ flat_dependant.query_params
+ flat_dependant.header_params
+ flat_dependant.cookie_params
)
def is_scalar_field(field: ModelField) -> bool:
field_info = field.field_info
if not (
field.shape == SHAPE_SINGLETON
and not lenient_issubclass(field.type_, BaseModel)
and not lenient_issubclass(field.type_, sequence_types + (dict,))
and not dataclasses.is_dataclass(field.type_)
and not isinstance(field_info, params.Body)
):
return False
if field.sub_fields:
if not all(is_scalar_field(f) for f in field.sub_fields):
return False
return True
def is_scalar_sequence_field(field: ModelField) -> bool: def is_body_allowed_for_status_code(status_code: Union[int, str, None]) -> bool:
if (field.shape in sequence_shapes) and not lenient_issubclass( if status_code is None:
field.type_, BaseModel
):
if field.sub_fields is not None:
for sub_field in field.sub_fields:
if not is_scalar_field(sub_field):
return False
return True return True
if lenient_issubclass(field.type_, sequence_types): # Ref: https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#patterned-fields-1
if status_code in {
"default",
"1XX",
"2XX",
"3XX",
"4XX",
"5XX",
}:
return True return True
return False current_status_code = int(status_code)
return not (current_status_code < 200 or current_status_code in {204, 304})
def is_generic_type(obj: Any) -> bool:
return hasattr(obj, "__origin__") and hasattr(obj, "__args__")
def substitute_generic_type(annotation: Any, typevars: Dict[str, Any]) -> Any:
collection_shells = {list: List, set: List, dict: Dict, tuple: Tuple}
if is_generic_type(annotation):
args = tuple(
substitute_generic_type(arg, typevars) for arg in annotation.__args__
)
annotation = collection_shells.get(annotation.__origin__, annotation)
return annotation[args]
return typevars.get(annotation.__name__, annotation)
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
typevars = None
if is_generic_type(call):
stub: GenericTypeStub = cast(GenericTypeStub, call)
typevars = {
typevar.__name__: value
for typevar, value in zip(stub.__origin__.__parameters__, stub.__args__)
}
call = stub.__origin__.__init__ # type: ignore
signature = inspect.signature(call)
globalns = getattr(call, "__globals__", {})
typed_params = [
inspect.Parameter(
name=param.name,
kind=param.kind,
default=param.default,
annotation=get_typed_annotation(param.annotation, globalns, typevars),
)
for param in signature.parameters.values()
if param.name != "self"
]
typed_signature = inspect.Signature(typed_params)
return typed_signature
def get_typed_annotation( def get_model_definitions(
annotation: Any,
globalns: Dict[str, Any],
typevars: Optional[Dict[str, type]] = None,
) -> Any:
if isinstance(annotation, str):
annotation = ForwardRef(annotation)
annotation = evaluate_forwardref(annotation, globalns, globalns)
if typevars:
annotation = substitute_generic_type(annotation, typevars)
return annotation
def get_typed_return_annotation(call: Callable[..., Any]) -> Any:
signature = inspect.signature(call)
annotation = signature.return_annotation
if annotation is inspect.Signature.empty:
return None
globalns = getattr(call, "__globals__", {})
return get_typed_annotation(annotation, globalns)
def get_dependant(
*, *,
path: str, flat_models: Set[Union[Type[BaseModel], Type[Enum]]],
call: Callable[..., Any], model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str],
name: Optional[str] = None, ) -> Dict[str, Any]:
security_scopes: Optional[List[str]] = None, definitions: Dict[str, Dict[str, Any]] = {}
use_cache: bool = True, for model in flat_models:
) -> Dependant: m_schema, m_definitions, m_nested_models = model_process_schema(
path_param_names = get_path_param_names(path) model, model_name_map=model_name_map, ref_prefix=REF_PREFIX
endpoint_signature = get_typed_signature(call) )
signature_params = endpoint_signature.parameters definitions.update(m_definitions)
dependant = Dependant( model_name = model_name_map[model]
call=call, if "description" in m_schema:
m_schema["description"] = m_schema["description"].split("\f")[0]
definitions[model_name] = m_schema
return definitions
def get_path_param_names(path: str) -> Set[str]:
return set(re.findall("{(.*?)}", path))
def create_response_field(
name: str,
type_: Type[Any],
class_validators: Optional[Dict[str, Validator]] = None,
default: Optional[Any] = None,
required: Union[bool, UndefinedType] = True,
model_config: Type[BaseConfig] = BaseConfig,
field_info: Optional[FieldInfo] = None,
alias: Optional[str] = None,
) -> ModelField:
"""
Create a new response field. Raises if type_ is invalid.
"""
class_validators = class_validators or {}
field_info = field_info or FieldInfo()
response_field = functools.partial(
ModelField,
name=name, name=name,
path=path, type_=type_,
security_scopes=security_scopes, class_validators=class_validators,
use_cache=use_cache, default=default,
required=required,
model_config=model_config,
alias=alias,
) )
for param_name, param in signature_params.items():
if isinstance(param.default, params.Depends):
sub_dependant = get_param_sub_dependant(
param=param, path=path, security_scopes=security_scopes
)
dependant.dependencies.append(sub_dependant)
continue
if add_non_field_param_to_dependency(param=param, dependant=dependant):
continue
param_field = get_param_field(
param=param, default_field_info=params.Query, param_name=param_name
)
if param_name in path_param_names:
assert is_scalar_field(
field=param_field
), "Path params must be of one of the supported types"
ignore_default = not isinstance(param.default, params.Path)
param_field = get_param_field(
param=param,
param_name=param_name,
default_field_info=params.Path,
force_type=params.ParamTypes.path,
ignore_default=ignore_default,
)
add_param_to_fields(field=param_field, dependant=dependant)
elif is_scalar_field(field=param_field):
add_param_to_fields(field=param_field, dependant=dependant)
elif isinstance(
param.default, (params.Query, params.Header)
) and is_scalar_sequence_field(param_field):
add_param_to_fields(field=param_field, dependant=dependant)
else:
field_info = param_field.field_info
assert isinstance(
field_info, params.Body
), f"Param: {param_field.name} can only be a request body, using Body()"
dependant.body_params.append(param_field)
return dependant
def add_non_field_param_to_dependency(
*, param: inspect.Parameter, dependant: Dependant
) -> Optional[bool]:
if lenient_issubclass(param.annotation, Request):
dependant.request_param_name = param.name
return True
elif lenient_issubclass(param.annotation, WebSocket):
dependant.websocket_param_name = param.name
return True
elif lenient_issubclass(param.annotation, HTTPConnection):
dependant.http_connection_param_name = param.name
return True
elif lenient_issubclass(param.annotation, Response):
dependant.response_param_name = param.name
return True
elif lenient_issubclass(param.annotation, BackgroundTasks):
dependant.background_tasks_param_name = param.name
return True
elif lenient_issubclass(param.annotation, SecurityScopes):
dependant.security_scopes_param_name = param.name
return True
return None
def get_param_field( try:
return response_field(field_info=field_info)
except RuntimeError:
raise fastapi.exceptions.FastAPIError(
"Invalid args for response field! Hint: "
f"check that {type_} is a valid Pydantic field type. "
"If you are using a return type annotation that is not a valid Pydantic "
"field (e.g. Union[Response, dict, None]) you can disable generating the "
"response model from the type annotation with the path operation decorator "
"parameter response_model=None. Read more: "
"https://fastapi.tiangolo.com/tutorial/response-model/"
) from None
def create_cloned_field(
field: ModelField,
*, *,
param: inspect.Parameter, cloned_types: Optional[Dict[Type[BaseModel], Type[BaseModel]]] = None,
param_name: str,
default_field_info: Type[params.Param] = params.Param,
force_type: Optional[params.ParamTypes] = None,
ignore_default: bool = False,
) -> ModelField: ) -> ModelField:
default_value: Any = Undefined # _cloned_types has already cloned types, to support recursive models
had_schema = False if cloned_types is None:
if not param.default == param.empty and ignore_default is False: cloned_types = {}
default_value = param.default original_type = field.type_
if isinstance(default_value, FieldInfo): if is_dataclass(original_type) and hasattr(original_type, "__pydantic_model__"):
had_schema = True original_type = original_type.__pydantic_model__
field_info = default_value use_type = original_type
default_value = field_info.default if lenient_issubclass(original_type, BaseModel):
if ( original_type = cast(Type[BaseModel], original_type)
isinstance(field_info, params.Param) use_type = cloned_types.get(original_type)
and getattr(field_info, "in_", None) is None if use_type is None:
): use_type = create_model(original_type.__name__, __base__=original_type)
field_info.in_ = default_field_info.in_ cloned_types[original_type] = use_type
if force_type: for f in original_type.__fields__.values():
field_info.in_ = force_type # type: ignore use_type.__fields__[f.name] = create_cloned_field(
else: f, cloned_types=cloned_types
field_info = default_field_info(default=default_value) )
required = True new_field = create_response_field(name=field.name, type_=use_type)
if default_value is Required or ignore_default: new_field.has_alias = field.has_alias
required = True new_field.alias = field.alias
default_value = None new_field.class_validators = field.class_validators
elif default_value is not Undefined: new_field.default = field.default
required = False new_field.required = field.required
annotation: Any = Any new_field.model_config = field.model_config
if not param.annotation == param.empty: new_field.field_info = field.field_info
annotation = param.annotation new_field.allow_none = field.allow_none
annotation = get_annotation_from_field_info(annotation, field_info, param_name) new_field.validate_always = field.validate_always
if not field_info.alias and getattr(field_info, "convert_underscores", None): if field.sub_fields:
alias = param.name.replace("_", "-") new_field.sub_fields = [
else: create_cloned_field(sub_field, cloned_types=cloned_types)
alias = field_info.alias or param.name for sub_field in field.sub_fields
field = create_response_field( ]
name=param.name, if field.key_field:
type_=annotation, new_field.key_field = create_cloned_field(
default=default_value, field.key_field, cloned_types=cloned_types
alias=alias, )
required=required, new_field.validators = field.validators
field_info=field_info, new_field.pre_validators = field.pre_validators
new_field.post_validators = field.post_validators
new_field.parse_json = field.parse_json
new_field.shape = field.shape
new_field.populate_validators()
return new_field
def generate_operation_id_for_path(
*, name: str, path: str, method: str
) -> str: # pragma: nocover
warnings.warn(
"fastapi.utils.generate_operation_id_for_path() was deprecated, "
"it is not used internally, and will be removed soon",
DeprecationWarning,
stacklevel=2,
) )
if not had_schema and not is_scalar_field(field=field): operation_id = name + path
field.field_info = params.Body(field_info.default) operation_id = re.sub(r"\W", "_", operation_id)
if not had_schema and lenient_issubclass(field.type_, UploadFile): operation_id = operation_id + "_" + method.lower()
field.field_info = params.File(field_info.default) return operation_id
return field
def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
field_info = cast(params.Param, field.field_info)
if field_info.in_ == params.ParamTypes.path:
dependant.path_params.append(field)
elif field_info.in_ == params.ParamTypes.query:
dependant.query_params.append(field)
elif field_info.in_ == params.ParamTypes.header:
dependant.header_params.append(field)
else:
assert (
field_info.in_ == params.ParamTypes.cookie
), f"non-body parameters must be in path, query, header or cookie: {field.name}"
dependant.cookie_params.append(field)
def is_coroutine_callable(call: Callable[..., Any]) -> bool:
if inspect.isroutine(call):
return inspect.iscoroutinefunction(call)
if inspect.isclass(call):
return False
dunder_call = getattr(call, "__call__", None) # noqa: B004
return inspect.iscoroutinefunction(dunder_call)
def is_async_gen_callable(call: Callable[..., Any]) -> bool:
if inspect.isasyncgenfunction(call):
return True
dunder_call = getattr(call, "__call__", None) # noqa: B004
return inspect.isasyncgenfunction(dunder_call)
def is_gen_callable(call: Callable[..., Any]) -> bool:
if inspect.isgeneratorfunction(call):
return True
dunder_call = getattr(call, "__call__", None) # noqa: B004
return inspect.isgeneratorfunction(dunder_call)
def generate_unique_id(route: "APIRoute") -> str:
operation_id = route.name + route.path_format
operation_id = re.sub(r"\W", "_", operation_id)
assert route.methods
operation_id = operation_id + "_" + list(route.methods)[0].lower()
return operation_id
async def solve_generator(
*, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any]
) -> Any:
if is_gen_callable(call):
cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values))
elif is_async_gen_callable(call):
cm = asynccontextmanager(call)(**sub_values)
return await stack.enter_async_context(cm)
def deep_dict_update(main_dict: Dict[Any, Any], update_dict: Dict[Any, Any]) -> None:
async def solve_dependencies( for key, value in update_dict.items():
*,
request: Union[Request, WebSocket],
dependant: Dependant,
body: Optional[Union[Dict[str, Any], FormData]] = None,
background_tasks: Optional[BackgroundTasks] = None,
response: Optional[Response] = None,
dependency_overrides_provider: Optional[Any] = None,
dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None,
) -> Tuple[
Dict[str, Any],
List[ErrorWrapper],
Optional[BackgroundTasks],
Response,
Dict[Tuple[Callable[..., Any], Tuple[str]], Any],
]:
values: Dict[str, Any] = {}
errors: List[ErrorWrapper] = []
if response is None:
response = Response()
del response.headers["content-length"]
response.status_code = None # type: ignore
dependency_cache = dependency_cache or {}
sub_dependant: Dependant
for sub_dependant in dependant.dependencies:
sub_dependant.call = cast(Callable[..., Any], sub_dependant.call)
sub_dependant.cache_key = cast(
Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key
)
call = sub_dependant.call
use_sub_dependant = sub_dependant
if ( if (
dependency_overrides_provider key in main_dict
and dependency_overrides_provider.dependency_overrides and isinstance(main_dict[key], dict)
and isinstance(value, dict)
): ):
original_call = sub_dependant.call deep_dict_update(main_dict[key], value)
call = getattr( elif (
dependency_overrides_provider, "dependency_overrides", {} key in main_dict
).get(original_call, original_call) and isinstance(main_dict[key], list)
use_path: str = sub_dependant.path # type: ignore and isinstance(update_dict[key], list)
use_sub_dependant = get_dependant(
path=use_path,
call=call,
name=sub_dependant.name,
security_scopes=sub_dependant.security_scopes,
)
solved_result = await solve_dependencies(
request=request,
dependant=use_sub_dependant,
body=body,
background_tasks=background_tasks,
response=response,
dependency_overrides_provider=dependency_overrides_provider,
dependency_cache=dependency_cache,
)
(
sub_values,
sub_errors,
background_tasks,
_, # the subdependency returns the same response we have
sub_dependency_cache,
) = solved_result
dependency_cache.update(sub_dependency_cache)
if sub_errors:
errors.extend(sub_errors)
continue
if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
solved = dependency_cache[sub_dependant.cache_key]
elif is_gen_callable(call) or is_async_gen_callable(call):
stack = request.scope.get("fastapi_astack")
assert isinstance(stack, AsyncExitStack)
solved = await solve_generator(
call=call, stack=stack, sub_values=sub_values
)
elif is_coroutine_callable(call):
solved = await call(**sub_values)
else:
solved = await run_in_threadpool(call, **sub_values)
if sub_dependant.name is not None:
values[sub_dependant.name] = solved
if sub_dependant.cache_key not in dependency_cache:
dependency_cache[sub_dependant.cache_key] = solved
path_values, path_errors = request_params_to_args(
dependant.path_params, request.path_params
)
query_values, query_errors = request_params_to_args(
dependant.query_params, request.query_params
)
header_values, header_errors = request_params_to_args(
dependant.header_params, request.headers
)
cookie_values, cookie_errors = request_params_to_args(
dependant.cookie_params, request.cookies
)
values.update(path_values)
values.update(query_values)
values.update(header_values)
values.update(cookie_values)
errors += path_errors + query_errors + header_errors + cookie_errors
if dependant.body_params:
(
body_values,
body_errors,
) = await request_body_to_args( # body_params checked above
required_params=dependant.body_params, received_body=body
)
values.update(body_values)
errors.extend(body_errors)
if dependant.http_connection_param_name:
values[dependant.http_connection_param_name] = request
if dependant.request_param_name and isinstance(request, Request):
values[dependant.request_param_name] = request
elif dependant.websocket_param_name and isinstance(request, WebSocket):
values[dependant.websocket_param_name] = request
if dependant.background_tasks_param_name:
if background_tasks is None:
background_tasks = BackgroundTasks()
values[dependant.background_tasks_param_name] = background_tasks
if dependant.response_param_name:
values[dependant.response_param_name] = response
if dependant.security_scopes_param_name:
values[dependant.security_scopes_param_name] = SecurityScopes(
scopes=dependant.security_scopes
)
return values, errors, background_tasks, response, dependency_cache
def request_params_to_args(
required_params: Sequence[ModelField],
received_params: Union[Mapping[str, Any], QueryParams, Headers],
) -> Tuple[Dict[str, Any], List[ErrorWrapper]]:
values = {}
errors = []
for field in required_params:
if is_scalar_sequence_field(field) and isinstance(
received_params, (QueryParams, Headers)
): ):
value = received_params.getlist(field.alias) or field.default main_dict[key] = main_dict[key] + update_dict[key]
else:
value = received_params.get(field.alias)
field_info = field.field_info
assert isinstance(
field_info, params.Param
), "Params must be subclasses of Param"
if value is None:
if field.required:
errors.append(
ErrorWrapper(
MissingError(), loc=(field_info.in_.value, field.alias)
)
)
else:
values[field.name] = deepcopy(field.default)
continue
v_, errors_ = field.validate(
value, values, loc=(field_info.in_.value, field.alias)
)
if isinstance(errors_, ErrorWrapper):
errors.append(errors_)
elif isinstance(errors_, list):
errors.extend(errors_)
else: else:
values[field.name] = v_ main_dict[key] = value
return values, errors
async def request_body_to_args( def get_value_or_default(
required_params: List[ModelField], first_item: Union[DefaultPlaceholder, DefaultType],
received_body: Optional[Union[Dict[str, Any], FormData]], *extra_items: Union[DefaultPlaceholder, DefaultType],
) -> Tuple[Dict[str, Any], List[ErrorWrapper]]: ) -> Union[DefaultPlaceholder, DefaultType]:
values = {} """
errors = [] Pass items or `DefaultPlaceholder`s by descending priority.
if required_params:
field = required_params[0]
field_info = field.field_info
embed = getattr(field_info, "embed", None)
field_alias_omitted = len(required_params) == 1 and not embed
if field_alias_omitted:
received_body = {field.alias: received_body}
for field in required_params: The first one to _not_ be a `DefaultPlaceholder` will be returned.
loc: Tuple[str, ...]
if field_alias_omitted:
loc = ("body",)
else:
loc = ("body", field.alias)
value: Optional[Any] = None Otherwise, the first item (a `DefaultPlaceholder`) will be returned.
if received_body is not None: """
if ( items = (first_item,) + extra_items
field.shape in sequence_shapes or field.type_ in sequence_types for item in items:
) and isinstance(received_body, FormData): if not isinstance(item, DefaultPlaceholder):
value = received_body.getlist(field.alias) return item
else: return first_item
try:
value = received_body.get(field.alias)
except AttributeError:
errors.append(get_missing_field_error(loc))
continue
if (
value is None
or (isinstance(field_info, params.Form) and value == "")
or (
isinstance(field_info, params.Form)
and field.shape in sequence_shapes
and len(value) == 0
)
):
if field.required:
errors.append(get_missing_field_error(loc))
else:
values[field.name] = deepcopy(field.default)
continue
if (
isinstance(field_info, params.File)
and lenient_issubclass(field.type_, bytes)
and isinstance(value, UploadFile)
):
value = await value.read()
elif (
field.shape in sequence_shapes
and isinstance(field_info, params.File)
and lenient_issubclass(field.type_, bytes)
and isinstance(value, sequence_types)
):
results: List[Union[bytes, str]] = []
async def process_fn(
fn: Callable[[], Coroutine[Any, Any, Any]]
) -> None:
result = await fn()
results.append(result)
async with anyio.create_task_group() as tg:
for sub_value in value:
tg.start_soon(process_fn, sub_value.read)
value = sequence_shape_to_type[field.shape](results)
v_, errors_ = field.validate(value, values, loc=loc)
if isinstance(errors_, ErrorWrapper):
errors.append(errors_)
elif isinstance(errors_, list):
errors.extend(errors_)
else:
values[field.name] = v_
return values, errors
def get_missing_field_error(loc: Tuple[str, ...]) -> ErrorWrapper:
missing_field_error = ErrorWrapper(MissingError(), loc=loc)
return missing_field_error
def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]:
flat_dependant = get_flat_dependant(dependant)
if not flat_dependant.body_params:
return None
first_param = flat_dependant.body_params[0]
field_info = first_param.field_info
embed = getattr(field_info, "embed", None)
body_param_names_set = {param.name for param in flat_dependant.body_params}
if len(body_param_names_set) == 1 and not embed:
check_file_field(first_param)
return first_param
# If one field requires to embed, all have to be embedded
# in case a sub-dependency is evaluated with a single unique body field
# That is combined (embedded) with other body fields
for param in flat_dependant.body_params:
setattr(param.field_info, "embed", True) # noqa: B010
model_name = "Body_" + name
BodyModel: Type[BaseModel] = create_model(model_name)
for f in flat_dependant.body_params:
BodyModel.__fields__[f.name] = f
required = any(True for f in flat_dependant.body_params if f.required)
BodyFieldInfo_kwargs: Dict[str, Any] = {"default": None}
if any(isinstance(f.field_info, params.File) for f in flat_dependant.body_params):
BodyFieldInfo: Type[params.Body] = params.File
elif any(isinstance(f.field_info, params.Form) for f in flat_dependant.body_params):
BodyFieldInfo = params.Form
else:
BodyFieldInfo = params.Body
body_param_media_types = [
f.field_info.media_type
for f in flat_dependant.body_params
if isinstance(f.field_info, params.Body)
]
if len(set(body_param_media_types)) == 1:
BodyFieldInfo_kwargs["media_type"] = body_param_media_types[0]
final_field = create_response_field(
name="body",
type_=BodyModel,
required=required,
alias="body",
field_info=BodyFieldInfo(**BodyFieldInfo_kwargs),
)
check_file_field(final_field)
return final_field

Loading…
Cancel
Save