Browse Source

♻️ Refactor internals for test coverage and performance (#9691)

* ♻️ Tweak import of Annotated from typing_extensions, they are installed anyway

* ♻️ Refactor _compat to define functions for Pydantic v1 or v2 once instead of checking inside

*  Add test for UploadFile for Pydantic v2

* ♻️ Refactor types and remove logic for impossible cases

*  Add missing tests from test refactor for path params

*  Add tests for new decimal encoder

* 💡 Add TODO comment for decimals in encoders

* 🔥 Remove unneeded dummy function

* 🔥 Remove section of code in field_annotation_is_scalar covered by sub-call to field_annotation_is_complex

* ♻️ Refactor and tweak variables and types in _compat

*  Add tests for corner cases and compat with Pydantic v1 and v2

* ♻️ Refactor type annotations
pull/9707/head
Sebastián Ramírez 2 years ago
committed by GitHub
parent
commit
cfb00b2119
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 337
      fastapi/_compat.py
  2. 13
      fastapi/dependencies/utils.py
  3. 1
      fastapi/encoders.py
  4. 9
      fastapi/security/oauth2.py
  5. 93
      tests/test_compat.py
  6. 6
      tests/test_datastructures.py
  7. 13
      tests/test_jsonable_encoder.py
  8. 67
      tests/test_path.py

337
fastapi/_compat.py

@ -113,22 +113,16 @@ if PYDANTIC_V2:
value: Any,
values: Dict[str, Any] = {}, # noqa: B006
*,
loc: Union[Tuple[Union[int, str], ...], str] = "",
) -> Tuple[Any, Union[List[ValidationError], None]]:
loc: Tuple[Union[int, str], ...] = (),
) -> Tuple[Any, Union[List[Dict[str, Any]], None]]:
try:
return (
self._type_adapter.validate_python(value, from_attributes=True),
None,
)
except ValidationError as exc:
if isinstance(loc, tuple):
use_loc = loc
elif loc == "":
use_loc = ()
else:
use_loc = (loc,)
return None, _regenerate_error_with_loc(
errors=exc.errors(), loc_prefix=use_loc
errors=exc.errors(), loc_prefix=loc
)
def serialize(
@ -161,13 +155,6 @@ if PYDANTIC_V2:
# ModelField to its JSON Schema.
return id(self)
def get_model_definitions(
*,
flat_models: Set[Union[Type[BaseModel], Type[Enum]]],
model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str],
) -> Dict[str, Any]:
return {}
def get_annotation_from_field_info(
annotation: Any, field_info: FieldInfo, field_name: str
) -> Any:
@ -176,6 +163,91 @@ if PYDANTIC_V2:
def _normalize_errors(errors: Sequence[Any]) -> List[Dict[str, Any]]:
return errors # type: ignore[return-value]
def _model_rebuild(model: Type[BaseModel]) -> None:
model.model_rebuild()
def _model_dump(
model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any
) -> Any:
return model.model_dump(mode=mode, **kwargs)
def _get_model_config(model: BaseModel) -> Any:
return model.model_config
def get_schema_from_model_field(
*,
field: ModelField,
schema_generator: GenerateJsonSchema,
model_name_map: ModelNameMap,
) -> Dict[str, Any]:
# This expects that GenerateJsonSchema was already used to generate the definitions
json_schema = schema_generator.generate_inner(field._type_adapter.core_schema)
if "$ref" not in json_schema:
# TODO remove when deprecating Pydantic v1
# Ref: https://github.com/pydantic/pydantic/blob/d61792cc42c80b13b23e3ffa74bc37ec7c77f7d1/pydantic/schema.py#L207
json_schema[
"title"
] = field.field_info.title or field.alias.title().replace("_", " ")
return json_schema
def get_compat_model_name_map(fields: List[ModelField]) -> ModelNameMap:
return {}
def get_definitions(
*,
fields: List[ModelField],
schema_generator: GenerateJsonSchema,
model_name_map: ModelNameMap,
) -> Dict[str, Dict[str, Any]]:
inputs = [
(field, "validation", field._type_adapter.core_schema) for field in fields
]
_, definitions = schema_generator.generate_definitions(inputs=inputs) # type: ignore[arg-type]
return definitions # type: ignore[return-value]
def is_scalar_field(field: ModelField) -> bool:
from fastapi import params
return field_annotation_is_scalar(
field.field_info.annotation
) and not isinstance(field.field_info, params.Body)
def is_sequence_field(field: ModelField) -> bool:
return field_annotation_is_sequence(field.field_info.annotation)
def is_scalar_sequence_field(field: ModelField) -> bool:
return field_annotation_is_scalar_sequence(field.field_info.annotation)
def is_bytes_field(field: ModelField) -> bool:
return is_bytes_or_nonable_bytes_annotation(field.type_)
def is_bytes_sequence_field(field: ModelField) -> bool:
return is_bytes_sequence_annotation(field.type_)
def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
return type(field_info).from_annotation(annotation)
def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
origin_type = (
get_origin(field.field_info.annotation) or field.field_info.annotation
)
assert issubclass(origin_type, sequence_types) # type: ignore[arg-type]
return sequence_annotation_to_type[origin_type](value) # type: ignore[no-any-return]
def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]:
error = ValidationError.from_exception_data(
"Field required", [{"type": "missing", "loc": loc, "input": {}}]
).errors()[0]
error["input"] = None
return error # type: ignore[return-value]
def create_body_model(
*, fields: Sequence[ModelField], model_name: str
) -> Type[BaseModel]:
field_params = {f.name: (f.field_info.annotation, f.field_info) for f in fields}
BodyModel: Type[BaseModel] = create_model(model_name, **field_params) # type: ignore[call-overload]
return BodyModel
else:
from fastapi.openapi.constants import REF_PREFIX as REF_PREFIX
from pydantic import AnyUrl as Url # noqa: F401
@ -333,87 +405,86 @@ else:
use_errors.append(error)
return use_errors
def _regenerate_error_with_loc(
*, errors: Sequence[Any], loc_prefix: Tuple[Union[str, int], ...]
) -> List[ValidationError]:
updated_loc_errors: List[Any] = [
{**err, "loc": loc_prefix + err.get("loc", ())}
for err in _normalize_errors(errors)
]
return updated_loc_errors
def _model_rebuild(model: Type[BaseModel]) -> None:
if PYDANTIC_V2:
model.model_rebuild()
else:
def _model_rebuild(model: Type[BaseModel]) -> None:
model.update_forward_refs()
def _model_dump(
model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any
) -> Any:
if PYDANTIC_V2:
return model.model_dump(mode=mode, **kwargs)
else:
def _model_dump(
model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any
) -> Any:
return model.dict(**kwargs)
def _get_model_config(model: BaseModel) -> Any:
if PYDANTIC_V2:
return model.model_config
else:
def _get_model_config(model: BaseModel) -> Any:
return model.__config__ # type: ignore[attr-defined]
def get_schema_from_model_field(
*,
field: ModelField,
schema_generator: GenerateJsonSchema,
model_name_map: ModelNameMap,
) -> Dict[str, Any]:
# This expects that GenerateJsonSchema was already used to generate the definitions
if PYDANTIC_V2:
json_schema = schema_generator.generate_inner(field._type_adapter.core_schema)
if "$ref" not in json_schema:
# TODO remove when deprecating Pydantic v1
# Ref: https://github.com/pydantic/pydantic/blob/d61792cc42c80b13b23e3ffa74bc37ec7c77f7d1/pydantic/schema.py#L207
json_schema[
"title"
] = field.field_info.title or field.alias.title().replace("_", " ")
return json_schema
else:
def get_schema_from_model_field(
*,
field: ModelField,
schema_generator: GenerateJsonSchema,
model_name_map: ModelNameMap,
) -> Dict[str, Any]:
# This expects that GenerateJsonSchema was already used to generate the definitions
return field_schema( # type: ignore[no-any-return]
field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
)[0]
def get_compat_model_name_map(fields: List[ModelField]) -> ModelNameMap:
if PYDANTIC_V2:
return {}
else:
def get_compat_model_name_map(fields: List[ModelField]) -> ModelNameMap:
models = get_flat_models_from_fields(fields, known_models=set())
return get_model_name_map(models) # type: ignore[no-any-return]
def get_definitions(
*,
fields: List[ModelField],
schema_generator: GenerateJsonSchema,
model_name_map: ModelNameMap,
) -> Dict[str, Dict[str, Any]]:
if PYDANTIC_V2:
inputs = [
(field, "validation", field._type_adapter.core_schema) for field in fields
]
_, definitions = schema_generator.generate_definitions(inputs=inputs) # type: ignore[arg-type]
return definitions # type: ignore[return-value]
else:
def get_definitions(
*,
fields: List[ModelField],
schema_generator: GenerateJsonSchema,
model_name_map: ModelNameMap,
) -> Dict[str, Dict[str, Any]]:
models = get_flat_models_from_fields(fields, known_models=set())
return get_model_definitions(flat_models=models, model_name_map=model_name_map)
def is_scalar_field(field: ModelField) -> bool:
return is_pv1_scalar_field(field)
def is_sequence_field(field: ModelField) -> bool:
return field.shape in sequence_shapes or _annotation_is_sequence(field.type_) # type: ignore[attr-defined]
def is_scalar_sequence_field(field: ModelField) -> bool:
return is_pv1_scalar_sequence_field(field)
def is_bytes_field(field: ModelField) -> bool:
return lenient_issubclass(field.type_, bytes)
def is_bytes_sequence_field(field: ModelField) -> bool:
return field.shape in sequence_shapes and lenient_issubclass(field.type_, bytes) # type: ignore[attr-defined]
def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
return copy(field_info)
def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
return sequence_shape_to_type[field.shape](value) # type: ignore[no-any-return,attr-defined]
def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]:
missing_field_error = ErrorWrapper(MissingError(), loc=loc) # type: ignore[call-arg]
new_error = ValidationError([missing_field_error], RequestErrorModel)
return new_error.errors()[0] # type: ignore[return-value]
def create_body_model(
*, fields: Sequence[ModelField], model_name: str
) -> Type[BaseModel]:
BodyModel = create_model(model_name)
for f in fields:
BodyModel.__fields__[f.name] = f # type: ignore[index]
return BodyModel
def _regenerate_error_with_loc(
*, errors: Sequence[Any], loc_prefix: Tuple[Union[str, int], ...]
) -> List[Dict[str, Any]]:
updated_loc_errors: List[Any] = [
{**err, "loc": loc_prefix + err.get("loc", ())}
for err in _normalize_errors(errors)
]
return updated_loc_errors
def _annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
if lenient_issubclass(annotation, (str, bytes)):
@ -453,10 +524,6 @@ def field_annotation_is_complex(annotation: Union[Type[Any], None]) -> bool:
def field_annotation_is_scalar(annotation: Any) -> bool:
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
return all(field_annotation_is_scalar(arg) for arg in get_args(annotation))
# handle Ellipsis here to make tuple[int, ...] work nicely
return annotation is Ellipsis or not field_annotation_is_complex(annotation)
@ -478,31 +545,6 @@ def field_annotation_is_scalar_sequence(annotation: Union[Type[Any], None]) -> b
)
def is_scalar_field(field: ModelField) -> bool:
from fastapi import params
if PYDANTIC_V2:
return field_annotation_is_scalar(
field.field_info.annotation
) and not isinstance(field.field_info, params.Body)
else:
return is_pv1_scalar_field(field)
def is_sequence_field(field: ModelField) -> bool:
if PYDANTIC_V2:
return field_annotation_is_sequence(field.field_info.annotation)
else:
return field.shape in sequence_shapes or _annotation_is_sequence(field.type_) # type: ignore[attr-defined]
def is_scalar_sequence_field(field: ModelField) -> bool:
if PYDANTIC_V2:
return field_annotation_is_scalar_sequence(field.field_info.annotation)
else:
return is_pv1_scalar_sequence_field(field)
def is_bytes_or_nonable_bytes_annotation(annotation: Any) -> bool:
if lenient_issubclass(annotation, bytes):
return True
@ -525,90 +567,31 @@ def is_uploadfile_or_nonable_uploadfile_annotation(annotation: Any) -> bool:
return False
def is_bytes_sequence_annotation(annotation: Union[Type[Any], None]) -> bool:
def is_bytes_sequence_annotation(annotation: Any) -> bool:
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
at_least_one_bytes_sequence = False
at_least_one = False
for arg in get_args(annotation):
if is_bytes_sequence_annotation(arg):
at_least_one_bytes_sequence = True
at_least_one = True
continue
return at_least_one_bytes_sequence
return at_least_one
return field_annotation_is_sequence(annotation) and all(
is_bytes_or_nonable_bytes_annotation(sub_annotation)
for sub_annotation in get_args(annotation)
)
def is_uploadfile_sequence_annotation(annotation: Union[Type[Any], None]) -> bool:
def is_uploadfile_sequence_annotation(annotation: Any) -> bool:
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
at_least_one_bytes_sequence = False
at_least_one = False
for arg in get_args(annotation):
if is_uploadfile_sequence_annotation(arg):
at_least_one_bytes_sequence = True
at_least_one = True
continue
return at_least_one_bytes_sequence
return at_least_one
return field_annotation_is_sequence(annotation) and all(
is_uploadfile_or_nonable_uploadfile_annotation(sub_annotation)
for sub_annotation in get_args(annotation)
)
def is_bytes_field(field: ModelField) -> bool:
if PYDANTIC_V2:
return is_bytes_or_nonable_bytes_annotation(field.type_)
else:
return lenient_issubclass(field.type_, bytes)
def is_bytes_sequence_field(field: ModelField) -> bool:
if PYDANTIC_V2:
return is_bytes_sequence_annotation(field.type_)
else:
return field.shape in sequence_shapes and lenient_issubclass(field.type_, bytes) # type: ignore[attr-defined]
def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
if PYDANTIC_V2:
return type(field_info).from_annotation(annotation)
else:
return copy(field_info)
def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
if PYDANTIC_V2:
origin_type = (
get_origin(field.field_info.annotation) or field.field_info.annotation
)
assert issubclass(origin_type, sequence_types) # type: ignore[arg-type]
return sequence_annotation_to_type[origin_type](value) # type: ignore[no-any-return]
else:
return sequence_shape_to_type[field.shape](value) # type: ignore[no-any-return,attr-defined]
def get_missing_field_error(loc: Tuple[str, ...]) -> ValidationError:
if PYDANTIC_V2:
error = ValidationError.from_exception_data(
"Field required", [{"type": "missing", "loc": loc, "input": {}}]
).errors()[0]
error["input"] = None
return error # type: ignore[return-value]
else:
missing_field_error = ErrorWrapper(MissingError(), loc=loc) # type: ignore[call-arg]
new_error = ValidationError([missing_field_error], RequestErrorModel)
return new_error.errors()[0] # type: ignore[return-value]
def create_body_model(
*, fields: Sequence[ModelField], model_name: str
) -> Type[BaseModel]:
if PYDANTIC_V2:
field_params = {f.name: (f.field_info.annotation, f.field_info) for f in fields}
BodyModel: Type[BaseModel] = create_model(model_name, **field_params) # type: ignore[call-overload]
return BodyModel
else:
BodyModel = create_model(model_name)
for f in fields:
BodyModel.__fields__[f.name] = f # type: ignore[index]
return BodyModel

13
fastapi/dependencies/utils.py

@ -25,7 +25,6 @@ from fastapi._compat import (
ModelField,
Required,
Undefined,
ValidationError,
_regenerate_error_with_loc,
copy_field_info,
create_body_model,
@ -659,13 +658,7 @@ def request_params_to_args(
values[field.name] = deepcopy(field.default)
continue
v_, errors_ = field.validate(value, values, loc=loc)
if isinstance(errors_, ValidationError):
new_errors = _regenerate_error_with_loc(
errors=errors_.errors(), loc_prefix=loc
)
new_error = ValidationError(title=errors_.title, errors=new_errors)
errors.append(new_error)
elif isinstance(errors_, ErrorWrapper):
if isinstance(errors_, ErrorWrapper):
errors.append(errors_)
elif isinstance(errors_, list):
new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=())
@ -678,9 +671,9 @@ def request_params_to_args(
async def request_body_to_args(
required_params: List[ModelField],
received_body: Optional[Union[Dict[str, Any], FormData]],
) -> Tuple[Dict[str, Any], List[ValidationError]]:
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
values = {}
errors: List[ValidationError] = []
errors: List[Dict[str, Any]] = []
if required_params:
field = required_params[0]
field_info = field.field_info

1
fastapi/encoders.py

@ -32,6 +32,7 @@ def isoformat(o: Union[datetime.date, datetime.time]) -> str:
# Taken from Pydantic v1 as is
# TODO: pv2 should this return strings instead?
def decimal_encoder(dec_value: Decimal) -> Union[int, float]:
"""
Encodes a Decimal as int of there's no exponent, otherwise float

9
fastapi/security/oauth2.py

@ -1,11 +1,5 @@
import sys
from typing import Any, Dict, List, Optional, Union, cast
if sys.version_info < (3, 9):
from typing_extensions import Annotated
else:
from typing import Annotated
from fastapi.exceptions import HTTPException
from fastapi.openapi.models import OAuth2 as OAuth2Model
from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel
@ -15,6 +9,9 @@ from fastapi.security.utils import get_authorization_scheme_param
from starlette.requests import Request
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
# TODO: import from typing when deprecating Python 3.9
from typing_extensions import Annotated
class OAuth2PasswordRequestForm:
"""

93
tests/test_compat.py

@ -0,0 +1,93 @@
from typing import List, Union
from fastapi import FastAPI, UploadFile
from fastapi._compat import (
ModelField,
Undefined,
_get_model_config,
is_bytes_sequence_annotation,
is_uploadfile_sequence_annotation,
)
from fastapi.testclient import TestClient
from pydantic import BaseConfig, BaseModel, ConfigDict
from pydantic.fields import FieldInfo
from .utils import needs_pydanticv1, needs_pydanticv2
@needs_pydanticv2
def test_model_field_default_required():
# For coverage
field_info = FieldInfo(annotation=str)
field = ModelField(name="foo", field_info=field_info)
assert field.default is Undefined
@needs_pydanticv1
def test_upload_file_dummy_general_plain_validator_function():
# For coverage
assert UploadFile.__get_pydantic_core_schema__(str, lambda x: None) == {}
@needs_pydanticv1
def test_union_scalar_list():
# For coverage
# TODO: there might not be a current valid code path that uses this, it would
# potentially enable query parameters defined as both a scalar and a list
# but that would require more refactors, also not sure it's really useful
from fastapi._compat import is_pv1_scalar_field
field_info = FieldInfo()
field = ModelField(
name="foo",
field_info=field_info,
type_=Union[str, List[int]],
class_validators={},
model_config=BaseConfig,
)
assert not is_pv1_scalar_field(field)
@needs_pydanticv2
def test_get_model_config():
# For coverage in Pydantic v2
class Foo(BaseModel):
model_config = ConfigDict(from_attributes=True)
foo = Foo()
config = _get_model_config(foo)
assert config == {"from_attributes": True}
def test_complex():
app = FastAPI()
@app.post("/")
def foo(foo: Union[str, List[int]]):
return foo
client = TestClient(app)
response = client.post("/", json="bar")
assert response.status_code == 200, response.text
assert response.json() == "bar"
response2 = client.post("/", json=[1, 2])
assert response2.status_code == 200, response2.text
assert response2.json() == [1, 2]
def test_is_bytes_sequence_annotation_union():
# For coverage
# TODO: in theory this would allow declaring types that could be lists of bytes
# to be read from files and other types, but I'm not even sure it's a good idea
# to support it as a first class "feature"
assert is_bytes_sequence_annotation(Union[List[str], List[bytes]])
def test_is_uploadfile_sequence_annotation():
# For coverage
# TODO: in theory this would allow declaring types that could be lists of UploadFile
# and other types, but I'm not even sure it's a good idea to support it as a first
# class "feature"
assert is_uploadfile_sequence_annotation(Union[List[str], List[UploadFile]])

6
tests/test_datastructures.py

@ -7,11 +7,17 @@ from fastapi.datastructures import Default
from fastapi.testclient import TestClient
# TODO: remove when deprecating Pydantic v1
def test_upload_file_invalid():
with pytest.raises(ValueError):
UploadFile.validate("not a Starlette UploadFile")
def test_upload_file_invalid_pydantic_v2():
with pytest.raises(ValueError):
UploadFile._validate("not a Starlette UploadFile", {})
def test_default_placeholder_equals():
placeholder_1 = Default("a")
placeholder_2 = Default("a")

13
tests/test_jsonable_encoder.py

@ -1,5 +1,6 @@
from dataclasses import dataclass
from datetime import datetime, timezone
from decimal import Decimal
from enum import Enum
from pathlib import PurePath, PurePosixPath, PureWindowsPath
from typing import Optional
@ -286,3 +287,15 @@ def test_encode_root():
model = ModelWithRoot(__root__="Foo")
assert jsonable_encoder(model) == "Foo"
@needs_pydanticv2
def test_decimal_encoder_float():
data = {"value": Decimal(1.23)}
assert jsonable_encoder(data) == {"value": 1.23}
@needs_pydanticv2
def test_decimal_encoder_int():
data = {"value": Decimal(2)}
assert jsonable_encoder(data) == {"value": 2}

67
tests/test_path.py

@ -409,6 +409,73 @@ def test_path_param_maxlength_foobar():
)
def test_path_param_min_maxlength_foo():
response = client.get("/path/param-min_maxlength/foo")
assert response.status_code == 200
assert response.json() == "foo"
def test_path_param_min_maxlength_foobar():
response = client.get("/path/param-min_maxlength/foobar")
assert response.status_code == 422
assert response.json() == IsDict(
{
"detail": [
{
"type": "string_too_long",
"loc": ["path", "item_id"],
"msg": "String should have at most 3 characters",
"input": "foobar",
"ctx": {"max_length": 3},
"url": match_pydantic_error_url("string_too_long"),
}
]
}
) | IsDict(
# TODO: remove when deprecating Pydantic v1
{
"detail": [
{
"loc": ["path", "item_id"],
"msg": "ensure this value has at most 3 characters",
"type": "value_error.any_str.max_length",
"ctx": {"limit_value": 3},
}
]
}
)
def test_path_param_min_maxlength_f():
response = client.get("/path/param-min_maxlength/f")
assert response.status_code == 422
assert response.json() == IsDict(
{
"detail": [
{
"type": "string_too_short",
"loc": ["path", "item_id"],
"msg": "String should have at least 2 characters",
"input": "f",
"ctx": {"min_length": 2},
"url": match_pydantic_error_url("string_too_short"),
}
]
}
) | IsDict(
{
"detail": [
{
"loc": ["path", "item_id"],
"msg": "ensure this value has at least 2 characters",
"type": "value_error.any_str.min_length",
"ctx": {"limit_value": 2},
}
]
}
)
def test_path_param_gt_42():
response = client.get("/path/param-gt/42")
assert response.status_code == 200

Loading…
Cancel
Save