From ec08e235eaac9d40f5686c4eaa3e863f7f84c7ae Mon Sep 17 00:00:00 2001 From: kaiix Date: Fri, 28 Feb 2025 23:23:22 +0800 Subject: [PATCH] Don't revalidate the response content if it is of the same type as response_model --- fastapi/routing.py | 31 ++++++++++++++++---------- tests/test_serialize_response_model.py | 21 +++++++++++++++++ 2 files changed, 40 insertions(+), 12 deletions(-) diff --git a/fastapi/routing.py b/fastapi/routing.py index 8ea4bb219..93f1dc333 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -143,6 +143,7 @@ def _merge_lifespan_context( async def serialize_response( *, field: Optional[ModelField] = None, + response_model: Any = Default(None), response_content: Any, include: Optional[IncEx] = None, exclude: Optional[IncEx] = None, @@ -152,7 +153,10 @@ async def serialize_response( exclude_none: bool = False, is_coroutine: bool = True, ) -> Any: - if field: + if not field: + return jsonable_encoder(response_content) + + if type(response_content) is not response_model: errors = [] if not hasattr(field, "serialize"): # pydantic v1 @@ -187,18 +191,18 @@ async def serialize_response( exclude_defaults=exclude_defaults, exclude_none=exclude_none, ) - - return jsonable_encoder( - value, - include=include, - exclude=exclude, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none, - ) else: - return jsonable_encoder(response_content) + value = response_content + + return jsonable_encoder( + value, + include=include, + exclude=exclude, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) async def run_endpoint_function( @@ -220,6 +224,7 @@ def get_request_handler( status_code: Optional[int] = None, response_class: Union[Type[Response], DefaultPlaceholder] = Default(JSONResponse), response_field: Optional[ModelField] = None, + response_model: Any = Default(None), response_model_include: Optional[IncEx] = None, response_model_exclude: Optional[IncEx] = None, response_model_by_alias: bool = True, @@ -327,6 +332,7 @@ def get_request_handler( content = await serialize_response( field=response_field, response_content=raw_response, + response_model=response_model, include=response_model_include, exclude=response_model_exclude, by_alias=response_model_by_alias, @@ -575,6 +581,7 @@ class APIRoute(routing.Route): status_code=self.status_code, response_class=self.response_class, response_field=self.secure_cloned_response_field, + response_model=self.response_model, response_model_include=self.response_model_include, response_model_exclude=self.response_model_exclude, response_model_by_alias=self.response_model_by_alias, diff --git a/tests/test_serialize_response_model.py b/tests/test_serialize_response_model.py index 3bb46b2e9..209efb3c4 100644 --- a/tests/test_serialize_response_model.py +++ b/tests/test_serialize_response_model.py @@ -1,6 +1,7 @@ from typing import Dict, List, Optional from fastapi import FastAPI +from fastapi._compat import PYDANTIC_V2 from pydantic import BaseModel, Field from starlette.testclient import TestClient @@ -152,3 +153,23 @@ def test_validdict_exclude_unset(): "k2": {"aliased_name": "bar", "price": 1.0}, "k3": {"aliased_name": "baz", "price": 2.0, "owner_ids": [1, 2, 3]}, } + + +if not PYDANTIC_V2: + from pydantic import validator + + class AutoIncrement(BaseModel): + count: int + + @validator("count") + def auto_increment(cls, count: int): + return count + 1 + + @app.post("/increment", response_model=AutoIncrement) + async def increment(): + return AutoIncrement(count=0) + + def test_response_model_should_not_revalidate_response_content_if_they_had_same_type(): + response = client.post("/increment") + response.raise_for_status() + assert response.json() == {"count": 1}