From aea04ee32ee1942e6e1a904527bb8da6ba76abd9 Mon Sep 17 00:00:00 2001 From: juhovh-aiven Date: Sat, 28 Mar 2020 02:19:17 +1100 Subject: [PATCH] :bug: Fix exclude_unset and aliases in response model validation (#1074) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix exclude_unset and aliases in response model validation. * :sparkles: Use by_alias from param :shrug: Co-authored-by: Sebastián Ramírez --- fastapi/routing.py | 32 +++-- tests/test_serialize_response_model.py | 154 +++++++++++++++++++++++++ 2 files changed, 179 insertions(+), 7 deletions(-) create mode 100644 tests/test_serialize_response_model.py diff --git a/fastapi/routing.py b/fastapi/routing.py index b36104869..b90935e15 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -48,6 +48,28 @@ except ImportError: # pragma: nocover from pydantic.fields import Field as ModelField # type: ignore +def _prepare_response_content( + res: Any, *, by_alias: bool = True, exclude_unset: bool +) -> Any: + if isinstance(res, BaseModel): + if PYDANTIC_1: + return res.dict(by_alias=by_alias, exclude_unset=exclude_unset) + else: + return res.dict( + by_alias=by_alias, skip_defaults=exclude_unset + ) # pragma: nocover + elif isinstance(res, list): + return [ + _prepare_response_content(item, exclude_unset=exclude_unset) for item in res + ] + elif isinstance(res, dict): + return { + k: _prepare_response_content(v, exclude_unset=exclude_unset) + for k, v in res.items() + } + return res + + async def serialize_response( *, field: ModelField = None, @@ -60,13 +82,9 @@ async def serialize_response( ) -> Any: if field: errors = [] - if exclude_unset and isinstance(response_content, BaseModel): - if PYDANTIC_1: - response_content = response_content.dict(exclude_unset=exclude_unset) - else: - response_content = response_content.dict( - skip_defaults=exclude_unset - ) # pragma: nocover + response_content = _prepare_response_content( + response_content, by_alias=by_alias, exclude_unset=exclude_unset + ) if is_coroutine: value, errors_ = field.validate(response_content, {}, loc=("response",)) else: diff --git a/tests/test_serialize_response_model.py b/tests/test_serialize_response_model.py new file mode 100644 index 000000000..adb7fda34 --- /dev/null +++ b/tests/test_serialize_response_model.py @@ -0,0 +1,154 @@ +from typing import Dict, List + +from fastapi import FastAPI +from pydantic import BaseModel, Field +from starlette.testclient import TestClient + +app = FastAPI() + + +class Item(BaseModel): + name: str = Field(..., alias="aliased_name") + price: float = None + owner_ids: List[int] = None + + +@app.get("/items/valid", response_model=Item) +def get_valid(): + return Item(aliased_name="valid", price=1.0) + + +@app.get("/items/coerce", response_model=Item) +def get_coerce(): + return Item(aliased_name="coerce", price="1.0") + + +@app.get("/items/validlist", response_model=List[Item]) +def get_validlist(): + return [ + Item(aliased_name="foo"), + Item(aliased_name="bar", price=1.0), + Item(aliased_name="baz", price=2.0, owner_ids=[1, 2, 3]), + ] + + +@app.get("/items/validdict", response_model=Dict[str, Item]) +def get_validdict(): + return { + "k1": Item(aliased_name="foo"), + "k2": Item(aliased_name="bar", price=1.0), + "k3": Item(aliased_name="baz", price=2.0, owner_ids=[1, 2, 3]), + } + + +@app.get( + "/items/valid-exclude-unset", response_model=Item, response_model_exclude_unset=True +) +def get_valid_exclude_unset(): + return Item(aliased_name="valid", price=1.0) + + +@app.get( + "/items/coerce-exclude-unset", + response_model=Item, + response_model_exclude_unset=True, +) +def get_coerce_exclude_unset(): + return Item(aliased_name="coerce", price="1.0") + + +@app.get( + "/items/validlist-exclude-unset", + response_model=List[Item], + response_model_exclude_unset=True, +) +def get_validlist_exclude_unset(): + return [ + Item(aliased_name="foo"), + Item(aliased_name="bar", price=1.0), + Item(aliased_name="baz", price=2.0, owner_ids=[1, 2, 3]), + ] + + +@app.get( + "/items/validdict-exclude-unset", + response_model=Dict[str, Item], + response_model_exclude_unset=True, +) +def get_validdict_exclude_unset(): + return { + "k1": Item(aliased_name="foo"), + "k2": Item(aliased_name="bar", price=1.0), + "k3": Item(aliased_name="baz", price=2.0, owner_ids=[1, 2, 3]), + } + + +client = TestClient(app) + + +def test_valid(): + response = client.get("/items/valid") + response.raise_for_status() + assert response.json() == {"aliased_name": "valid", "price": 1.0, "owner_ids": None} + + +def test_coerce(): + response = client.get("/items/coerce") + response.raise_for_status() + assert response.json() == { + "aliased_name": "coerce", + "price": 1.0, + "owner_ids": None, + } + + +def test_validlist(): + response = client.get("/items/validlist") + response.raise_for_status() + assert response.json() == [ + {"aliased_name": "foo", "price": None, "owner_ids": None}, + {"aliased_name": "bar", "price": 1.0, "owner_ids": None}, + {"aliased_name": "baz", "price": 2.0, "owner_ids": [1, 2, 3]}, + ] + + +def test_validdict(): + response = client.get("/items/validdict") + response.raise_for_status() + assert response.json() == { + "k1": {"aliased_name": "foo", "price": None, "owner_ids": None}, + "k2": {"aliased_name": "bar", "price": 1.0, "owner_ids": None}, + "k3": {"aliased_name": "baz", "price": 2.0, "owner_ids": [1, 2, 3]}, + } + + +def test_valid_exclude_unset(): + response = client.get("/items/valid-exclude-unset") + response.raise_for_status() + assert response.json() == {"aliased_name": "valid", "price": 1.0} + + +def test_coerce_exclude_unset(): + response = client.get("/items/coerce-exclude-unset") + response.raise_for_status() + assert response.json() == {"aliased_name": "coerce", "price": 1.0} + + +def test_validlist_exclude_unset(): + response = client.get("/items/validlist-exclude-unset") + response.raise_for_status() + assert response.json() == [ + {"aliased_name": "foo"}, + {"aliased_name": "bar", "price": 1.0}, + {"aliased_name": "baz", "price": 2.0, "owner_ids": [1, 2, 3]}, + ] + + +def test_validdict_exclude_unset(): + response = client.get("/items/validdict-exclude-unset") + response.raise_for_status() + assert response.json() == { + "k1": {"aliased_name": "foo"}, + "k2": {"aliased_name": "bar", "price": 1.0}, + "k3": {"aliased_name": "baz", "price": 2.0, "owner_ids": [1, 2, 3]}, + }