diff --git a/fastapi/encoders.py b/fastapi/encoders.py index d75664deb..cb94a91a7 100644 --- a/fastapi/encoders.py +++ b/fastapi/encoders.py @@ -1,6 +1,6 @@ from enum import Enum from types import GeneratorType -from typing import Any, Dict, List, Set, Union +from typing import Any, Callable, Dict, List, Set, Tuple, Union from fastapi.logger import logger from fastapi.utils import PYDANTIC_1 @@ -11,6 +11,21 @@ SetIntStr = Set[Union[int, str]] DictIntStrAny = Dict[Union[int, str], Any] +def generate_encoders_by_class_tuples( + type_encoder_map: Dict[Any, Callable] +) -> Dict[Callable, Tuple]: + encoders_by_classes: Dict[Callable, List] = {} + for type_, encoder in type_encoder_map.items(): + encoders_by_classes.setdefault(encoder, []).append(type_) + encoders_by_class_tuples: Dict[Callable, Tuple] = {} + for encoder, classes in encoders_by_classes.items(): + encoders_by_class_tuples[encoder] = tuple(classes) + return encoders_by_class_tuples + + +encoders_by_class_tuples = generate_encoders_by_class_tuples(ENCODERS_BY_TYPE) + + def jsonable_encoder( obj: Any, include: Union[SetIntStr, DictIntStrAny] = None, @@ -106,24 +121,31 @@ def jsonable_encoder( ) ) return encoded_list + + if custom_encoder: + if type(obj) in custom_encoder: + return custom_encoder[type(obj)](obj) + else: + for encoder_type, encoder in custom_encoder.items(): + if isinstance(obj, encoder_type): + return encoder(obj) + + if type(obj) in ENCODERS_BY_TYPE: + return ENCODERS_BY_TYPE[type(obj)](obj) + for encoder, classes_tuple in encoders_by_class_tuples.items(): + if isinstance(obj, classes_tuple): + return encoder(obj) + errors: List[Exception] = [] try: - if custom_encoder and type(obj) in custom_encoder: - encoder = custom_encoder[type(obj)] - else: - encoder = ENCODERS_BY_TYPE[type(obj)] - return encoder(obj) - except KeyError as e: + data = dict(obj) + except Exception as e: errors.append(e) try: - data = dict(obj) + data = vars(obj) except Exception as e: errors.append(e) - try: - data = vars(obj) - except Exception as e: - errors.append(e) - raise ValueError(errors) + raise ValueError(errors) return jsonable_encoder( data, by_alias=by_alias, diff --git a/tests/test_inherited_custom_class.py b/tests/test_inherited_custom_class.py new file mode 100644 index 000000000..a9f673898 --- /dev/null +++ b/tests/test_inherited_custom_class.py @@ -0,0 +1,73 @@ +import uuid + +import pytest +from fastapi import FastAPI +from pydantic import BaseModel +from starlette.testclient import TestClient + +app = FastAPI() + + +class MyUuid: + def __init__(self, uuid_string: str): + self.uuid = uuid_string + + def __str__(self): + return self.uuid + + @property + def __class__(self): + return uuid.UUID + + @property + def __dict__(self): + """Spoof a missing __dict__ by raising TypeError, this is how + asyncpg.pgroto.pgproto.UUID behaves""" + raise TypeError("vars() argument must have __dict__ attribute") + + +@app.get("/fast_uuid") +def return_fast_uuid(): + # I don't want to import asyncpg for this test so I made my own UUID + # Import asyncpg and uncomment the two lines below for the actual bug + + # from asyncpg.pgproto import pgproto + # asyncpg_uuid = pgproto.UUID("a10ff360-3b1e-4984-a26f-d3ab460bdb51") + + asyncpg_uuid = MyUuid("a10ff360-3b1e-4984-a26f-d3ab460bdb51") + assert isinstance(asyncpg_uuid, uuid.UUID) + assert type(asyncpg_uuid) != uuid.UUID + with pytest.raises(TypeError): + vars(asyncpg_uuid) + return {"fast_uuid": asyncpg_uuid} + + +class SomeCustomClass(BaseModel): + class Config: + arbitrary_types_allowed = True + json_encoders = {uuid.UUID: str} + + a_uuid: MyUuid + + +@app.get("/get_custom_class") +def return_some_user(): + # Test that the fix also works for custom pydantic classes + return SomeCustomClass(a_uuid=MyUuid("b8799909-f914-42de-91bc-95c819218d01")) + + +client = TestClient(app) + + +def test_dt(): + with client: + response_simple = client.get("/fast_uuid") + response_pydantic = client.get("/get_custom_class") + + assert response_simple.json() == { + "fast_uuid": "a10ff360-3b1e-4984-a26f-d3ab460bdb51" + } + + assert response_pydantic.json() == { + "a_uuid": "b8799909-f914-42de-91bc-95c819218d01" + }