From 497718e93d9defe1eecbb2903e8858be21c46aeb Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Thu, 13 Jul 2023 17:58:14 +0200 Subject: [PATCH 1/2] =?UTF-8?q?=E2=9C=A8=20Add=20custom=20GenerateJsonSche?= =?UTF-8?q?ma?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastapi/_compat.py | 71 ++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 68 insertions(+), 3 deletions(-) diff --git a/fastapi/_compat.py b/fastapi/_compat.py index 2b4d3e725..492322337 100644 --- a/fastapi/_compat.py +++ b/fastapi/_compat.py @@ -48,8 +48,10 @@ sequence_annotation_to_type = { sequence_types = tuple(sequence_annotation_to_type.keys()) if PYDANTIC_V2: + from typing import Sequence + from pydantic import PydanticSchemaGenerationError as PydanticSchemaGenerationError - from pydantic import TypeAdapter + from pydantic import PydanticUserError, TypeAdapter from pydantic import ValidationError as ValidationError from pydantic._internal._schema_generation_shared import ( # type: ignore[attr-defined] GetJsonSchemaHandler as GetJsonSchemaHandler, @@ -57,10 +59,17 @@ if PYDANTIC_V2: from pydantic._internal._typing_extra import eval_type_lenient from pydantic._internal._utils import lenient_issubclass as lenient_issubclass from pydantic.fields import FieldInfo - from pydantic.json_schema import GenerateJsonSchema as GenerateJsonSchema + from pydantic.json_schema import ( + DEFAULT_REF_TEMPLATE, + DefsRef, + JsonSchemaKeyT, + JsonSchemaMode, + _sort_json_schema, + ) + from pydantic.json_schema import GenerateJsonSchema as _GenerateJsonSchema from pydantic.json_schema import JsonSchemaValue as JsonSchemaValue from pydantic_core import CoreSchema as CoreSchema - from pydantic_core import PydanticUndefined, PydanticUndefinedType + from pydantic_core import PydanticUndefined, PydanticUndefinedType, core_schema from pydantic_core import Url as Url try: @@ -72,6 +81,62 @@ if PYDANTIC_V2: general_plain_validator_function as with_info_plain_validator_function, # noqa: F401 ) + + class GenerateJsonSchema(_GenerateJsonSchema): + def __init__( + self, by_alias: bool = True, ref_template: str = DEFAULT_REF_TEMPLATE + ): + super().__init__(by_alias=by_alias, ref_template=ref_template) + self.skip_null_schema = False + + def nullable_schema( + self, schema: core_schema.NullableSchema + ) -> JsonSchemaValue: + if self.skip_null_schema: + return super().generate_inner(schema["schema"]) + return super().nullable_schema(schema) + + def generate_definitions( + self, + inputs: Sequence[ + tuple[JsonSchemaKeyT, JsonSchemaMode, core_schema.CoreSchema] + ], + ) -> tuple[ + dict[tuple[JsonSchemaKeyT, JsonSchemaMode], JsonSchemaValue], + dict[DefsRef, JsonSchemaValue], + ]: + # Avoid circular import - Maybe there's a better way to check if it's a Param + from fastapi.params import Param + + if self._used: + raise PydanticUserError( + "This JSON schema generator has already been used to generate a JSON schema. " + f"You must create a new instance of {type(self).__name__} to generate a new JSON schema.", + code="json-schema-already-used", + ) + + for key, mode, schema in inputs: + self.mode = mode + self.skip_null_schema = isinstance(key, ModelField) and isinstance( + key.field_info, Param + ) + self.generate_inner(schema) + + definitions_remapping = self._build_definitions_remapping() + + json_schemas_map: dict[tuple[JsonSchemaKeyT, JsonSchemaMode], DefsRef] = {} + for key, mode, schema in inputs: + self.mode = mode + json_schema = self.generate_inner(schema) + json_schemas_map[(key, mode)] = definitions_remapping.remap_json_schema( + json_schema + ) + + json_schema = {"$defs": self.definitions} + json_schema = definitions_remapping.remap_json_schema(json_schema) + self._used = True + return json_schemas_map, _sort_json_schema(json_schema["$defs"]) # type: ignore + RequiredParam = PydanticUndefined Undefined = PydanticUndefined UndefinedType = PydanticUndefinedType From 169e5b2b4eee2f8a87006bda250483fa6c80561c Mon Sep 17 00:00:00 2001 From: Florian Maurer Date: Thu, 6 Jun 2024 14:46:57 +0200 Subject: [PATCH 2/2] update mode and default_schema fix tests to work on python3.8 --- fastapi/_compat.py | 29 +++++++++++++++++++++-------- tests/test_multi_body_errors.py | 12 ++++++++++-- 2 files changed, 31 insertions(+), 10 deletions(-) diff --git a/fastapi/_compat.py b/fastapi/_compat.py index 492322337..0a6fecad4 100644 --- a/fastapi/_compat.py +++ b/fastapi/_compat.py @@ -81,7 +81,6 @@ if PYDANTIC_V2: general_plain_validator_function as with_info_plain_validator_function, # noqa: F401 ) - class GenerateJsonSchema(_GenerateJsonSchema): def __init__( self, by_alias: bool = True, ref_template: str = DEFAULT_REF_TEMPLATE @@ -96,14 +95,25 @@ if PYDANTIC_V2: return super().generate_inner(schema["schema"]) return super().nullable_schema(schema) + def default_schema( + self, schema: core_schema.WithDefaultSchema + ) -> JsonSchemaValue: + json_schema = super().default_schema(schema) + if ( + self.skip_null_schema + and json_schema.get("default", PydanticUndefined) is None + ): + json_schema.pop("default") + return json_schema + def generate_definitions( self, inputs: Sequence[ - tuple[JsonSchemaKeyT, JsonSchemaMode, core_schema.CoreSchema] + Tuple[JsonSchemaKeyT, JsonSchemaMode, core_schema.CoreSchema] ], - ) -> tuple[ - dict[tuple[JsonSchemaKeyT, JsonSchemaMode], JsonSchemaValue], - dict[DefsRef, JsonSchemaValue], + ) -> Tuple[ + Dict[Tuple[JsonSchemaKeyT, JsonSchemaMode], JsonSchemaValue], + Dict[DefsRef, JsonSchemaValue], ]: # Avoid circular import - Maybe there's a better way to check if it's a Param from fastapi.params import Param @@ -116,7 +126,7 @@ if PYDANTIC_V2: ) for key, mode, schema in inputs: - self.mode = mode + self._mode = mode self.skip_null_schema = isinstance(key, ModelField) and isinstance( key.field_info, Param ) @@ -124,9 +134,12 @@ if PYDANTIC_V2: definitions_remapping = self._build_definitions_remapping() - json_schemas_map: dict[tuple[JsonSchemaKeyT, JsonSchemaMode], DefsRef] = {} + json_schemas_map: Dict[Tuple[JsonSchemaKeyT, JsonSchemaMode], DefsRef] = {} for key, mode, schema in inputs: - self.mode = mode + self._mode = mode + self.skip_null_schema = isinstance(key, ModelField) and isinstance( + key.field_info, Param + ) json_schema = self.generate_inner(schema) json_schemas_map[(key, mode)] = definitions_remapping.remap_json_schema( json_schema diff --git a/tests/test_multi_body_errors.py b/tests/test_multi_body_errors.py index 0102f0f1a..5db4a6f9f 100644 --- a/tests/test_multi_body_errors.py +++ b/tests/test_multi_body_errors.py @@ -83,7 +83,11 @@ def test_put_incorrect_body_multiple(): }, { "type": "decimal_parsing", - "loc": ["body", 0, "age"], + "loc": [ + "body", + 0, + "age", + ], "msg": "Input should be a valid decimal", "input": "five", }, @@ -95,7 +99,11 @@ def test_put_incorrect_body_multiple(): }, { "type": "decimal_parsing", - "loc": ["body", 1, "age"], + "loc": [ + "body", + 1, + "age", + ], "msg": "Input should be a valid decimal", "input": "six", },