pythonasyncioapiasyncfastapiframeworkjsonjson-schemaopenapiopenapi3pydanticpython-typespython3redocreststarletteswaggerswagger-uiuvicornweb
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
334 lines
11 KiB
334 lines
11 KiB
from copy import copy
|
|
from dataclasses import dataclass, is_dataclass
|
|
from enum import Enum
|
|
from typing import (
|
|
Any,
|
|
Callable,
|
|
Dict,
|
|
List,
|
|
Sequence,
|
|
Set,
|
|
Tuple,
|
|
Type,
|
|
Union,
|
|
)
|
|
|
|
from fastapi._compat import shared
|
|
from fastapi.openapi.constants import REF_PREFIX as REF_PREFIX
|
|
from fastapi.types import ModelNameMap
|
|
from pydantic.version import VERSION as PYDANTIC_VERSION
|
|
from typing_extensions import Literal
|
|
|
|
# (major, minor) of the installed Pydantic, parsed from its version string.
PYDANTIC_VERSION_MINOR_TUPLE = tuple(map(int, PYDANTIC_VERSION.split(".")[:2]))
# True when the installed Pydantic's major version is 2.
PYDANTIC_V2 = PYDANTIC_VERSION_MINOR_TUPLE[0] == 2
# Keeping old "Required" functionality from Pydantic V1, without
# shadowing typing.Required.
RequiredParam: Any = ...
|
|
|
|
if not PYDANTIC_V2:
|
|
from pydantic import BaseConfig as BaseConfig
|
|
from pydantic import BaseModel as BaseModel
|
|
from pydantic import ValidationError as ValidationError
|
|
from pydantic import create_model as create_model
|
|
from pydantic.class_validators import Validator as Validator
|
|
from pydantic.color import Color as Color
|
|
from pydantic.error_wrappers import ErrorWrapper as ErrorWrapper
|
|
from pydantic.errors import MissingError
|
|
from pydantic.fields import ( # type: ignore[attr-defined]
|
|
SHAPE_FROZENSET,
|
|
SHAPE_LIST,
|
|
SHAPE_SEQUENCE,
|
|
SHAPE_SET,
|
|
SHAPE_SINGLETON,
|
|
SHAPE_TUPLE,
|
|
SHAPE_TUPLE_ELLIPSIS,
|
|
)
|
|
from pydantic.fields import FieldInfo as FieldInfo
|
|
from pydantic.fields import ModelField as ModelField # type: ignore[attr-defined]
|
|
from pydantic.fields import Undefined as Undefined # type: ignore[attr-defined]
|
|
from pydantic.fields import ( # type: ignore[attr-defined]
|
|
UndefinedType as UndefinedType,
|
|
)
|
|
from pydantic.networks import AnyUrl as AnyUrl
|
|
from pydantic.networks import NameEmail as NameEmail
|
|
from pydantic.schema import TypeModelSet as TypeModelSet
|
|
from pydantic.schema import (
|
|
field_schema,
|
|
get_flat_models_from_fields,
|
|
model_process_schema,
|
|
)
|
|
from pydantic.schema import (
|
|
get_annotation_from_field_info as get_annotation_from_field_info,
|
|
)
|
|
from pydantic.schema import get_flat_models_from_field as get_flat_models_from_field
|
|
from pydantic.schema import get_model_name_map as get_model_name_map
|
|
from pydantic.types import SecretBytes as SecretBytes
|
|
from pydantic.types import SecretStr as SecretStr
|
|
from pydantic.typing import evaluate_forwardref as evaluate_forwardref
|
|
from pydantic.utils import lenient_issubclass as lenient_issubclass
|
|
|
|
|
|
else:
|
|
from pydantic.v1 import BaseConfig as BaseConfig # type: ignore[assignment]
|
|
from pydantic.v1 import BaseModel as BaseModel # type: ignore[assignment]
|
|
from pydantic.v1 import ( # type: ignore[assignment]
|
|
ValidationError as ValidationError,
|
|
)
|
|
from pydantic.v1 import create_model as create_model # type: ignore[no-redef]
|
|
from pydantic.v1.class_validators import Validator as Validator
|
|
from pydantic.v1.color import Color as Color # type: ignore[assignment]
|
|
from pydantic.v1.error_wrappers import ErrorWrapper as ErrorWrapper
|
|
from pydantic.v1.errors import MissingError
|
|
from pydantic.v1.fields import (
|
|
SHAPE_FROZENSET,
|
|
SHAPE_LIST,
|
|
SHAPE_SEQUENCE,
|
|
SHAPE_SET,
|
|
SHAPE_SINGLETON,
|
|
SHAPE_TUPLE,
|
|
SHAPE_TUPLE_ELLIPSIS,
|
|
)
|
|
from pydantic.v1.fields import FieldInfo as FieldInfo # type: ignore[assignment]
|
|
from pydantic.v1.fields import ModelField as ModelField
|
|
from pydantic.v1.fields import Undefined as Undefined
|
|
from pydantic.v1.fields import UndefinedType as UndefinedType
|
|
from pydantic.v1.networks import AnyUrl as AnyUrl
|
|
from pydantic.v1.networks import ( # type: ignore[assignment]
|
|
NameEmail as NameEmail,
|
|
)
|
|
from pydantic.v1.schema import TypeModelSet as TypeModelSet
|
|
from pydantic.v1.schema import (
|
|
field_schema,
|
|
get_flat_models_from_fields,
|
|
model_process_schema,
|
|
)
|
|
from pydantic.v1.schema import (
|
|
get_annotation_from_field_info as get_annotation_from_field_info,
|
|
)
|
|
from pydantic.v1.schema import (
|
|
get_flat_models_from_field as get_flat_models_from_field,
|
|
)
|
|
from pydantic.v1.schema import get_model_name_map as get_model_name_map
|
|
from pydantic.v1.types import ( # type: ignore[assignment]
|
|
SecretBytes as SecretBytes,
|
|
)
|
|
from pydantic.v1.types import ( # type: ignore[assignment]
|
|
SecretStr as SecretStr,
|
|
)
|
|
from pydantic.v1.typing import evaluate_forwardref as evaluate_forwardref
|
|
from pydantic.v1.utils import lenient_issubclass as lenient_issubclass
|
|
|
|
|
|
# Permissive aliases defined by this module. They presumably stand in for
# names that only exist in Pydantic v2 -- TODO confirm against the v2 module.
GetJsonSchemaHandler = Any
JsonSchemaValue = Dict[str, Any]
CoreSchema = Any
# ``Url`` is simply another name for the ``AnyUrl`` imported above.
Url = AnyUrl
|
|
|
|
# Field shapes that this module treats as sequences.
sequence_shapes = {
    SHAPE_LIST,
    SHAPE_SEQUENCE,
    SHAPE_SET,
    SHAPE_FROZENSET,
    SHAPE_TUPLE,
    SHAPE_TUPLE_ELLIPSIS,
}
|
|
# Concrete container type used to rebuild a value for each sequence shape
# (looked up by ``serialize_sequence_value``).
# Every member of ``sequence_shapes`` needs an entry here; without the
# SHAPE_FROZENSET mapping a frozenset-shaped field raised ``KeyError`` on
# lookup even though it is recognised as a sequence shape.
sequence_shape_to_type = {
    SHAPE_LIST: list,
    SHAPE_SET: set,
    SHAPE_FROZENSET: frozenset,
    SHAPE_TUPLE: tuple,
    SHAPE_SEQUENCE: list,
    SHAPE_TUPLE_ELLIPSIS: list,
}
|
|
|
|
|
|
@dataclass
class GenerateJsonSchema:
    """Minimal dataclass carrying only a ``ref_template`` string."""

    # Template string for schema references; how it is consumed is not
    # visible in this module.
    ref_template: str
|
|
|
|
|
|
class PydanticSchemaGenerationError(Exception):
    """Exception type defined by this module; it adds nothing to ``Exception``."""
|
|
|
|
|
|
# Field-less model named "Request"; passed as the model argument when this
# module builds ``ValidationError`` instances.
RequestErrorModel: Type[BaseModel] = create_model("Request")
|
|
|
|
|
|
def with_info_plain_validator_function(
    function: Callable[..., Any],
    *,
    ref: Union[str, None] = None,
    metadata: Any = None,
    serialization: Any = None,
) -> Any:
    """Stub that ignores all of its arguments and returns a new empty dict.

    Args:
        function: Accepted but unused.
        ref: Accepted but unused.
        metadata: Accepted but unused.
        serialization: Accepted but unused.

    Returns:
        A fresh, empty ``dict`` on every call.
    """
    placeholder: Dict[str, Any] = {}
    return placeholder
|
|
|
|
|
|
def get_model_definitions(
    *,
    flat_models: Set[Union[Type[BaseModel], Type[Enum]]],
    model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str],
) -> Dict[str, Any]:
    """Build a name -> schema mapping for every model in ``flat_models``.

    Args:
        flat_models: Model (or enum) classes to generate schemas for.
        model_name_map: Maps each class in ``flat_models`` to the name its
            schema is stored under.

    Returns:
        The schemas produced by ``model_process_schema`` together with any
        definitions it reported, keyed by name. Each ``"description"`` is
        cut at the first form-feed character.
    """
    collected: Dict[str, Dict[str, Any]] = {}
    for model_cls in flat_models:
        schema, extra_definitions, _nested = model_process_schema(
            model_cls, model_name_map=model_name_map, ref_prefix=REF_PREFIX
        )
        collected.update(extra_definitions)
        # Assigned after the update so the model's own schema wins on a
        # name clash with one of its reported definitions.
        collected[model_name_map[model_cls]] = schema
    for schema in collected.values():
        if "description" in schema:
            # Keep only the text before the first "\f".
            schema["description"] = schema["description"].partition("\f")[0]
    return collected
|
|
|
|
|
|
def is_pv1_scalar_field(field: ModelField) -> bool:
    """Return whether ``field`` counts as a scalar field.

    A field is scalar when its shape is ``SHAPE_SINGLETON``, its type is not
    a ``BaseModel`` subclass, a ``dict`` subclass, a sequence annotation or a
    dataclass, its field info is not a ``params.Body``, and every sub-field
    (if any) is itself scalar.
    """
    # Imported inside the function, as in the original code (presumably to
    # avoid a circular import -- TODO confirm).
    from fastapi import params

    field_info = field.field_info
    if field.shape != SHAPE_SINGLETON:
        return False
    annotation = field.type_
    disqualified = (
        lenient_issubclass(annotation, BaseModel)
        or lenient_issubclass(annotation, dict)
        or shared.field_annotation_is_sequence(annotation)
        or is_dataclass(annotation)
        or isinstance(field_info, params.Body)
    )
    if disqualified:
        return False
    # No sub-fields (None or empty) means there is nothing left to reject.
    return all(is_pv1_scalar_field(sub) for sub in field.sub_fields or ())
|
|
|
|
|
|
def is_pv1_scalar_sequence_field(field: ModelField) -> bool:
    """Return whether ``field`` is a sequence whose items are all scalar.

    When the field has a sequence shape and its type is not a ``BaseModel``
    subclass, the answer is whether every sub-field is scalar (no sub-fields
    counts as True). Otherwise the answer is whether the field's type is a
    sequence annotation.
    """
    shape_matches = field.shape in sequence_shapes and not lenient_issubclass(
        field.type_, BaseModel
    )
    if not shape_matches:
        return True if shared._annotation_is_sequence(field.type_) else False
    sub_fields = field.sub_fields
    if sub_fields is None:
        return True
    return all(is_pv1_scalar_field(sub) for sub in sub_fields)
|
|
|
|
|
|
def _normalize_errors(errors: Sequence[Any]) -> List[Dict[str, Any]]:
    """Flatten ``errors`` into a single list.

    ``ErrorWrapper`` items are expanded via a ``ValidationError`` built
    against ``RequestErrorModel``; nested lists are flattened recursively;
    anything else is kept unchanged.
    """
    flattened: List[Any] = []
    for item in errors:
        if isinstance(item, ErrorWrapper):
            wrapped = ValidationError(  # type: ignore[call-arg]
                errors=[item], model=RequestErrorModel
            )
            flattened.extend(wrapped.errors())
            continue
        if isinstance(item, list):
            flattened.extend(_normalize_errors(item))
            continue
        flattened.append(item)
    return flattened
|
|
|
|
|
|
def _regenerate_error_with_loc(
    *, errors: Sequence[Any], loc_prefix: Tuple[Union[str, int], ...]
) -> List[Dict[str, Any]]:
    """Return normalized copies of ``errors`` with ``loc_prefix`` prepended.

    Each error is first passed through ``_normalize_errors``. An error with
    no ``"loc"`` key is treated as having an empty location tuple.
    """
    relocated: List[Any] = []
    for err in _normalize_errors(errors):
        updated = {**err}
        updated["loc"] = loc_prefix + err.get("loc", ())
        relocated.append(updated)
    return relocated
|
|
|
|
|
|
def _model_rebuild(model: Type[BaseModel]) -> None:
    """Call ``update_forward_refs()`` on ``model``; nothing is returned."""
    model.update_forward_refs()
    return None
|
|
|
|
|
|
def _model_dump(
    model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any
) -> Any:
    """Return ``model.dict(**kwargs)``.

    ``mode`` is accepted but not consulted by this implementation; only the
    extra keyword arguments are forwarded.
    """
    dumped = model.dict(**kwargs)
    return dumped
|
|
|
|
|
|
def _get_model_config(model: BaseModel) -> Any:
    """Return the ``__config__`` attribute of ``model``."""
    config = model.__config__  # type: ignore[attr-defined]
    return config
|
|
|
|
|
|
def get_schema_from_model_field(
    *,
    field: ModelField,
    model_name_map: ModelNameMap,
    field_mapping: Dict[
        Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
    ],
    separate_input_output_schemas: bool = True,
) -> Dict[str, Any]:
    """Return the schema for ``field`` as produced by ``field_schema``.

    Args:
        field: The field to generate a schema for.
        model_name_map: Forwarded to ``field_schema``.
        field_mapping: Accepted but not used by this implementation.
        separate_input_output_schemas: Accepted but not used by this
            implementation.

    Returns:
        The first element of the tuple returned by ``field_schema``.
    """
    schema_result = field_schema(
        field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
    )
    return schema_result[0]  # type: ignore[no-any-return]
|
|
|
|
|
|
# def get_compat_model_name_map(fields: List[ModelField]) -> ModelNameMap:
|
|
# models = get_flat_models_from_fields(fields, known_models=set())
|
|
# return get_model_name_map(models) # type: ignore[no-any-return]
|
|
|
|
|
|
def get_definitions(
    *,
    fields: List[ModelField],
    model_name_map: ModelNameMap,
    separate_input_output_schemas: bool = True,
) -> Tuple[
    Dict[Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue],
    Dict[str, Dict[str, Any]],
]:
    """Return ``(field_mapping, definitions)`` for ``fields``.

    The field mapping is always an empty dict here. The definitions come
    from ``get_model_definitions`` applied to the flat models collected from
    ``fields``. ``separate_input_output_schemas`` is accepted but not used.
    """
    flat_models = get_flat_models_from_fields(fields, known_models=set())
    definitions = get_model_definitions(
        flat_models=flat_models, model_name_map=model_name_map
    )
    return {}, definitions
|
|
|
|
|
|
def is_scalar_field(field: ModelField) -> bool:
    """Return whether ``field`` is scalar; delegates to ``is_pv1_scalar_field``."""
    is_scalar = is_pv1_scalar_field(field)
    return is_scalar
|
|
|
|
|
|
def is_sequence_field(field: ModelField) -> bool:
    """Return whether ``field`` has a sequence shape or a sequence annotation."""
    if field.shape in sequence_shapes:
        return True
    # Fall back to inspecting the annotation itself.
    return shared._annotation_is_sequence(field.type_)
|
|
|
|
|
|
def is_scalar_sequence_field(field: ModelField) -> bool:
    """Return whether ``field`` is a sequence of scalars.

    Delegates to ``is_pv1_scalar_sequence_field``.
    """
    is_scalar_sequence = is_pv1_scalar_sequence_field(field)
    return is_scalar_sequence
|
|
|
|
|
|
def is_bytes_field(field: ModelField) -> bool:
    """Return whether ``field.type_`` is a subclass of ``bytes``."""
    annotation = field.type_
    return lenient_issubclass(annotation, bytes)  # type: ignore[no-any-return]
|
|
|
|
|
|
def is_bytes_sequence_field(field: ModelField) -> bool:
    """Return whether ``field`` has a sequence shape and a ``bytes`` type."""
    if field.shape not in sequence_shapes:
        return False
    return lenient_issubclass(field.type_, bytes)  # type: ignore[no-any-return]
|
|
|
|
|
|
def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
    """Return a shallow copy of ``field_info``.

    ``annotation`` is accepted but not used by this implementation.
    """
    duplicate = copy(field_info)
    return duplicate
|
|
|
|
|
|
def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
    """Convert ``value`` to the container type registered for ``field.shape``.

    The container type is looked up in ``sequence_shape_to_type``; a shape
    with no entry there raises ``KeyError``.
    """
    container_type = sequence_shape_to_type[field.shape]
    return container_type(value)  # type: ignore[no-any-return]
|
|
|
|
|
|
def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]:
    """Return the error dict for a missing field at ``loc``.

    A ``MissingError`` is wrapped in an ``ErrorWrapper`` at ``loc`` and put
    into a ``ValidationError`` against ``RequestErrorModel``; the first entry
    of that error's ``errors()`` list is returned.
    """
    wrapped = ErrorWrapper(MissingError(), loc=loc)
    validation_error = ValidationError([wrapped], RequestErrorModel)
    error_dicts = validation_error.errors()
    return error_dicts[0]  # type: ignore[return-value]
|
|
|
|
|
|
def create_body_model(
    *, fields: Sequence[ModelField], model_name: str
) -> Type[BaseModel]:
    """Create a model named ``model_name`` holding the given ``fields``.

    The model is created empty and each field is then registered in its
    ``__fields__`` mapping under the field's ``name``; a later field with
    the same name replaces an earlier one.
    """
    body_model = create_model(model_name)
    registry = body_model.__fields__
    for model_field in fields:
        registry[model_field.name] = model_field  # type: ignore[index]
    return body_model
|
|
|
|
|
|
def get_model_fields(model: Type[BaseModel]) -> List[ModelField]:
    """Return the values of ``model.__fields__`` as a new list."""
    registered = model.__fields__  # type: ignore[attr-defined]
    return [*registered.values()]
|
|
|