diff --git a/fastapi/encoders.py b/fastapi/encoders.py index 451ea0760..88f287084 100644 --- a/fastapi/encoders.py +++ b/fastapi/encoders.py @@ -14,7 +14,7 @@ from ipaddress import ( from pathlib import Path, PurePath from re import Pattern from types import GeneratorType -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Callable, Dict, Optional, Tuple, Type, Union from uuid import UUID from fastapi.types import IncEx @@ -64,9 +64,6 @@ ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = { datetime.timedelta: lambda td: td.total_seconds(), Decimal: decimal_encoder, Enum: lambda o: o.value, - frozenset: list, - deque: list, - GeneratorType: list, IPv4Address: str, IPv4Interface: str, IPv4Network: str, @@ -75,11 +72,12 @@ ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = { IPv6Network: str, NameEmail: str, Path: str, + PurePath: str, Pattern: lambda o: o.pattern, SecretBytes: str, SecretStr: str, - set: list, UUID: str, + UndefinedType: lambda _: None, Url: str, AnyUrl: str, } @@ -98,6 +96,8 @@ def generate_encoders_by_class_tuples( encoders_by_class_tuples = generate_encoders_by_class_tuples(ENCODERS_BY_TYPE) +NoneType = type(None) + def jsonable_encoder( obj: Annotated[ @@ -201,18 +201,74 @@ def jsonable_encoder( Read more about it in the [FastAPI docs for JSON Compatible Encoder](https://fastapi.tiangolo.com/tutorial/encoder/). """ - custom_encoder = custom_encoder or {} - if custom_encoder: - if type(obj) in custom_encoder: - return custom_encoder[type(obj)](obj) - else: - for encoder_type, encoder_instance in custom_encoder.items(): - if isinstance(obj, encoder_type): - return encoder_instance(obj) if include is not None and not isinstance(include, (set, dict)): include = set(include) + if exclude is not None and not isinstance(exclude, (set, dict)): exclude = set(exclude) + + return encode_value( + obj, + include=include, + exclude=exclude, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + custom_encoder=custom_encoder, + sqlalchemy_safe=sqlalchemy_safe, + ) + + +def encode_value( + obj: Any, + include: Optional[IncEx] = None, + exclude: Optional[IncEx] = None, + by_alias: bool = True, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + custom_encoder: Optional[Dict[Any, Callable[[Any], Any]]] = None, + sqlalchemy_safe: bool = True, +) -> Any: + if custom_encoder: + encoder = find_encoder(obj, custom_encoder) + if encoder: + return encoder(obj) + + if isinstance(obj, (str, int, float, NoneType)): # type: ignore[arg-type, misc] + return obj + + if isinstance(obj, (list, set, frozenset, GeneratorType, tuple, deque)): + encoded_list = [] + for item in obj: + value = encode_value( + item, + include=include, + exclude=exclude, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + custom_encoder=custom_encoder, + sqlalchemy_safe=sqlalchemy_safe, + ) + encoded_list.append(value) + + return encoded_list + + if isinstance(obj, dict): + return encode_dict( + obj, + include=include, + exclude=exclude, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_none=exclude_none, + custom_encoder=custom_encoder, + sqlalchemy_safe=sqlalchemy_safe, + ) + if isinstance(obj, BaseModel): # TODO: remove when deprecating Pydantic v1 encoders: Dict[Any, Any] = {} @@ -220,6 +276,7 @@ def jsonable_encoder( encoders = getattr(obj.__config__, "json_encoders", {}) # type: ignore[attr-defined] if custom_encoder: encoders.update(custom_encoder) + obj_dict = _model_dump( obj, mode="json", @@ -230,9 +287,11 @@ def jsonable_encoder( exclude_none=exclude_none, exclude_defaults=exclude_defaults, ) + if "__root__" in obj_dict: obj_dict = obj_dict["__root__"] - return jsonable_encoder( + + return encode_value( obj_dict, exclude_none=exclude_none, exclude_defaults=exclude_defaults, @@ -240,104 +299,106 @@ def jsonable_encoder( custom_encoder=encoders, sqlalchemy_safe=sqlalchemy_safe, ) + if dataclasses.is_dataclass(obj): obj_dict = dataclasses.asdict(obj) - return jsonable_encoder( + return encode_dict( obj_dict, include=include, exclude=exclude, by_alias=by_alias, exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, exclude_none=exclude_none, custom_encoder=custom_encoder, sqlalchemy_safe=sqlalchemy_safe, ) - if isinstance(obj, Enum): - return obj.value - if isinstance(obj, PurePath): - return str(obj) - if isinstance(obj, (str, int, float, type(None))): - return obj - if isinstance(obj, UndefinedType): - return None - if isinstance(obj, dict): - encoded_dict = {} - allowed_keys = set(obj.keys()) - if include is not None: - allowed_keys &= set(include) - if exclude is not None: - allowed_keys -= set(exclude) - for key, value in obj.items(): - if ( - ( - not sqlalchemy_safe - or (not isinstance(key, str)) - or (not key.startswith("_sa")) - ) - and (value is not None or not exclude_none) - and key in allowed_keys - ): - encoded_key = jsonable_encoder( - key, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_none=exclude_none, - custom_encoder=custom_encoder, - sqlalchemy_safe=sqlalchemy_safe, - ) - encoded_value = jsonable_encoder( - value, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_none=exclude_none, - custom_encoder=custom_encoder, - sqlalchemy_safe=sqlalchemy_safe, - ) - encoded_dict[encoded_key] = encoded_value - return encoded_dict - if isinstance(obj, (list, set, frozenset, GeneratorType, tuple, deque)): - encoded_list = [] - for item in obj: - encoded_list.append( - jsonable_encoder( - item, - include=include, - exclude=exclude, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none, - custom_encoder=custom_encoder, - sqlalchemy_safe=sqlalchemy_safe, - ) - ) - return encoded_list - if type(obj) in ENCODERS_BY_TYPE: - return ENCODERS_BY_TYPE[type(obj)](obj) - for encoder, classes_tuple in encoders_by_class_tuples.items(): - if isinstance(obj, classes_tuple): - return encoder(obj) + encoder = find_encoder(obj, ENCODERS_BY_TYPE) + if encoder: + return encoder(obj) try: - data = dict(obj) + obj_dict = dict(obj) except Exception as e: - errors: List[Exception] = [] - errors.append(e) + errors = [e] try: - data = vars(obj) + obj_dict = vars(obj) except Exception as e: errors.append(e) raise ValueError(errors) from e - return jsonable_encoder( - data, + + return encode_dict( + obj_dict, include=include, exclude=exclude, by_alias=by_alias, exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, exclude_none=exclude_none, custom_encoder=custom_encoder, sqlalchemy_safe=sqlalchemy_safe, ) + + +def encode_dict( + obj: Dict[Any, Any], + include: Optional[IncEx] = None, + exclude: Optional[IncEx] = None, + by_alias: bool = True, + exclude_unset: bool = False, + exclude_none: bool = False, + custom_encoder: Optional[Dict[Any, Callable[[Any], Any]]] = None, + sqlalchemy_safe: bool = True, +) -> Any: + encoded_dict = {} + allowed_keys = set(obj.keys()) + + if include is not None: + allowed_keys &= set(include) + + if exclude is not None: + allowed_keys -= set(exclude) + + for key, value in obj.items(): + if key not in allowed_keys: + continue + if value is None and exclude_none: + continue + if sqlalchemy_safe and isinstance(key, str) and key.startswith("_sa"): + continue + + encoded_key = encode_value( + key, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_none=exclude_none, + custom_encoder=custom_encoder, + sqlalchemy_safe=sqlalchemy_safe, + ) + + encoded_value = encode_value( + value, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_none=exclude_none, + custom_encoder=custom_encoder, + sqlalchemy_safe=sqlalchemy_safe, + ) + encoded_dict[encoded_key] = encoded_value + + return encoded_dict + + +def find_encoder( + value: Any, encoders: Dict[Any, Callable[[Any], Any]] +) -> Optional[Callable[[Any], Any]]: + # fastpath for exact class match + encoder = encoders.get(type(value)) + if encoder: + return encoder + + # fallback to isinstance which uses MRO + for encoder_type, encoder in encoders.items(): + if isinstance(value, encoder_type): + return encoder + + return None diff --git a/tests/test_jsonable_encoder.py b/tests/test_jsonable_encoder.py index 1906d6bf1..b1cb56c65 100644 --- a/tests/test_jsonable_encoder.py +++ b/tests/test_jsonable_encoder.py @@ -88,6 +88,38 @@ def test_encode_dict(): } +def test_encode_dict_with_nonprimative_keys(): + class CustomString: + value: str + + def __init__(self, value: str) -> None: + self.value = value + + def __hash__(self): + return hash(self.value) + + assert jsonable_encoder( + {CustomString("foo"): "bar"}, custom_encoder={CustomString: lambda v: v.value} + ) == {"foo": "bar"} + + +def test_encode_dict_with_custom_encoder_keys(): + assert jsonable_encoder( + {"foo": "bar"}, custom_encoder={str: lambda v: "_" + v} + ) == {"_foo": "_bar"} + + +def test_encode_dict_with_sqlalchemy_safe(): + obj = {"_sa_foo": "foo", "bar": "bar"} + assert jsonable_encoder(obj, sqlalchemy_safe=True) == {"bar": "bar"} + assert jsonable_encoder(obj, sqlalchemy_safe=False) == obj + + +def test_encode_dict_with_exclude_none(): + assert jsonable_encoder({"foo": None}, exclude_none=True) == {} + assert jsonable_encoder({"foo": None}, exclude_none=False) == {"foo": None} + + def test_encode_class(): person = Person(name="Foo") pet = Pet(owner=person, name="Firulais")