diff --git a/fastapi/_compat.py b/fastapi/_compat.py index c07e4a3b0..32fd15900 100644 --- a/fastapi/_compat.py +++ b/fastapi/_compat.py @@ -48,8 +48,10 @@ sequence_types = tuple(sequence_annotation_to_type.keys()) Url: Type[Any] if PYDANTIC_V2: + from typing import Sequence + from pydantic import PydanticSchemaGenerationError as PydanticSchemaGenerationError - from pydantic import TypeAdapter + from pydantic import PydanticUserError, TypeAdapter from pydantic import ValidationError as ValidationError from pydantic._internal._schema_generation_shared import ( # type: ignore[attr-defined] GetJsonSchemaHandler as GetJsonSchemaHandler, @@ -57,10 +59,17 @@ if PYDANTIC_V2: from pydantic._internal._typing_extra import eval_type_lenient from pydantic._internal._utils import lenient_issubclass as lenient_issubclass from pydantic.fields import FieldInfo - from pydantic.json_schema import GenerateJsonSchema as GenerateJsonSchema + from pydantic.json_schema import ( + DEFAULT_REF_TEMPLATE, + DefsRef, + JsonSchemaKeyT, + JsonSchemaMode, + _sort_json_schema, + ) + from pydantic.json_schema import GenerateJsonSchema as _GenerateJsonSchema from pydantic.json_schema import JsonSchemaValue as JsonSchemaValue from pydantic_core import CoreSchema as CoreSchema - from pydantic_core import PydanticUndefined, PydanticUndefinedType + from pydantic_core import PydanticUndefined, PydanticUndefinedType, core_schema from pydantic_core import Url as Url try: @@ -72,6 +81,75 @@ if PYDANTIC_V2: general_plain_validator_function as with_info_plain_validator_function, # noqa: F401 ) + class GenerateJsonSchema(_GenerateJsonSchema): + def __init__( + self, by_alias: bool = True, ref_template: str = DEFAULT_REF_TEMPLATE + ): + super().__init__(by_alias=by_alias, ref_template=ref_template) + self.skip_null_schema = False + + def nullable_schema( + self, schema: core_schema.NullableSchema + ) -> JsonSchemaValue: + if self.skip_null_schema: + return super().generate_inner(schema["schema"]) + return super().nullable_schema(schema) + + def default_schema( + self, schema: core_schema.WithDefaultSchema + ) -> JsonSchemaValue: + json_schema = super().default_schema(schema) + if ( + self.skip_null_schema + and json_schema.get("default", PydanticUndefined) is None + ): + json_schema.pop("default") + return json_schema + + def generate_definitions( + self, + inputs: Sequence[ + Tuple[JsonSchemaKeyT, JsonSchemaMode, core_schema.CoreSchema] + ], + ) -> Tuple[ + Dict[Tuple[JsonSchemaKeyT, JsonSchemaMode], JsonSchemaValue], + Dict[DefsRef, JsonSchemaValue], + ]: + # Avoid circular import - Maybe there's a better way to check if it's a Param + from fastapi.params import Param + + if self._used: + raise PydanticUserError( + "This JSON schema generator has already been used to generate a JSON schema. " + f"You must create a new instance of {type(self).__name__} to generate a new JSON schema.", + code="json-schema-already-used", + ) + + for key, mode, schema in inputs: + self._mode = mode + self.skip_null_schema = isinstance(key, ModelField) and isinstance( + key.field_info, Param + ) + self.generate_inner(schema) + + definitions_remapping = self._build_definitions_remapping() + + json_schemas_map: Dict[Tuple[JsonSchemaKeyT, JsonSchemaMode], DefsRef] = {} + for key, mode, schema in inputs: + self._mode = mode + self.skip_null_schema = isinstance(key, ModelField) and isinstance( + key.field_info, Param + ) + json_schema = self.generate_inner(schema) + json_schemas_map[(key, mode)] = definitions_remapping.remap_json_schema( + json_schema + ) + + json_schema = {"$defs": self.definitions} + json_schema = definitions_remapping.remap_json_schema(json_schema) + self._used = True + return json_schemas_map, _sort_json_schema(json_schema["$defs"]) # type: ignore + RequiredParam = PydanticUndefined Undefined = PydanticUndefined UndefinedType = PydanticUndefinedType diff --git a/tests/test_multi_body_errors.py b/tests/test_multi_body_errors.py index 0102f0f1a..5db4a6f9f 100644 --- a/tests/test_multi_body_errors.py +++ b/tests/test_multi_body_errors.py @@ -83,7 +83,11 @@ def test_put_incorrect_body_multiple(): }, { "type": "decimal_parsing", - "loc": ["body", 0, "age"], + "loc": [ + "body", + 0, + "age", + ], "msg": "Input should be a valid decimal", "input": "five", }, @@ -95,7 +99,11 @@ def test_put_incorrect_body_multiple(): }, { "type": "decimal_parsing", - "loc": ["body", 1, "age"], + "loc": [ + "body", + 1, + "age", + ], "msg": "Input should be a valid decimal", "input": "six", },