Browse Source

Allow v1.BaseModels to be used when pydantic v2 is installed

pull/10223/head
chbndrhnns 2 years ago
parent
commit
b9d2e9b6a1
  1. 68
      fastapi/_compat.py
  2. 80
      tests/test_pydantic_v1_models.py
  3. 25
      tests/test_response_model_v1.py

68
fastapi/_compat.py

@ -19,14 +19,13 @@ from typing import (
from fastapi.exceptions import RequestErrorModel
from fastapi.types import IncEx, ModelNameMap, UnionType
from pydantic import BaseModel, create_model
from pydantic import BaseModel, PydanticDeprecatedSince20, create_model, v1
from pydantic.version import VERSION as PYDANTIC_VERSION
from starlette.datastructures import UploadFile
from typing_extensions import Annotated, Literal, get_args, get_origin
PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")
sequence_annotation_to_type = {
Sequence: list,
List: list,
@ -98,9 +97,12 @@ if PYDANTIC_V2:
return self.field_info.annotation
def __post_init__(self) -> None:
self._type_adapter: TypeAdapter[Any] = TypeAdapter(
Annotated[self.field_info.annotation, self.field_info]
)
try:
self._type_adapter: TypeAdapter[Any] = TypeAdapter(
Annotated[self.field_info.annotation, self.field_info]
)
except PydanticDeprecatedSince20:
pass
def get_default(self) -> Any:
if self.field_info.is_required():
@ -123,6 +125,14 @@ if PYDANTIC_V2:
return None, _regenerate_error_with_loc(
errors=exc.errors(), loc_prefix=loc
)
except AttributeError:
# pydantic v1
try:
return v1.parse_obj_as(self.type_, value), None
except v1.ValidationError as exc:
return None, _regenerate_error_with_loc(
errors=exc.errors(), loc_prefix=loc
)
def serialize(
self,
@ -136,18 +146,42 @@ if PYDANTIC_V2:
exclude_defaults: bool = False,
exclude_none: bool = False,
) -> Any:
# What calls this code passes a value that already called
# self._type_adapter.validate_python(value)
return self._type_adapter.dump_python(
value,
mode=mode,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
try:
# What calls this code passes a value that already called
# self._type_adapter.validate_python(value)
return self._type_adapter.dump_python(
value,
mode=mode,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
except AttributeError:
# pydantic v1
try:
return value.dict(
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
except AttributeError:
return [
item.dict(
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
for item in value
]
def __hash__(self) -> int:
# Each ModelField is unique for our purposes, to allow making a dict from

80
tests/test_pydantic_v1_models.py

@ -0,0 +1,80 @@
from typing import Annotated
import pytest
from fastapi import Body, FastAPI
from fastapi.exceptions import ResponseValidationError
from fastapi.testclient import TestClient
from pydantic import v1
class Item(v1.BaseModel):
name: str
description: str | None = None
price: float
tax: float | None = None
tags: list = []
class Model(v1.BaseModel):
name: str
app = FastAPI()
@app.post("/request_body")
async def request_body(body: Annotated[Item, Body()]):
return body
@app.get("/response_model", response_model=Model)
async def response_model():
return Model(name="valid_model")
@app.get("/response_model__invalid", response_model=Model)
async def response_model__invalid():
return 1
@app.get("/response_model_list", response_model=list[Model])
async def response_model_list():
return [Model(name="valid_model")]
@app.get("/response_model_list__invalid", response_model=list[Model])
async def response_model_list__invalid():
return [1]
client = TestClient(app)
class TestResponseModel:
def test_simple__valid(self):
response = client.get("/response_model")
assert response.status_code == 200
assert response.json() == {"name": "valid_model"}
def test_simple__invalid(self):
with pytest.raises(ResponseValidationError):
client.get("/response_model__invalid")
def test_list__valid(self):
response = client.get("/response_model_list")
assert response.status_code == 200
assert response.json() == [{"name": "valid_model"}]
def test_list__invalid(self):
with pytest.raises(ResponseValidationError):
client.get("/response_model_list__invalid")
class TestRequestBody:
def test_model__valid(self):
response = client.post("/request_body", json={"name": "myname", "price": 1.0})
assert response.status_code == 200, response.text
def test_model__invalid(self):
response = client.post("/request_body", json={"name": "myname"})
assert response.status_code == 422, response.text

25
tests/test_response_model_v1.py

@ -1,25 +0,0 @@
from typing import List
from fastapi import FastAPI
from fastapi.testclient import TestClient
from pydantic import BaseModel, v1
class Model(v1.BaseModel):
name: str
app = FastAPI()
@app.get("/valid", response_model=Model)
def valid1():
pass
client = TestClient(app)
def test_path_operations():
response = client.get("/valid")
assert response.status_code == 200, response.text
Loading…
Cancel
Save