"""Dispatch layer between the ``v1`` and ``v2`` compatibility modules.

The names re-exported below come from ``.v2`` when ``PYDANTIC_V2`` is true and
from ``.v1`` otherwise.  Each helper function inspects its argument, routes to
``v1`` when the argument is a ``v1`` type, and otherwise falls back to ``v2``.
``v2`` is always imported lazily inside the function body, never at module
level.
"""

from functools import lru_cache
from typing import (
    Any,
    Dict,
    List,
    Sequence,
    Tuple,
    Type,
)

from fastapi._compat import v1
from fastapi._compat.shared import PYDANTIC_V2, lenient_issubclass
from fastapi.types import ModelNameMap
from pydantic import BaseModel
from typing_extensions import Literal

from .model_field import ModelField

if PYDANTIC_V2:
    from .v2 import BaseConfig as BaseConfig
    from .v2 import FieldInfo as FieldInfo
    from .v2 import PydanticSchemaGenerationError as PydanticSchemaGenerationError
    from .v2 import RequiredParam as RequiredParam
    from .v2 import Undefined as Undefined
    from .v2 import UndefinedType as UndefinedType
    from .v2 import Url as Url
    from .v2 import Validator as Validator
    from .v2 import evaluate_forwardref as evaluate_forwardref
    from .v2 import get_missing_field_error as get_missing_field_error
    from .v2 import (
        with_info_plain_validator_function as with_info_plain_validator_function,
    )
else:
    from .v1 import BaseConfig as BaseConfig  # type: ignore[assignment]
    from .v1 import FieldInfo as FieldInfo
    from .v1 import (  # type: ignore[assignment]
        PydanticSchemaGenerationError as PydanticSchemaGenerationError,
    )
    from .v1 import RequiredParam as RequiredParam
    from .v1 import Undefined as Undefined
    from .v1 import UndefinedType as UndefinedType
    from .v1 import Url as Url  # type: ignore[assignment]
    from .v1 import Validator as Validator
    from .v1 import evaluate_forwardref as evaluate_forwardref
    from .v1 import get_missing_field_error as get_missing_field_error
    from .v1 import (  # type: ignore[assignment]
        with_info_plain_validator_function as with_info_plain_validator_function,
    )


@lru_cache
def get_cached_model_fields(model: Type[BaseModel]) -> List[ModelField]:
    """Return the fields of ``model``, memoized per model class.

    Subclasses of ``v1.BaseModel`` are handled by ``v1.get_model_fields``;
    everything else is handed to ``v2.get_model_fields``.
    """
    if lenient_issubclass(model, v1.BaseModel):
        return v1.get_model_fields(model)

    # Lazy import: v2 is only loaded when a non-v1 model is actually seen.
    from . import v2

    return v2.get_model_fields(model)  # type: ignore[return-value]


def _is_undefined(value: object) -> bool:
    """Tell whether ``value`` is an instance of either ``UndefinedType``."""
    if isinstance(value, v1.UndefinedType):
        return True
    if not PYDANTIC_V2:
        return False

    from . import v2

    return isinstance(value, v2.UndefinedType)


def _get_model_config(model: BaseModel) -> Any:
    """Return the config of ``model`` via the matching compat module.

    Returns ``None`` when ``model`` is not a ``v1.BaseModel`` instance and
    ``PYDANTIC_V2`` is false.
    """
    if isinstance(model, v1.BaseModel):
        return v1._get_model_config(model)
    if not PYDANTIC_V2:
        return None

    from . import v2

    return v2._get_model_config(model)


def _model_dump(
    model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any
) -> Any:
    """Dump ``model`` through the matching compat module.

    ``mode`` and any extra keyword arguments are forwarded unchanged.  Returns
    ``None`` when ``model`` is not a ``v1.BaseModel`` instance and
    ``PYDANTIC_V2`` is false.
    """
    if isinstance(model, v1.BaseModel):
        return v1._model_dump(model, mode=mode, **kwargs)
    if not PYDANTIC_V2:
        return None

    from . import v2

    return v2._model_dump(model, mode=mode, **kwargs)


def _is_error_wrapper(exc: Exception) -> bool:
    """Tell whether ``exc`` is an instance of either ``ErrorWrapper``."""
    if isinstance(exc, v1.ErrorWrapper):
        return True
    if not PYDANTIC_V2:
        return False

    from . import v2

    return isinstance(exc, v2.ErrorWrapper)


def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
    """Copy ``field_info`` for ``annotation`` using the matching compat module.

    A ``field_info`` that is not a ``v1.FieldInfo`` requires ``PYDANTIC_V2``.
    """
    if isinstance(field_info, v1.FieldInfo):
        return v1.copy_field_info(field_info=field_info, annotation=annotation)

    assert PYDANTIC_V2
    from . import v2

    return v2.copy_field_info(field_info=field_info, annotation=annotation)


def create_body_model(
    *, fields: Sequence[ModelField], model_name: str
) -> Type[BaseModel]:
    """Build a body model named ``model_name`` from ``fields``.

    Only the first field is inspected to pick the implementation; an empty
    ``fields`` sequence is routed to ``v2`` (which requires ``PYDANTIC_V2``).
    """
    if fields and isinstance(fields[0], v1.ModelField):
        return v1.create_body_model(fields=fields, model_name=model_name)

    assert PYDANTIC_V2
    from . import v2

    return v2.create_body_model(fields=fields, model_name=model_name)  # type: ignore[arg-type]


def get_annotation_from_field_info(
    annotation: Any, field_info: FieldInfo, field_name: str
) -> Any:
    """Resolve the annotation for ``field_name`` via the matching compat module.

    The implementation is chosen by the type of ``field_info``.
    """
    if isinstance(field_info, v1.FieldInfo):
        return v1.get_annotation_from_field_info(
            annotation=annotation, field_info=field_info, field_name=field_name
        )

    assert PYDANTIC_V2
    from . import v2

    return v2.get_annotation_from_field_info(
        annotation=annotation, field_info=field_info, field_name=field_name
    )


def is_bytes_field(field: ModelField) -> bool:
    """Delegate to ``is_bytes_field`` of the compat module matching ``field``."""
    if isinstance(field, v1.ModelField):
        return v1.is_bytes_field(field)

    assert PYDANTIC_V2
    from . import v2

    return v2.is_bytes_field(field)  # type: ignore[arg-type]


def is_bytes_sequence_field(field: ModelField) -> bool:
    """Delegate to ``is_bytes_sequence_field`` of the module matching ``field``."""
    if isinstance(field, v1.ModelField):
        return v1.is_bytes_sequence_field(field)

    assert PYDANTIC_V2
    from . import v2

    return v2.is_bytes_sequence_field(field)  # type: ignore[arg-type]


def is_scalar_field(field: ModelField) -> bool:
    """Delegate to ``is_scalar_field`` of the compat module matching ``field``."""
    if isinstance(field, v1.ModelField):
        return v1.is_scalar_field(field)

    assert PYDANTIC_V2
    from . import v2

    return v2.is_scalar_field(field)  # type: ignore[arg-type]


def is_scalar_sequence_field(field: ModelField) -> bool:
    """Delegate to ``is_scalar_sequence_field`` of the module matching ``field``."""
    if isinstance(field, v1.ModelField):
        return v1.is_scalar_sequence_field(field)

    assert PYDANTIC_V2
    from . import v2

    return v2.is_scalar_sequence_field(field)  # type: ignore[arg-type]


def is_sequence_field(field: ModelField) -> bool:
    """Delegate to ``is_sequence_field`` of the compat module matching ``field``."""
    if isinstance(field, v1.ModelField):
        return v1.is_sequence_field(field)

    assert PYDANTIC_V2
    from . import v2

    return v2.is_sequence_field(field)  # type: ignore[arg-type]


def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
    """Serialize ``value`` for ``field`` via the matching compat module."""
    if isinstance(field, v1.ModelField):
        return v1.serialize_sequence_value(field=field, value=value)

    assert PYDANTIC_V2
    from . import v2

    return v2.serialize_sequence_value(field=field, value=value)  # type: ignore[arg-type]


def _model_rebuild(model: Type[BaseModel]) -> None:
    """Rebuild ``model`` through the matching compat module.

    Does nothing when ``model`` is not a ``v1.BaseModel`` subclass and
    ``PYDANTIC_V2`` is false.
    """
    if lenient_issubclass(model, v1.BaseModel):
        v1._model_rebuild(model)
        return

    if PYDANTIC_V2:
        from . import v2

        v2._model_rebuild(model)


def get_compat_model_name_map(fields: List[ModelField]) -> ModelNameMap:
    """Build a model-name map covering both v1 and v2 fields in ``fields``.

    The v1 flat models are always collected.  When ``PYDANTIC_V2`` is true the
    v2 flat models are unioned in and ``v2.get_model_name_map`` produces the
    result; otherwise ``v1.get_model_name_map`` is used on the v1 models alone.
    """
    legacy_fields = [item for item in fields if isinstance(item, v1.ModelField)]
    legacy_models = v1.get_flat_models_from_fields(legacy_fields, known_models=set())  # type: ignore[attr-defined]

    if not PYDANTIC_V2:
        return v1.get_model_name_map(legacy_models)

    from . import v2

    current_fields = [item for item in fields if isinstance(item, v2.ModelField)]
    current_models = v2.get_flat_models_from_fields(
        current_fields, known_models=set()
    )
    return v2.get_model_name_map(legacy_models.union(current_models))


def get_definitions(
    *,
    fields: List[ModelField],
    model_name_map: ModelNameMap,
    separate_input_output_schemas: bool = True,
) -> Tuple[
    Dict[Tuple[ModelField, Literal["validation", "serialization"]], v1.JsonSchemaValue],
    Dict[str, Dict[str, Any]],
]:
    """Return ``(field_mapping, definitions)`` for ``fields``.

    ``v1.get_definitions`` is always called with the v1 fields.  When
    ``PYDANTIC_V2`` is true, ``v2.get_definitions`` is called with the v2 fields
    and both results are merged, with v2 entries winning on key collisions.
    """
    legacy_fields = [item for item in fields if isinstance(item, v1.ModelField)]
    legacy_field_maps, legacy_definitions = v1.get_definitions(
        fields=legacy_fields,
        model_name_map=model_name_map,
        separate_input_output_schemas=separate_input_output_schemas,
    )
    if not PYDANTIC_V2:
        return legacy_field_maps, legacy_definitions

    from . import v2

    current_fields = [item for item in fields if isinstance(item, v2.ModelField)]
    current_field_maps, current_definitions = v2.get_definitions(
        fields=current_fields,
        model_name_map=model_name_map,
        separate_input_output_schemas=separate_input_output_schemas,
    )
    # Later unpacking wins, so v2 entries override v1 entries with equal keys.
    merged_definitions = {**legacy_definitions, **current_definitions}
    merged_field_maps = {**legacy_field_maps, **current_field_maps}
    return merged_field_maps, merged_definitions


def get_schema_from_model_field(
    *,
    field: ModelField,
    model_name_map: ModelNameMap,
    field_mapping: Dict[
        Tuple[ModelField, Literal["validation", "serialization"]], v1.JsonSchemaValue
    ],
    separate_input_output_schemas: bool = True,
) -> Dict[str, Any]:
    """Return the schema for ``field`` via the matching compat module.

    All keyword arguments are forwarded unchanged to the chosen implementation.
    """
    if isinstance(field, v1.ModelField):
        return v1.get_schema_from_model_field(
            field=field,
            model_name_map=model_name_map,
            field_mapping=field_mapping,
            separate_input_output_schemas=separate_input_output_schemas,
        )

    assert PYDANTIC_V2
    from . import v2

    return v2.get_schema_from_model_field(
        field=field,  # type: ignore[arg-type]
        model_name_map=model_name_map,
        field_mapping=field_mapping,  # type: ignore[arg-type]
        separate_input_output_schemas=separate_input_output_schemas,
    )


def _is_model_field(value: Any) -> bool:
    """Tell whether ``value`` is an instance of either ``ModelField``."""
    if isinstance(value, v1.ModelField):
        return True
    if not PYDANTIC_V2:
        return False

    from . import v2

    return isinstance(value, v2.ModelField)


def _is_model_class(value: Any) -> bool:
    """Tell whether ``value`` is a subclass of either ``BaseModel``."""
    if lenient_issubclass(value, v1.BaseModel):
        return True
    if not PYDANTIC_V2:
        return False

    from . import v2

    return lenient_issubclass(value, v2.BaseModel)  # type: ignore[attr-defined]