Browse Source

Merge ee3c9eac6f into 6e69d62bfe

pull/11542/merge
Tamir Duberstein 2 days ago
committed by GitHub
parent
commit
3a2898140a
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 111
      fastapi/_compat.py
  2. 30
      fastapi/dependencies/utils.py
  3. 12
      fastapi/exceptions.py
  4. 18
      fastapi/routing.py
  5. 8
      tests/test_filter_pydantic_sub_model/app_pv1.py
  6. 16
      tests/test_filter_pydantic_sub_model/test_filter_pydantic_sub_model_pv1.py
  7. 67
      tests/test_filter_pydantic_sub_model_pv2.py

111
fastapi/_compat.py

@ -4,6 +4,7 @@ from dataclasses import dataclass, is_dataclass
from enum import Enum from enum import Enum
from functools import lru_cache from functools import lru_cache
from typing import ( from typing import (
TYPE_CHECKING,
Any, Any,
Callable, Callable,
Deque, Deque,
@ -24,7 +25,14 @@ from fastapi.types import IncEx, ModelNameMap, UnionType
from pydantic import BaseModel, create_model from pydantic import BaseModel, create_model
from pydantic.version import VERSION as PYDANTIC_VERSION from pydantic.version import VERSION as PYDANTIC_VERSION
from starlette.datastructures import UploadFile from starlette.datastructures import UploadFile
from typing_extensions import Annotated, Literal, get_args, get_origin from typing_extensions import (
Annotated,
Literal,
TypeAlias,
assert_never,
get_args,
get_origin,
)
PYDANTIC_VERSION_MINOR_TUPLE = tuple(int(x) for x in PYDANTIC_VERSION.split(".")[:2]) PYDANTIC_VERSION_MINOR_TUPLE = tuple(int(x) for x in PYDANTIC_VERSION.split(".")[:2])
PYDANTIC_V2 = PYDANTIC_VERSION_MINOR_TUPLE[0] == 2 PYDANTIC_V2 = PYDANTIC_VERSION_MINOR_TUPLE[0] == 2
@ -61,6 +69,7 @@ if PYDANTIC_V2:
from pydantic.json_schema import GenerateJsonSchema as GenerateJsonSchema from pydantic.json_schema import GenerateJsonSchema as GenerateJsonSchema
from pydantic.json_schema import JsonSchemaValue as JsonSchemaValue from pydantic.json_schema import JsonSchemaValue as JsonSchemaValue
from pydantic_core import CoreSchema as CoreSchema from pydantic_core import CoreSchema as CoreSchema
from pydantic_core import ErrorDetails as ErrorDetails
from pydantic_core import PydanticUndefined, PydanticUndefinedType from pydantic_core import PydanticUndefined, PydanticUndefinedType
from pydantic_core import Url as Url from pydantic_core import Url as Url
@ -70,7 +79,7 @@ if PYDANTIC_V2:
) )
except ImportError: # pragma: no cover except ImportError: # pragma: no cover
from pydantic_core.core_schema import ( from pydantic_core.core_schema import (
general_plain_validator_function as with_info_plain_validator_function, # noqa: F401 general_plain_validator_function as with_info_plain_validator_function,
) )
RequiredParam = PydanticUndefined RequiredParam = PydanticUndefined
@ -85,6 +94,9 @@ if PYDANTIC_V2:
class ErrorWrapper(Exception): class ErrorWrapper(Exception):
pass pass
# See https://github.com/pydantic/pydantic/blob/bb18ac5/pydantic/error_wrappers.py#L45-L47.
ErrorList: TypeAlias = Union[Sequence["ErrorList"], ErrorWrapper]
@dataclass @dataclass
class ModelField: class ModelField:
field_info: FieldInfo field_info: FieldInfo
@ -118,22 +130,25 @@ if PYDANTIC_V2:
return Undefined return Undefined
return self.field_info.get_default(call_default_factory=True) return self.field_info.get_default(call_default_factory=True)
# See https://github.com/pydantic/pydantic/blob/bb18ac5/pydantic/fields.py#L850-L852 for the signature.
def validate( def validate(
self, self,
value: Any, value: Any,
values: Dict[str, Any] = {}, # noqa: B006 values: Dict[str, Any] = {}, # noqa: B006
*, *,
loc: Tuple[Union[int, str], ...] = (), loc: Tuple[Union[int, str], ...] = (),
) -> Tuple[Any, Union[List[Dict[str, Any]], None]]: ) -> Tuple[Any, Union[ErrorList, Sequence[ErrorDetails], None]]:
try: try:
return ( return (
self._type_adapter.validate_python(value, from_attributes=True), self._type_adapter.validate_python(value, from_attributes=True),
None, None,
) )
except ValidationError as exc: except ValidationError as exc:
return None, _regenerate_error_with_loc( errors: List[ErrorDetails] = [
errors=exc.errors(include_url=False), loc_prefix=loc {**err, "loc": loc + err["loc"]}
) for err in exc.errors(include_url=False)
]
return None, errors
def serialize( def serialize(
self, self,
@ -170,7 +185,13 @@ if PYDANTIC_V2:
) -> Any: ) -> Any:
return annotation return annotation
def _normalize_errors(errors: Sequence[Any]) -> List[Dict[str, Any]]: def _normalize_errors(
errors: Union[ErrorList, Sequence[ErrorDetails]],
) -> List[ErrorDetails]:
assert isinstance(errors, Sequence), type(errors)
for error in errors:
assert not isinstance(error, ErrorWrapper)
assert not isinstance(error, Sequence)
return errors # type: ignore[return-value] return errors # type: ignore[return-value]
def _model_rebuild(model: Type[BaseModel]) -> None: def _model_rebuild(model: Type[BaseModel]) -> None:
@ -272,12 +293,12 @@ if PYDANTIC_V2:
assert issubclass(origin_type, sequence_types) # type: ignore[arg-type] assert issubclass(origin_type, sequence_types) # type: ignore[arg-type]
return sequence_annotation_to_type[origin_type](value) # type: ignore[no-any-return] return sequence_annotation_to_type[origin_type](value) # type: ignore[no-any-return]
def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]: def get_missing_field_error(loc: Tuple[str, ...]) -> ErrorDetails:
error = ValidationError.from_exception_data( [error] = ValidationError.from_exception_data(
"Field required", [{"type": "missing", "loc": loc, "input": {}}] "Field required", [{"type": "missing", "loc": loc, "input": {}}]
).errors(include_url=False)[0] ).errors(include_url=False, include_input=False)
error["input"] = None error["input"] = None
return error # type: ignore[return-value] return error
def create_body_model( def create_body_model(
*, fields: Sequence[ModelField], model_name: str *, fields: Sequence[ModelField], model_name: str
@ -296,14 +317,22 @@ else:
from fastapi.openapi.constants import REF_PREFIX as REF_PREFIX from fastapi.openapi.constants import REF_PREFIX as REF_PREFIX
from pydantic import AnyUrl as Url # noqa: F401 from pydantic import AnyUrl as Url # noqa: F401
from pydantic import ( # type: ignore[assignment] from pydantic import ( # type: ignore[assignment]
BaseConfig as BaseConfig, # noqa: F401 BaseConfig as BaseConfig,
) )
from pydantic import ValidationError as ValidationError # noqa: F401 from pydantic import ValidationError as ValidationError
from pydantic.class_validators import ( # type: ignore[no-redef] from pydantic.class_validators import ( # type: ignore[no-redef]
Validator as Validator, # noqa: F401 Validator as Validator,
)
if TYPE_CHECKING: # pragma: nocover
from pydantic.error_wrappers import ( # type: ignore[no-redef]
ErrorDict as ErrorDetails,
)
from pydantic.error_wrappers import ( # type: ignore[no-redef]
ErrorList as ErrorList,
) )
from pydantic.error_wrappers import ( # type: ignore[no-redef] from pydantic.error_wrappers import ( # type: ignore[no-redef]
ErrorWrapper as ErrorWrapper, # noqa: F401 ErrorWrapper as ErrorWrapper,
) )
from pydantic.errors import MissingError from pydantic.errors import MissingError
from pydantic.fields import ( # type: ignore[attr-defined] from pydantic.fields import ( # type: ignore[attr-defined]
@ -317,7 +346,7 @@ else:
) )
from pydantic.fields import FieldInfo as FieldInfo from pydantic.fields import FieldInfo as FieldInfo
from pydantic.fields import ( # type: ignore[no-redef,attr-defined] from pydantic.fields import ( # type: ignore[no-redef,attr-defined]
ModelField as ModelField, # noqa: F401 ModelField as ModelField,
) )
# Keeping old "Required" functionality from Pydantic V1, without # Keeping old "Required" functionality from Pydantic V1, without
@ -327,7 +356,7 @@ else:
Undefined as Undefined, Undefined as Undefined,
) )
from pydantic.fields import ( # type: ignore[no-redef, attr-defined] from pydantic.fields import ( # type: ignore[no-redef, attr-defined]
UndefinedType as UndefinedType, # noqa: F401 UndefinedType as UndefinedType,
) )
from pydantic.schema import ( from pydantic.schema import (
field_schema, field_schema,
@ -335,14 +364,14 @@ else:
get_model_name_map, get_model_name_map,
model_process_schema, model_process_schema,
) )
from pydantic.schema import ( # type: ignore[no-redef] # noqa: F401 from pydantic.schema import ( # type: ignore[no-redef]
get_annotation_from_field_info as get_annotation_from_field_info, get_annotation_from_field_info as get_annotation_from_field_info,
) )
from pydantic.typing import ( # type: ignore[no-redef] from pydantic.typing import ( # type: ignore[no-redef]
evaluate_forwardref as evaluate_forwardref, # noqa: F401 evaluate_forwardref as evaluate_forwardref,
) )
from pydantic.utils import ( # type: ignore[no-redef] from pydantic.utils import ( # type: ignore[no-redef]
lenient_issubclass as lenient_issubclass, # noqa: F401 lenient_issubclass as lenient_issubclass,
) )
GetJsonSchemaHandler = Any # type: ignore[assignment,misc] GetJsonSchemaHandler = Any # type: ignore[assignment,misc]
@ -432,18 +461,23 @@ else:
return True return True
return False return False
def _normalize_errors(errors: Sequence[Any]) -> List[Dict[str, Any]]: def _normalize_errors(
use_errors: List[Any] = [] errors: Union[ErrorList, Sequence["ErrorDetails"]],
for error in errors: ) -> List["ErrorDetails"]:
if isinstance(error, ErrorWrapper): use_errors: List[ErrorDetails] = []
new_errors = ValidationError( # type: ignore[call-arg] if isinstance(errors, ErrorWrapper):
errors=[error], model=RequestErrorModel use_errors.extend(
ValidationError( # type: ignore[call-arg]
errors=[errors], model=RequestErrorModel
).errors() ).errors()
use_errors.extend(new_errors) )
elif isinstance(error, list): elif isinstance(errors, Sequence):
for error in errors:
assert not isinstance(error, dict)
use_errors.extend(_normalize_errors(error)) use_errors.extend(_normalize_errors(error))
else: return use_errors
use_errors.append(error) else:
assert_never(errors) # pragma: no cover
return use_errors return use_errors
def _model_rebuild(model: Type[BaseModel]) -> None: def _model_rebuild(model: Type[BaseModel]) -> None:
@ -514,10 +548,10 @@ else:
def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]: def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
return sequence_shape_to_type[field.shape](value) # type: ignore[no-any-return,attr-defined] return sequence_shape_to_type[field.shape](value) # type: ignore[no-any-return,attr-defined]
def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]: def get_missing_field_error(loc: Tuple[str, ...]) -> "ErrorDetails":
missing_field_error = ErrorWrapper(MissingError(), loc=loc) # type: ignore[call-arg] missing_field_error = ErrorWrapper(MissingError(), loc=loc) # type: ignore[call-arg]
new_error = ValidationError([missing_field_error], RequestErrorModel) [new_error] = ValidationError([missing_field_error], RequestErrorModel).errors()
return new_error.errors()[0] # type: ignore[return-value] return new_error
def create_body_model( def create_body_model(
*, fields: Sequence[ModelField], model_name: str *, fields: Sequence[ModelField], model_name: str
@ -531,17 +565,6 @@ else:
return list(model.__fields__.values()) # type: ignore[attr-defined] return list(model.__fields__.values()) # type: ignore[attr-defined]
def _regenerate_error_with_loc(
*, errors: Sequence[Any], loc_prefix: Tuple[Union[str, int], ...]
) -> List[Dict[str, Any]]:
updated_loc_errors: List[Any] = [
{**err, "loc": loc_prefix + err.get("loc", ())}
for err in _normalize_errors(errors)
]
return updated_loc_errors
def _annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool: def _annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
if lenient_issubclass(annotation, (str, bytes)): if lenient_issubclass(annotation, (str, bytes)):
return False return False

30
fastapi/dependencies/utils.py

@ -3,6 +3,7 @@ from contextlib import AsyncExitStack, contextmanager
from copy import copy, deepcopy from copy import copy, deepcopy
from dataclasses import dataclass from dataclasses import dataclass
from typing import ( from typing import (
TYPE_CHECKING,
Any, Any,
Callable, Callable,
Coroutine, Coroutine,
@ -22,11 +23,10 @@ import anyio
from fastapi import params from fastapi import params
from fastapi._compat import ( from fastapi._compat import (
PYDANTIC_V2, PYDANTIC_V2,
ErrorWrapper,
ModelField, ModelField,
RequiredParam, RequiredParam,
Undefined, Undefined,
_regenerate_error_with_loc, _normalize_errors,
copy_field_info, copy_field_info,
create_body_model, create_body_model,
evaluate_forwardref, evaluate_forwardref,
@ -46,6 +46,9 @@ from fastapi._compat import (
serialize_sequence_value, serialize_sequence_value,
value_is_sequence, value_is_sequence,
) )
if TYPE_CHECKING: # pragma: nocover
from fastapi._compat import ErrorDetails
from fastapi.background import BackgroundTasks from fastapi.background import BackgroundTasks
from fastapi.concurrency import ( from fastapi.concurrency import (
asynccontextmanager, asynccontextmanager,
@ -563,7 +566,7 @@ async def solve_generator(
@dataclass @dataclass
class SolvedDependency: class SolvedDependency:
values: Dict[str, Any] values: Dict[str, Any]
errors: List[Any] errors: List["ErrorDetails"]
background_tasks: Optional[StarletteBackgroundTasks] background_tasks: Optional[StarletteBackgroundTasks]
response: Response response: Response
dependency_cache: Dict[Tuple[Callable[..., Any], Tuple[str]], Any] dependency_cache: Dict[Tuple[Callable[..., Any], Tuple[str]], Any]
@ -582,7 +585,7 @@ async def solve_dependencies(
embed_body_fields: bool, embed_body_fields: bool,
) -> SolvedDependency: ) -> SolvedDependency:
values: Dict[str, Any] = {} values: Dict[str, Any] = {}
errors: List[Any] = [] errors: List[ErrorDetails] = []
if response is None: if response is None:
response = Response() response = Response()
del response.headers["content-length"] del response.headers["content-length"]
@ -658,7 +661,8 @@ async def solve_dependencies(
values.update(query_values) values.update(query_values)
values.update(header_values) values.update(header_values)
values.update(cookie_values) values.update(cookie_values)
errors += path_errors + query_errors + header_errors + cookie_errors for errors_ in (path_errors, query_errors, header_errors, cookie_errors):
errors.extend(errors_)
if dependant.body_params: if dependant.body_params:
( (
body_values, body_values,
@ -697,17 +701,15 @@ async def solve_dependencies(
def _validate_value_with_model_field( def _validate_value_with_model_field(
*, field: ModelField, value: Any, values: Dict[str, Any], loc: Tuple[str, ...] *, field: ModelField, value: Any, values: Dict[str, Any], loc: Tuple[str, ...]
) -> Tuple[Any, List[Any]]: ) -> Tuple[Any, List["ErrorDetails"]]:
if value is None: if value is None:
if field.required: if field.required:
return None, [get_missing_field_error(loc=loc)] return None, [get_missing_field_error(loc=loc)]
else: else:
return deepcopy(field.default), [] return deepcopy(field.default), []
v_, errors_ = field.validate(value, values, loc=loc) v_, errors_ = field.validate(value, values, loc=loc)
if isinstance(errors_, ErrorWrapper): if errors_ is not None:
return None, [errors_] new_errors = _normalize_errors(errors_)
elif isinstance(errors_, list):
new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=())
return None, new_errors return None, new_errors
else: else:
return v_, [] return v_, []
@ -740,9 +742,9 @@ def _get_multidict_value(
def request_params_to_args( def request_params_to_args(
fields: Sequence[ModelField], fields: Sequence[ModelField],
received_params: Union[Mapping[str, Any], QueryParams, Headers], received_params: Union[Mapping[str, Any], QueryParams, Headers],
) -> Tuple[Dict[str, Any], List[Any]]: ) -> Tuple[Dict[str, Any], List["ErrorDetails"]]:
values: Dict[str, Any] = {} values: Dict[str, Any] = {}
errors: List[Dict[str, Any]] = [] errors: List[ErrorDetails] = []
if not fields: if not fields:
return values, errors return values, errors
@ -906,9 +908,9 @@ async def request_body_to_args(
body_fields: List[ModelField], body_fields: List[ModelField],
received_body: Optional[Union[Dict[str, Any], FormData]], received_body: Optional[Union[Dict[str, Any], FormData]],
embed_body_fields: bool, embed_body_fields: bool,
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]: ) -> Tuple[Dict[str, Any], List["ErrorDetails"]]:
values: Dict[str, Any] = {} values: Dict[str, Any] = {}
errors: List[Dict[str, Any]] = [] errors: List[ErrorDetails] = []
assert body_fields, "request_body_to_args() should be called with fields" assert body_fields, "request_body_to_args() should be called with fields"
single_not_embedded_field = len(body_fields) == 1 and not embed_body_fields single_not_embedded_field = len(body_fields) == 1 and not embed_body_fields
first_field = body_fields[0] first_field = body_fields[0]

12
fastapi/exceptions.py

@ -1,5 +1,7 @@
from typing import Any, Dict, Optional, Sequence, Type, Union from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Type, Union
if TYPE_CHECKING: # pragma: nocover
from fastapi._compat import ErrorDetails
from pydantic import BaseModel, create_model from pydantic import BaseModel, create_model
from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.exceptions import WebSocketException as StarletteWebSocketException from starlette.exceptions import WebSocketException as StarletteWebSocketException
@ -147,15 +149,15 @@ class FastAPIError(RuntimeError):
class ValidationException(Exception): class ValidationException(Exception):
def __init__(self, errors: Sequence[Any]) -> None: def __init__(self, errors: Sequence["ErrorDetails"]) -> None:
self._errors = errors self._errors = errors
def errors(self) -> Sequence[Any]: def errors(self) -> Sequence["ErrorDetails"]:
return self._errors return self._errors
class RequestValidationError(ValidationException): class RequestValidationError(ValidationException):
def __init__(self, errors: Sequence[Any], *, body: Any = None) -> None: def __init__(self, errors: Sequence["ErrorDetails"], *, body: Any = None) -> None:
super().__init__(errors) super().__init__(errors)
self.body = body self.body = body
@ -165,7 +167,7 @@ class WebSocketRequestValidationError(ValidationException):
class ResponseValidationError(ValidationException): class ResponseValidationError(ValidationException):
def __init__(self, errors: Sequence[Any], *, body: Any = None) -> None: def __init__(self, errors: Sequence["ErrorDetails"], *, body: Any = None) -> None:
super().__init__(errors) super().__init__(errors)
self.body = body self.body = body

18
fastapi/routing.py

@ -154,7 +154,6 @@ async def serialize_response(
is_coroutine: bool = True, is_coroutine: bool = True,
) -> Any: ) -> Any:
if field: if field:
errors = []
if not hasattr(field, "serialize"): if not hasattr(field, "serialize"):
# pydantic v1 # pydantic v1
response_content = _prepare_response_content( response_content = _prepare_response_content(
@ -164,15 +163,11 @@ async def serialize_response(
exclude_none=exclude_none, exclude_none=exclude_none,
) )
if is_coroutine: if is_coroutine:
value, errors_ = field.validate(response_content, {}, loc=("response",)) value, errors = field.validate(response_content, {}, loc=("response",))
else: else:
value, errors_ = await run_in_threadpool( value, errors = await run_in_threadpool(
field.validate, response_content, {}, loc=("response",) field.validate, response_content, {}, loc=("response",)
) )
if isinstance(errors_, list):
errors.extend(errors_)
elif errors_:
errors.append(errors_)
if errors: if errors:
raise ResponseValidationError( raise ResponseValidationError(
errors=_normalize_errors(errors), body=response_content errors=_normalize_errors(errors), body=response_content
@ -287,7 +282,6 @@ def get_request_handler(
status_code=400, detail="There was an error parsing the body" status_code=400, detail="There was an error parsing the body"
) )
raise http_error from e raise http_error from e
errors: List[Any] = []
async with AsyncExitStack() as async_exit_stack: async with AsyncExitStack() as async_exit_stack:
solved_result = await solve_dependencies( solved_result = await solve_dependencies(
request=request, request=request,
@ -341,9 +335,7 @@ def get_request_handler(
response.body = b"" response.body = b""
response.headers.raw.extend(solved_result.response.headers.raw) response.headers.raw.extend(solved_result.response.headers.raw)
if errors: if errors:
validation_error = RequestValidationError( validation_error = RequestValidationError(errors, body=body)
_normalize_errors(errors), body=body
)
raise validation_error raise validation_error
if response is None: if response is None:
raise FastAPIError( raise FastAPIError(
@ -377,9 +369,7 @@ def get_websocket_app(
embed_body_fields=embed_body_fields, embed_body_fields=embed_body_fields,
) )
if solved_result.errors: if solved_result.errors:
raise WebSocketRequestValidationError( raise WebSocketRequestValidationError(solved_result.errors)
_normalize_errors(solved_result.errors)
)
assert dependant.call is not None, "dependant.call must be a function" assert dependant.call is not None, "dependant.call must be a function"
await dependant.call(**solved_result.values) await dependant.call(**solved_result.values)

8
tests/test_filter_pydantic_sub_model/app_pv1.py

@ -1,4 +1,4 @@
from typing import Optional from typing import Any, Dict, Optional
from fastapi import Depends, FastAPI from fastapi import Depends, FastAPI
from pydantic import BaseModel, validator from pydantic import BaseModel, validator
@ -20,7 +20,7 @@ class ModelA(BaseModel):
model_b: ModelB model_b: ModelB
@validator("name") @validator("name")
def lower_username(cls, name: str, values): def lower_username(cls, name: str, values: Dict[str, Any]) -> str:
if not name.endswith("A"): if not name.endswith("A"):
raise ValueError("name must end in A") raise ValueError("name must end in A")
return name return name
@ -31,5 +31,7 @@ async def get_model_c() -> ModelC:
@app.get("/model/{name}", response_model=ModelA) @app.get("/model/{name}", response_model=ModelA)
async def get_model_a(name: str, model_c=Depends(get_model_c)): async def get_model_a(
name: str, model_c: ModelC = Depends(get_model_c)
) -> Dict[str, Any]:
return {"name": name, "description": "model-a-desc", "model_b": model_c} return {"name": name, "description": "model-a-desc", "model_b": model_c}

16
tests/test_filter_pydantic_sub_model/test_filter_pydantic_sub_model_pv1.py

@ -1,4 +1,9 @@
from typing import TYPE_CHECKING, Sequence
import pytest import pytest
if TYPE_CHECKING: # pragma: nocover
from fastapi._compat import ErrorDetails
from fastapi.exceptions import ResponseValidationError from fastapi.exceptions import ResponseValidationError
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
@ -6,7 +11,7 @@ from ..utils import needs_pydanticv1
@pytest.fixture(name="client") @pytest.fixture(name="client")
def get_client(): def get_client() -> TestClient:
from .app_pv1 import app from .app_pv1 import app
client = TestClient(app) client = TestClient(app)
@ -14,7 +19,7 @@ def get_client():
@needs_pydanticv1 @needs_pydanticv1
def test_filter_sub_model(client: TestClient): def test_filter_sub_model(client: TestClient) -> None:
response = client.get("/model/modelA") response = client.get("/model/modelA")
assert response.status_code == 200, response.text assert response.status_code == 200, response.text
assert response.json() == { assert response.json() == {
@ -25,10 +30,11 @@ def test_filter_sub_model(client: TestClient):
@needs_pydanticv1 @needs_pydanticv1
def test_validator_is_cloned(client: TestClient): def test_validator_is_cloned(client: TestClient) -> None:
with pytest.raises(ResponseValidationError) as err: with pytest.raises(ResponseValidationError) as err:
client.get("/model/modelX") client.get("/model/modelX")
assert err.value.errors() == [ errors: Sequence[ErrorDetails] = err.value.errors()
assert errors == [
{ {
"loc": ("response", "name"), "loc": ("response", "name"),
"msg": "name must end in A", "msg": "name must end in A",
@ -38,7 +44,7 @@ def test_validator_is_cloned(client: TestClient):
@needs_pydanticv1 @needs_pydanticv1
def test_openapi_schema(client: TestClient): def test_openapi_schema(client: TestClient) -> None:
response = client.get("/openapi.json") response = client.get("/openapi.json")
assert response.status_code == 200, response.text assert response.status_code == 200, response.text
assert response.json() == { assert response.json() == {

67
tests/test_filter_pydantic_sub_model_pv2.py

@ -1,8 +1,11 @@
from typing import Optional from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence
import pytest import pytest
from dirty_equals import HasRepr, IsDict, IsOneOf from dirty_equals import HasRepr
from fastapi import Depends, FastAPI from fastapi import Depends, FastAPI
if TYPE_CHECKING: # pragma: nocover
from fastapi._compat import ErrorDetails
from fastapi.exceptions import ResponseValidationError from fastapi.exceptions import ResponseValidationError
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
@ -10,7 +13,7 @@ from .utils import needs_pydanticv2
@pytest.fixture(name="client") @pytest.fixture(name="client")
def get_client(): def get_client() -> TestClient:
from pydantic import BaseModel, ValidationInfo, field_validator from pydantic import BaseModel, ValidationInfo, field_validator
app = FastAPI() app = FastAPI()
@ -27,7 +30,7 @@ def get_client():
foo: ModelB foo: ModelB
@field_validator("name") @field_validator("name")
def lower_username(cls, name: str, info: ValidationInfo): def lower_username(cls, name: str, info: ValidationInfo) -> str:
if not name.endswith("A"): if not name.endswith("A"):
raise ValueError("name must end in A") raise ValueError("name must end in A")
return name return name
@ -36,7 +39,9 @@ def get_client():
return ModelC(username="test-user", password="test-password") return ModelC(username="test-user", password="test-password")
@app.get("/model/{name}", response_model=ModelA) @app.get("/model/{name}", response_model=ModelA)
async def get_model_a(name: str, model_c=Depends(get_model_c)): async def get_model_a(
name: str, model_c: ModelC = Depends(get_model_c)
) -> Dict[str, Any]:
return {"name": name, "description": "model-a-desc", "foo": model_c} return {"name": name, "description": "model-a-desc", "foo": model_c}
client = TestClient(app) client = TestClient(app)
@ -44,7 +49,7 @@ def get_client():
@needs_pydanticv2 @needs_pydanticv2
def test_filter_sub_model(client: TestClient): def test_filter_sub_model(client: TestClient) -> None:
response = client.get("/model/modelA") response = client.get("/model/modelA")
assert response.status_code == 200, response.text assert response.status_code == 200, response.text
assert response.json() == { assert response.json() == {
@ -55,32 +60,23 @@ def test_filter_sub_model(client: TestClient):
@needs_pydanticv2 @needs_pydanticv2
def test_validator_is_cloned(client: TestClient): def test_validator_is_cloned(client: TestClient) -> None:
with pytest.raises(ResponseValidationError) as err: with pytest.raises(ResponseValidationError) as err:
client.get("/model/modelX") client.get("/model/modelX")
assert err.value.errors() == [ errors: Sequence[ErrorDetails] = err.value.errors()
IsDict( assert errors == [
{ {
"type": "value_error", "type": "value_error",
"loc": ("response", "name"), "loc": ("response", "name"),
"msg": "Value error, name must end in A", "msg": "Value error, name must end in A",
"input": "modelX", "input": "modelX",
"ctx": {"error": HasRepr("ValueError('name must end in A')")}, "ctx": {"error": HasRepr("ValueError('name must end in A')")},
} }
)
| IsDict(
# TODO remove when deprecating Pydantic v1
{
"loc": ("response", "name"),
"msg": "name must end in A",
"type": "value_error",
}
)
] ]
@needs_pydanticv2 @needs_pydanticv2
def test_openapi_schema(client: TestClient): def test_openapi_schema(client: TestClient) -> None:
response = client.get("/openapi.json") response = client.get("/openapi.json")
assert response.status_code == 200, response.text assert response.status_code == 200, response.text
assert response.json() == { assert response.json() == {
@ -137,23 +133,14 @@ def test_openapi_schema(client: TestClient):
}, },
"ModelA": { "ModelA": {
"title": "ModelA", "title": "ModelA",
"required": IsOneOf( "required": ["name", "foo"],
["name", "description", "foo"],
# TODO remove when deprecating Pydantic v1
["name", "foo"],
),
"type": "object", "type": "object",
"properties": { "properties": {
"name": {"title": "Name", "type": "string"}, "name": {"title": "Name", "type": "string"},
"description": IsDict( "description": {
{ "title": "Description",
"title": "Description", "anyOf": [{"type": "string"}, {"type": "null"}],
"anyOf": [{"type": "string"}, {"type": "null"}], },
}
)
|
# TODO remove when deprecating Pydantic v1
IsDict({"title": "Description", "type": "string"}),
"foo": {"$ref": "#/components/schemas/ModelB"}, "foo": {"$ref": "#/components/schemas/ModelB"},
}, },
}, },

Loading…
Cancel
Save