Johannes Rueschel 1 day ago
committed by GitHub
parent
commit
0a6fbe3900
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 70
      fastapi/_compat.py
  2. 99
      tests/test_pydantic_v1_models.py

70
fastapi/_compat.py

@ -28,7 +28,6 @@ from typing_extensions import Annotated, Literal, get_args, get_origin
PYDANTIC_VERSION_MINOR_TUPLE = tuple(int(x) for x in PYDANTIC_VERSION.split(".")[:2])
PYDANTIC_V2 = PYDANTIC_VERSION_MINOR_TUPLE[0] == 2
sequence_annotation_to_type = {
Sequence: list,
List: list,
@ -108,9 +107,14 @@ if PYDANTIC_V2:
return self.field_info.annotation
def __post_init__(self) -> None:
self._type_adapter: TypeAdapter[Any] = TypeAdapter(
Annotated[self.field_info.annotation, self.field_info]
)
from pydantic import PydanticDeprecatedSince20
try:
self._type_adapter: TypeAdapter[Any] = TypeAdapter(
Annotated[self.field_info.annotation, self.field_info]
)
except PydanticDeprecatedSince20:
pass
def get_default(self) -> Any:
if self.field_info.is_required():
@ -133,6 +137,16 @@ if PYDANTIC_V2:
return None, _regenerate_error_with_loc(
errors=exc.errors(include_url=False), loc_prefix=loc
)
except AttributeError:
# pydantic v1
from pydantic import v1
try:
return v1.parse_obj_as(self.type_, value), None
except v1.ValidationError as exc:
return None, _regenerate_error_with_loc(
errors=exc.errors(), loc_prefix=loc
)
def serialize(
self,
@ -146,18 +160,42 @@ if PYDANTIC_V2:
exclude_defaults: bool = False,
exclude_none: bool = False,
) -> Any:
# What calls this code passes a value that already called
# self._type_adapter.validate_python(value)
return self._type_adapter.dump_python(
value,
mode=mode,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
try:
# What calls this code passes a value that already called
# self._type_adapter.validate_python(value)
return self._type_adapter.dump_python(
value,
mode=mode,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
except AttributeError:
# pydantic v1
try:
return value.dict(
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
except AttributeError:
return [
item.dict(
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
for item in value
]
def __hash__(self) -> int:
# Each ModelField is unique for our purposes, to allow making a dict from

99
tests/test_pydantic_v1_models.py

@ -0,0 +1,99 @@
from typing import List, Optional
import pytest
from fastapi import Body, FastAPI
from fastapi._compat import PYDANTIC_V2
from fastapi.exceptions import ResponseValidationError
from fastapi.testclient import TestClient
from typing_extensions import Annotated
from tests.utils import needs_pydanticv2
if PYDANTIC_V2:
from pydantic import v1
class Item(v1.BaseModel):
name: str
description: Optional[str] = None
price: float
tax: Optional[float] = None
tags: list = []
class Model(v1.BaseModel):
name: str
else:
from pydantic import BaseModel
class Item(BaseModel):
name: str
description: Optional[str] = None
price: float
tax: Optional[float] = None
tags: list = []
class Model(BaseModel):
name: str
app = FastAPI()
@app.post("/request_body")
async def request_body(body: Annotated[Item, Body()]):
return body
@app.get("/response_model", response_model=Model)
async def response_model():
return Model(name="valid_model")
@app.get("/response_model__invalid", response_model=Model)
async def response_model__invalid():
return 1
@app.get("/response_model_list", response_model=List[Model])
async def response_model_list():
return [Model(name="valid_model")]
@app.get("/response_model_list__invalid", response_model=List[Model])
async def response_model_list__invalid():
return [1]
client = TestClient(app)
@needs_pydanticv2
class TestResponseModel:
def test_simple__valid(self):
response = client.get("/response_model")
assert response.status_code == 200
assert response.json() == {"name": "valid_model"}
def test_simple__invalid(self):
with pytest.raises(ResponseValidationError):
client.get("/response_model__invalid")
def test_list__valid(self):
response = client.get("/response_model_list")
assert response.status_code == 200
assert response.json() == [{"name": "valid_model"}]
def test_list__invalid(self):
with pytest.raises(ResponseValidationError):
client.get("/response_model_list__invalid")
@needs_pydanticv2
class TestRequestBody:
def test_model__valid(self):
response = client.post("/request_body", json={"name": "myname", "price": 1.0})
assert response.status_code == 200, response.text
def test_model__invalid(self):
response = client.post("/request_body", json={"name": "myname"})
assert response.status_code == 422, response.text
Loading…
Cancel
Save