Browse Source

♻️ Use new Pydantic v2 JSON Schema generator (#9813)

Co-authored-by: David Montague <[email protected]>
pull/9814/head
Sebastián Ramírez 2 years ago
committed by GitHub
parent
commit
d4e3dcfa3a
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 23
      fastapi/_compat.py
  2. 22
      fastapi/openapi/utils.py
  3. 6
      fastapi/routing.py
  4. 6
      fastapi/utils.py

23
fastapi/_compat.py

@ -79,6 +79,7 @@ if PYDANTIC_V2:
class ModelField:
field_info: FieldInfo
name: str
mode: Literal["validation", "serialization"] = "validation"
@property
def alias(self) -> str:
@ -178,9 +179,12 @@ if PYDANTIC_V2:
field: ModelField,
schema_generator: GenerateJsonSchema,
model_name_map: ModelNameMap,
field_mapping: Dict[
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
],
) -> Dict[str, Any]:
# This expects that GenerateJsonSchema was already used to generate the definitions
json_schema = schema_generator.generate_inner(field._type_adapter.core_schema)
json_schema = field_mapping[(field, field.mode)]
if "$ref" not in json_schema:
# TODO remove when deprecating Pydantic v1
# Ref: https://github.com/pydantic/pydantic/blob/d61792cc42c80b13b23e3ffa74bc37ec7c77f7d1/pydantic/schema.py#L207
@ -197,12 +201,12 @@ if PYDANTIC_V2:
fields: List[ModelField],
schema_generator: GenerateJsonSchema,
model_name_map: ModelNameMap,
) -> Dict[str, Dict[str, Any]]:
) -> Tuple[Dict[str, Any], Dict[str, Dict[str, Any]]]:
inputs = [
(field, "validation", field._type_adapter.core_schema) for field in fields
(field, field.mode, field._type_adapter.core_schema) for field in fields
]
_, definitions = schema_generator.generate_definitions(inputs=inputs) # type: ignore[arg-type]
return definitions # type: ignore[return-value]
field_mapping, definitions = schema_generator.generate_definitions(inputs=inputs) # type: ignore[arg-type]
return field_mapping, definitions # type: ignore[return-value]
def is_scalar_field(field: ModelField) -> bool:
from fastapi import params
@ -419,6 +423,9 @@ else:
field: ModelField,
schema_generator: GenerateJsonSchema,
model_name_map: ModelNameMap,
field_mapping: Dict[
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
],
) -> Dict[str, Any]:
# This expects that GenerateJsonSchema was already used to generate the definitions
return field_schema( # type: ignore[no-any-return]
@ -434,9 +441,11 @@ else:
fields: List[ModelField],
schema_generator: GenerateJsonSchema,
model_name_map: ModelNameMap,
) -> Dict[str, Dict[str, Any]]:
) -> Tuple[Dict[str, Any], Dict[str, Dict[str, Any]]]:
models = get_flat_models_from_fields(fields, known_models=set())
return get_model_definitions(flat_models=models, model_name_map=model_name_map)
return {}, get_model_definitions(
flat_models=models, model_name_map=model_name_map
)
def is_scalar_field(field: ModelField) -> bool:
return is_pv1_scalar_field(field)

22
fastapi/openapi/utils.py

@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type, Union,
from fastapi import routing
from fastapi._compat import (
GenerateJsonSchema,
JsonSchemaValue,
ModelField,
Undefined,
get_compat_model_name_map,
@ -30,6 +31,7 @@ from fastapi.utils import (
from starlette.responses import JSONResponse
from starlette.routing import BaseRoute
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
from typing_extensions import Literal
validation_error_definition = {
"title": "ValidationError",
@ -90,6 +92,9 @@ def get_openapi_operation_parameters(
all_route_params: Sequence[ModelField],
schema_generator: GenerateJsonSchema,
model_name_map: ModelNameMap,
field_mapping: Dict[
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
],
) -> List[Dict[str, Any]]:
parameters = []
for param in all_route_params:
@ -101,6 +106,7 @@ def get_openapi_operation_parameters(
field=param,
schema_generator=schema_generator,
model_name_map=model_name_map,
field_mapping=field_mapping,
)
parameter = {
"name": param.alias,
@ -123,6 +129,9 @@ def get_openapi_operation_request_body(
body_field: Optional[ModelField],
schema_generator: GenerateJsonSchema,
model_name_map: ModelNameMap,
field_mapping: Dict[
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
],
) -> Optional[Dict[str, Any]]:
if not body_field:
return None
@ -131,6 +140,7 @@ def get_openapi_operation_request_body(
field=body_field,
schema_generator=schema_generator,
model_name_map=model_name_map,
field_mapping=field_mapping,
)
field_info = cast(Body, body_field.field_info)
request_media_type = field_info.media_type
@ -198,6 +208,9 @@ def get_openapi_path(
operation_ids: Set[str],
schema_generator: GenerateJsonSchema,
model_name_map: ModelNameMap,
field_mapping: Dict[
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
],
) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
path = {}
security_schemes: Dict[str, Any] = {}
@ -228,6 +241,7 @@ def get_openapi_path(
all_route_params=all_route_params,
schema_generator=schema_generator,
model_name_map=model_name_map,
field_mapping=field_mapping,
)
parameters.extend(operation_parameters)
if parameters:
@ -248,6 +262,7 @@ def get_openapi_path(
body_field=route.body_field,
schema_generator=schema_generator,
model_name_map=model_name_map,
field_mapping=field_mapping,
)
if request_body_oai:
operation["requestBody"] = request_body_oai
@ -264,6 +279,7 @@ def get_openapi_path(
operation_ids=operation_ids,
schema_generator=schema_generator,
model_name_map=model_name_map,
field_mapping=field_mapping,
)
callbacks[callback.name] = {callback.path: cb_path}
operation["callbacks"] = callbacks
@ -293,6 +309,7 @@ def get_openapi_path(
field=route.response_field,
schema_generator=schema_generator,
model_name_map=model_name_map,
field_mapping=field_mapping,
)
else:
response_schema = {}
@ -325,6 +342,7 @@ def get_openapi_path(
field=field,
schema_generator=schema_generator,
model_name_map=model_name_map,
field_mapping=field_mapping,
)
media_type = route_response_media_type or "application/json"
additional_schema = (
@ -437,7 +455,7 @@ def get_openapi(
all_fields = get_fields_from_routes(list(routes or []) + list(webhooks or []))
model_name_map = get_compat_model_name_map(all_fields)
schema_generator = GenerateJsonSchema(ref_template=REF_TEMPLATE)
definitions = get_definitions(
field_mapping, definitions = get_definitions(
fields=all_fields,
schema_generator=schema_generator,
model_name_map=model_name_map,
@ -449,6 +467,7 @@ def get_openapi(
operation_ids=operation_ids,
schema_generator=schema_generator,
model_name_map=model_name_map,
field_mapping=field_mapping,
)
if result:
path, security_schemes, path_definitions = result
@ -467,6 +486,7 @@ def get_openapi(
operation_ids=operation_ids,
schema_generator=schema_generator,
model_name_map=model_name_map,
field_mapping=field_mapping,
)
if result:
path, security_schemes, path_definitions = result

6
fastapi/routing.py

@ -446,7 +446,11 @@ class APIRoute(routing.Route):
), f"Status code {status_code} must not have a response body"
response_name = "Response_" + self.unique_id
self.response_field = create_response_field(
name=response_name, type_=self.response_model
name=response_name,
type_=self.response_model,
# TODO: This should actually set mode='serialization', just, that changes the schemas
# mode="serialization",
mode="validation",
)
# Create a clone of the field, so that a Pydantic submodel is not returned
# as is just because it's an instance of a subclass of a more limited class

6
fastapi/utils.py

@ -28,6 +28,7 @@ from fastapi._compat import (
from fastapi.datastructures import DefaultPlaceholder, DefaultType
from pydantic import BaseModel, create_model
from pydantic.fields import FieldInfo
from typing_extensions import Literal
if TYPE_CHECKING: # pragma: nocover
from .routing import APIRoute
@ -68,6 +69,7 @@ def create_response_field(
model_config: Type[BaseConfig] = BaseConfig,
field_info: Optional[FieldInfo] = None,
alias: Optional[str] = None,
mode: Literal["validation", "serialization"] = "validation",
) -> ModelField:
"""
Create a new response field. Raises if type_ is invalid.
@ -80,7 +82,9 @@ def create_response_field(
else:
field_info = field_info or FieldInfo()
kwargs = {"name": name, "field_info": field_info}
if not PYDANTIC_V2:
if PYDANTIC_V2:
kwargs.update({"mode": mode})
else:
kwargs.update(
{
"type_": type_,

Loading…
Cancel
Save