diff --git a/fastapi/utils.py b/fastapi/utils.py index f24f28073..154dd9aa1 100644 --- a/fastapi/utils.py +++ b/fastapi/utils.py @@ -131,17 +131,26 @@ def create_response_field( ) -def create_cloned_field(field: ModelField) -> ModelField: +def create_cloned_field( + field: ModelField, *, cloned_types: Dict[Type[BaseModel], Type[BaseModel]] = None, +) -> ModelField: + # _cloned_types has already cloned types, to support recursive models + if cloned_types is None: + cloned_types = dict() original_type = field.type_ if is_dataclass(original_type) and hasattr(original_type, "__pydantic_model__"): original_type = original_type.__pydantic_model__ # type: ignore use_type = original_type if lenient_issubclass(original_type, BaseModel): original_type = cast(Type[BaseModel], original_type) - use_type = create_model(original_type.__name__, __base__=original_type) - for f in original_type.__fields__.values(): - use_type.__fields__[f.name] = create_cloned_field(f) - + use_type = cloned_types.get(original_type) + if use_type is None: + use_type = create_model(original_type.__name__, __base__=original_type) + cloned_types[original_type] = use_type + for f in original_type.__fields__.values(): + use_type.__fields__[f.name] = create_cloned_field( + f, cloned_types=cloned_types + ) new_field = create_response_field(name=field.name, type_=use_type) new_field.has_alias = field.has_alias new_field.alias = field.alias @@ -157,10 +166,13 @@ def create_cloned_field(field: ModelField) -> ModelField: new_field.validate_always = field.validate_always if field.sub_fields: new_field.sub_fields = [ - create_cloned_field(sub_field) for sub_field in field.sub_fields + create_cloned_field(sub_field, cloned_types=cloned_types) + for sub_field in field.sub_fields ] if field.key_field: - new_field.key_field = create_cloned_field(field.key_field) + new_field.key_field = create_cloned_field( + field.key_field, cloned_types=cloned_types + ) new_field.validators = field.validators if PYDANTIC_1: new_field.pre_validators = field.pre_validators diff --git a/tests/test_validate_response_recursive.py b/tests/test_validate_response_recursive.py new file mode 100644 index 000000000..8b77ed14a --- /dev/null +++ b/tests/test_validate_response_recursive.py @@ -0,0 +1,80 @@ +from typing import List + +from fastapi import FastAPI +from fastapi.testclient import TestClient +from pydantic import BaseModel + +app = FastAPI() + + +class RecursiveItem(BaseModel): + sub_items: List["RecursiveItem"] = [] + name: str + + +RecursiveItem.update_forward_refs() + + +class RecursiveSubitemInSubmodel(BaseModel): + sub_items2: List["RecursiveItemViaSubmodel"] = [] + name: str + + +class RecursiveItemViaSubmodel(BaseModel): + sub_items1: List[RecursiveSubitemInSubmodel] = [] + name: str + + +RecursiveSubitemInSubmodel.update_forward_refs() + + +@app.get("/items/recursive", response_model=RecursiveItem) +def get_recursive(): + return {"name": "item", "sub_items": [{"name": "subitem", "sub_items": []}]} + + +@app.get("/items/recursive-submodel", response_model=RecursiveItemViaSubmodel) +def get_recursive_submodel(): + return { + "name": "item", + "sub_items1": [ + { + "name": "subitem", + "sub_items2": [ + { + "name": "subsubitem", + "sub_items1": [{"name": "subsubsubitem", "sub_items2": []}], + } + ], + } + ], + } + + +client = TestClient(app) + + +def test_recursive(): + response = client.get("/items/recursive") + assert response.status_code == 200 + assert response.json() == { + "sub_items": [{"name": "subitem", "sub_items": []}], + "name": "item", + } + + response = client.get("/items/recursive-submodel") + assert response.status_code == 200 + assert response.json() == { + "name": "item", + "sub_items1": [ + { + "name": "subitem", + "sub_items2": [ + { + "name": "subsubitem", + "sub_items1": [{"name": "subsubsubitem", "sub_items2": []}], + } + ], + } + ], + }