From 30e7a1894c0e1c812c6589d2cd343b587011be83 Mon Sep 17 00:00:00 2001
From: chbndrhnns
Date: Fri, 18 Jul 2025 18:29:05 +0200
Subject: [PATCH] Generate schema for pydantic v1 models

---
 fastapi/_compat.py               | 159 +++++++++++++++++++++++++++----
 tests/test_pydantic_v1_models.py |  22 +++++
 2 files changed, 163 insertions(+), 18 deletions(-)

diff --git a/fastapi/_compat.py b/fastapi/_compat.py
index 278f03156..5b2be9c70 100644
--- a/fastapi/_compat.py
+++ b/fastapi/_compat.py
@@ -233,15 +233,64 @@ if PYDANTIC_V2:
         override_mode: Union[Literal["validation"], None] = (
             None if separate_input_output_schemas else "validation"
         )
-        # This expects that GenerateJsonSchema was already used to generate the definitions
-        json_schema = field_mapping[(field, override_mode or field.mode)]
-        if "$ref" not in json_schema:
-            # TODO remove when deprecating Pydantic v1
-            # Ref: https://github.com/pydantic/pydantic/blob/d61792cc42c80b13b23e3ffa74bc37ec7c77f7d1/pydantic/schema.py#L207
-            json_schema["title"] = (
-                field.field_info.title or field.alias.title().replace("_", " ")
-            )
-        return json_schema
+
+        # Check if field is a v1 or v2 field
+        is_pydantic_v2_field = hasattr(field, "_type_adapter")
+        if is_pydantic_v2_field:
+            # V2 field - use field_mapping
+            json_schema = field_mapping[(field, override_mode or field.mode)]
+            if "$ref" not in json_schema:
+                # TODO remove when deprecating Pydantic v1
+                # Ref: https://github.com/pydantic/pydantic/blob/d61792cc42c80b13b23e3ffa74bc37ec7c77f7d1/pydantic/schema.py#L207
+                json_schema["title"] = (
+                    field.field_info.title or field.alias.title().replace("_", " ")
+                )
+            return json_schema
+        else:
+            # V1 field - use model_schema from pydantic v1
+            from fastapi.openapi.constants import REF_PREFIX
+            from pydantic import v1
+
+            # Extract the model type from the field
+            field_type = field.type_
+
+            # Handle different field types
+            if hasattr(field_type, "__origin__") and field_type.__origin__ is not None:
+                # Handle generic types like List[Model]
+                args = getattr(field_type, "__args__", [])
+                for arg in args:
+                    if lenient_issubclass(arg, v1.BaseModel):
+                        field_type = arg
+                        break
+
+            # If field type is a v1 model, generate schema for it
+            if lenient_issubclass(field_type, v1.BaseModel):
+                # Get model name from model_name_map or use the class name
+                model_name = model_name_map.get(field_type, field_type.__name__)
+
+                # Generate schema for the model
+                # Note: model_schema doesn't accept model_name_map, but we've already
+                # extracted the model name above
+                v1.schema.model_schema(field_type, by_alias=True, ref_prefix=REF_PREFIX)
+
+                # Return a reference to the model schema
+                return {"$ref": f"{REF_PREFIX}{model_name}"}
+
+            # Fallback to a simple schema based on the field info
+            schema = {
+                "title": field.field_info.title
+                or field.alias.title().replace("_", " "),
+                "type": "object",
+            }
+
+            # Add description if available
+            if (
+                hasattr(field.field_info, "description")
+                and field.field_info.description
+            ):
+                schema["description"] = field.field_info.description
+
+            return schema
 
     def get_compat_model_name_map(fields: List[ModelField]) -> ModelNameMap:
         return {}
@@ -261,18 +310,92 @@ if PYDANTIC_V2:
         override_mode: Union[Literal["validation"], None] = (
             None if separate_input_output_schemas else "validation"
         )
-        inputs = [
-            (field, override_mode or field.mode, field._type_adapter.core_schema)
-            for field in fields
-        ]
-        field_mapping, definitions = schema_generator.generate_definitions(
-            inputs=inputs
-        )
-        for item_def in cast(Dict[str, Dict[str, Any]], definitions).values():
+
+        # Split fields into v1 and v2 fields
+        v1_fields = []
+        v2_fields = []
+
+        for field in fields:
+            # Check if field has _type_adapter attribute (v2) or not (v1)
+            if hasattr(field, "_type_adapter"):
+                v2_fields.append(field)
+            else:
+                v1_fields.append(field)
+
+        # Process v2 fields if any
+        field_mapping: Dict[
+            Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
+        ] = {}
+        definitions: Dict[Any, Dict[str, Any]] = {}
+
+        if v2_fields:
+            inputs = [
+                (field, override_mode or field.mode, field._type_adapter.core_schema)
+                for field in v2_fields
+            ]
+            field_mapping, definitions = schema_generator.generate_definitions(
+                inputs=inputs
+            )
+
+        # Process v1 fields if any
+        if v1_fields:
+            # Import necessary functions from pydantic v1
+            from fastapi.openapi.constants import REF_PREFIX
+            from pydantic import v1
+
+            # Extract v1 models from fields
+            v1_models = set()
+            for field in v1_fields:
+                # Get the model from the field's type
+                field_type = field.type_
+                if (
+                    hasattr(field_type, "__origin__")
+                    and field_type.__origin__ is not None
+                ):
+                    # Handle generic types like List[Model]
+                    args = getattr(field_type, "__args__", [])
+                    for arg in args:
+                        if lenient_issubclass(arg, v1.BaseModel):
+                            v1_models.add(arg)
+                elif lenient_issubclass(field_type, v1.BaseModel):
+                    v1_models.add(field_type)
+
+            # If we found any v1 models, generate their schemas
+            if v1_models:
+                # Get model name map for v1 models
+                v1_model_name_map = model_name_map or v1.schema.get_model_name_map(
+                    v1_models
+                )
+
+                # Generate definitions for v1 models
+                v1_definitions = {}
+                for model in v1_models:
+                    m_schema, m_definitions, m_nested_models = (
+                        v1.schema.model_process_schema(
+                            model,
+                            model_name_map=cast(
+                                Dict[Union[Type["v1.BaseModel"], Type[Enum]], str],
+                                v1_model_name_map,
+                            ),
+                            ref_prefix=REF_PREFIX,
+                        )
+                    )
+                    v1_definitions.update(m_definitions)
+                    model_name = v1_model_name_map[model]
+                    if "description" in m_schema:
+                        m_schema["description"] = m_schema["description"].split("\f")[0]
+                    v1_definitions[model_name] = m_schema
+
+                # Merge definitions
+                definitions.update(cast(Dict[str, Dict[str, Any]], v1_definitions))
+
+        # Clean up descriptions
+        for item_def in definitions.values():
             if "description" in item_def:
                 item_description = cast(str, item_def["description"]).split("\f")[0]
                 item_def["description"] = item_description
-        return field_mapping, definitions  # type: ignore[return-value]
+
+        return field_mapping, definitions
 
     def is_scalar_field(field: ModelField) -> bool:
         from fastapi import params
diff --git a/tests/test_pydantic_v1_models.py b/tests/test_pydantic_v1_models.py
index 2728d9f15..595101e54 100644
--- a/tests/test_pydantic_v1_models.py
+++ b/tests/test_pydantic_v1_models.py
@@ -97,3 +97,25 @@ class TestRequestBody:
     def test_model__invalid(self):
         response = client.post("/request_body", json={"name": "myname"})
         assert response.status_code == 422, response.text
+
+
+@needs_pydanticv2
+class TestSchema:
+    def test_can_generate(self):
+        spec = app.openapi()
+        schema = spec["paths"]["/request_body"]["post"]["requestBody"]["content"][
+            "application/json"
+        ]["schema"]
+        # Check that the schema is not empty and contains the expected properties
+        assert "$ref" in schema
+        ref = schema["$ref"].split("/")[-1]
+        assert ref in spec["components"]["schemas"]
+        item_schema = spec["components"]["schemas"][ref]
+        assert item_schema["properties"]["name"]["type"] == "string"
+        assert item_schema["properties"]["description"]["type"] == "string"
+        assert item_schema["properties"]["price"]["type"] == "number"
"number" + assert item_schema["properties"]["tax"]["type"] == "number" + assert item_schema["properties"]["tags"]["type"] == "array" + assert "required" in item_schema + assert "name" in item_schema["required"] + assert "price" in item_schema["required"]