From d4e3dcfa3aeb8dec72e296dbd8ede074b097e146 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Wed, 5 Jul 2023 19:05:39 +0200 Subject: [PATCH] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Use=20new=20Pydantic=20v2?= =?UTF-8?q?=20JSON=20Schema=20generator=20(#9813)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: David Montague <35119617+dmontagu@users.noreply.github.com> --- fastapi/_compat.py | 23 ++++++++++++++++------- fastapi/openapi/utils.py | 22 +++++++++++++++++++++- fastapi/routing.py | 6 +++++- fastapi/utils.py | 6 +++++- 4 files changed, 47 insertions(+), 10 deletions(-) diff --git a/fastapi/_compat.py b/fastapi/_compat.py index 85a017fc0..b10e1ac05 100644 --- a/fastapi/_compat.py +++ b/fastapi/_compat.py @@ -79,6 +79,7 @@ if PYDANTIC_V2: class ModelField: field_info: FieldInfo name: str + mode: Literal["validation", "serialization"] = "validation" @property def alias(self) -> str: @@ -178,9 +179,12 @@ if PYDANTIC_V2: field: ModelField, schema_generator: GenerateJsonSchema, model_name_map: ModelNameMap, + field_mapping: Dict[ + Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue + ], ) -> Dict[str, Any]: # This expects that GenerateJsonSchema was already used to generate the definitions - json_schema = schema_generator.generate_inner(field._type_adapter.core_schema) + json_schema = field_mapping[(field, field.mode)] if "$ref" not in json_schema: # TODO remove when deprecating Pydantic v1 # Ref: https://github.com/pydantic/pydantic/blob/d61792cc42c80b13b23e3ffa74bc37ec7c77f7d1/pydantic/schema.py#L207 @@ -197,12 +201,12 @@ if PYDANTIC_V2: fields: List[ModelField], schema_generator: GenerateJsonSchema, model_name_map: ModelNameMap, - ) -> Dict[str, Dict[str, Any]]: + ) -> Tuple[Dict[str, Any], Dict[str, Dict[str, Any]]]: inputs = [ - (field, "validation", field._type_adapter.core_schema) for field in fields + (field, field.mode, field._type_adapter.core_schema) for field in fields ] - _, definitions = schema_generator.generate_definitions(inputs=inputs) # type: ignore[arg-type] - return definitions # type: ignore[return-value] + field_mapping, definitions = schema_generator.generate_definitions(inputs=inputs) # type: ignore[arg-type] + return field_mapping, definitions # type: ignore[return-value] def is_scalar_field(field: ModelField) -> bool: from fastapi import params @@ -419,6 +423,9 @@ else: field: ModelField, schema_generator: GenerateJsonSchema, model_name_map: ModelNameMap, + field_mapping: Dict[ + Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue + ], ) -> Dict[str, Any]: # This expects that GenerateJsonSchema was already used to generate the definitions return field_schema( # type: ignore[no-any-return] @@ -434,9 +441,11 @@ else: fields: List[ModelField], schema_generator: GenerateJsonSchema, model_name_map: ModelNameMap, - ) -> Dict[str, Dict[str, Any]]: + ) -> Tuple[Dict[str, Any], Dict[str, Dict[str, Any]]]: models = get_flat_models_from_fields(fields, known_models=set()) - return get_model_definitions(flat_models=models, model_name_map=model_name_map) + return {}, get_model_definitions( + flat_models=models, model_name_map=model_name_map + ) def is_scalar_field(field: ModelField) -> bool: return is_pv1_scalar_field(field) diff --git a/fastapi/openapi/utils.py b/fastapi/openapi/utils.py index 3d292a819..e295361e6 100644 --- a/fastapi/openapi/utils.py +++ b/fastapi/openapi/utils.py @@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type, Union, from fastapi import routing from fastapi._compat import ( GenerateJsonSchema, + JsonSchemaValue, ModelField, Undefined, get_compat_model_name_map, @@ -30,6 +31,7 @@ from fastapi.utils import ( from starlette.responses import JSONResponse from starlette.routing import BaseRoute from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY +from typing_extensions import Literal validation_error_definition = { "title": "ValidationError", @@ -90,6 +92,9 @@ def get_openapi_operation_parameters( all_route_params: Sequence[ModelField], schema_generator: GenerateJsonSchema, model_name_map: ModelNameMap, + field_mapping: Dict[ + Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue + ], ) -> List[Dict[str, Any]]: parameters = [] for param in all_route_params: @@ -101,6 +106,7 @@ def get_openapi_operation_parameters( field=param, schema_generator=schema_generator, model_name_map=model_name_map, + field_mapping=field_mapping, ) parameter = { "name": param.alias, @@ -123,6 +129,9 @@ def get_openapi_operation_request_body( body_field: Optional[ModelField], schema_generator: GenerateJsonSchema, model_name_map: ModelNameMap, + field_mapping: Dict[ + Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue + ], ) -> Optional[Dict[str, Any]]: if not body_field: return None @@ -131,6 +140,7 @@ def get_openapi_operation_request_body( field=body_field, schema_generator=schema_generator, model_name_map=model_name_map, + field_mapping=field_mapping, ) field_info = cast(Body, body_field.field_info) request_media_type = field_info.media_type @@ -198,6 +208,9 @@ def get_openapi_path( operation_ids: Set[str], schema_generator: GenerateJsonSchema, model_name_map: ModelNameMap, + field_mapping: Dict[ + Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue + ], ) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: path = {} security_schemes: Dict[str, Any] = {} @@ -228,6 +241,7 @@ def get_openapi_path( all_route_params=all_route_params, schema_generator=schema_generator, model_name_map=model_name_map, + field_mapping=field_mapping, ) parameters.extend(operation_parameters) if parameters: @@ -248,6 +262,7 @@ def get_openapi_path( body_field=route.body_field, schema_generator=schema_generator, model_name_map=model_name_map, + field_mapping=field_mapping, ) if request_body_oai: operation["requestBody"] = request_body_oai @@ -264,6 +279,7 @@ def get_openapi_path( operation_ids=operation_ids, schema_generator=schema_generator, model_name_map=model_name_map, + field_mapping=field_mapping, ) callbacks[callback.name] = {callback.path: cb_path} operation["callbacks"] = callbacks @@ -293,6 +309,7 @@ def get_openapi_path( field=route.response_field, schema_generator=schema_generator, model_name_map=model_name_map, + field_mapping=field_mapping, ) else: response_schema = {} @@ -325,6 +342,7 @@ def get_openapi_path( field=field, schema_generator=schema_generator, model_name_map=model_name_map, + field_mapping=field_mapping, ) media_type = route_response_media_type or "application/json" additional_schema = ( @@ -437,7 +455,7 @@ def get_openapi( all_fields = get_fields_from_routes(list(routes or []) + list(webhooks or [])) model_name_map = get_compat_model_name_map(all_fields) schema_generator = GenerateJsonSchema(ref_template=REF_TEMPLATE) - definitions = get_definitions( + field_mapping, definitions = get_definitions( fields=all_fields, schema_generator=schema_generator, model_name_map=model_name_map, @@ -449,6 +467,7 @@ def get_openapi( operation_ids=operation_ids, schema_generator=schema_generator, model_name_map=model_name_map, + field_mapping=field_mapping, ) if result: path, security_schemes, path_definitions = result @@ -467,6 +486,7 @@ def get_openapi( operation_ids=operation_ids, schema_generator=schema_generator, model_name_map=model_name_map, + field_mapping=field_mapping, ) if result: path, security_schemes, path_definitions = result diff --git a/fastapi/routing.py b/fastapi/routing.py index ce4e88c86..d8ff0579c 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -446,7 +446,11 @@ class APIRoute(routing.Route): ), f"Status code {status_code} must not have a response body" response_name = "Response_" + self.unique_id self.response_field = create_response_field( - name=response_name, type_=self.response_model + name=response_name, + type_=self.response_model, + # TODO: This should actually set mode='serialization', just, that changes the schemas + # mode="serialization", + mode="validation", ) # Create a clone of the field, so that a Pydantic submodel is not returned # as is just because it's an instance of a subclass of a more limited class diff --git a/fastapi/utils.py b/fastapi/utils.py index 2efe7f15a..267d64ce8 100644 --- a/fastapi/utils.py +++ b/fastapi/utils.py @@ -28,6 +28,7 @@ from fastapi._compat import ( from fastapi.datastructures import DefaultPlaceholder, DefaultType from pydantic import BaseModel, create_model from pydantic.fields import FieldInfo +from typing_extensions import Literal if TYPE_CHECKING: # pragma: nocover from .routing import APIRoute @@ -68,6 +69,7 @@ def create_response_field( model_config: Type[BaseConfig] = BaseConfig, field_info: Optional[FieldInfo] = None, alias: Optional[str] = None, + mode: Literal["validation", "serialization"] = "validation", ) -> ModelField: """ Create a new response field. Raises if type_ is invalid. @@ -80,7 +82,9 @@ def create_response_field( else: field_info = field_info or FieldInfo() kwargs = {"name": name, "field_info": field_info} - if not PYDANTIC_V2: + if PYDANTIC_V2: + kwargs.update({"mode": mode}) + else: kwargs.update( { "type_": type_,