Browse Source

🐛 Fix handling arbitrary types when using `arbitrary_types_allowed=True` (#14482)

pull/14485/head
Sebastián Ramírez 7 months ago
committed by GitHub
parent
commit
42b250d14d
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 32
      fastapi/_compat/v2.py
  2. 141
      tests/test_arbitrary_types.py

32
fastapi/_compat/v2.py

@ -1,7 +1,7 @@
import re import re
import warnings import warnings
from copy import copy, deepcopy from copy import copy, deepcopy
from dataclasses import dataclass from dataclasses import dataclass, is_dataclass
from enum import Enum from enum import Enum
from typing import ( from typing import (
Any, Any,
@ -18,7 +18,7 @@ from typing import (
from fastapi._compat import may_v1, shared from fastapi._compat import may_v1, shared
from fastapi.openapi.constants import REF_TEMPLATE from fastapi.openapi.constants import REF_TEMPLATE
from fastapi.types import IncEx, ModelNameMap, UnionType from fastapi.types import IncEx, ModelNameMap, UnionType
from pydantic import BaseModel, TypeAdapter, create_model from pydantic import BaseModel, ConfigDict, TypeAdapter, create_model
from pydantic import PydanticSchemaGenerationError as PydanticSchemaGenerationError from pydantic import PydanticSchemaGenerationError as PydanticSchemaGenerationError
from pydantic import PydanticUndefinedAnnotation as PydanticUndefinedAnnotation from pydantic import PydanticUndefinedAnnotation as PydanticUndefinedAnnotation
from pydantic import ValidationError as ValidationError from pydantic import ValidationError as ValidationError
@ -64,6 +64,7 @@ class ModelField:
field_info: FieldInfo field_info: FieldInfo
name: str name: str
mode: Literal["validation", "serialization"] = "validation" mode: Literal["validation", "serialization"] = "validation"
config: Union[ConfigDict, None] = None
@property @property
def alias(self) -> str: def alias(self) -> str:
@ -94,8 +95,14 @@ class ModelField:
warnings.simplefilter( warnings.simplefilter(
"ignore", category=UnsupportedFieldAttributeWarning "ignore", category=UnsupportedFieldAttributeWarning
) )
annotated_args = (
self.field_info.annotation,
*self.field_info.metadata,
self.field_info,
)
self._type_adapter: TypeAdapter[Any] = TypeAdapter( self._type_adapter: TypeAdapter[Any] = TypeAdapter(
Annotated[self.field_info.annotation, self.field_info] Annotated[annotated_args],
config=self.config,
) )
def get_default(self) -> Any: def get_default(self) -> Any:
@ -412,10 +419,21 @@ def create_body_model(
def get_model_fields(model: Type[BaseModel]) -> List[ModelField]: def get_model_fields(model: Type[BaseModel]) -> List[ModelField]:
return [ model_fields: List[ModelField] = []
ModelField(field_info=field_info, name=name) for name, field_info in model.model_fields.items():
for name, field_info in model.model_fields.items() type_ = field_info.annotation
] if lenient_issubclass(type_, (BaseModel, dict)) or is_dataclass(type_):
model_config = None
else:
model_config = model.model_config
model_fields.append(
ModelField(
field_info=field_info,
name=name,
config=model_config,
)
)
return model_fields
# Duplicate of several schema functions from Pydantic v1 to make them compatible with # Duplicate of several schema functions from Pydantic v1 to make them compatible with

141
tests/test_arbitrary_types.py

@ -0,0 +1,141 @@
from typing import List
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from inline_snapshot import snapshot
from typing_extensions import Annotated
from .utils import needs_pydanticv2
@pytest.fixture(name="client")
def get_client():
from pydantic import (
BaseModel,
ConfigDict,
PlainSerializer,
TypeAdapter,
WithJsonSchema,
)
class FakeNumpyArray:
def __init__(self):
self.data = [1.0, 2.0, 3.0]
FakeNumpyArrayPydantic = Annotated[
FakeNumpyArray,
WithJsonSchema(TypeAdapter(List[float]).json_schema()),
PlainSerializer(lambda v: v.data),
]
class MyModel(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
custom_field: FakeNumpyArrayPydantic
app = FastAPI()
@app.get("/")
def test() -> MyModel:
return MyModel(custom_field=FakeNumpyArray())
client = TestClient(app)
return client
@needs_pydanticv2
def test_get(client: TestClient):
response = client.get("/")
assert response.json() == {"custom_field": [1.0, 2.0, 3.0]}
@needs_pydanticv2
def test_typeadapter():
# This test is only to confirm that Pydantic alone is working as expected
from pydantic import (
BaseModel,
ConfigDict,
PlainSerializer,
TypeAdapter,
WithJsonSchema,
)
class FakeNumpyArray:
def __init__(self):
self.data = [1.0, 2.0, 3.0]
FakeNumpyArrayPydantic = Annotated[
FakeNumpyArray,
WithJsonSchema(TypeAdapter(List[float]).json_schema()),
PlainSerializer(lambda v: v.data),
]
class MyModel(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
custom_field: FakeNumpyArrayPydantic
ta = TypeAdapter(MyModel)
assert ta.dump_python(MyModel(custom_field=FakeNumpyArray())) == {
"custom_field": [1.0, 2.0, 3.0]
}
assert ta.json_schema() == snapshot(
{
"properties": {
"custom_field": {
"items": {"type": "number"},
"title": "Custom Field",
"type": "array",
}
},
"required": ["custom_field"],
"title": "MyModel",
"type": "object",
}
)
@needs_pydanticv2
def test_openapi_schema(client: TestClient):
response = client.get("openapi.json")
assert response.json() == snapshot(
{
"openapi": "3.1.0",
"info": {"title": "FastAPI", "version": "0.1.0"},
"paths": {
"/": {
"get": {
"summary": "Test",
"operationId": "test__get",
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/MyModel"
}
}
},
}
},
}
}
},
"components": {
"schemas": {
"MyModel": {
"properties": {
"custom_field": {
"items": {"type": "number"},
"type": "array",
"title": "Custom Field",
}
},
"type": "object",
"required": ["custom_field"],
"title": "MyModel",
}
}
},
}
)
Loading…
Cancel
Save