From 3025a368c6ea7cc9260fe81c20b8af30e9d8aab5 Mon Sep 17 00:00:00 2001 From: dconathan Date: Fri, 30 Aug 2019 18:12:15 -0500 Subject: [PATCH] :sparkles: Add support and tests for Pydantic dataclasses in response_model (#454) --- fastapi/utils.py | 3 ++ tests/test_serialize_response.py | 48 ++++++++++-------- tests/test_serialize_response_dataclass.py | 58 ++++++++++++++++++++++ tests/test_validate_response.py | 51 +++++++++++++++++++ tests/test_validate_response_dataclass.py | 53 ++++++++++++++++++++ 5 files changed, 192 insertions(+), 21 deletions(-) create mode 100644 tests/test_serialize_response_dataclass.py create mode 100644 tests/test_validate_response.py create mode 100644 tests/test_validate_response_dataclass.py diff --git a/fastapi/utils.py b/fastapi/utils.py index de0260615..17a16b522 100644 --- a/fastapi/utils.py +++ b/fastapi/utils.py @@ -1,4 +1,5 @@ import re +from dataclasses import is_dataclass from typing import Any, Dict, List, Sequence, Set, Type, cast from fastapi import routing @@ -52,6 +53,8 @@ def get_path_param_names(path: str) -> Set[str]: def create_cloned_field(field: Field) -> Field: original_type = field.type_ + if is_dataclass(original_type) and hasattr(original_type, "__pydantic_model__"): + original_type = original_type.__pydantic_model__ # type: ignore use_type = original_type if lenient_issubclass(original_type, BaseModel): original_type = cast(Type[BaseModel], original_type) diff --git a/tests/test_serialize_response.py b/tests/test_serialize_response.py index c0382b899..5fff871f0 100644 --- a/tests/test_serialize_response.py +++ b/tests/test_serialize_response.py @@ -1,8 +1,7 @@ from typing import List -import pytest from fastapi import FastAPI -from pydantic import BaseModel, ValidationError +from pydantic import BaseModel from starlette.testclient import TestClient app = FastAPI() @@ -14,38 +13,45 @@ class Item(BaseModel): owner_ids: List[int] = None -@app.get("/items/invalid", response_model=Item) -def get_invalid(): - return {"name": "invalid", "price": "foo"} +@app.get("/items/valid", response_model=Item) +def get_valid(): + return {"name": "valid", "price": 1.0} -@app.get("/items/innerinvalid", response_model=Item) -def get_innerinvalid(): - return {"name": "double invalid", "price": "foo", "owner_ids": ["foo", "bar"]} +@app.get("/items/coerce", response_model=Item) +def get_coerce(): + return {"name": "coerce", "price": "1.0"} -@app.get("/items/invalidlist", response_model=List[Item]) -def get_invalidlist(): +@app.get("/items/validlist", response_model=List[Item]) +def get_validlist(): return [ {"name": "foo"}, - {"name": "bar", "price": "bar"}, - {"name": "baz", "price": "baz"}, + {"name": "bar", "price": 1.0}, + {"name": "baz", "price": 2.0, "owner_ids": [1, 2, 3]}, ] client = TestClient(app) -def test_invalid(): - with pytest.raises(ValidationError): - client.get("/items/invalid") +def test_valid(): + response = client.get("/items/valid") + response.raise_for_status() + assert response.json() == {"name": "valid", "price": 1.0, "owner_ids": None} -def test_double_invalid(): - with pytest.raises(ValidationError): - client.get("/items/innerinvalid") +def test_coerce(): + response = client.get("/items/coerce") + response.raise_for_status() + assert response.json() == {"name": "coerce", "price": 1.0, "owner_ids": None} -def test_invalid_list(): - with pytest.raises(ValidationError): - client.get("/items/invalidlist") +def test_validlist(): + response = client.get("/items/validlist") + response.raise_for_status() + assert response.json() == [ + {"name": "foo", "price": None, "owner_ids": None}, + {"name": "bar", "price": 1.0, "owner_ids": None}, + {"name": "baz", "price": 2.0, "owner_ids": [1, 2, 3]}, + ] diff --git a/tests/test_serialize_response_dataclass.py b/tests/test_serialize_response_dataclass.py new file mode 100644 index 000000000..ee701f969 --- /dev/null +++ b/tests/test_serialize_response_dataclass.py @@ -0,0 +1,58 @@ +from typing import List + +from fastapi import FastAPI +from pydantic.dataclasses import dataclass +from starlette.testclient import TestClient + +app = FastAPI() + + +@dataclass +class Item: + name: str + price: float = None + owner_ids: List[int] = None + + +@app.get("/items/valid", response_model=Item) +def get_valid(): + return {"name": "valid", "price": 1.0} + + +@app.get("/items/coerce", response_model=Item) +def get_coerce(): + return {"name": "coerce", "price": "1.0"} + + +@app.get("/items/validlist", response_model=List[Item]) +def get_validlist(): + return [ + {"name": "foo"}, + {"name": "bar", "price": 1.0}, + {"name": "baz", "price": 2.0, "owner_ids": [1, 2, 3]}, + ] + + +client = TestClient(app) + + +def test_valid(): + response = client.get("/items/valid") + response.raise_for_status() + assert response.json() == {"name": "valid", "price": 1.0, "owner_ids": None} + + +def test_coerce(): + response = client.get("/items/coerce") + response.raise_for_status() + assert response.json() == {"name": "coerce", "price": 1.0, "owner_ids": None} + + +def test_validlist(): + response = client.get("/items/validlist") + response.raise_for_status() + assert response.json() == [ + {"name": "foo", "price": None, "owner_ids": None}, + {"name": "bar", "price": 1.0, "owner_ids": None}, + {"name": "baz", "price": 2.0, "owner_ids": [1, 2, 3]}, + ] diff --git a/tests/test_validate_response.py b/tests/test_validate_response.py new file mode 100644 index 000000000..c0382b899 --- /dev/null +++ b/tests/test_validate_response.py @@ -0,0 +1,51 @@ +from typing import List + +import pytest +from fastapi import FastAPI +from pydantic import BaseModel, ValidationError +from starlette.testclient import TestClient + +app = FastAPI() + + +class Item(BaseModel): + name: str + price: float = None + owner_ids: List[int] = None + + +@app.get("/items/invalid", response_model=Item) +def get_invalid(): + return {"name": "invalid", "price": "foo"} + + +@app.get("/items/innerinvalid", response_model=Item) +def get_innerinvalid(): + return {"name": "double invalid", "price": "foo", "owner_ids": ["foo", "bar"]} + + +@app.get("/items/invalidlist", response_model=List[Item]) +def get_invalidlist(): + return [ + {"name": "foo"}, + {"name": "bar", "price": "bar"}, + {"name": "baz", "price": "baz"}, + ] + + +client = TestClient(app) + + +def test_invalid(): + with pytest.raises(ValidationError): + client.get("/items/invalid") + + +def test_double_invalid(): + with pytest.raises(ValidationError): + client.get("/items/innerinvalid") + + +def test_invalid_list(): + with pytest.raises(ValidationError): + client.get("/items/invalidlist") diff --git a/tests/test_validate_response_dataclass.py b/tests/test_validate_response_dataclass.py new file mode 100644 index 000000000..4a066416a --- /dev/null +++ b/tests/test_validate_response_dataclass.py @@ -0,0 +1,53 @@ +from typing import List + +import pytest +from fastapi import FastAPI +from pydantic import ValidationError +from pydantic.dataclasses import dataclass +from starlette.testclient import TestClient + +app = FastAPI() + + +@dataclass +class Item: + name: str + price: float = None + owner_ids: List[int] = None + + +@app.get("/items/invalid", response_model=Item) +def get_invalid(): + return {"name": "invalid", "price": "foo"} + + +@app.get("/items/innerinvalid", response_model=Item) +def get_innerinvalid(): + return {"name": "double invalid", "price": "foo", "owner_ids": ["foo", "bar"]} + + +@app.get("/items/invalidlist", response_model=List[Item]) +def get_invalidlist(): + return [ + {"name": "foo"}, + {"name": "bar", "price": "bar"}, + {"name": "baz", "price": "baz"}, + ] + + +client = TestClient(app) + + +def test_invalid(): + with pytest.raises(ValidationError): + client.get("/items/invalid") + + +def test_double_invalid(): + with pytest.raises(ValidationError): + client.get("/items/innerinvalid") + + +def test_invalid_list(): + with pytest.raises(ValidationError): + client.get("/items/invalidlist")