diff --git a/fastapi/encoders.py b/fastapi/encoders.py index 6b6e5d0f7..82e3ffa06 100644 --- a/fastapi/encoders.py +++ b/fastapi/encoders.py @@ -12,12 +12,20 @@ def jsonable_encoder( exclude: Set[str] = set(), by_alias: bool = False, include_none: bool = True, + custom_encoder: dict = {}, ) -> Any: if isinstance(obj, BaseModel): - return jsonable_encoder( - obj.dict(include=include, exclude=exclude, by_alias=by_alias), - include_none=include_none, - ) + if not obj.Config.json_encoders: + return jsonable_encoder( + obj.dict(include=include, exclude=exclude, by_alias=by_alias), + include_none=include_none, + ) + else: + return jsonable_encoder( + obj.dict(include=include, exclude=exclude, by_alias=by_alias), + include_none=include_none, + custom_encoder=obj.Config.json_encoders, + ) if isinstance(obj, Enum): return obj.value if isinstance(obj, (str, int, float, type(None))): @@ -25,8 +33,16 @@ def jsonable_encoder( if isinstance(obj, dict): return { jsonable_encoder( - key, by_alias=by_alias, include_none=include_none - ): jsonable_encoder(value, by_alias=by_alias, include_none=include_none) + key, + by_alias=by_alias, + include_none=include_none, + custom_encoder=custom_encoder, + ): jsonable_encoder( + value, + by_alias=by_alias, + include_none=include_none, + custom_encoder=custom_encoder, + ) for key, value in obj.items() if value is not None or include_none } @@ -38,12 +54,16 @@ def jsonable_encoder( exclude=exclude, by_alias=by_alias, include_none=include_none, + custom_encoder=custom_encoder, ) for item in obj ] errors = [] try: - encoder = ENCODERS_BY_TYPE[type(obj)] + if custom_encoder and type(obj) in custom_encoder: + encoder = custom_encoder[type(obj)] + else: + encoder = ENCODERS_BY_TYPE[type(obj)] return encoder(obj) except KeyError as e: errors.append(e) diff --git a/scripts/test.sh b/scripts/test.sh old mode 100644 new mode 100755 diff --git a/tests/test_datetime.py b/tests/test_datetime.py new file mode 100644 index 000000000..c16166ca4 --- /dev/null +++ b/tests/test_datetime.py @@ -0,0 +1,35 @@ +import json +from datetime import datetime, timezone + +from fastapi import FastAPI +from pydantic import BaseModel +from starlette.testclient import TestClient + + +class ModelWithDatetimeField(BaseModel): + dt_field: datetime + + class Config: + json_encoders = { + datetime: lambda dt: dt.replace( + microsecond=0, tzinfo=timezone.utc + ).isoformat() + } + + +app = FastAPI() +model = ModelWithDatetimeField(dt_field=datetime.utcnow()) + + +@app.get("/model", response_model=ModelWithDatetimeField) +def get_model(): + return model + + +client = TestClient(app) + + +def test_dt(): + with client: + response = client.get("/model") + assert json.loads(model.json()) == response.json()