Browse Source

Apply custom encoders to model field values

pull/15654/head
jiyujie2006 5 days ago
parent
commit
dc27bbc317
  1. 52
      fastapi/encoders.py
  2. 24
      tests/test_jsonable_encoder.py

52
fastapi/encoders.py

@ -242,19 +242,23 @@ def jsonable_encoder(
exclude = set(exclude) # type: ignore[assignment] # ty: ignore[invalid-assignment]
if isinstance(obj, BaseModel):
if custom_encoder:
encoded_values: dict[int, Any] = {}
no_encoder = object()
def custom_encoder_fallback(value: Any) -> Any:
def custom_encode(value: Any) -> Any:
if type(value) in custom_encoder:
encoded_value = custom_encoder[type(value)](value)
else:
for encoder_type, encoder_instance in custom_encoder.items():
if isinstance(value, encoder_type):
encoded_value = encoder_instance(value)
break
else:
raise TypeError(
f"Object of type {type(value).__name__} is not JSON serializable"
)
return custom_encoder[type(value)](value)
for encoder_type, encoder_instance in custom_encoder.items():
if isinstance(value, encoder_type):
return encoder_instance(value)
return no_encoder
def jsonable_custom_encoded(value: Any) -> Any:
encoded_value = custom_encode(value)
if encoded_value is no_encoder:
raise TypeError(
f"Object of type {type(value).__name__} is not JSON serializable"
)
return jsonable_encoder(
encoded_value,
by_alias=by_alias,
@ -265,6 +269,11 @@ def jsonable_encoder(
sqlalchemy_safe=sqlalchemy_safe,
)
def custom_encoder_fallback(value: Any) -> Any:
encoded_value = jsonable_custom_encoded(value)
encoded_values[id(value)] = encoded_value
return encoded_value
obj_dict = obj.__pydantic_serializer__.to_python(
obj,
mode="json",
@ -276,6 +285,27 @@ def jsonable_encoder(
exclude_defaults=exclude_defaults,
fallback=custom_encoder_fallback,
)
for field_name, field in type(obj).model_fields.items():
value = getattr(obj, field_name)
encoded_value = encoded_values.get(id(value), no_encoder)
if encoded_value is no_encoder:
encoded_value = custom_encode(value)
if encoded_value is no_encoder:
continue
encoded_value = jsonable_encoder(
encoded_value,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
field_key = field_name
if by_alias:
field_key = field.serialization_alias or field.alias or field_name
if field_key in obj_dict:
obj_dict[field_key] = encoded_value
return jsonable_encoder(
obj_dict,
exclude_none=exclude_none,

24
tests/test_jsonable_encoder.py

@ -285,18 +285,36 @@ def test_custom_encoder_model_field_uses_caller_options():
def test_custom_encoder_model_field_does_not_encode_field_names():
class ModelWithString(BaseModel):
value: str
assert jsonable_encoder(
ModelWithString(value="encoded"),
custom_encoder={str: str.upper},
) == {"value": "ENCODED"}
def test_custom_encoder_model_field_applies_to_known_field_types():
class CustomValue:
pass
class NestedModel(BaseModel):
value: int
class ModelWithCustomValue(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
value: CustomValue
nested: NestedModel
assert jsonable_encoder(
ModelWithCustomValue(value=CustomValue()),
custom_encoder={CustomValue: lambda _: "encoded", str: str.upper},
) == {"value": "ENCODED"}
ModelWithCustomValue(value=CustomValue(), nested=NestedModel(value=1)),
custom_encoder={
CustomValue: lambda _: "encoded",
NestedModel: lambda _: "nested",
str: str.upper,
},
) == {"value": "ENCODED", "nested": "NESTED"}
def test_custom_enum_encoders():

Loading…
Cancel
Save