Browse Source

♻️ Use new Pydantic v2 JSON Schema generator (#9813)

Co-authored-by: David Montague <[email protected]>
pull/9814/head
Sebastián Ramírez 2 years ago
committed by GitHub
parent
commit
d4e3dcfa3a
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 23
      fastapi/_compat.py
  2. 22
      fastapi/openapi/utils.py
  3. 6
      fastapi/routing.py
  4. 6
      fastapi/utils.py

23
fastapi/_compat.py

@ -79,6 +79,7 @@ if PYDANTIC_V2:
class ModelField: class ModelField:
field_info: FieldInfo field_info: FieldInfo
name: str name: str
mode: Literal["validation", "serialization"] = "validation"
@property @property
def alias(self) -> str: def alias(self) -> str:
@ -178,9 +179,12 @@ if PYDANTIC_V2:
field: ModelField, field: ModelField,
schema_generator: GenerateJsonSchema, schema_generator: GenerateJsonSchema,
model_name_map: ModelNameMap, model_name_map: ModelNameMap,
field_mapping: Dict[
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
],
) -> Dict[str, Any]: ) -> Dict[str, Any]:
# This expects that GenerateJsonSchema was already used to generate the definitions # This expects that GenerateJsonSchema was already used to generate the definitions
json_schema = schema_generator.generate_inner(field._type_adapter.core_schema) json_schema = field_mapping[(field, field.mode)]
if "$ref" not in json_schema: if "$ref" not in json_schema:
# TODO remove when deprecating Pydantic v1 # TODO remove when deprecating Pydantic v1
# Ref: https://github.com/pydantic/pydantic/blob/d61792cc42c80b13b23e3ffa74bc37ec7c77f7d1/pydantic/schema.py#L207 # Ref: https://github.com/pydantic/pydantic/blob/d61792cc42c80b13b23e3ffa74bc37ec7c77f7d1/pydantic/schema.py#L207
@ -197,12 +201,12 @@ if PYDANTIC_V2:
fields: List[ModelField], fields: List[ModelField],
schema_generator: GenerateJsonSchema, schema_generator: GenerateJsonSchema,
model_name_map: ModelNameMap, model_name_map: ModelNameMap,
) -> Dict[str, Dict[str, Any]]: ) -> Tuple[Dict[str, Any], Dict[str, Dict[str, Any]]]:
inputs = [ inputs = [
(field, "validation", field._type_adapter.core_schema) for field in fields (field, field.mode, field._type_adapter.core_schema) for field in fields
] ]
_, definitions = schema_generator.generate_definitions(inputs=inputs) # type: ignore[arg-type] field_mapping, definitions = schema_generator.generate_definitions(inputs=inputs) # type: ignore[arg-type]
return definitions # type: ignore[return-value] return field_mapping, definitions # type: ignore[return-value]
def is_scalar_field(field: ModelField) -> bool: def is_scalar_field(field: ModelField) -> bool:
from fastapi import params from fastapi import params
@ -419,6 +423,9 @@ else:
field: ModelField, field: ModelField,
schema_generator: GenerateJsonSchema, schema_generator: GenerateJsonSchema,
model_name_map: ModelNameMap, model_name_map: ModelNameMap,
field_mapping: Dict[
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
],
) -> Dict[str, Any]: ) -> Dict[str, Any]:
# This expects that GenerateJsonSchema was already used to generate the definitions # This expects that GenerateJsonSchema was already used to generate the definitions
return field_schema( # type: ignore[no-any-return] return field_schema( # type: ignore[no-any-return]
@ -434,9 +441,11 @@ else:
fields: List[ModelField], fields: List[ModelField],
schema_generator: GenerateJsonSchema, schema_generator: GenerateJsonSchema,
model_name_map: ModelNameMap, model_name_map: ModelNameMap,
) -> Dict[str, Dict[str, Any]]: ) -> Tuple[Dict[str, Any], Dict[str, Dict[str, Any]]]:
models = get_flat_models_from_fields(fields, known_models=set()) models = get_flat_models_from_fields(fields, known_models=set())
return get_model_definitions(flat_models=models, model_name_map=model_name_map) return {}, get_model_definitions(
flat_models=models, model_name_map=model_name_map
)
def is_scalar_field(field: ModelField) -> bool: def is_scalar_field(field: ModelField) -> bool:
return is_pv1_scalar_field(field) return is_pv1_scalar_field(field)

22
fastapi/openapi/utils.py

@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type, Union,
from fastapi import routing from fastapi import routing
from fastapi._compat import ( from fastapi._compat import (
GenerateJsonSchema, GenerateJsonSchema,
JsonSchemaValue,
ModelField, ModelField,
Undefined, Undefined,
get_compat_model_name_map, get_compat_model_name_map,
@ -30,6 +31,7 @@ from fastapi.utils import (
from starlette.responses import JSONResponse from starlette.responses import JSONResponse
from starlette.routing import BaseRoute from starlette.routing import BaseRoute
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
from typing_extensions import Literal
validation_error_definition = { validation_error_definition = {
"title": "ValidationError", "title": "ValidationError",
@ -90,6 +92,9 @@ def get_openapi_operation_parameters(
all_route_params: Sequence[ModelField], all_route_params: Sequence[ModelField],
schema_generator: GenerateJsonSchema, schema_generator: GenerateJsonSchema,
model_name_map: ModelNameMap, model_name_map: ModelNameMap,
field_mapping: Dict[
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
],
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
parameters = [] parameters = []
for param in all_route_params: for param in all_route_params:
@ -101,6 +106,7 @@ def get_openapi_operation_parameters(
field=param, field=param,
schema_generator=schema_generator, schema_generator=schema_generator,
model_name_map=model_name_map, model_name_map=model_name_map,
field_mapping=field_mapping,
) )
parameter = { parameter = {
"name": param.alias, "name": param.alias,
@ -123,6 +129,9 @@ def get_openapi_operation_request_body(
body_field: Optional[ModelField], body_field: Optional[ModelField],
schema_generator: GenerateJsonSchema, schema_generator: GenerateJsonSchema,
model_name_map: ModelNameMap, model_name_map: ModelNameMap,
field_mapping: Dict[
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
],
) -> Optional[Dict[str, Any]]: ) -> Optional[Dict[str, Any]]:
if not body_field: if not body_field:
return None return None
@ -131,6 +140,7 @@ def get_openapi_operation_request_body(
field=body_field, field=body_field,
schema_generator=schema_generator, schema_generator=schema_generator,
model_name_map=model_name_map, model_name_map=model_name_map,
field_mapping=field_mapping,
) )
field_info = cast(Body, body_field.field_info) field_info = cast(Body, body_field.field_info)
request_media_type = field_info.media_type request_media_type = field_info.media_type
@ -198,6 +208,9 @@ def get_openapi_path(
operation_ids: Set[str], operation_ids: Set[str],
schema_generator: GenerateJsonSchema, schema_generator: GenerateJsonSchema,
model_name_map: ModelNameMap, model_name_map: ModelNameMap,
field_mapping: Dict[
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
],
) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: ) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
path = {} path = {}
security_schemes: Dict[str, Any] = {} security_schemes: Dict[str, Any] = {}
@ -228,6 +241,7 @@ def get_openapi_path(
all_route_params=all_route_params, all_route_params=all_route_params,
schema_generator=schema_generator, schema_generator=schema_generator,
model_name_map=model_name_map, model_name_map=model_name_map,
field_mapping=field_mapping,
) )
parameters.extend(operation_parameters) parameters.extend(operation_parameters)
if parameters: if parameters:
@ -248,6 +262,7 @@ def get_openapi_path(
body_field=route.body_field, body_field=route.body_field,
schema_generator=schema_generator, schema_generator=schema_generator,
model_name_map=model_name_map, model_name_map=model_name_map,
field_mapping=field_mapping,
) )
if request_body_oai: if request_body_oai:
operation["requestBody"] = request_body_oai operation["requestBody"] = request_body_oai
@ -264,6 +279,7 @@ def get_openapi_path(
operation_ids=operation_ids, operation_ids=operation_ids,
schema_generator=schema_generator, schema_generator=schema_generator,
model_name_map=model_name_map, model_name_map=model_name_map,
field_mapping=field_mapping,
) )
callbacks[callback.name] = {callback.path: cb_path} callbacks[callback.name] = {callback.path: cb_path}
operation["callbacks"] = callbacks operation["callbacks"] = callbacks
@ -293,6 +309,7 @@ def get_openapi_path(
field=route.response_field, field=route.response_field,
schema_generator=schema_generator, schema_generator=schema_generator,
model_name_map=model_name_map, model_name_map=model_name_map,
field_mapping=field_mapping,
) )
else: else:
response_schema = {} response_schema = {}
@ -325,6 +342,7 @@ def get_openapi_path(
field=field, field=field,
schema_generator=schema_generator, schema_generator=schema_generator,
model_name_map=model_name_map, model_name_map=model_name_map,
field_mapping=field_mapping,
) )
media_type = route_response_media_type or "application/json" media_type = route_response_media_type or "application/json"
additional_schema = ( additional_schema = (
@ -437,7 +455,7 @@ def get_openapi(
all_fields = get_fields_from_routes(list(routes or []) + list(webhooks or [])) all_fields = get_fields_from_routes(list(routes or []) + list(webhooks or []))
model_name_map = get_compat_model_name_map(all_fields) model_name_map = get_compat_model_name_map(all_fields)
schema_generator = GenerateJsonSchema(ref_template=REF_TEMPLATE) schema_generator = GenerateJsonSchema(ref_template=REF_TEMPLATE)
definitions = get_definitions( field_mapping, definitions = get_definitions(
fields=all_fields, fields=all_fields,
schema_generator=schema_generator, schema_generator=schema_generator,
model_name_map=model_name_map, model_name_map=model_name_map,
@ -449,6 +467,7 @@ def get_openapi(
operation_ids=operation_ids, operation_ids=operation_ids,
schema_generator=schema_generator, schema_generator=schema_generator,
model_name_map=model_name_map, model_name_map=model_name_map,
field_mapping=field_mapping,
) )
if result: if result:
path, security_schemes, path_definitions = result path, security_schemes, path_definitions = result
@ -467,6 +486,7 @@ def get_openapi(
operation_ids=operation_ids, operation_ids=operation_ids,
schema_generator=schema_generator, schema_generator=schema_generator,
model_name_map=model_name_map, model_name_map=model_name_map,
field_mapping=field_mapping,
) )
if result: if result:
path, security_schemes, path_definitions = result path, security_schemes, path_definitions = result

6
fastapi/routing.py

@ -446,7 +446,11 @@ class APIRoute(routing.Route):
), f"Status code {status_code} must not have a response body" ), f"Status code {status_code} must not have a response body"
response_name = "Response_" + self.unique_id response_name = "Response_" + self.unique_id
self.response_field = create_response_field( self.response_field = create_response_field(
name=response_name, type_=self.response_model name=response_name,
type_=self.response_model,
# TODO: This should actually set mode='serialization', just, that changes the schemas
# mode="serialization",
mode="validation",
) )
# Create a clone of the field, so that a Pydantic submodel is not returned # Create a clone of the field, so that a Pydantic submodel is not returned
# as is just because it's an instance of a subclass of a more limited class # as is just because it's an instance of a subclass of a more limited class

6
fastapi/utils.py

@ -28,6 +28,7 @@ from fastapi._compat import (
from fastapi.datastructures import DefaultPlaceholder, DefaultType from fastapi.datastructures import DefaultPlaceholder, DefaultType
from pydantic import BaseModel, create_model from pydantic import BaseModel, create_model
from pydantic.fields import FieldInfo from pydantic.fields import FieldInfo
from typing_extensions import Literal
if TYPE_CHECKING: # pragma: nocover if TYPE_CHECKING: # pragma: nocover
from .routing import APIRoute from .routing import APIRoute
@ -68,6 +69,7 @@ def create_response_field(
model_config: Type[BaseConfig] = BaseConfig, model_config: Type[BaseConfig] = BaseConfig,
field_info: Optional[FieldInfo] = None, field_info: Optional[FieldInfo] = None,
alias: Optional[str] = None, alias: Optional[str] = None,
mode: Literal["validation", "serialization"] = "validation",
) -> ModelField: ) -> ModelField:
""" """
Create a new response field. Raises if type_ is invalid. Create a new response field. Raises if type_ is invalid.
@ -80,7 +82,9 @@ def create_response_field(
else: else:
field_info = field_info or FieldInfo() field_info = field_info or FieldInfo()
kwargs = {"name": name, "field_info": field_info} kwargs = {"name": name, "field_info": field_info}
if not PYDANTIC_V2: if PYDANTIC_V2:
kwargs.update({"mode": mode})
else:
kwargs.update( kwargs.update(
{ {
"type_": type_, "type_": type_,

Loading…
Cancel
Save