Ben Brady 1 week ago
committed by GitHub
parent
commit
0647595beb
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 219
      fastapi/encoders.py
  2. 32
      tests/test_jsonable_encoder.py

219
fastapi/encoders.py

@ -14,7 +14,7 @@ from ipaddress import (
from pathlib import Path, PurePath from pathlib import Path, PurePath
from re import Pattern from re import Pattern
from types import GeneratorType from types import GeneratorType
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
from uuid import UUID from uuid import UUID
from fastapi.types import IncEx from fastapi.types import IncEx
@ -64,9 +64,6 @@ ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = {
datetime.timedelta: lambda td: td.total_seconds(), datetime.timedelta: lambda td: td.total_seconds(),
Decimal: decimal_encoder, Decimal: decimal_encoder,
Enum: lambda o: o.value, Enum: lambda o: o.value,
frozenset: list,
deque: list,
GeneratorType: list,
IPv4Address: str, IPv4Address: str,
IPv4Interface: str, IPv4Interface: str,
IPv4Network: str, IPv4Network: str,
@ -75,11 +72,12 @@ ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = {
IPv6Network: str, IPv6Network: str,
NameEmail: str, NameEmail: str,
Path: str, Path: str,
PurePath: str,
Pattern: lambda o: o.pattern, Pattern: lambda o: o.pattern,
SecretBytes: str, SecretBytes: str,
SecretStr: str, SecretStr: str,
set: list,
UUID: str, UUID: str,
UndefinedType: lambda _: None,
Url: str, Url: str,
AnyUrl: str, AnyUrl: str,
} }
@ -98,6 +96,8 @@ def generate_encoders_by_class_tuples(
encoders_by_class_tuples = generate_encoders_by_class_tuples(ENCODERS_BY_TYPE) encoders_by_class_tuples = generate_encoders_by_class_tuples(ENCODERS_BY_TYPE)
NoneType = type(None)
def jsonable_encoder( def jsonable_encoder(
obj: Annotated[ obj: Annotated[
@ -201,18 +201,74 @@ def jsonable_encoder(
Read more about it in the Read more about it in the
[FastAPI docs for JSON Compatible Encoder](https://fastapi.tiangolo.com/tutorial/encoder/). [FastAPI docs for JSON Compatible Encoder](https://fastapi.tiangolo.com/tutorial/encoder/).
""" """
custom_encoder = custom_encoder or {}
if custom_encoder:
if type(obj) in custom_encoder:
return custom_encoder[type(obj)](obj)
else:
for encoder_type, encoder_instance in custom_encoder.items():
if isinstance(obj, encoder_type):
return encoder_instance(obj)
if include is not None and not isinstance(include, (set, dict)): if include is not None and not isinstance(include, (set, dict)):
include = set(include) include = set(include)
if exclude is not None and not isinstance(exclude, (set, dict)): if exclude is not None and not isinstance(exclude, (set, dict)):
exclude = set(exclude) exclude = set(exclude)
return encode_value(
obj,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
def encode_value(
obj: Any,
include: Optional[IncEx] = None,
exclude: Optional[IncEx] = None,
by_alias: bool = True,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
custom_encoder: Optional[Dict[Any, Callable[[Any], Any]]] = None,
sqlalchemy_safe: bool = True,
) -> Any:
if custom_encoder:
encoder = find_encoder(obj, custom_encoder)
if encoder:
return encoder(obj)
if isinstance(obj, (str, int, float, NoneType)): # type: ignore[arg-type, misc]
return obj
if isinstance(obj, (list, set, frozenset, GeneratorType, tuple, deque)):
encoded_list = []
for item in obj:
value = encode_value(
item,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
encoded_list.append(value)
return encoded_list
if isinstance(obj, dict):
return encode_dict(
obj,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
if isinstance(obj, BaseModel): if isinstance(obj, BaseModel):
# TODO: remove when deprecating Pydantic v1 # TODO: remove when deprecating Pydantic v1
encoders: Dict[Any, Any] = {} encoders: Dict[Any, Any] = {}
@ -220,6 +276,7 @@ def jsonable_encoder(
encoders = getattr(obj.__config__, "json_encoders", {}) # type: ignore[attr-defined] encoders = getattr(obj.__config__, "json_encoders", {}) # type: ignore[attr-defined]
if custom_encoder: if custom_encoder:
encoders.update(custom_encoder) encoders.update(custom_encoder)
obj_dict = _model_dump( obj_dict = _model_dump(
obj, obj,
mode="json", mode="json",
@ -230,9 +287,11 @@ def jsonable_encoder(
exclude_none=exclude_none, exclude_none=exclude_none,
exclude_defaults=exclude_defaults, exclude_defaults=exclude_defaults,
) )
if "__root__" in obj_dict: if "__root__" in obj_dict:
obj_dict = obj_dict["__root__"] obj_dict = obj_dict["__root__"]
return jsonable_encoder(
return encode_value(
obj_dict, obj_dict,
exclude_none=exclude_none, exclude_none=exclude_none,
exclude_defaults=exclude_defaults, exclude_defaults=exclude_defaults,
@ -240,45 +299,74 @@ def jsonable_encoder(
custom_encoder=encoders, custom_encoder=encoders,
sqlalchemy_safe=sqlalchemy_safe, sqlalchemy_safe=sqlalchemy_safe,
) )
if dataclasses.is_dataclass(obj): if dataclasses.is_dataclass(obj):
obj_dict = dataclasses.asdict(obj) obj_dict = dataclasses.asdict(obj)
return jsonable_encoder( return encode_dict(
obj_dict, obj_dict,
include=include, include=include,
exclude=exclude, exclude=exclude,
by_alias=by_alias, by_alias=by_alias,
exclude_unset=exclude_unset, exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none, exclude_none=exclude_none,
custom_encoder=custom_encoder, custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe, sqlalchemy_safe=sqlalchemy_safe,
) )
if isinstance(obj, Enum):
return obj.value encoder = find_encoder(obj, ENCODERS_BY_TYPE)
if isinstance(obj, PurePath): if encoder:
return str(obj) return encoder(obj)
if isinstance(obj, (str, int, float, type(None))):
return obj try:
if isinstance(obj, UndefinedType): obj_dict = dict(obj)
return None except Exception as e:
if isinstance(obj, dict): errors = [e]
try:
obj_dict = vars(obj)
except Exception as e:
errors.append(e)
raise ValueError(errors) from e
return encode_dict(
obj_dict,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
def encode_dict(
obj: Dict[Any, Any],
include: Optional[IncEx] = None,
exclude: Optional[IncEx] = None,
by_alias: bool = True,
exclude_unset: bool = False,
exclude_none: bool = False,
custom_encoder: Optional[Dict[Any, Callable[[Any], Any]]] = None,
sqlalchemy_safe: bool = True,
) -> Any:
encoded_dict = {} encoded_dict = {}
allowed_keys = set(obj.keys()) allowed_keys = set(obj.keys())
if include is not None: if include is not None:
allowed_keys &= set(include) allowed_keys &= set(include)
if exclude is not None: if exclude is not None:
allowed_keys -= set(exclude) allowed_keys -= set(exclude)
for key, value in obj.items(): for key, value in obj.items():
if ( if key not in allowed_keys:
( continue
not sqlalchemy_safe if value is None and exclude_none:
or (not isinstance(key, str)) continue
or (not key.startswith("_sa")) if sqlalchemy_safe and isinstance(key, str) and key.startswith("_sa"):
) continue
and (value is not None or not exclude_none)
and key in allowed_keys encoded_key = encode_value(
):
encoded_key = jsonable_encoder(
key, key,
by_alias=by_alias, by_alias=by_alias,
exclude_unset=exclude_unset, exclude_unset=exclude_unset,
@ -286,7 +374,8 @@ def jsonable_encoder(
custom_encoder=custom_encoder, custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe, sqlalchemy_safe=sqlalchemy_safe,
) )
encoded_value = jsonable_encoder(
encoded_value = encode_value(
value, value,
by_alias=by_alias, by_alias=by_alias,
exclude_unset=exclude_unset, exclude_unset=exclude_unset,
@ -295,49 +384,21 @@ def jsonable_encoder(
sqlalchemy_safe=sqlalchemy_safe, sqlalchemy_safe=sqlalchemy_safe,
) )
encoded_dict[encoded_key] = encoded_value encoded_dict[encoded_key] = encoded_value
return encoded_dict return encoded_dict
if isinstance(obj, (list, set, frozenset, GeneratorType, tuple, deque)):
encoded_list = []
for item in obj:
encoded_list.append(
jsonable_encoder(
item,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
)
return encoded_list
if type(obj) in ENCODERS_BY_TYPE:
return ENCODERS_BY_TYPE[type(obj)](obj)
for encoder, classes_tuple in encoders_by_class_tuples.items():
if isinstance(obj, classes_tuple):
return encoder(obj)
try: def find_encoder(
data = dict(obj) value: Any, encoders: Dict[Any, Callable[[Any], Any]]
except Exception as e: ) -> Optional[Callable[[Any], Any]]:
errors: List[Exception] = [] # fastpath for exact class match
errors.append(e) encoder = encoders.get(type(value))
try: if encoder:
data = vars(obj) return encoder
except Exception as e:
errors.append(e) # fallback to isinstance which uses MRO
raise ValueError(errors) from e for encoder_type, encoder in encoders.items():
return jsonable_encoder( if isinstance(value, encoder_type):
data, return encoder
include=include,
exclude=exclude, return None
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)

32
tests/test_jsonable_encoder.py

@ -88,6 +88,38 @@ def test_encode_dict():
} }
def test_encode_dict_with_nonprimative_keys():
class CustomString:
value: str
def __init__(self, value: str) -> None:
self.value = value
def __hash__(self):
return hash(self.value)
assert jsonable_encoder(
{CustomString("foo"): "bar"}, custom_encoder={CustomString: lambda v: v.value}
) == {"foo": "bar"}
def test_encode_dict_with_custom_encoder_keys():
assert jsonable_encoder(
{"foo": "bar"}, custom_encoder={str: lambda v: "_" + v}
) == {"_foo": "_bar"}
def test_encode_dict_with_sqlalchemy_safe():
obj = {"_sa_foo": "foo", "bar": "bar"}
assert jsonable_encoder(obj, sqlalchemy_safe=True) == {"bar": "bar"}
assert jsonable_encoder(obj, sqlalchemy_safe=False) == obj
def test_encode_dict_with_exclude_none():
assert jsonable_encoder({"foo": None}, exclude_none=True) == {}
assert jsonable_encoder({"foo": None}, exclude_none=False) == {"foo": None}
def test_encode_class(): def test_encode_class():
person = Person(name="Foo") person = Person(name="Foo")
pet = Pet(owner=person, name="Firulais") pet = Pet(owner=person, name="Firulais")

Loading…
Cancel
Save