diff --git a/fastapi/_compat.py b/fastapi/_compat.py index eb55b08f2..daa8abbfa 100644 --- a/fastapi/_compat.py +++ b/fastapi/_compat.py @@ -19,14 +19,13 @@ from typing import ( from fastapi.exceptions import RequestErrorModel from fastapi.types import IncEx, ModelNameMap, UnionType -from pydantic import BaseModel, create_model +from pydantic import BaseModel, PydanticDeprecatedSince20, create_model, v1 from pydantic.version import VERSION as PYDANTIC_VERSION from starlette.datastructures import UploadFile from typing_extensions import Annotated, Literal, get_args, get_origin PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") - sequence_annotation_to_type = { Sequence: list, List: list, @@ -98,9 +97,12 @@ if PYDANTIC_V2: return self.field_info.annotation def __post_init__(self) -> None: - self._type_adapter: TypeAdapter[Any] = TypeAdapter( - Annotated[self.field_info.annotation, self.field_info] - ) + try: + self._type_adapter: TypeAdapter[Any] = TypeAdapter( + Annotated[self.field_info.annotation, self.field_info] + ) + except PydanticDeprecatedSince20: + pass def get_default(self) -> Any: if self.field_info.is_required(): @@ -123,6 +125,14 @@ if PYDANTIC_V2: return None, _regenerate_error_with_loc( errors=exc.errors(), loc_prefix=loc ) + except AttributeError: + # pydantic v1 + try: + return v1.parse_obj_as(self.type_, value), None + except v1.ValidationError as exc: + return None, _regenerate_error_with_loc( + errors=exc.errors(), loc_prefix=loc + ) def serialize( self, @@ -136,18 +146,42 @@ if PYDANTIC_V2: exclude_defaults: bool = False, exclude_none: bool = False, ) -> Any: - # What calls this code passes a value that already called - # self._type_adapter.validate_python(value) - return self._type_adapter.dump_python( - value, - mode=mode, - include=include, - exclude=exclude, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none, - ) + try: + # What calls this code passes a value that already called + # self._type_adapter.validate_python(value) + return self._type_adapter.dump_python( + value, + mode=mode, + include=include, + exclude=exclude, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + except AttributeError: + # pydantic v1 + try: + return value.dict( + include=include, + exclude=exclude, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + except AttributeError: + return [ + item.dict( + include=include, + exclude=exclude, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + for item in value + ] def __hash__(self) -> int: # Each ModelField is unique for our purposes, to allow making a dict from diff --git a/tests/test_pydantic_v1_models.py b/tests/test_pydantic_v1_models.py new file mode 100644 index 000000000..e3af655f4 --- /dev/null +++ b/tests/test_pydantic_v1_models.py @@ -0,0 +1,80 @@ +from typing import Annotated + +import pytest +from fastapi import Body, FastAPI +from fastapi.exceptions import ResponseValidationError +from fastapi.testclient import TestClient +from pydantic import v1 + + +class Item(v1.BaseModel): + name: str + description: str | None = None + price: float + tax: float | None = None + tags: list = [] + + +class Model(v1.BaseModel): + name: str + + +app = FastAPI() + + +@app.post("/request_body") +async def request_body(body: Annotated[Item, Body()]): + return body + + +@app.get("/response_model", response_model=Model) +async def response_model(): + return Model(name="valid_model") + + +@app.get("/response_model__invalid", response_model=Model) +async def response_model__invalid(): + return 1 + + +@app.get("/response_model_list", response_model=list[Model]) +async def response_model_list(): + return [Model(name="valid_model")] + + +@app.get("/response_model_list__invalid", response_model=list[Model]) +async def response_model_list__invalid(): + return [1] + + +client = TestClient(app) + + +class TestResponseModel: + def test_simple__valid(self): + response = client.get("/response_model") + assert response.status_code == 200 + assert response.json() == {"name": "valid_model"} + + def test_simple__invalid(self): + with pytest.raises(ResponseValidationError): + client.get("/response_model__invalid") + + def test_list__valid(self): + response = client.get("/response_model_list") + assert response.status_code == 200 + assert response.json() == [{"name": "valid_model"}] + + def test_list__invalid(self): + with pytest.raises(ResponseValidationError): + client.get("/response_model_list__invalid") + + +class TestRequestBody: + def test_model__valid(self): + response = client.post("/request_body", json={"name": "myname", "price": 1.0}) + assert response.status_code == 200, response.text + + def test_model__invalid(self): + response = client.post("/request_body", json={"name": "myname"}) + assert response.status_code == 422, response.text diff --git a/tests/test_response_model_v1.py b/tests/test_response_model_v1.py deleted file mode 100644 index d5c5ca195..000000000 --- a/tests/test_response_model_v1.py +++ /dev/null @@ -1,25 +0,0 @@ -from typing import List - -from fastapi import FastAPI -from fastapi.testclient import TestClient -from pydantic import BaseModel, v1 - - -class Model(v1.BaseModel): - name: str - - -app = FastAPI() - - -@app.get("/valid", response_model=Model) -def valid1(): - pass - - -client = TestClient(app) - - -def test_path_operations(): - response = client.get("/valid") - assert response.status_code == 200, response.text