Browse Source

Add support and tests for Pydantic dataclasses in response_model (#454)

pull/490/head
dconathan 6 years ago
committed by Sebastián Ramírez
parent
commit
3025a368c6
  1. 3
      fastapi/utils.py
  2. 48
      tests/test_serialize_response.py
  3. 58
      tests/test_serialize_response_dataclass.py
  4. 51
      tests/test_validate_response.py
  5. 53
      tests/test_validate_response_dataclass.py

3
fastapi/utils.py

@ -1,4 +1,5 @@
import re import re
from dataclasses import is_dataclass
from typing import Any, Dict, List, Sequence, Set, Type, cast from typing import Any, Dict, List, Sequence, Set, Type, cast
from fastapi import routing from fastapi import routing
@ -52,6 +53,8 @@ def get_path_param_names(path: str) -> Set[str]:
def create_cloned_field(field: Field) -> Field: def create_cloned_field(field: Field) -> Field:
original_type = field.type_ original_type = field.type_
if is_dataclass(original_type) and hasattr(original_type, "__pydantic_model__"):
original_type = original_type.__pydantic_model__ # type: ignore
use_type = original_type use_type = original_type
if lenient_issubclass(original_type, BaseModel): if lenient_issubclass(original_type, BaseModel):
original_type = cast(Type[BaseModel], original_type) original_type = cast(Type[BaseModel], original_type)

48
tests/test_serialize_response.py

@ -1,8 +1,7 @@
from typing import List from typing import List
import pytest
from fastapi import FastAPI from fastapi import FastAPI
from pydantic import BaseModel, ValidationError from pydantic import BaseModel
from starlette.testclient import TestClient from starlette.testclient import TestClient
app = FastAPI() app = FastAPI()
@ -14,38 +13,45 @@ class Item(BaseModel):
owner_ids: List[int] = None owner_ids: List[int] = None
@app.get("/items/invalid", response_model=Item) @app.get("/items/valid", response_model=Item)
def get_invalid(): def get_valid():
return {"name": "invalid", "price": "foo"} return {"name": "valid", "price": 1.0}
@app.get("/items/innerinvalid", response_model=Item) @app.get("/items/coerce", response_model=Item)
def get_innerinvalid(): def get_coerce():
return {"name": "double invalid", "price": "foo", "owner_ids": ["foo", "bar"]} return {"name": "coerce", "price": "1.0"}
@app.get("/items/invalidlist", response_model=List[Item]) @app.get("/items/validlist", response_model=List[Item])
def get_invalidlist(): def get_validlist():
return [ return [
{"name": "foo"}, {"name": "foo"},
{"name": "bar", "price": "bar"}, {"name": "bar", "price": 1.0},
{"name": "baz", "price": "baz"}, {"name": "baz", "price": 2.0, "owner_ids": [1, 2, 3]},
] ]
client = TestClient(app) client = TestClient(app)
def test_invalid(): def test_valid():
with pytest.raises(ValidationError): response = client.get("/items/valid")
client.get("/items/invalid") response.raise_for_status()
assert response.json() == {"name": "valid", "price": 1.0, "owner_ids": None}
def test_double_invalid(): def test_coerce():
with pytest.raises(ValidationError): response = client.get("/items/coerce")
client.get("/items/innerinvalid") response.raise_for_status()
assert response.json() == {"name": "coerce", "price": 1.0, "owner_ids": None}
def test_invalid_list(): def test_validlist():
with pytest.raises(ValidationError): response = client.get("/items/validlist")
client.get("/items/invalidlist") response.raise_for_status()
assert response.json() == [
{"name": "foo", "price": None, "owner_ids": None},
{"name": "bar", "price": 1.0, "owner_ids": None},
{"name": "baz", "price": 2.0, "owner_ids": [1, 2, 3]},
]

58
tests/test_serialize_response_dataclass.py

@ -0,0 +1,58 @@
from typing import List
from fastapi import FastAPI
from pydantic.dataclasses import dataclass
from starlette.testclient import TestClient
app = FastAPI()
@dataclass
class Item:
name: str
price: float = None
owner_ids: List[int] = None
@app.get("/items/valid", response_model=Item)
def get_valid():
return {"name": "valid", "price": 1.0}
@app.get("/items/coerce", response_model=Item)
def get_coerce():
return {"name": "coerce", "price": "1.0"}
@app.get("/items/validlist", response_model=List[Item])
def get_validlist():
return [
{"name": "foo"},
{"name": "bar", "price": 1.0},
{"name": "baz", "price": 2.0, "owner_ids": [1, 2, 3]},
]
client = TestClient(app)
def test_valid():
response = client.get("/items/valid")
response.raise_for_status()
assert response.json() == {"name": "valid", "price": 1.0, "owner_ids": None}
def test_coerce():
response = client.get("/items/coerce")
response.raise_for_status()
assert response.json() == {"name": "coerce", "price": 1.0, "owner_ids": None}
def test_validlist():
response = client.get("/items/validlist")
response.raise_for_status()
assert response.json() == [
{"name": "foo", "price": None, "owner_ids": None},
{"name": "bar", "price": 1.0, "owner_ids": None},
{"name": "baz", "price": 2.0, "owner_ids": [1, 2, 3]},
]

51
tests/test_validate_response.py

@ -0,0 +1,51 @@
from typing import List
import pytest
from fastapi import FastAPI
from pydantic import BaseModel, ValidationError
from starlette.testclient import TestClient
app = FastAPI()
class Item(BaseModel):
name: str
price: float = None
owner_ids: List[int] = None
@app.get("/items/invalid", response_model=Item)
def get_invalid():
return {"name": "invalid", "price": "foo"}
@app.get("/items/innerinvalid", response_model=Item)
def get_innerinvalid():
return {"name": "double invalid", "price": "foo", "owner_ids": ["foo", "bar"]}
@app.get("/items/invalidlist", response_model=List[Item])
def get_invalidlist():
return [
{"name": "foo"},
{"name": "bar", "price": "bar"},
{"name": "baz", "price": "baz"},
]
client = TestClient(app)
def test_invalid():
with pytest.raises(ValidationError):
client.get("/items/invalid")
def test_double_invalid():
with pytest.raises(ValidationError):
client.get("/items/innerinvalid")
def test_invalid_list():
with pytest.raises(ValidationError):
client.get("/items/invalidlist")

53
tests/test_validate_response_dataclass.py

@ -0,0 +1,53 @@
from typing import List
import pytest
from fastapi import FastAPI
from pydantic import ValidationError
from pydantic.dataclasses import dataclass
from starlette.testclient import TestClient
app = FastAPI()
@dataclass
class Item:
name: str
price: float = None
owner_ids: List[int] = None
@app.get("/items/invalid", response_model=Item)
def get_invalid():
return {"name": "invalid", "price": "foo"}
@app.get("/items/innerinvalid", response_model=Item)
def get_innerinvalid():
return {"name": "double invalid", "price": "foo", "owner_ids": ["foo", "bar"]}
@app.get("/items/invalidlist", response_model=List[Item])
def get_invalidlist():
return [
{"name": "foo"},
{"name": "bar", "price": "bar"},
{"name": "baz", "price": "baz"},
]
client = TestClient(app)
def test_invalid():
with pytest.raises(ValidationError):
client.get("/items/invalid")
def test_double_invalid():
with pytest.raises(ValidationError):
client.get("/items/innerinvalid")
def test_invalid_list():
with pytest.raises(ValidationError):
client.get("/items/invalidlist")
Loading…
Cancel
Save