Browse Source

♻️ Refactor internals for test coverage and performance (#9691)

* ♻️ Tweak import of Annotated from typing_extensions, they are installed anyway

* ♻️ Refactor _compat to define functions for Pydantic v1 or v2 once instead of checking inside

*  Add test for UploadFile for Pydantic v2

* ♻️ Refactor types and remove logic for impossible cases

*  Add missing tests from test refactor for path params

*  Add tests for new decimal encoder

* 💡 Add TODO comment for decimals in encoders

* 🔥 Remove unneeded dummy function

* 🔥 Remove section of code in field_annotation_is_scalar covered by sub-call to field_annotation_is_complex

* ♻️ Refactor and tweak variables and types in _compat

*  Add tests for corner cases and compat with Pydantic v1 and v2

* ♻️ Refactor type annotations
pull/9707/head
Sebastián Ramírez 2 years ago
committed by GitHub
parent
commit
cfb00b2119
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 337
      fastapi/_compat.py
  2. 13
      fastapi/dependencies/utils.py
  3. 1
      fastapi/encoders.py
  4. 9
      fastapi/security/oauth2.py
  5. 93
      tests/test_compat.py
  6. 6
      tests/test_datastructures.py
  7. 13
      tests/test_jsonable_encoder.py
  8. 67
      tests/test_path.py

337
fastapi/_compat.py

@ -113,22 +113,16 @@ if PYDANTIC_V2:
value: Any, value: Any,
values: Dict[str, Any] = {}, # noqa: B006 values: Dict[str, Any] = {}, # noqa: B006
*, *,
loc: Union[Tuple[Union[int, str], ...], str] = "", loc: Tuple[Union[int, str], ...] = (),
) -> Tuple[Any, Union[List[ValidationError], None]]: ) -> Tuple[Any, Union[List[Dict[str, Any]], None]]:
try: try:
return ( return (
self._type_adapter.validate_python(value, from_attributes=True), self._type_adapter.validate_python(value, from_attributes=True),
None, None,
) )
except ValidationError as exc: except ValidationError as exc:
if isinstance(loc, tuple):
use_loc = loc
elif loc == "":
use_loc = ()
else:
use_loc = (loc,)
return None, _regenerate_error_with_loc( return None, _regenerate_error_with_loc(
errors=exc.errors(), loc_prefix=use_loc errors=exc.errors(), loc_prefix=loc
) )
def serialize( def serialize(
@ -161,13 +155,6 @@ if PYDANTIC_V2:
# ModelField to its JSON Schema. # ModelField to its JSON Schema.
return id(self) return id(self)
def get_model_definitions(
*,
flat_models: Set[Union[Type[BaseModel], Type[Enum]]],
model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str],
) -> Dict[str, Any]:
return {}
def get_annotation_from_field_info( def get_annotation_from_field_info(
annotation: Any, field_info: FieldInfo, field_name: str annotation: Any, field_info: FieldInfo, field_name: str
) -> Any: ) -> Any:
@ -176,6 +163,91 @@ if PYDANTIC_V2:
def _normalize_errors(errors: Sequence[Any]) -> List[Dict[str, Any]]: def _normalize_errors(errors: Sequence[Any]) -> List[Dict[str, Any]]:
return errors # type: ignore[return-value] return errors # type: ignore[return-value]
def _model_rebuild(model: Type[BaseModel]) -> None:
model.model_rebuild()
def _model_dump(
model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any
) -> Any:
return model.model_dump(mode=mode, **kwargs)
def _get_model_config(model: BaseModel) -> Any:
return model.model_config
def get_schema_from_model_field(
*,
field: ModelField,
schema_generator: GenerateJsonSchema,
model_name_map: ModelNameMap,
) -> Dict[str, Any]:
# This expects that GenerateJsonSchema was already used to generate the definitions
json_schema = schema_generator.generate_inner(field._type_adapter.core_schema)
if "$ref" not in json_schema:
# TODO remove when deprecating Pydantic v1
# Ref: https://github.com/pydantic/pydantic/blob/d61792cc42c80b13b23e3ffa74bc37ec7c77f7d1/pydantic/schema.py#L207
json_schema[
"title"
] = field.field_info.title or field.alias.title().replace("_", " ")
return json_schema
def get_compat_model_name_map(fields: List[ModelField]) -> ModelNameMap:
return {}
def get_definitions(
*,
fields: List[ModelField],
schema_generator: GenerateJsonSchema,
model_name_map: ModelNameMap,
) -> Dict[str, Dict[str, Any]]:
inputs = [
(field, "validation", field._type_adapter.core_schema) for field in fields
]
_, definitions = schema_generator.generate_definitions(inputs=inputs) # type: ignore[arg-type]
return definitions # type: ignore[return-value]
def is_scalar_field(field: ModelField) -> bool:
from fastapi import params
return field_annotation_is_scalar(
field.field_info.annotation
) and not isinstance(field.field_info, params.Body)
def is_sequence_field(field: ModelField) -> bool:
return field_annotation_is_sequence(field.field_info.annotation)
def is_scalar_sequence_field(field: ModelField) -> bool:
return field_annotation_is_scalar_sequence(field.field_info.annotation)
def is_bytes_field(field: ModelField) -> bool:
return is_bytes_or_nonable_bytes_annotation(field.type_)
def is_bytes_sequence_field(field: ModelField) -> bool:
return is_bytes_sequence_annotation(field.type_)
def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
return type(field_info).from_annotation(annotation)
def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
origin_type = (
get_origin(field.field_info.annotation) or field.field_info.annotation
)
assert issubclass(origin_type, sequence_types) # type: ignore[arg-type]
return sequence_annotation_to_type[origin_type](value) # type: ignore[no-any-return]
def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]:
error = ValidationError.from_exception_data(
"Field required", [{"type": "missing", "loc": loc, "input": {}}]
).errors()[0]
error["input"] = None
return error # type: ignore[return-value]
def create_body_model(
*, fields: Sequence[ModelField], model_name: str
) -> Type[BaseModel]:
field_params = {f.name: (f.field_info.annotation, f.field_info) for f in fields}
BodyModel: Type[BaseModel] = create_model(model_name, **field_params) # type: ignore[call-overload]
return BodyModel
else: else:
from fastapi.openapi.constants import REF_PREFIX as REF_PREFIX from fastapi.openapi.constants import REF_PREFIX as REF_PREFIX
from pydantic import AnyUrl as Url # noqa: F401 from pydantic import AnyUrl as Url # noqa: F401
@ -333,87 +405,86 @@ else:
use_errors.append(error) use_errors.append(error)
return use_errors return use_errors
def _model_rebuild(model: Type[BaseModel]) -> None:
def _regenerate_error_with_loc(
*, errors: Sequence[Any], loc_prefix: Tuple[Union[str, int], ...]
) -> List[ValidationError]:
updated_loc_errors: List[Any] = [
{**err, "loc": loc_prefix + err.get("loc", ())}
for err in _normalize_errors(errors)
]
return updated_loc_errors
def _model_rebuild(model: Type[BaseModel]) -> None:
if PYDANTIC_V2:
model.model_rebuild()
else:
model.update_forward_refs() model.update_forward_refs()
def _model_dump(
def _model_dump( model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any
model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any ) -> Any:
) -> Any:
if PYDANTIC_V2:
return model.model_dump(mode=mode, **kwargs)
else:
return model.dict(**kwargs) return model.dict(**kwargs)
def _get_model_config(model: BaseModel) -> Any:
def _get_model_config(model: BaseModel) -> Any:
if PYDANTIC_V2:
return model.model_config
else:
return model.__config__ # type: ignore[attr-defined] return model.__config__ # type: ignore[attr-defined]
def get_schema_from_model_field(
def get_schema_from_model_field( *,
*, field: ModelField,
field: ModelField, schema_generator: GenerateJsonSchema,
schema_generator: GenerateJsonSchema, model_name_map: ModelNameMap,
model_name_map: ModelNameMap, ) -> Dict[str, Any]:
) -> Dict[str, Any]: # This expects that GenerateJsonSchema was already used to generate the definitions
# This expects that GenerateJsonSchema was already used to generate the definitions
if PYDANTIC_V2:
json_schema = schema_generator.generate_inner(field._type_adapter.core_schema)
if "$ref" not in json_schema:
# TODO remove when deprecating Pydantic v1
# Ref: https://github.com/pydantic/pydantic/blob/d61792cc42c80b13b23e3ffa74bc37ec7c77f7d1/pydantic/schema.py#L207
json_schema[
"title"
] = field.field_info.title or field.alias.title().replace("_", " ")
return json_schema
else:
return field_schema( # type: ignore[no-any-return] return field_schema( # type: ignore[no-any-return]
field, model_name_map=model_name_map, ref_prefix=REF_PREFIX field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
)[0] )[0]
def get_compat_model_name_map(fields: List[ModelField]) -> ModelNameMap:
def get_compat_model_name_map(fields: List[ModelField]) -> ModelNameMap:
if PYDANTIC_V2:
return {}
else:
models = get_flat_models_from_fields(fields, known_models=set()) models = get_flat_models_from_fields(fields, known_models=set())
return get_model_name_map(models) # type: ignore[no-any-return] return get_model_name_map(models) # type: ignore[no-any-return]
def get_definitions(
def get_definitions( *,
*, fields: List[ModelField],
fields: List[ModelField], schema_generator: GenerateJsonSchema,
schema_generator: GenerateJsonSchema, model_name_map: ModelNameMap,
model_name_map: ModelNameMap, ) -> Dict[str, Dict[str, Any]]:
) -> Dict[str, Dict[str, Any]]:
if PYDANTIC_V2:
inputs = [
(field, "validation", field._type_adapter.core_schema) for field in fields
]
_, definitions = schema_generator.generate_definitions(inputs=inputs) # type: ignore[arg-type]
return definitions # type: ignore[return-value]
else:
models = get_flat_models_from_fields(fields, known_models=set()) models = get_flat_models_from_fields(fields, known_models=set())
return get_model_definitions(flat_models=models, model_name_map=model_name_map) return get_model_definitions(flat_models=models, model_name_map=model_name_map)
def is_scalar_field(field: ModelField) -> bool:
return is_pv1_scalar_field(field)
def is_sequence_field(field: ModelField) -> bool:
return field.shape in sequence_shapes or _annotation_is_sequence(field.type_) # type: ignore[attr-defined]
def is_scalar_sequence_field(field: ModelField) -> bool:
return is_pv1_scalar_sequence_field(field)
def is_bytes_field(field: ModelField) -> bool:
return lenient_issubclass(field.type_, bytes)
def is_bytes_sequence_field(field: ModelField) -> bool:
return field.shape in sequence_shapes and lenient_issubclass(field.type_, bytes) # type: ignore[attr-defined]
def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
return copy(field_info)
def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
return sequence_shape_to_type[field.shape](value) # type: ignore[no-any-return,attr-defined]
def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]:
missing_field_error = ErrorWrapper(MissingError(), loc=loc) # type: ignore[call-arg]
new_error = ValidationError([missing_field_error], RequestErrorModel)
return new_error.errors()[0] # type: ignore[return-value]
def create_body_model(
*, fields: Sequence[ModelField], model_name: str
) -> Type[BaseModel]:
BodyModel = create_model(model_name)
for f in fields:
BodyModel.__fields__[f.name] = f # type: ignore[index]
return BodyModel
def _regenerate_error_with_loc(
*, errors: Sequence[Any], loc_prefix: Tuple[Union[str, int], ...]
) -> List[Dict[str, Any]]:
updated_loc_errors: List[Any] = [
{**err, "loc": loc_prefix + err.get("loc", ())}
for err in _normalize_errors(errors)
]
return updated_loc_errors
def _annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool: def _annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
if lenient_issubclass(annotation, (str, bytes)): if lenient_issubclass(annotation, (str, bytes)):
@ -453,10 +524,6 @@ def field_annotation_is_complex(annotation: Union[Type[Any], None]) -> bool:
def field_annotation_is_scalar(annotation: Any) -> bool: def field_annotation_is_scalar(annotation: Any) -> bool:
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
return all(field_annotation_is_scalar(arg) for arg in get_args(annotation))
# handle Ellipsis here to make tuple[int, ...] work nicely # handle Ellipsis here to make tuple[int, ...] work nicely
return annotation is Ellipsis or not field_annotation_is_complex(annotation) return annotation is Ellipsis or not field_annotation_is_complex(annotation)
@ -478,31 +545,6 @@ def field_annotation_is_scalar_sequence(annotation: Union[Type[Any], None]) -> b
) )
def is_scalar_field(field: ModelField) -> bool:
from fastapi import params
if PYDANTIC_V2:
return field_annotation_is_scalar(
field.field_info.annotation
) and not isinstance(field.field_info, params.Body)
else:
return is_pv1_scalar_field(field)
def is_sequence_field(field: ModelField) -> bool:
if PYDANTIC_V2:
return field_annotation_is_sequence(field.field_info.annotation)
else:
return field.shape in sequence_shapes or _annotation_is_sequence(field.type_) # type: ignore[attr-defined]
def is_scalar_sequence_field(field: ModelField) -> bool:
if PYDANTIC_V2:
return field_annotation_is_scalar_sequence(field.field_info.annotation)
else:
return is_pv1_scalar_sequence_field(field)
def is_bytes_or_nonable_bytes_annotation(annotation: Any) -> bool: def is_bytes_or_nonable_bytes_annotation(annotation: Any) -> bool:
if lenient_issubclass(annotation, bytes): if lenient_issubclass(annotation, bytes):
return True return True
@ -525,90 +567,31 @@ def is_uploadfile_or_nonable_uploadfile_annotation(annotation: Any) -> bool:
return False return False
def is_bytes_sequence_annotation(annotation: Union[Type[Any], None]) -> bool: def is_bytes_sequence_annotation(annotation: Any) -> bool:
origin = get_origin(annotation) origin = get_origin(annotation)
if origin is Union or origin is UnionType: if origin is Union or origin is UnionType:
at_least_one_bytes_sequence = False at_least_one = False
for arg in get_args(annotation): for arg in get_args(annotation):
if is_bytes_sequence_annotation(arg): if is_bytes_sequence_annotation(arg):
at_least_one_bytes_sequence = True at_least_one = True
continue continue
return at_least_one_bytes_sequence return at_least_one
return field_annotation_is_sequence(annotation) and all( return field_annotation_is_sequence(annotation) and all(
is_bytes_or_nonable_bytes_annotation(sub_annotation) is_bytes_or_nonable_bytes_annotation(sub_annotation)
for sub_annotation in get_args(annotation) for sub_annotation in get_args(annotation)
) )
def is_uploadfile_sequence_annotation(annotation: Union[Type[Any], None]) -> bool: def is_uploadfile_sequence_annotation(annotation: Any) -> bool:
origin = get_origin(annotation) origin = get_origin(annotation)
if origin is Union or origin is UnionType: if origin is Union or origin is UnionType:
at_least_one_bytes_sequence = False at_least_one = False
for arg in get_args(annotation): for arg in get_args(annotation):
if is_uploadfile_sequence_annotation(arg): if is_uploadfile_sequence_annotation(arg):
at_least_one_bytes_sequence = True at_least_one = True
continue continue
return at_least_one_bytes_sequence return at_least_one
return field_annotation_is_sequence(annotation) and all( return field_annotation_is_sequence(annotation) and all(
is_uploadfile_or_nonable_uploadfile_annotation(sub_annotation) is_uploadfile_or_nonable_uploadfile_annotation(sub_annotation)
for sub_annotation in get_args(annotation) for sub_annotation in get_args(annotation)
) )
def is_bytes_field(field: ModelField) -> bool:
if PYDANTIC_V2:
return is_bytes_or_nonable_bytes_annotation(field.type_)
else:
return lenient_issubclass(field.type_, bytes)
def is_bytes_sequence_field(field: ModelField) -> bool:
if PYDANTIC_V2:
return is_bytes_sequence_annotation(field.type_)
else:
return field.shape in sequence_shapes and lenient_issubclass(field.type_, bytes) # type: ignore[attr-defined]
def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
if PYDANTIC_V2:
return type(field_info).from_annotation(annotation)
else:
return copy(field_info)
def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
if PYDANTIC_V2:
origin_type = (
get_origin(field.field_info.annotation) or field.field_info.annotation
)
assert issubclass(origin_type, sequence_types) # type: ignore[arg-type]
return sequence_annotation_to_type[origin_type](value) # type: ignore[no-any-return]
else:
return sequence_shape_to_type[field.shape](value) # type: ignore[no-any-return,attr-defined]
def get_missing_field_error(loc: Tuple[str, ...]) -> ValidationError:
if PYDANTIC_V2:
error = ValidationError.from_exception_data(
"Field required", [{"type": "missing", "loc": loc, "input": {}}]
).errors()[0]
error["input"] = None
return error # type: ignore[return-value]
else:
missing_field_error = ErrorWrapper(MissingError(), loc=loc) # type: ignore[call-arg]
new_error = ValidationError([missing_field_error], RequestErrorModel)
return new_error.errors()[0] # type: ignore[return-value]
def create_body_model(
*, fields: Sequence[ModelField], model_name: str
) -> Type[BaseModel]:
if PYDANTIC_V2:
field_params = {f.name: (f.field_info.annotation, f.field_info) for f in fields}
BodyModel: Type[BaseModel] = create_model(model_name, **field_params) # type: ignore[call-overload]
return BodyModel
else:
BodyModel = create_model(model_name)
for f in fields:
BodyModel.__fields__[f.name] = f # type: ignore[index]
return BodyModel

13
fastapi/dependencies/utils.py

@ -25,7 +25,6 @@ from fastapi._compat import (
ModelField, ModelField,
Required, Required,
Undefined, Undefined,
ValidationError,
_regenerate_error_with_loc, _regenerate_error_with_loc,
copy_field_info, copy_field_info,
create_body_model, create_body_model,
@ -659,13 +658,7 @@ def request_params_to_args(
values[field.name] = deepcopy(field.default) values[field.name] = deepcopy(field.default)
continue continue
v_, errors_ = field.validate(value, values, loc=loc) v_, errors_ = field.validate(value, values, loc=loc)
if isinstance(errors_, ValidationError): if isinstance(errors_, ErrorWrapper):
new_errors = _regenerate_error_with_loc(
errors=errors_.errors(), loc_prefix=loc
)
new_error = ValidationError(title=errors_.title, errors=new_errors)
errors.append(new_error)
elif isinstance(errors_, ErrorWrapper):
errors.append(errors_) errors.append(errors_)
elif isinstance(errors_, list): elif isinstance(errors_, list):
new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=()) new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=())
@ -678,9 +671,9 @@ def request_params_to_args(
async def request_body_to_args( async def request_body_to_args(
required_params: List[ModelField], required_params: List[ModelField],
received_body: Optional[Union[Dict[str, Any], FormData]], received_body: Optional[Union[Dict[str, Any], FormData]],
) -> Tuple[Dict[str, Any], List[ValidationError]]: ) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
values = {} values = {}
errors: List[ValidationError] = [] errors: List[Dict[str, Any]] = []
if required_params: if required_params:
field = required_params[0] field = required_params[0]
field_info = field.field_info field_info = field.field_info

1
fastapi/encoders.py

@ -32,6 +32,7 @@ def isoformat(o: Union[datetime.date, datetime.time]) -> str:
# Taken from Pydantic v1 as is # Taken from Pydantic v1 as is
# TODO: pv2 should this return strings instead?
def decimal_encoder(dec_value: Decimal) -> Union[int, float]: def decimal_encoder(dec_value: Decimal) -> Union[int, float]:
""" """
Encodes a Decimal as int of there's no exponent, otherwise float Encodes a Decimal as int of there's no exponent, otherwise float

9
fastapi/security/oauth2.py

@ -1,11 +1,5 @@
import sys
from typing import Any, Dict, List, Optional, Union, cast from typing import Any, Dict, List, Optional, Union, cast
if sys.version_info < (3, 9):
from typing_extensions import Annotated
else:
from typing import Annotated
from fastapi.exceptions import HTTPException from fastapi.exceptions import HTTPException
from fastapi.openapi.models import OAuth2 as OAuth2Model from fastapi.openapi.models import OAuth2 as OAuth2Model
from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel
@ -15,6 +9,9 @@ from fastapi.security.utils import get_authorization_scheme_param
from starlette.requests import Request from starlette.requests import Request
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
# TODO: import from typing when deprecating Python 3.9
from typing_extensions import Annotated
class OAuth2PasswordRequestForm: class OAuth2PasswordRequestForm:
""" """

93
tests/test_compat.py

@ -0,0 +1,93 @@
from typing import List, Union
from fastapi import FastAPI, UploadFile
from fastapi._compat import (
ModelField,
Undefined,
_get_model_config,
is_bytes_sequence_annotation,
is_uploadfile_sequence_annotation,
)
from fastapi.testclient import TestClient
from pydantic import BaseConfig, BaseModel, ConfigDict
from pydantic.fields import FieldInfo
from .utils import needs_pydanticv1, needs_pydanticv2
@needs_pydanticv2
def test_model_field_default_required():
# For coverage
field_info = FieldInfo(annotation=str)
field = ModelField(name="foo", field_info=field_info)
assert field.default is Undefined
@needs_pydanticv1
def test_upload_file_dummy_general_plain_validator_function():
# For coverage
assert UploadFile.__get_pydantic_core_schema__(str, lambda x: None) == {}
@needs_pydanticv1
def test_union_scalar_list():
# For coverage
# TODO: there might not be a current valid code path that uses this, it would
# potentially enable query parameters defined as both a scalar and a list
# but that would require more refactors, also not sure it's really useful
from fastapi._compat import is_pv1_scalar_field
field_info = FieldInfo()
field = ModelField(
name="foo",
field_info=field_info,
type_=Union[str, List[int]],
class_validators={},
model_config=BaseConfig,
)
assert not is_pv1_scalar_field(field)
@needs_pydanticv2
def test_get_model_config():
# For coverage in Pydantic v2
class Foo(BaseModel):
model_config = ConfigDict(from_attributes=True)
foo = Foo()
config = _get_model_config(foo)
assert config == {"from_attributes": True}
def test_complex():
app = FastAPI()
@app.post("/")
def foo(foo: Union[str, List[int]]):
return foo
client = TestClient(app)
response = client.post("/", json="bar")
assert response.status_code == 200, response.text
assert response.json() == "bar"
response2 = client.post("/", json=[1, 2])
assert response2.status_code == 200, response2.text
assert response2.json() == [1, 2]
def test_is_bytes_sequence_annotation_union():
# For coverage
# TODO: in theory this would allow declaring types that could be lists of bytes
# to be read from files and other types, but I'm not even sure it's a good idea
# to support it as a first class "feature"
assert is_bytes_sequence_annotation(Union[List[str], List[bytes]])
def test_is_uploadfile_sequence_annotation():
# For coverage
# TODO: in theory this would allow declaring types that could be lists of UploadFile
# and other types, but I'm not even sure it's a good idea to support it as a first
# class "feature"
assert is_uploadfile_sequence_annotation(Union[List[str], List[UploadFile]])

6
tests/test_datastructures.py

@ -7,11 +7,17 @@ from fastapi.datastructures import Default
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
# TODO: remove when deprecating Pydantic v1
def test_upload_file_invalid(): def test_upload_file_invalid():
with pytest.raises(ValueError): with pytest.raises(ValueError):
UploadFile.validate("not a Starlette UploadFile") UploadFile.validate("not a Starlette UploadFile")
def test_upload_file_invalid_pydantic_v2():
with pytest.raises(ValueError):
UploadFile._validate("not a Starlette UploadFile", {})
def test_default_placeholder_equals(): def test_default_placeholder_equals():
placeholder_1 = Default("a") placeholder_1 = Default("a")
placeholder_2 = Default("a") placeholder_2 = Default("a")

13
tests/test_jsonable_encoder.py

@ -1,5 +1,6 @@
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime, timezone from datetime import datetime, timezone
from decimal import Decimal
from enum import Enum from enum import Enum
from pathlib import PurePath, PurePosixPath, PureWindowsPath from pathlib import PurePath, PurePosixPath, PureWindowsPath
from typing import Optional from typing import Optional
@ -286,3 +287,15 @@ def test_encode_root():
model = ModelWithRoot(__root__="Foo") model = ModelWithRoot(__root__="Foo")
assert jsonable_encoder(model) == "Foo" assert jsonable_encoder(model) == "Foo"
@needs_pydanticv2
def test_decimal_encoder_float():
data = {"value": Decimal(1.23)}
assert jsonable_encoder(data) == {"value": 1.23}
@needs_pydanticv2
def test_decimal_encoder_int():
data = {"value": Decimal(2)}
assert jsonable_encoder(data) == {"value": 2}

67
tests/test_path.py

@ -409,6 +409,73 @@ def test_path_param_maxlength_foobar():
) )
def test_path_param_min_maxlength_foo():
response = client.get("/path/param-min_maxlength/foo")
assert response.status_code == 200
assert response.json() == "foo"
def test_path_param_min_maxlength_foobar():
response = client.get("/path/param-min_maxlength/foobar")
assert response.status_code == 422
assert response.json() == IsDict(
{
"detail": [
{
"type": "string_too_long",
"loc": ["path", "item_id"],
"msg": "String should have at most 3 characters",
"input": "foobar",
"ctx": {"max_length": 3},
"url": match_pydantic_error_url("string_too_long"),
}
]
}
) | IsDict(
# TODO: remove when deprecating Pydantic v1
{
"detail": [
{
"loc": ["path", "item_id"],
"msg": "ensure this value has at most 3 characters",
"type": "value_error.any_str.max_length",
"ctx": {"limit_value": 3},
}
]
}
)
def test_path_param_min_maxlength_f():
response = client.get("/path/param-min_maxlength/f")
assert response.status_code == 422
assert response.json() == IsDict(
{
"detail": [
{
"type": "string_too_short",
"loc": ["path", "item_id"],
"msg": "String should have at least 2 characters",
"input": "f",
"ctx": {"min_length": 2},
"url": match_pydantic_error_url("string_too_short"),
}
]
}
) | IsDict(
{
"detail": [
{
"loc": ["path", "item_id"],
"msg": "ensure this value has at least 2 characters",
"type": "value_error.any_str.min_length",
"ctx": {"limit_value": 2},
}
]
}
)
def test_path_param_gt_42(): def test_path_param_gt_42():
response = client.get("/path/param-gt/42") response = client.get("/path/param-gt/42")
assert response.status_code == 200 assert response.status_code == 200

Loading…
Cancel
Save