From 0e0931d3082cd44f56f13617b1d76b5dab87eaff Mon Sep 17 00:00:00 2001 From: Rubikoid Date: Sun, 4 Jul 2021 21:53:40 +0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20Fix=20include/exclude=20for=20di?= =?UTF-8?q?cts=20in=20`jsonable=5Fencoder`=20(#2016)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sebastián Ramírez --- fastapi/encoders.py | 4 +- tests/test_response_model_include_exclude.py | 174 +++++++++++++++++++ 2 files changed, 176 insertions(+), 2 deletions(-) create mode 100644 tests/test_response_model_include_exclude.py diff --git a/fastapi/encoders.py b/fastapi/encoders.py index 6a2a75dda..51cab419d 100644 --- a/fastapi/encoders.py +++ b/fastapi/encoders.py @@ -36,9 +36,9 @@ def jsonable_encoder( custom_encoder: Dict[Any, Callable[[Any], Any]] = {}, sqlalchemy_safe: bool = True, ) -> Any: - if include is not None and not isinstance(include, set): + if include is not None and not isinstance(include, (set, dict)): include = set(include) - if exclude is not None and not isinstance(exclude, set): + if exclude is not None and not isinstance(exclude, (set, dict)): exclude = set(exclude) if isinstance(obj, BaseModel): encoder = getattr(obj.__config__, "json_encoders", {}) diff --git a/tests/test_response_model_include_exclude.py b/tests/test_response_model_include_exclude.py new file mode 100644 index 000000000..533f8105b --- /dev/null +++ b/tests/test_response_model_include_exclude.py @@ -0,0 +1,174 @@ +from fastapi import FastAPI +from fastapi.testclient import TestClient +from pydantic import BaseModel + + +class Test(BaseModel): + foo: str + bar: str + + +class Test2(BaseModel): + test: Test + baz: str + + +class Test3(BaseModel): + name: str + age: int + test2: Test2 + + +app = FastAPI() + + +@app.get( + "/simple_include", + response_model=Test2, + response_model_include={"baz": ..., "test": {"foo"}}, +) +def simple_include(): + return Test2( + test=Test(foo="simple_include test foo", bar="simple_include test bar"), + baz="simple_include test2 baz", + ) + + +@app.get( + "/simple_include_dict", + response_model=Test2, + response_model_include={"baz": ..., "test": {"foo"}}, +) +def simple_include_dict(): + return { + "test": { + "foo": "simple_include_dict test foo", + "bar": "simple_include_dict test bar", + }, + "baz": "simple_include_dict test2 baz", + } + + +@app.get( + "/simple_exclude", + response_model=Test2, + response_model_exclude={"test": {"bar"}}, +) +def simple_exclude(): + return Test2( + test=Test(foo="simple_exclude test foo", bar="simple_exclude test bar"), + baz="simple_exclude test2 baz", + ) + + +@app.get( + "/simple_exclude_dict", + response_model=Test2, + response_model_exclude={"test": {"bar"}}, +) +def simple_exclude_dict(): + return { + "test": { + "foo": "simple_exclude_dict test foo", + "bar": "simple_exclude_dict test bar", + }, + "baz": "simple_exclude_dict test2 baz", + } + + +@app.get( + "/mixed", + response_model=Test3, + response_model_include={"test2", "name"}, + response_model_exclude={"test2": {"baz"}}, +) +def mixed(): + return Test3( + name="mixed test3 name", + age=3, + test2=Test2( + test=Test(foo="mixed test foo", bar="mixed test bar"), baz="mixed test2 baz" + ), + ) + + +@app.get( + "/mixed_dict", + response_model=Test3, + response_model_include={"test2", "name"}, + response_model_exclude={"test2": {"baz"}}, +) +def mixed_dict(): + return { + "name": "mixed_dict test3 name", + "age": 3, + "test2": { + "test": {"foo": "mixed_dict test foo", "bar": "mixed_dict test bar"}, + "baz": "mixed_dict test2 baz", + }, + } + + +client = TestClient(app) + + +def test_nested_include_simple(): + response = client.get("/simple_include") + + assert response.status_code == 200, response.text + + assert response.json() == { + "baz": "simple_include test2 baz", + "test": {"foo": "simple_include test foo"}, + } + + +def test_nested_include_simple_dict(): + response = client.get("/simple_include_dict") + + assert response.status_code == 200, response.text + + assert response.json() == { + "baz": "simple_include_dict test2 baz", + "test": {"foo": "simple_include_dict test foo"}, + } + + +def test_nested_exclude_simple(): + response = client.get("/simple_exclude") + assert response.status_code == 200, response.text + assert response.json() == { + "baz": "simple_exclude test2 baz", + "test": {"foo": "simple_exclude test foo"}, + } + + +def test_nested_exclude_simple_dict(): + response = client.get("/simple_exclude_dict") + assert response.status_code == 200, response.text + assert response.json() == { + "baz": "simple_exclude_dict test2 baz", + "test": {"foo": "simple_exclude_dict test foo"}, + } + + +def test_nested_include_mixed(): + response = client.get("/mixed") + assert response.status_code == 200, response.text + assert response.json() == { + "name": "mixed test3 name", + "test2": { + "test": {"foo": "mixed test foo", "bar": "mixed test bar"}, + }, + } + + +def test_nested_include_mixed_dict(): + response = client.get("/mixed_dict") + assert response.status_code == 200, response.text + assert response.json() == { + "name": "mixed_dict test3 name", + "test2": { + "test": {"foo": "mixed_dict test foo", "bar": "mixed_dict test bar"}, + }, + }