Vadim 5 days ago
committed by GitHub
parent
commit
3addbf0a66
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 16
      fastapi/_compat.py
  2. 24
      tests/test_compat.py

16
fastapi/_compat.py

@ -246,7 +246,21 @@ if PYDANTIC_V2:
) and not isinstance(field.field_info, params.Body)
def is_sequence_field(field: ModelField) -> bool:
return field_annotation_is_sequence(field.field_info.annotation)
return field_annotation_is_sequence(
field.field_info.annotation
) or field_annotation_is_optional_sequence(field.field_info.annotation)
def field_annotation_is_optional_sequence(
annotation: Union[Type[Any], None],
) -> bool:
origin = get_origin(annotation)
if origin is Union:
args = get_args(annotation)
for arg in args:
if hasattr(arg, "__origin__"):
if arg.__origin__ in sequence_types:
return True
return False
def is_scalar_sequence_field(field: ModelField) -> bool:
return field_annotation_is_scalar_sequence(field.field_info.annotation)

24
tests/test_compat.py

@ -1,4 +1,4 @@
from typing import Any, Dict, List, Union
from typing import Any, Dict, FrozenSet, List, Optional, Set, Union
from fastapi import FastAPI, UploadFile
from fastapi._compat import (
@ -9,6 +9,7 @@ from fastapi._compat import (
get_model_fields,
is_bytes_sequence_annotation,
is_scalar_field,
is_sequence_field,
is_uploadfile_sequence_annotation,
)
from fastapi.testclient import TestClient
@ -96,6 +97,27 @@ def test_is_uploadfile_sequence_annotation():
assert is_uploadfile_sequence_annotation(Union[List[str], List[UploadFile]])
@needs_pydanticv2
def test_model_optional_union_v2():
# For coverage
types = [
Optional[List[str]],
Union[List[int], List[float]],
Optional[Set[int]],
Union[Set[int], Set[float]],
Optional[FrozenSet[int]],
Union[List[int], None],
]
for annotation in types:
field_info = FieldInfo(annotation=annotation)
field = ModelField(name="foo", field_info=field_info)
assert is_sequence_field(field) is True
field_info_str = FieldInfo(annotation=str)
field_str = ModelField(name="foo", field_info=field_info_str)
assert is_sequence_field(field_str) is False
def test_is_pv1_scalar_field():
# For coverage
class Model(BaseModel):

Loading…
Cancel
Save