|
|
|
@ -1,7 +1,7 @@ |
|
|
|
import re |
|
|
|
import warnings |
|
|
|
from collections.abc import Sequence |
|
|
|
from copy import copy, deepcopy |
|
|
|
from copy import copy |
|
|
|
from dataclasses import dataclass, is_dataclass |
|
|
|
from enum import Enum |
|
|
|
from functools import lru_cache |
|
|
|
@ -169,11 +169,11 @@ class ModelField: |
|
|
|
values: dict[str, Any] = {}, # noqa: B006 |
|
|
|
*, |
|
|
|
loc: tuple[Union[int, str], ...] = (), |
|
|
|
) -> tuple[Any, Union[list[dict[str, Any]], None]]: |
|
|
|
) -> tuple[Any, list[dict[str, Any]]]: |
|
|
|
try: |
|
|
|
return ( |
|
|
|
self._type_adapter.validate_python(value, from_attributes=True), |
|
|
|
None, |
|
|
|
[], |
|
|
|
) |
|
|
|
except ValidationError as exc: |
|
|
|
return None, _regenerate_error_with_loc( |
|
|
|
@ -305,94 +305,12 @@ def get_definitions( |
|
|
|
if "description" in item_def: |
|
|
|
item_description = cast(str, item_def["description"]).split("\f")[0] |
|
|
|
item_def["description"] = item_description |
|
|
|
new_mapping, new_definitions = _remap_definitions_and_field_mappings( |
|
|
|
model_name_map=model_name_map, |
|
|
|
definitions=definitions, # type: ignore[arg-type] |
|
|
|
field_mapping=field_mapping, |
|
|
|
) |
|
|
|
return new_mapping, new_definitions |
|
|
|
|
|
|
|
|
|
|
|
def _replace_refs( |
|
|
|
*, |
|
|
|
schema: dict[str, Any], |
|
|
|
old_name_to_new_name_map: dict[str, str], |
|
|
|
) -> dict[str, Any]: |
|
|
|
new_schema = deepcopy(schema) |
|
|
|
for key, value in new_schema.items(): |
|
|
|
if key == "$ref": |
|
|
|
value = schema["$ref"] |
|
|
|
if isinstance(value, str): |
|
|
|
ref_name = schema["$ref"].split("/")[-1] |
|
|
|
if ref_name in old_name_to_new_name_map: |
|
|
|
new_name = old_name_to_new_name_map[ref_name] |
|
|
|
new_schema["$ref"] = REF_TEMPLATE.format(model=new_name) |
|
|
|
continue |
|
|
|
if isinstance(value, dict): |
|
|
|
new_schema[key] = _replace_refs( |
|
|
|
schema=value, |
|
|
|
old_name_to_new_name_map=old_name_to_new_name_map, |
|
|
|
) |
|
|
|
elif isinstance(value, list): |
|
|
|
new_value = [] |
|
|
|
for item in value: |
|
|
|
if isinstance(item, dict): |
|
|
|
new_item = _replace_refs( |
|
|
|
schema=item, |
|
|
|
old_name_to_new_name_map=old_name_to_new_name_map, |
|
|
|
) |
|
|
|
new_value.append(new_item) |
|
|
|
|
|
|
|
else: |
|
|
|
new_value.append(item) |
|
|
|
new_schema[key] = new_value |
|
|
|
return new_schema |
|
|
|
|
|
|
|
|
|
|
|
def _remap_definitions_and_field_mappings( |
|
|
|
*, |
|
|
|
model_name_map: ModelNameMap, |
|
|
|
definitions: dict[str, Any], |
|
|
|
field_mapping: dict[ |
|
|
|
tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue |
|
|
|
], |
|
|
|
) -> tuple[ |
|
|
|
dict[tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue], |
|
|
|
dict[str, Any], |
|
|
|
]: |
|
|
|
old_name_to_new_name_map = {} |
|
|
|
for field_key, schema in field_mapping.items(): |
|
|
|
model = field_key[0].type_ |
|
|
|
if model not in model_name_map or "$ref" not in schema: |
|
|
|
continue |
|
|
|
new_name = model_name_map[model] |
|
|
|
old_name = schema["$ref"].split("/")[-1] |
|
|
|
if old_name in {f"{new_name}-Input", f"{new_name}-Output"}: |
|
|
|
continue |
|
|
|
old_name_to_new_name_map[old_name] = new_name |
|
|
|
|
|
|
|
new_field_mapping: dict[ |
|
|
|
tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue |
|
|
|
] = {} |
|
|
|
for field_key, schema in field_mapping.items(): |
|
|
|
new_schema = _replace_refs( |
|
|
|
schema=schema, |
|
|
|
old_name_to_new_name_map=old_name_to_new_name_map, |
|
|
|
) |
|
|
|
new_field_mapping[field_key] = new_schema |
|
|
|
|
|
|
|
new_definitions = {} |
|
|
|
for key, value in definitions.items(): |
|
|
|
if key in old_name_to_new_name_map: |
|
|
|
new_key = old_name_to_new_name_map[key] |
|
|
|
else: |
|
|
|
new_key = key |
|
|
|
new_value = _replace_refs( |
|
|
|
schema=value, |
|
|
|
old_name_to_new_name_map=old_name_to_new_name_map, |
|
|
|
) |
|
|
|
new_definitions[new_key] = new_value |
|
|
|
return new_field_mapping, new_definitions |
|
|
|
# definitions: dict[DefsRef, dict[str, Any]] |
|
|
|
# but mypy complains about general str in other places that are not declared as |
|
|
|
# DefsRef, although DefsRef is just str: |
|
|
|
# DefsRef = NewType('DefsRef', str) |
|
|
|
# So, a cast to simplify the types here |
|
|
|
return field_mapping, cast(dict[str, dict[str, Any]], definitions) |
|
|
|
|
|
|
|
|
|
|
|
def is_scalar_field(field: ModelField) -> bool: |
|
|
|
@ -441,7 +359,7 @@ def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]: |
|
|
|
return shared.sequence_annotation_to_type[origin_type](value) # type: ignore[no-any-return,index] |
|
|
|
|
|
|
|
|
|
|
|
def get_missing_field_error(loc: tuple[str, ...]) -> dict[str, Any]: |
|
|
|
def get_missing_field_error(loc: tuple[Union[int, str], ...]) -> dict[str, Any]: |
|
|
|
error = ValidationError.from_exception_data( |
|
|
|
"Field required", [{"type": "missing", "loc": loc, "input": {}}] |
|
|
|
).errors(include_url=False)[0] |
|
|
|
@ -499,17 +417,6 @@ def get_model_name_map(unique_models: TypeModelSet) -> dict[TypeModelOrEnum, str |
|
|
|
return {v: k for k, v in name_model_map.items()} |
|
|
|
|
|
|
|
|
|
|
|
def get_compat_model_name_map(fields: list[ModelField]) -> ModelNameMap: |
|
|
|
all_flat_models: TypeModelSet = set() |
|
|
|
|
|
|
|
v2_model_fields = [field for field in fields if isinstance(field, ModelField)] |
|
|
|
v2_flat_models = get_flat_models_from_fields(v2_model_fields, known_models=set()) |
|
|
|
all_flat_models = all_flat_models.union(v2_flat_models) |
|
|
|
|
|
|
|
model_name_map = get_model_name_map(all_flat_models) |
|
|
|
return model_name_map |
|
|
|
|
|
|
|
|
|
|
|
def get_flat_models_from_model( |
|
|
|
model: type["BaseModel"], known_models: Union[TypeModelSet, None] = None |
|
|
|
) -> TypeModelSet: |
|
|
|
|