diff --git a/fastapi/_compat.py b/fastapi/_compat.py index c07e4a3b0..fa1c5604b 100644 --- a/fastapi/_compat.py +++ b/fastapi/_compat.py @@ -28,7 +28,6 @@ from typing_extensions import Annotated, Literal, get_args, get_origin PYDANTIC_VERSION_MINOR_TUPLE = tuple(int(x) for x in PYDANTIC_VERSION.split(".")[:2]) PYDANTIC_V2 = PYDANTIC_VERSION_MINOR_TUPLE[0] == 2 - sequence_annotation_to_type = { Sequence: list, List: list, @@ -108,9 +107,14 @@ if PYDANTIC_V2: return self.field_info.annotation def __post_init__(self) -> None: - self._type_adapter: TypeAdapter[Any] = TypeAdapter( - Annotated[self.field_info.annotation, self.field_info] - ) + from pydantic import PydanticDeprecatedSince20 + + try: + self._type_adapter: TypeAdapter[Any] = TypeAdapter( + Annotated[self.field_info.annotation, self.field_info] + ) + except PydanticDeprecatedSince20: + pass def get_default(self) -> Any: if self.field_info.is_required(): @@ -133,6 +137,16 @@ if PYDANTIC_V2: return None, _regenerate_error_with_loc( errors=exc.errors(include_url=False), loc_prefix=loc ) + except AttributeError: + # pydantic v1 + from pydantic import v1 + + try: + return v1.parse_obj_as(self.type_, value), None + except v1.ValidationError as exc: + return None, _regenerate_error_with_loc( + errors=exc.errors(), loc_prefix=loc + ) def serialize( self, @@ -146,18 +160,42 @@ if PYDANTIC_V2: exclude_defaults: bool = False, exclude_none: bool = False, ) -> Any: - # What calls this code passes a value that already called - # self._type_adapter.validate_python(value) - return self._type_adapter.dump_python( - value, - mode=mode, - include=include, - exclude=exclude, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none, - ) + try: + # What calls this code passes a value that already called + # self._type_adapter.validate_python(value) + return self._type_adapter.dump_python( + value, + mode=mode, + include=include, + exclude=exclude, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + except AttributeError: + # pydantic v1 + try: + return value.dict( + include=include, + exclude=exclude, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + except AttributeError: + return [ + item.dict( + include=include, + exclude=exclude, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + for item in value + ] def __hash__(self) -> int: # Each ModelField is unique for our purposes, to allow making a dict from diff --git a/tests/test_pydantic_v1_models.py b/tests/test_pydantic_v1_models.py new file mode 100644 index 000000000..2728d9f15 --- /dev/null +++ b/tests/test_pydantic_v1_models.py @@ -0,0 +1,99 @@ +from typing import List, Optional + +import pytest +from fastapi import Body, FastAPI +from fastapi._compat import PYDANTIC_V2 +from fastapi.exceptions import ResponseValidationError +from fastapi.testclient import TestClient +from typing_extensions import Annotated + +from tests.utils import needs_pydanticv2 + +if PYDANTIC_V2: + from pydantic import v1 + + class Item(v1.BaseModel): + name: str + description: Optional[str] = None + price: float + tax: Optional[float] = None + tags: list = [] + + class Model(v1.BaseModel): + name: str + +else: + from pydantic import BaseModel + + class Item(BaseModel): + name: str + description: Optional[str] = None + price: float + tax: Optional[float] = None + tags: list = [] + + class Model(BaseModel): + name: str + + +app = FastAPI() + + +@app.post("/request_body") +async def request_body(body: Annotated[Item, Body()]): + return body + + +@app.get("/response_model", response_model=Model) +async def response_model(): + return Model(name="valid_model") + + +@app.get("/response_model__invalid", response_model=Model) +async def response_model__invalid(): + return 1 + + +@app.get("/response_model_list", response_model=List[Model]) +async def response_model_list(): + return [Model(name="valid_model")] + + +@app.get("/response_model_list__invalid", response_model=List[Model]) +async def response_model_list__invalid(): + return [1] + + +client = TestClient(app) + + +@needs_pydanticv2 +class TestResponseModel: + def test_simple__valid(self): + response = client.get("/response_model") + assert response.status_code == 200 + assert response.json() == {"name": "valid_model"} + + def test_simple__invalid(self): + with pytest.raises(ResponseValidationError): + client.get("/response_model__invalid") + + def test_list__valid(self): + response = client.get("/response_model_list") + assert response.status_code == 200 + assert response.json() == [{"name": "valid_model"}] + + def test_list__invalid(self): + with pytest.raises(ResponseValidationError): + client.get("/response_model_list__invalid") + + +@needs_pydanticv2 +class TestRequestBody: + def test_model__valid(self): + response = client.post("/request_body", json={"name": "myname", "price": 1.0}) + assert response.status_code == 200, response.text + + def test_model__invalid(self): + response = client.post("/request_body", json={"name": "myname"}) + assert response.status_code == 422, response.text