Browse Source

🔨 refactored jsonable_encoder

pull/13982/head
Ben Brady 4 weeks ago
parent
commit
39c51a38ec
  1. 215
      fastapi/encoders.py

215
fastapi/encoders.py

@ -75,11 +75,13 @@ ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = {
IPv6Network: str,
NameEmail: str,
Path: str,
PurePath: str,
Pattern: lambda o: o.pattern,
SecretBytes: str,
SecretStr: str,
set: list,
UUID: str,
UndefinedType: lambda _: None,
Url: str,
AnyUrl: str,
}
@ -98,6 +100,10 @@ def generate_encoders_by_class_tuples(
encoders_by_class_tuples = generate_encoders_by_class_tuples(ENCODERS_BY_TYPE)
NoneType = type(None)
primitive_types = (str, int, float, NoneType)
iterable_types = (list, set, frozenset, GeneratorType, tuple, deque)
def jsonable_encoder(
obj: Annotated[
@ -201,18 +207,74 @@ def jsonable_encoder(
Read more about it in the
[FastAPI docs for JSON Compatible Encoder](https://fastapi.tiangolo.com/tutorial/encoder/).
"""
custom_encoder = custom_encoder or {}
if custom_encoder:
if type(obj) in custom_encoder:
return custom_encoder[type(obj)](obj)
else:
for encoder_type, encoder_instance in custom_encoder.items():
if isinstance(obj, encoder_type):
return encoder_instance(obj)
if include is not None and not isinstance(include, (set, dict)):
include = set(include)
if exclude is not None and not isinstance(exclude, (set, dict)):
exclude = set(exclude)
return encode_value(
obj,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
def encode_value(
obj: Any,
include: Optional[IncEx] = None,
exclude: Optional[IncEx] = None,
by_alias: bool = True,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
custom_encoder: Optional[Dict[Any, Callable[[Any], Any]]] = None,
sqlalchemy_safe: bool = True,
) -> Any:
if custom_encoder:
for encoder_type, encoder_instance in custom_encoder.items():
if isinstance(obj, encoder_type):
return encoder_instance(obj)
if isinstance(obj, primitive_types):
return obj
if isinstance(obj, iterable_types):
encoded_list = []
for item in obj:
value = encode_value(
item,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
encoded_list.append(value)
return encoded_list
if isinstance(obj, dict):
return encode_dict(
obj,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
if isinstance(obj, BaseModel):
# TODO: remove when deprecating Pydantic v1
encoders: Dict[Any, Any] = {}
@ -220,6 +282,7 @@ def jsonable_encoder(
encoders = getattr(obj.__config__, "json_encoders", {}) # type: ignore[attr-defined]
if custom_encoder:
encoders.update(custom_encoder)
obj_dict = _model_dump(
obj,
mode="json",
@ -230,9 +293,11 @@ def jsonable_encoder(
exclude_none=exclude_none,
exclude_defaults=exclude_defaults,
)
if "__root__" in obj_dict:
obj_dict = obj_dict["__root__"]
return jsonable_encoder(
return encode_value(
obj_dict,
exclude_none=exclude_none,
exclude_defaults=exclude_defaults,
@ -240,104 +305,96 @@ def jsonable_encoder(
custom_encoder=encoders,
sqlalchemy_safe=sqlalchemy_safe,
)
if dataclasses.is_dataclass(obj):
obj_dict = dataclasses.asdict(obj)
return jsonable_encoder(
obj_dict = dataclasses.asdict(obj) # type: ignore
return encode_dict(
obj_dict,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
if isinstance(obj, Enum):
return obj.value
if isinstance(obj, PurePath):
return str(obj)
if isinstance(obj, (str, int, float, type(None))):
return obj
if isinstance(obj, UndefinedType):
return None
if isinstance(obj, dict):
encoded_dict = {}
allowed_keys = set(obj.keys())
if include is not None:
allowed_keys &= set(include)
if exclude is not None:
allowed_keys -= set(exclude)
for key, value in obj.items():
if (
(
not sqlalchemy_safe
or (not isinstance(key, str))
or (not key.startswith("_sa"))
)
and (value is not None or not exclude_none)
and key in allowed_keys
):
encoded_key = jsonable_encoder(
key,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
encoded_value = jsonable_encoder(
value,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
encoded_dict[encoded_key] = encoded_value
return encoded_dict
if isinstance(obj, (list, set, frozenset, GeneratorType, tuple, deque)):
encoded_list = []
for item in obj:
encoded_list.append(
jsonable_encoder(
item,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
)
return encoded_list
if type(obj) in ENCODERS_BY_TYPE:
return ENCODERS_BY_TYPE[type(obj)](obj)
for encoder, classes_tuple in encoders_by_class_tuples.items():
if isinstance(obj, classes_tuple):
return encoder(obj)
try:
data = dict(obj)
obj_dict = dict(obj)
except Exception as e:
errors: List[Exception] = []
errors.append(e)
errors = [e]
try:
data = vars(obj)
obj_dict = vars(obj)
except Exception as e:
errors.append(e)
raise ValueError(errors) from e
return jsonable_encoder(
data,
return encode_dict(
obj_dict,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
def encode_dict(
obj: Any,
include: Optional[IncEx] = None,
exclude: Optional[IncEx] = None,
by_alias: bool = True,
exclude_unset: bool = False,
exclude_none: bool = False,
custom_encoder: Optional[Dict[Any, Callable[[Any], Any]]] = None,
sqlalchemy_safe: bool = True,
) -> Any:
encoded_dict = {}
allowed_keys = set(obj.keys())
if include is not None:
allowed_keys &= set(include)
if exclude is not None:
allowed_keys -= set(exclude)
for key, value in obj.items():
if sqlalchemy_safe and isinstance(key, str) and key.startswith("_sa"):
continue
if value is None and exclude_none:
continue
if key not in allowed_keys:
continue
if isinstance(key, primitive_types):
encoded_key = key
else:
encoded_key = encode_value(
key,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
encoded_value = encode_value(
value,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
encoded_dict[encoded_key] = encoded_value
return encoded_dict

Loading…
Cancel
Save