From d4608a00cf4855021dfb1a780556e24dedc94b14 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sun, 23 Jan 2022 17:32:04 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20Prefer=20custom=20encoder=20over?= =?UTF-8?q?=20defaults=20if=20specified=20in=20`jsonable=5Fencoder`=20(#44?= =?UTF-8?q?67)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Vivek Sunder --- fastapi/encoders.py | 18 +++++++++--------- tests/test_jsonable_encoder.py | 15 +++++++++++++++ 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/fastapi/encoders.py b/fastapi/encoders.py index 3f599c9fa..4b7ffe313 100644 --- a/fastapi/encoders.py +++ b/fastapi/encoders.py @@ -34,9 +34,17 @@ def jsonable_encoder( exclude_unset: bool = False, exclude_defaults: bool = False, exclude_none: bool = False, - custom_encoder: Dict[Any, Callable[[Any], Any]] = {}, + custom_encoder: Optional[Dict[Any, Callable[[Any], Any]]] = None, sqlalchemy_safe: bool = True, ) -> Any: + custom_encoder = custom_encoder or {} + if custom_encoder: + if type(obj) in custom_encoder: + return custom_encoder[type(obj)](obj) + else: + for encoder_type, encoder_instance in custom_encoder.items(): + if isinstance(obj, encoder_type): + return encoder_instance(obj) if include is not None and not isinstance(include, (set, dict)): include = set(include) if exclude is not None and not isinstance(exclude, (set, dict)): @@ -118,14 +126,6 @@ def jsonable_encoder( ) return encoded_list - if custom_encoder: - if type(obj) in custom_encoder: - return custom_encoder[type(obj)](obj) - else: - for encoder_type, encoder in custom_encoder.items(): - if isinstance(obj, encoder_type): - return encoder(obj) - if type(obj) in ENCODERS_BY_TYPE: return ENCODERS_BY_TYPE[type(obj)](obj) for encoder, classes_tuple in encoders_by_class_tuples.items(): diff --git a/tests/test_jsonable_encoder.py b/tests/test_jsonable_encoder.py index e2aa8adf8..fa82b5ea8 100644 --- a/tests/test_jsonable_encoder.py +++ b/tests/test_jsonable_encoder.py @@ -161,6 +161,21 @@ def test_custom_encoders(): assert encoded_instance["dt_field"] == instance.dt_field.isoformat() +def test_custom_enum_encoders(): + def custom_enum_encoder(v: Enum): + return v.value.lower() + + class MyEnum(Enum): + ENUM_VAL_1 = "ENUM_VAL_1" + + instance = MyEnum.ENUM_VAL_1 + + encoded_instance = jsonable_encoder( + instance, custom_encoder={MyEnum: custom_enum_encoder} + ) + assert encoded_instance == custom_enum_encoder(instance) + + def test_encode_model_with_path(model_with_path): if isinstance(model_with_path.path, PureWindowsPath): expected = "\\foo\\bar"