Browse Source

♻️ Refactor internal `check_file_field()`, rename to `ensure_multipart_is_installed()` to clarify its purpose (#12106)

pull/12112/head
Sebastián Ramírez 7 months ago
committed by GitHub
parent
commit
23bda0ffeb
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 47
      fastapi/dependencies/utils.py

47
fastapi/dependencies/utils.py

@ -80,25 +80,23 @@ multipart_incorrect_install_error = (
) )
def check_file_field(field: ModelField) -> None: def ensure_multipart_is_installed() -> None:
field_info = field.field_info try:
if isinstance(field_info, params.Form): # __version__ is available in both multiparts, and can be mocked
from multipart import __version__ # type: ignore
assert __version__
try: try:
# __version__ is available in both multiparts, and can be mocked # parse_options_header is only available in the right multipart
from multipart import __version__ # type: ignore from multipart.multipart import parse_options_header # type: ignore
assert __version__ assert parse_options_header
try:
# parse_options_header is only available in the right multipart
from multipart.multipart import parse_options_header # type: ignore
assert parse_options_header
except ImportError:
logger.error(multipart_incorrect_install_error)
raise RuntimeError(multipart_incorrect_install_error) from None
except ImportError: except ImportError:
logger.error(multipart_not_installed_error) logger.error(multipart_incorrect_install_error)
raise RuntimeError(multipart_not_installed_error) from None raise RuntimeError(multipart_incorrect_install_error) from None
except ImportError:
logger.error(multipart_not_installed_error)
raise RuntimeError(multipart_not_installed_error) from None
def get_param_sub_dependant( def get_param_sub_dependant(
@ -336,6 +334,7 @@ def analyze_param(
if annotation is not inspect.Signature.empty: if annotation is not inspect.Signature.empty:
use_annotation = annotation use_annotation = annotation
type_annotation = annotation type_annotation = annotation
# Extract Annotated info
if get_origin(use_annotation) is Annotated: if get_origin(use_annotation) is Annotated:
annotated_args = get_args(annotation) annotated_args = get_args(annotation)
type_annotation = annotated_args[0] type_annotation = annotated_args[0]
@ -355,6 +354,7 @@ def analyze_param(
) )
else: else:
fastapi_annotation = None fastapi_annotation = None
# Set default for Annotated FieldInfo
if isinstance(fastapi_annotation, FieldInfo): if isinstance(fastapi_annotation, FieldInfo):
# Copy `field_info` because we mutate `field_info.default` below. # Copy `field_info` because we mutate `field_info.default` below.
field_info = copy_field_info( field_info = copy_field_info(
@ -369,9 +369,10 @@ def analyze_param(
field_info.default = value field_info.default = value
else: else:
field_info.default = Required field_info.default = Required
# Get Annotated Depends
elif isinstance(fastapi_annotation, params.Depends): elif isinstance(fastapi_annotation, params.Depends):
depends = fastapi_annotation depends = fastapi_annotation
# Get Depends from default value
if isinstance(value, params.Depends): if isinstance(value, params.Depends):
assert depends is None, ( assert depends is None, (
"Cannot specify `Depends` in `Annotated` and default value" "Cannot specify `Depends` in `Annotated` and default value"
@ -382,6 +383,7 @@ def analyze_param(
f" default value together for {param_name!r}" f" default value together for {param_name!r}"
) )
depends = value depends = value
# Get FieldInfo from default value
elif isinstance(value, FieldInfo): elif isinstance(value, FieldInfo):
assert field_info is None, ( assert field_info is None, (
"Cannot specify FastAPI annotations in `Annotated` and default value" "Cannot specify FastAPI annotations in `Annotated` and default value"
@ -391,11 +393,13 @@ def analyze_param(
if PYDANTIC_V2: if PYDANTIC_V2:
field_info.annotation = type_annotation field_info.annotation = type_annotation
# Get Depends from type annotation
if depends is not None and depends.dependency is None: if depends is not None and depends.dependency is None:
# Copy `depends` before mutating it # Copy `depends` before mutating it
depends = copy(depends) depends = copy(depends)
depends.dependency = type_annotation depends.dependency = type_annotation
# Handle non-param type annotations like Request
if lenient_issubclass( if lenient_issubclass(
type_annotation, type_annotation,
( (
@ -411,6 +415,7 @@ def analyze_param(
assert ( assert (
field_info is None field_info is None
), f"Cannot specify FastAPI annotation for type {type_annotation!r}" ), f"Cannot specify FastAPI annotation for type {type_annotation!r}"
# Handle default assignations, neither field_info nor depends was not found in Annotated nor default value
elif field_info is None and depends is None: elif field_info is None and depends is None:
default_value = value if value is not inspect.Signature.empty else Required default_value = value if value is not inspect.Signature.empty else Required
if is_path_param: if is_path_param:
@ -428,7 +433,9 @@ def analyze_param(
field_info = params.Query(annotation=use_annotation, default=default_value) field_info = params.Query(annotation=use_annotation, default=default_value)
field = None field = None
# It's a field_info, not a dependency
if field_info is not None: if field_info is not None:
# Handle field_info.in_
if is_path_param: if is_path_param:
assert isinstance(field_info, params.Path), ( assert isinstance(field_info, params.Path), (
f"Cannot use `{field_info.__class__.__name__}` for path param" f"Cannot use `{field_info.__class__.__name__}` for path param"
@ -444,6 +451,8 @@ def analyze_param(
field_info, field_info,
param_name, param_name,
) )
if isinstance(field_info, params.Form):
ensure_multipart_is_installed()
if not field_info.alias and getattr(field_info, "convert_underscores", None): if not field_info.alias and getattr(field_info, "convert_underscores", None):
alias = param_name.replace("_", "-") alias = param_name.replace("_", "-")
else: else:
@ -786,7 +795,6 @@ def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]:
embed = getattr(field_info, "embed", None) embed = getattr(field_info, "embed", None)
body_param_names_set = {param.name for param in flat_dependant.body_params} body_param_names_set = {param.name for param in flat_dependant.body_params}
if len(body_param_names_set) == 1 and not embed: if len(body_param_names_set) == 1 and not embed:
check_file_field(first_param)
return first_param return first_param
# If one field requires to embed, all have to be embedded # If one field requires to embed, all have to be embedded
# in case a sub-dependency is evaluated with a single unique body field # in case a sub-dependency is evaluated with a single unique body field
@ -825,5 +833,4 @@ def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]:
alias="body", alias="body",
field_info=BodyFieldInfo(**BodyFieldInfo_kwargs), field_info=BodyFieldInfo(**BodyFieldInfo_kwargs),
) )
check_file_field(final_field)
return final_field return final_field

Loading…
Cancel
Save