diff --git a/fastapi/encoders.py b/fastapi/encoders.py index 51cab419d..3f599c9fa 100644 --- a/fastapi/encoders.py +++ b/fastapi/encoders.py @@ -1,3 +1,4 @@ +import dataclasses from collections import defaultdict from enum import Enum from pathlib import PurePath @@ -61,6 +62,8 @@ def jsonable_encoder( custom_encoder=encoder, sqlalchemy_safe=sqlalchemy_safe, ) + if dataclasses.is_dataclass(obj): + return dataclasses.asdict(obj) if isinstance(obj, Enum): return obj.value if isinstance(obj, PurePath): diff --git a/fastapi/routing.py b/fastapi/routing.py index a7c62af3c..f5fe0c0bd 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -1,4 +1,5 @@ import asyncio +import dataclasses import email.message import enum import inspect @@ -90,6 +91,8 @@ def _prepare_response_content( ) for k, v in res.items() } + elif dataclasses.is_dataclass(res): + return dataclasses.asdict(res) return res diff --git a/tests/test_serialize_response_dataclass.py b/tests/test_serialize_response_dataclass.py index d1b64c4e8..e520338ec 100644 --- a/tests/test_serialize_response_dataclass.py +++ b/tests/test_serialize_response_dataclass.py @@ -19,6 +19,11 @@ def get_valid(): return {"name": "valid", "price": 1.0} +@app.get("/items/object", response_model=Item) +def get_object(): + return Item(name="object", price=1.0, owner_ids=[1, 2, 3]) + + @app.get("/items/coerce", response_model=Item) def get_coerce(): return {"name": "coerce", "price": "1.0"} @@ -33,6 +38,29 @@ def get_validlist(): ] +@app.get("/items/objectlist", response_model=List[Item]) +def get_objectlist(): + return [ + Item(name="foo"), + Item(name="bar", price=1.0), + Item(name="baz", price=2.0, owner_ids=[1, 2, 3]), + ] + + +@app.get("/items/no-response-model/object") +def get_no_response_model_object(): + return Item(name="object", price=1.0, owner_ids=[1, 2, 3]) + + +@app.get("/items/no-response-model/objectlist") +def get_no_response_model_objectlist(): + return [ + Item(name="foo"), + Item(name="bar", price=1.0), + Item(name="baz", price=2.0, owner_ids=[1, 2, 3]), + ] + + client = TestClient(app) @@ -42,6 +70,12 @@ def test_valid(): assert response.json() == {"name": "valid", "price": 1.0, "owner_ids": None} +def test_object(): + response = client.get("/items/object") + response.raise_for_status() + assert response.json() == {"name": "object", "price": 1.0, "owner_ids": [1, 2, 3]} + + def test_coerce(): response = client.get("/items/coerce") response.raise_for_status() @@ -56,3 +90,29 @@ def test_validlist(): {"name": "bar", "price": 1.0, "owner_ids": None}, {"name": "baz", "price": 2.0, "owner_ids": [1, 2, 3]}, ] + + +def test_objectlist(): + response = client.get("/items/objectlist") + response.raise_for_status() + assert response.json() == [ + {"name": "foo", "price": None, "owner_ids": None}, + {"name": "bar", "price": 1.0, "owner_ids": None}, + {"name": "baz", "price": 2.0, "owner_ids": [1, 2, 3]}, + ] + + +def test_no_response_model_object(): + response = client.get("/items/no-response-model/object") + response.raise_for_status() + assert response.json() == {"name": "object", "price": 1.0, "owner_ids": [1, 2, 3]} + + +def test_no_response_model_objectlist(): + response = client.get("/items/no-response-model/objectlist") + response.raise_for_status() + assert response.json() == [ + {"name": "foo", "price": None, "owner_ids": None}, + {"name": "bar", "price": 1.0, "owner_ids": None}, + {"name": "baz", "price": 2.0, "owner_ids": [1, 2, 3]}, + ]