From 96789cf28aca9c6d53f2b6f4d88cb65a0cef4fa4 Mon Sep 17 00:00:00 2001 From: Maxime Goyette Date: Mon, 10 Mar 2025 18:29:35 -0400 Subject: [PATCH] Feature: jsonable_encoder should pass context to pydantic models --- fastapi/_compat.py | 2 ++ fastapi/encoders.py | 15 +++++++++++++++ tests/test_compat.py | 11 +++++++++++ tests/test_jsonable_encoder.py | 28 +++++++++++++++++++++++++++- 4 files changed, 55 insertions(+), 1 deletion(-) diff --git a/fastapi/_compat.py b/fastapi/_compat.py index c07e4a3b0..663ff526e 100644 --- a/fastapi/_compat.py +++ b/fastapi/_compat.py @@ -447,6 +447,8 @@ else: def _model_dump( model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any ) -> Any: + if not PYDANTIC_V2: + kwargs.pop("context", None) return model.dict(**kwargs) def _get_model_config(model: BaseModel) -> Any: diff --git a/fastapi/encoders.py b/fastapi/encoders.py index 451ea0760..c9ec03a0e 100644 --- a/fastapi/encoders.py +++ b/fastapi/encoders.py @@ -126,6 +126,14 @@ def jsonable_encoder( """ ), ] = None, + context: Annotated[ + Optional[Any], + Doc( + """ + Pydantic's `context` parameter, passed to Pydantic serializers. + """ + ), + ] = None, by_alias: Annotated[ bool, Doc( @@ -225,6 +233,7 @@ def jsonable_encoder( mode="json", include=include, exclude=exclude, + context=context, by_alias=by_alias, exclude_unset=exclude_unset, exclude_none=exclude_none, @@ -234,6 +243,7 @@ def jsonable_encoder( obj_dict = obj_dict["__root__"] return jsonable_encoder( obj_dict, + context=context, exclude_none=exclude_none, exclude_defaults=exclude_defaults, # TODO: remove when deprecating Pydantic v1 @@ -246,6 +256,7 @@ def jsonable_encoder( obj_dict, include=include, exclude=exclude, + context=context, by_alias=by_alias, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults, @@ -280,6 +291,7 @@ def jsonable_encoder( ): encoded_key = jsonable_encoder( key, + context=context, by_alias=by_alias, exclude_unset=exclude_unset, exclude_none=exclude_none, @@ -288,6 +300,7 @@ def jsonable_encoder( ) encoded_value = jsonable_encoder( value, + context=context, by_alias=by_alias, exclude_unset=exclude_unset, exclude_none=exclude_none, @@ -304,6 +317,7 @@ def jsonable_encoder( item, include=include, exclude=exclude, + context=context, by_alias=by_alias, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults, @@ -334,6 +348,7 @@ def jsonable_encoder( data, include=include, exclude=exclude, + context=context, by_alias=by_alias, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults, diff --git a/tests/test_compat.py b/tests/test_compat.py index f4a3093c5..65ecb7ce3 100644 --- a/tests/test_compat.py +++ b/tests/test_compat.py @@ -5,6 +5,7 @@ from fastapi._compat import ( ModelField, Undefined, _get_model_config, + _model_dump, get_cached_model_fields, get_model_fields, is_bytes_sequence_annotation, @@ -118,3 +119,13 @@ def test_get_model_fields_cached(): assert non_cached_fields is not non_cached_fields2 assert cached_fields is cached_fields2 + + +@needs_pydanticv1 +def test_model_dump_remove_context_kwarg_in_pydanticv1() -> None: + class Model(BaseModel): + foo: str + + # The following instruction would throw an error + # in pv1 if the context kwarg was not removed. + _model_dump(Model(foo="bar"), context=1) diff --git a/tests/test_jsonable_encoder.py b/tests/test_jsonable_encoder.py index 1906d6bf1..094ac8d13 100644 --- a/tests/test_jsonable_encoder.py +++ b/tests/test_jsonable_encoder.py @@ -9,7 +9,11 @@ from typing import Optional import pytest from fastapi._compat import PYDANTIC_V2, Undefined from fastapi.encoders import jsonable_encoder -from pydantic import BaseModel, Field, ValidationError +from pydantic import ( + BaseModel, + Field, + ValidationError, +) from .utils import needs_pydanticv1, needs_pydanticv2 @@ -316,3 +320,25 @@ def test_encode_deque_encodes_child_models(): def test_encode_pydantic_undefined(): data = {"value": Undefined} assert jsonable_encoder(data) == {"value": None} + + +@needs_pydanticv2 +def test_encode_with_context() -> None: + from pydantic import SerializationInfo, field_serializer + + class ModelWithContextualSerializer(BaseModel): + value: int + + @field_serializer("value") + def serialize_value(value: int, info: SerializationInfo) -> int: + if info.context is not None and isinstance( + value_from_context := info.context.get("value"), int + ): + return value_from_context + + return value + + model = ModelWithContextualSerializer(value=1) + + assert jsonable_encoder(model) == {"value": 1} + assert jsonable_encoder(model, context={"value": 2}) == {"value": 2}