Browse Source

Add support and tests for Pydantic dataclasses in response_model (#454)

pull/490/head
dconathan 6 years ago
committed by Sebastián Ramírez
parent
commit
3025a368c6
  1. 3
      fastapi/utils.py
  2. 48
      tests/test_serialize_response.py
  3. 58
      tests/test_serialize_response_dataclass.py
  4. 51
      tests/test_validate_response.py
  5. 53
      tests/test_validate_response_dataclass.py

3
fastapi/utils.py

@ -1,4 +1,5 @@
import re
from dataclasses import is_dataclass
from typing import Any, Dict, List, Sequence, Set, Type, cast
from fastapi import routing
@ -52,6 +53,8 @@ def get_path_param_names(path: str) -> Set[str]:
def create_cloned_field(field: Field) -> Field:
original_type = field.type_
if is_dataclass(original_type) and hasattr(original_type, "__pydantic_model__"):
original_type = original_type.__pydantic_model__ # type: ignore
use_type = original_type
if lenient_issubclass(original_type, BaseModel):
original_type = cast(Type[BaseModel], original_type)

48
tests/test_serialize_response.py

@ -1,8 +1,7 @@
from typing import List
import pytest
from fastapi import FastAPI
from pydantic import BaseModel, ValidationError
from pydantic import BaseModel
from starlette.testclient import TestClient
app = FastAPI()
@ -14,38 +13,45 @@ class Item(BaseModel):
owner_ids: List[int] = None
@app.get("/items/invalid", response_model=Item)
def get_invalid():
return {"name": "invalid", "price": "foo"}
@app.get("/items/valid", response_model=Item)
def get_valid():
return {"name": "valid", "price": 1.0}
@app.get("/items/innerinvalid", response_model=Item)
def get_innerinvalid():
return {"name": "double invalid", "price": "foo", "owner_ids": ["foo", "bar"]}
@app.get("/items/coerce", response_model=Item)
def get_coerce():
return {"name": "coerce", "price": "1.0"}
@app.get("/items/invalidlist", response_model=List[Item])
def get_invalidlist():
@app.get("/items/validlist", response_model=List[Item])
def get_validlist():
return [
{"name": "foo"},
{"name": "bar", "price": "bar"},
{"name": "baz", "price": "baz"},
{"name": "bar", "price": 1.0},
{"name": "baz", "price": 2.0, "owner_ids": [1, 2, 3]},
]
client = TestClient(app)
def test_invalid():
with pytest.raises(ValidationError):
client.get("/items/invalid")
def test_valid():
response = client.get("/items/valid")
response.raise_for_status()
assert response.json() == {"name": "valid", "price": 1.0, "owner_ids": None}
def test_double_invalid():
with pytest.raises(ValidationError):
client.get("/items/innerinvalid")
def test_coerce():
response = client.get("/items/coerce")
response.raise_for_status()
assert response.json() == {"name": "coerce", "price": 1.0, "owner_ids": None}
def test_invalid_list():
with pytest.raises(ValidationError):
client.get("/items/invalidlist")
def test_validlist():
response = client.get("/items/validlist")
response.raise_for_status()
assert response.json() == [
{"name": "foo", "price": None, "owner_ids": None},
{"name": "bar", "price": 1.0, "owner_ids": None},
{"name": "baz", "price": 2.0, "owner_ids": [1, 2, 3]},
]

58
tests/test_serialize_response_dataclass.py

@ -0,0 +1,58 @@
from typing import List
from fastapi import FastAPI
from pydantic.dataclasses import dataclass
from starlette.testclient import TestClient
app = FastAPI()
@dataclass
class Item:
name: str
price: float = None
owner_ids: List[int] = None
@app.get("/items/valid", response_model=Item)
def get_valid():
return {"name": "valid", "price": 1.0}
@app.get("/items/coerce", response_model=Item)
def get_coerce():
return {"name": "coerce", "price": "1.0"}
@app.get("/items/validlist", response_model=List[Item])
def get_validlist():
return [
{"name": "foo"},
{"name": "bar", "price": 1.0},
{"name": "baz", "price": 2.0, "owner_ids": [1, 2, 3]},
]
client = TestClient(app)
def test_valid():
response = client.get("/items/valid")
response.raise_for_status()
assert response.json() == {"name": "valid", "price": 1.0, "owner_ids": None}
def test_coerce():
response = client.get("/items/coerce")
response.raise_for_status()
assert response.json() == {"name": "coerce", "price": 1.0, "owner_ids": None}
def test_validlist():
response = client.get("/items/validlist")
response.raise_for_status()
assert response.json() == [
{"name": "foo", "price": None, "owner_ids": None},
{"name": "bar", "price": 1.0, "owner_ids": None},
{"name": "baz", "price": 2.0, "owner_ids": [1, 2, 3]},
]

51
tests/test_validate_response.py

@ -0,0 +1,51 @@
from typing import List
import pytest
from fastapi import FastAPI
from pydantic import BaseModel, ValidationError
from starlette.testclient import TestClient
app = FastAPI()
class Item(BaseModel):
name: str
price: float = None
owner_ids: List[int] = None
@app.get("/items/invalid", response_model=Item)
def get_invalid():
return {"name": "invalid", "price": "foo"}
@app.get("/items/innerinvalid", response_model=Item)
def get_innerinvalid():
return {"name": "double invalid", "price": "foo", "owner_ids": ["foo", "bar"]}
@app.get("/items/invalidlist", response_model=List[Item])
def get_invalidlist():
return [
{"name": "foo"},
{"name": "bar", "price": "bar"},
{"name": "baz", "price": "baz"},
]
client = TestClient(app)
def test_invalid():
with pytest.raises(ValidationError):
client.get("/items/invalid")
def test_double_invalid():
with pytest.raises(ValidationError):
client.get("/items/innerinvalid")
def test_invalid_list():
with pytest.raises(ValidationError):
client.get("/items/invalidlist")

53
tests/test_validate_response_dataclass.py

@ -0,0 +1,53 @@
from typing import List
import pytest
from fastapi import FastAPI
from pydantic import ValidationError
from pydantic.dataclasses import dataclass
from starlette.testclient import TestClient
app = FastAPI()
@dataclass
class Item:
name: str
price: float = None
owner_ids: List[int] = None
@app.get("/items/invalid", response_model=Item)
def get_invalid():
return {"name": "invalid", "price": "foo"}
@app.get("/items/innerinvalid", response_model=Item)
def get_innerinvalid():
return {"name": "double invalid", "price": "foo", "owner_ids": ["foo", "bar"]}
@app.get("/items/invalidlist", response_model=List[Item])
def get_invalidlist():
return [
{"name": "foo"},
{"name": "bar", "price": "bar"},
{"name": "baz", "price": "baz"},
]
client = TestClient(app)
def test_invalid():
with pytest.raises(ValidationError):
client.get("/items/invalid")
def test_double_invalid():
with pytest.raises(ValidationError):
client.get("/items/innerinvalid")
def test_invalid_list():
with pytest.raises(ValidationError):
client.get("/items/invalidlist")
Loading…
Cancel
Save