diff --git a/fastapi/_compat.py b/fastapi/_compat.py index c07e4a3b0..836285c4c 100644 --- a/fastapi/_compat.py +++ b/fastapi/_compat.py @@ -572,6 +572,9 @@ def field_annotation_is_complex(annotation: Union[Type[Any], None]) -> bool: if origin is Union or origin is UnionType: return any(field_annotation_is_complex(arg) for arg in get_args(annotation)) + if origin is Annotated: + return field_annotation_is_complex(get_args(annotation)[0]) + return ( _annotation_is_complex(annotation) or _annotation_is_complex(origin) diff --git a/tests/test_union_body_discriminator.py b/tests/test_union_body_discriminator.py new file mode 100644 index 000000000..14bf43ff3 --- /dev/null +++ b/tests/test_union_body_discriminator.py @@ -0,0 +1,57 @@ +from typing import Any, Dict, Union + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient +from typing_extensions import Annotated + +from .utils import needs_pydanticv2 + + +@pytest.fixture(name="client") +def get_client() -> TestClient: + from pydantic import BaseModel, Discriminator, Tag + + app = FastAPI() + + class FirstItem(BaseModel): + value: str + price: int + + class OtherItem(BaseModel): + value: str + price: float + + def get_discriminator_value(v: Any) -> str: + if isinstance(v, dict): + return v.get("value") + return v.value + + Item = Annotated[ + Union[ + Annotated[FirstItem, Tag("first")], + Annotated[OtherItem, Tag("other")], + ], + Discriminator(get_discriminator_value), + ] + + @app.post("/items/") + def save_union_body_discriminator(item: Item) -> Dict[str, Any]: + return {"item": item} + + client = TestClient(app) + return client + + +@needs_pydanticv2 +def test_post_item(client: TestClient) -> None: + response = client.post("/items/", json={"value": "first", "price": 100}) + assert response.status_code == 200, response.text + assert response.json() == {"item": {"value": "first", "price": 100}} + + +@needs_pydanticv2 +def test_post_other_item(client: TestClient) -> None: + response = client.post("/items/", json={"value": "other", "price": 100.5}) + assert response.status_code == 200, response.text + assert response.json() == {"item": {"value": "other", "price": 100.5}}