diff --git a/fastapi/encoders.py b/fastapi/encoders.py index 25ef6dc12..6b6e5d0f7 100644 --- a/fastapi/encoders.py +++ b/fastapi/encoders.py @@ -3,7 +3,7 @@ from types import GeneratorType from typing import Any, Set from pydantic import BaseModel -from pydantic.json import pydantic_encoder +from pydantic.json import ENCODERS_BY_TYPE def jsonable_encoder( @@ -12,64 +12,11 @@ def jsonable_encoder( exclude: Set[str] = set(), by_alias: bool = False, include_none: bool = True, - root_encoder: bool = True, -) -> Any: - errors = [] - try: - return known_data_encoder( - obj, - include=include, - exclude=exclude, - by_alias=by_alias, - include_none=include_none, - ) - except Exception as e: - if not root_encoder: - raise e - errors.append(e) - try: - data = dict(obj) - return jsonable_encoder( - data, - include=include, - exclude=exclude, - by_alias=by_alias, - include_none=include_none, - root_encoder=False, - ) - except Exception as e: - if not root_encoder: - raise e - errors.append(e) - try: - data = vars(obj) - return jsonable_encoder( - data, - include=include, - exclude=exclude, - by_alias=by_alias, - include_none=include_none, - root_encoder=False, - ) - except Exception as e: - if not root_encoder: - raise e - errors.append(e) - raise ValueError(errors) - - -def known_data_encoder( - obj: Any, - include: Set[str] = None, - exclude: Set[str] = set(), - by_alias: bool = False, - include_none: bool = True, ) -> Any: if isinstance(obj, BaseModel): return jsonable_encoder( obj.dict(include=include, exclude=exclude, by_alias=by_alias), include_none=include_none, - root_encoder=False, ) if isinstance(obj, Enum): return obj.value @@ -78,10 +25,8 @@ def known_data_encoder( if isinstance(obj, dict): return { jsonable_encoder( - key, by_alias=by_alias, include_none=include_none, root_encoder=False - ): jsonable_encoder( - value, by_alias=by_alias, include_none=include_none, root_encoder=False - ) + key, by_alias=by_alias, include_none=include_none + ): jsonable_encoder(value, by_alias=by_alias, include_none=include_none) for key, value in obj.items() if value is not None or include_none } @@ -93,8 +38,22 @@ def known_data_encoder( exclude=exclude, by_alias=by_alias, include_none=include_none, - root_encoder=False, ) for item in obj ] - return pydantic_encoder(obj) + errors = [] + try: + encoder = ENCODERS_BY_TYPE[type(obj)] + return encoder(obj) + except KeyError as e: + errors.append(e) + try: + data = dict(obj) + except Exception as e: + errors.append(e) + try: + data = vars(obj) + except Exception as e: + errors.append(e) + raise ValueError(errors) + return jsonable_encoder(data, by_alias=by_alias, include_none=include_none) diff --git a/tests/test_jsonable_encoder.py b/tests/test_jsonable_encoder.py new file mode 100644 index 000000000..9108df9d8 --- /dev/null +++ b/tests/test_jsonable_encoder.py @@ -0,0 +1,50 @@ +import pytest +from fastapi.encoders import jsonable_encoder + + +class Person: + def __init__(self, name: str): + self.name = name + + +class Pet: + def __init__(self, owner: Person, name: str): + self.owner = owner + self.name = name + + +class DictablePerson(Person): + def __iter__(self): + return ((k, v) for k, v in self.__dict__.items()) + + +class DictablePet(Pet): + def __iter__(self): + return ((k, v) for k, v in self.__dict__.items()) + + +class Unserializable: + def __iter__(self): + raise NotImplementedError() + + @property + def __dict__(self): + raise NotImplementedError() + + +def test_encode_class(): + person = Person(name="Foo") + pet = Pet(owner=person, name="Firulais") + assert jsonable_encoder(pet) == {"name": "Firulais", "owner": {"name": "Foo"}} + + +def test_encode_dictable(): + person = DictablePerson(name="Foo") + pet = DictablePet(owner=person, name="Firulais") + assert jsonable_encoder(pet) == {"name": "Firulais", "owner": {"name": "Foo"}} + + +def test_encode_unsupported(): + unserializable = Unserializable() + with pytest.raises(ValueError): + jsonable_encoder(unserializable)