diff --git a/fastapi/_compat.py b/fastapi/_compat.py index 227ad837d..6e9a6f0ed 100644 --- a/fastapi/_compat.py +++ b/fastapi/_compat.py @@ -246,7 +246,21 @@ if PYDANTIC_V2: ) and not isinstance(field.field_info, params.Body) def is_sequence_field(field: ModelField) -> bool: - return field_annotation_is_sequence(field.field_info.annotation) + return field_annotation_is_sequence( + field.field_info.annotation + ) or field_annotation_is_optional_sequence(field.field_info.annotation) + + def field_annotation_is_optional_sequence( + annotation: Union[Type[Any], None], + ) -> bool: + origin = get_origin(annotation) + if origin is Union: + args = get_args(annotation) + for arg in args: + if hasattr(arg, "__origin__"): + if arg.__origin__ in sequence_types: + return True + return False def is_scalar_sequence_field(field: ModelField) -> bool: return field_annotation_is_scalar_sequence(field.field_info.annotation) diff --git a/tests/test_compat.py b/tests/test_compat.py index f4a3093c5..c05e5ddad 100644 --- a/tests/test_compat.py +++ b/tests/test_compat.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Union +from typing import Any, Dict, FrozenSet, List, Optional, Set, Union from fastapi import FastAPI, UploadFile from fastapi._compat import ( @@ -9,6 +9,7 @@ from fastapi._compat import ( get_model_fields, is_bytes_sequence_annotation, is_scalar_field, + is_sequence_field, is_uploadfile_sequence_annotation, ) from fastapi.testclient import TestClient @@ -96,6 +97,27 @@ def test_is_uploadfile_sequence_annotation(): assert is_uploadfile_sequence_annotation(Union[List[str], List[UploadFile]]) +@needs_pydanticv2 +def test_model_optional_union_v2(): + # For coverage + types = [ + Optional[List[str]], + Union[List[int], List[float]], + Optional[Set[int]], + Union[Set[int], Set[float]], + Optional[FrozenSet[int]], + Union[List[int], None], + ] + for annotation in types: + field_info = FieldInfo(annotation=annotation) + field = ModelField(name="foo", field_info=field_info) + assert is_sequence_field(field) is True + + field_info_str = FieldInfo(annotation=str) + field_str = ModelField(name="foo", field_info=field_info_str) + assert is_sequence_field(field_str) is False + + def test_is_pv1_scalar_field(): # For coverage class Model(BaseModel):