Browse Source

🐛 Check already cloned fields in create_cloned_field to support recursive models (#1164)

* FIX: #894
Include recursion check for create_cloned_field.
Added test for recursive model.

* ♻️ Refactor and format create_cloned_field()

Co-authored-by: Lukas Voegtle <[email protected]>
Co-authored-by: Sebastián Ramírez <[email protected]>
pull/1182/head
voegtlel 5 years ago
committed by GitHub
parent
commit
0f152b4e97
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 26
      fastapi/utils.py
  2. 80
      tests/test_validate_response_recursive.py

26
fastapi/utils.py

@ -131,17 +131,26 @@ def create_response_field(
)
def create_cloned_field(field: ModelField) -> ModelField:
def create_cloned_field(
field: ModelField, *, cloned_types: Dict[Type[BaseModel], Type[BaseModel]] = None,
) -> ModelField:
# _cloned_types has already cloned types, to support recursive models
if cloned_types is None:
cloned_types = dict()
original_type = field.type_
if is_dataclass(original_type) and hasattr(original_type, "__pydantic_model__"):
original_type = original_type.__pydantic_model__ # type: ignore
use_type = original_type
if lenient_issubclass(original_type, BaseModel):
original_type = cast(Type[BaseModel], original_type)
use_type = create_model(original_type.__name__, __base__=original_type)
for f in original_type.__fields__.values():
use_type.__fields__[f.name] = create_cloned_field(f)
use_type = cloned_types.get(original_type)
if use_type is None:
use_type = create_model(original_type.__name__, __base__=original_type)
cloned_types[original_type] = use_type
for f in original_type.__fields__.values():
use_type.__fields__[f.name] = create_cloned_field(
f, cloned_types=cloned_types
)
new_field = create_response_field(name=field.name, type_=use_type)
new_field.has_alias = field.has_alias
new_field.alias = field.alias
@ -157,10 +166,13 @@ def create_cloned_field(field: ModelField) -> ModelField:
new_field.validate_always = field.validate_always
if field.sub_fields:
new_field.sub_fields = [
create_cloned_field(sub_field) for sub_field in field.sub_fields
create_cloned_field(sub_field, cloned_types=cloned_types)
for sub_field in field.sub_fields
]
if field.key_field:
new_field.key_field = create_cloned_field(field.key_field)
new_field.key_field = create_cloned_field(
field.key_field, cloned_types=cloned_types
)
new_field.validators = field.validators
if PYDANTIC_1:
new_field.pre_validators = field.pre_validators

80
tests/test_validate_response_recursive.py

@ -0,0 +1,80 @@
from typing import List
from fastapi import FastAPI
from fastapi.testclient import TestClient
from pydantic import BaseModel
app = FastAPI()
class RecursiveItem(BaseModel):
sub_items: List["RecursiveItem"] = []
name: str
RecursiveItem.update_forward_refs()
class RecursiveSubitemInSubmodel(BaseModel):
sub_items2: List["RecursiveItemViaSubmodel"] = []
name: str
class RecursiveItemViaSubmodel(BaseModel):
sub_items1: List[RecursiveSubitemInSubmodel] = []
name: str
RecursiveSubitemInSubmodel.update_forward_refs()
@app.get("/items/recursive", response_model=RecursiveItem)
def get_recursive():
return {"name": "item", "sub_items": [{"name": "subitem", "sub_items": []}]}
@app.get("/items/recursive-submodel", response_model=RecursiveItemViaSubmodel)
def get_recursive_submodel():
return {
"name": "item",
"sub_items1": [
{
"name": "subitem",
"sub_items2": [
{
"name": "subsubitem",
"sub_items1": [{"name": "subsubsubitem", "sub_items2": []}],
}
],
}
],
}
client = TestClient(app)
def test_recursive():
response = client.get("/items/recursive")
assert response.status_code == 200
assert response.json() == {
"sub_items": [{"name": "subitem", "sub_items": []}],
"name": "item",
}
response = client.get("/items/recursive-submodel")
assert response.status_code == 200
assert response.json() == {
"name": "item",
"sub_items1": [
{
"name": "subitem",
"sub_items2": [
{
"name": "subsubitem",
"sub_items1": [{"name": "subsubsubitem", "sub_items2": []}],
}
],
}
],
}
Loading…
Cancel
Save