Ben Brady 3 days ago
committed by GitHub
parent
commit
0647595beb
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 237
      fastapi/encoders.py
  2. 32
      tests/test_jsonable_encoder.py

237
fastapi/encoders.py

@ -14,7 +14,7 @@ from ipaddress import (
from pathlib import Path, PurePath
from re import Pattern
from types import GeneratorType
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
from uuid import UUID
from fastapi.types import IncEx
@ -64,9 +64,6 @@ ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = {
datetime.timedelta: lambda td: td.total_seconds(),
Decimal: decimal_encoder,
Enum: lambda o: o.value,
frozenset: list,
deque: list,
GeneratorType: list,
IPv4Address: str,
IPv4Interface: str,
IPv4Network: str,
@ -75,11 +72,12 @@ ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = {
IPv6Network: str,
NameEmail: str,
Path: str,
PurePath: str,
Pattern: lambda o: o.pattern,
SecretBytes: str,
SecretStr: str,
set: list,
UUID: str,
UndefinedType: lambda _: None,
Url: str,
AnyUrl: str,
}
@ -98,6 +96,8 @@ def generate_encoders_by_class_tuples(
encoders_by_class_tuples = generate_encoders_by_class_tuples(ENCODERS_BY_TYPE)
NoneType = type(None)
def jsonable_encoder(
obj: Annotated[
@ -201,18 +201,74 @@ def jsonable_encoder(
Read more about it in the
[FastAPI docs for JSON Compatible Encoder](https://fastapi.tiangolo.com/tutorial/encoder/).
"""
custom_encoder = custom_encoder or {}
if custom_encoder:
if type(obj) in custom_encoder:
return custom_encoder[type(obj)](obj)
else:
for encoder_type, encoder_instance in custom_encoder.items():
if isinstance(obj, encoder_type):
return encoder_instance(obj)
if include is not None and not isinstance(include, (set, dict)):
include = set(include)
if exclude is not None and not isinstance(exclude, (set, dict)):
exclude = set(exclude)
return encode_value(
obj,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
def encode_value(
obj: Any,
include: Optional[IncEx] = None,
exclude: Optional[IncEx] = None,
by_alias: bool = True,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
custom_encoder: Optional[Dict[Any, Callable[[Any], Any]]] = None,
sqlalchemy_safe: bool = True,
) -> Any:
if custom_encoder:
encoder = find_encoder(obj, custom_encoder)
if encoder:
return encoder(obj)
if isinstance(obj, (str, int, float, NoneType)): # type: ignore[arg-type, misc]
return obj
if isinstance(obj, (list, set, frozenset, GeneratorType, tuple, deque)):
encoded_list = []
for item in obj:
value = encode_value(
item,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
encoded_list.append(value)
return encoded_list
if isinstance(obj, dict):
return encode_dict(
obj,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
if isinstance(obj, BaseModel):
# TODO: remove when deprecating Pydantic v1
encoders: Dict[Any, Any] = {}
@ -220,6 +276,7 @@ def jsonable_encoder(
encoders = getattr(obj.__config__, "json_encoders", {}) # type: ignore[attr-defined]
if custom_encoder:
encoders.update(custom_encoder)
obj_dict = _model_dump(
obj,
mode="json",
@ -230,9 +287,11 @@ def jsonable_encoder(
exclude_none=exclude_none,
exclude_defaults=exclude_defaults,
)
if "__root__" in obj_dict:
obj_dict = obj_dict["__root__"]
return jsonable_encoder(
return encode_value(
obj_dict,
exclude_none=exclude_none,
exclude_defaults=exclude_defaults,
@ -240,104 +299,106 @@ def jsonable_encoder(
custom_encoder=encoders,
sqlalchemy_safe=sqlalchemy_safe,
)
if dataclasses.is_dataclass(obj):
obj_dict = dataclasses.asdict(obj)
return jsonable_encoder(
return encode_dict(
obj_dict,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
if isinstance(obj, Enum):
return obj.value
if isinstance(obj, PurePath):
return str(obj)
if isinstance(obj, (str, int, float, type(None))):
return obj
if isinstance(obj, UndefinedType):
return None
if isinstance(obj, dict):
encoded_dict = {}
allowed_keys = set(obj.keys())
if include is not None:
allowed_keys &= set(include)
if exclude is not None:
allowed_keys -= set(exclude)
for key, value in obj.items():
if (
(
not sqlalchemy_safe
or (not isinstance(key, str))
or (not key.startswith("_sa"))
)
and (value is not None or not exclude_none)
and key in allowed_keys
):
encoded_key = jsonable_encoder(
key,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
encoded_value = jsonable_encoder(
value,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
encoded_dict[encoded_key] = encoded_value
return encoded_dict
if isinstance(obj, (list, set, frozenset, GeneratorType, tuple, deque)):
encoded_list = []
for item in obj:
encoded_list.append(
jsonable_encoder(
item,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
)
return encoded_list
if type(obj) in ENCODERS_BY_TYPE:
return ENCODERS_BY_TYPE[type(obj)](obj)
for encoder, classes_tuple in encoders_by_class_tuples.items():
if isinstance(obj, classes_tuple):
return encoder(obj)
encoder = find_encoder(obj, ENCODERS_BY_TYPE)
if encoder:
return encoder(obj)
try:
data = dict(obj)
obj_dict = dict(obj)
except Exception as e:
errors: List[Exception] = []
errors.append(e)
errors = [e]
try:
data = vars(obj)
obj_dict = vars(obj)
except Exception as e:
errors.append(e)
raise ValueError(errors) from e
return jsonable_encoder(
data,
return encode_dict(
obj_dict,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
def encode_dict(
obj: Dict[Any, Any],
include: Optional[IncEx] = None,
exclude: Optional[IncEx] = None,
by_alias: bool = True,
exclude_unset: bool = False,
exclude_none: bool = False,
custom_encoder: Optional[Dict[Any, Callable[[Any], Any]]] = None,
sqlalchemy_safe: bool = True,
) -> Any:
encoded_dict = {}
allowed_keys = set(obj.keys())
if include is not None:
allowed_keys &= set(include)
if exclude is not None:
allowed_keys -= set(exclude)
for key, value in obj.items():
if key not in allowed_keys:
continue
if value is None and exclude_none:
continue
if sqlalchemy_safe and isinstance(key, str) and key.startswith("_sa"):
continue
encoded_key = encode_value(
key,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
encoded_value = encode_value(
value,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
encoded_dict[encoded_key] = encoded_value
return encoded_dict
def find_encoder(
value: Any, encoders: Dict[Any, Callable[[Any], Any]]
) -> Optional[Callable[[Any], Any]]:
# fastpath for exact class match
encoder = encoders.get(type(value))
if encoder:
return encoder
# fallback to isinstance which uses MRO
for encoder_type, encoder in encoders.items():
if isinstance(value, encoder_type):
return encoder
return None

32
tests/test_jsonable_encoder.py

@ -88,6 +88,38 @@ def test_encode_dict():
}
def test_encode_dict_with_nonprimative_keys():
class CustomString:
value: str
def __init__(self, value: str) -> None:
self.value = value
def __hash__(self):
return hash(self.value)
assert jsonable_encoder(
{CustomString("foo"): "bar"}, custom_encoder={CustomString: lambda v: v.value}
) == {"foo": "bar"}
def test_encode_dict_with_custom_encoder_keys():
assert jsonable_encoder(
{"foo": "bar"}, custom_encoder={str: lambda v: "_" + v}
) == {"_foo": "_bar"}
def test_encode_dict_with_sqlalchemy_safe():
obj = {"_sa_foo": "foo", "bar": "bar"}
assert jsonable_encoder(obj, sqlalchemy_safe=True) == {"bar": "bar"}
assert jsonable_encoder(obj, sqlalchemy_safe=False) == obj
def test_encode_dict_with_exclude_none():
assert jsonable_encoder({"foo": None}, exclude_none=True) == {}
assert jsonable_encoder({"foo": None}, exclude_none=False) == {"foo": None}
def test_encode_class():
person = Person(name="Foo")
pet = Pet(owner=person, name="Firulais")

Loading…
Cancel
Save