From 497718e93d9defe1eecbb2903e8858be21c46aeb Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Thu, 13 Jul 2023 17:58:14 +0200 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Add=20custom=20GenerateJsonSchema?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastapi/_compat.py | 71 ++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 68 insertions(+), 3 deletions(-) diff --git a/fastapi/_compat.py b/fastapi/_compat.py index 2b4d3e725..492322337 100644 --- a/fastapi/_compat.py +++ b/fastapi/_compat.py @@ -48,8 +48,10 @@ sequence_annotation_to_type = { sequence_types = tuple(sequence_annotation_to_type.keys()) if PYDANTIC_V2: + from typing import Sequence + from pydantic import PydanticSchemaGenerationError as PydanticSchemaGenerationError - from pydantic import TypeAdapter + from pydantic import PydanticUserError, TypeAdapter from pydantic import ValidationError as ValidationError from pydantic._internal._schema_generation_shared import ( # type: ignore[attr-defined] GetJsonSchemaHandler as GetJsonSchemaHandler, @@ -57,10 +59,17 @@ if PYDANTIC_V2: from pydantic._internal._typing_extra import eval_type_lenient from pydantic._internal._utils import lenient_issubclass as lenient_issubclass from pydantic.fields import FieldInfo - from pydantic.json_schema import GenerateJsonSchema as GenerateJsonSchema + from pydantic.json_schema import ( + DEFAULT_REF_TEMPLATE, + DefsRef, + JsonSchemaKeyT, + JsonSchemaMode, + _sort_json_schema, + ) + from pydantic.json_schema import GenerateJsonSchema as _GenerateJsonSchema from pydantic.json_schema import JsonSchemaValue as JsonSchemaValue from pydantic_core import CoreSchema as CoreSchema - from pydantic_core import PydanticUndefined, PydanticUndefinedType + from pydantic_core import PydanticUndefined, PydanticUndefinedType, core_schema from pydantic_core import Url as Url try: @@ -72,6 +81,62 @@ if PYDANTIC_V2: general_plain_validator_function as with_info_plain_validator_function, # noqa: F401 ) + + class GenerateJsonSchema(_GenerateJsonSchema): + def __init__( + self, by_alias: bool = True, ref_template: str = DEFAULT_REF_TEMPLATE + ): + super().__init__(by_alias=by_alias, ref_template=ref_template) + self.skip_null_schema = False + + def nullable_schema( + self, schema: core_schema.NullableSchema + ) -> JsonSchemaValue: + if self.skip_null_schema: + return super().generate_inner(schema["schema"]) + return super().nullable_schema(schema) + + def generate_definitions( + self, + inputs: Sequence[ + tuple[JsonSchemaKeyT, JsonSchemaMode, core_schema.CoreSchema] + ], + ) -> tuple[ + dict[tuple[JsonSchemaKeyT, JsonSchemaMode], JsonSchemaValue], + dict[DefsRef, JsonSchemaValue], + ]: + # Avoid circular import - Maybe there's a better way to check if it's a Param + from fastapi.params import Param + + if self._used: + raise PydanticUserError( + "This JSON schema generator has already been used to generate a JSON schema. " + f"You must create a new instance of {type(self).__name__} to generate a new JSON schema.", + code="json-schema-already-used", + ) + + for key, mode, schema in inputs: + self.mode = mode + self.skip_null_schema = isinstance(key, ModelField) and isinstance( + key.field_info, Param + ) + self.generate_inner(schema) + + definitions_remapping = self._build_definitions_remapping() + + json_schemas_map: dict[tuple[JsonSchemaKeyT, JsonSchemaMode], DefsRef] = {} + for key, mode, schema in inputs: + self.mode = mode + json_schema = self.generate_inner(schema) + json_schemas_map[(key, mode)] = definitions_remapping.remap_json_schema( + json_schema + ) + + json_schema = {"$defs": self.definitions} + json_schema = definitions_remapping.remap_json_schema(json_schema) + self._used = True + return json_schemas_map, _sort_json_schema(json_schema["$defs"]) # type: ignore + RequiredParam = PydanticUndefined Undefined = PydanticUndefined UndefinedType = PydanticUndefinedType