pythonasyncioapiasyncfastapiframeworkjsonjson-schemaopenapiopenapi3pydanticpython-typespython3redocreststarletteswaggerswagger-uiuvicornweb
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
158 lines
5.9 KiB
158 lines
5.9 KiB
import re
|
|
from dataclasses import is_dataclass
|
|
from typing import Any, Dict, List, Sequence, Set, Type, cast
|
|
|
|
from fastapi import routing
|
|
from fastapi.logger import logger
|
|
from fastapi.openapi.constants import REF_PREFIX
|
|
from pydantic import BaseConfig, BaseModel, create_model
|
|
from pydantic.schema import get_flat_models_from_fields, model_process_schema
|
|
from pydantic.utils import lenient_issubclass
|
|
from starlette.routing import BaseRoute
|
|
|
|
try:
|
|
from pydantic.fields import FieldInfo, ModelField
|
|
|
|
PYDANTIC_1 = True
|
|
except ImportError: # pragma: nocover
|
|
# TODO: remove when removing support for Pydantic < 1.0.0
|
|
from pydantic.fields import Field as ModelField # type: ignore
|
|
from pydantic import Schema as FieldInfo # type: ignore
|
|
|
|
logger.warning(
|
|
"Pydantic versions < 1.0.0 are deprecated in FastAPI and support will be "
|
|
"removed soon."
|
|
)
|
|
PYDANTIC_1 = False
|
|
|
|
|
|
# TODO: remove when removing support for Pydantic < 1.0.0
def get_field_info(field: ModelField) -> FieldInfo:
    """Return the field-info object attached to ``field``.

    With Pydantic >= 1.0 (``PYDANTIC_1`` true) this is ``field.field_info``;
    on the legacy branch the same object lives on ``field.schema``.
    """
    if not PYDANTIC_1:  # pragma: nocover
        # Legacy Pydantic stored the info under ``schema``.
        return field.schema  # type: ignore
    return field.field_info  # type: ignore
|
|
|
|
|
|
# TODO: remove when removing support for Pydantic < 1.0.0
def warning_response_model_skip_defaults_deprecated() -> None:
    """Log a deprecation warning for the ``response_model_skip_defaults`` option.

    The message points users to ``response_model_exclude_unset`` instead.
    """
    message = (  # pragma: nocover
        "response_model_skip_defaults has been deprecated in favor of "
        "response_model_exclude_unset to keep in line with Pydantic v1, support for "
        "it will be removed soon."
    )
    logger.warning(message)  # pragma: nocover
|
|
|
|
|
|
def get_flat_models_from_routes(routes: Sequence[BaseRoute]) -> Set[Type[BaseModel]]:
    """Return every Pydantic model referenced by the schema-visible routes.

    A route contributes only when its ``include_in_schema`` attribute is truthy
    and it is a ``routing.APIRoute``. For each such route the following are
    collected:

    - the ``body_field``;
    - the ``response_field``;
    - every value of ``response_fields``.

    The combined list is passed to Pydantic's ``get_flat_models_from_fields``
    with an empty ``known_models`` set, and that result is returned.

    Raises ``AssertionError`` (when assertions are enabled) if a route's
    ``body_field`` is set but is not a ``ModelField``.
    """
    request_body_fields: List[ModelField] = []
    response_model_fields: List[ModelField] = []
    for route in routes:
        # Skip routes excluded from the schema, then anything that is not an
        # APIRoute (same evaluation order as a short-circuiting ``and``).
        if not getattr(route, "include_in_schema", None):
            continue
        if not isinstance(route, routing.APIRoute):
            continue
        if route.body_field:
            assert isinstance(
                route.body_field, ModelField
            ), "A request body must be a Pydantic Field"
            request_body_fields.append(route.body_field)
        if route.response_field:
            response_model_fields.append(route.response_field)
        if route.response_fields:
            response_model_fields.extend(route.response_fields.values())
    all_fields = request_body_fields + response_model_fields
    return get_flat_models_from_fields(all_fields, known_models=set())
|
|
|
|
|
|
def get_model_definitions(
    *, flat_models: Set[Type[BaseModel]], model_name_map: Dict[Type[BaseModel], str]
) -> Dict[str, Any]:
    """Build a name -> schema mapping for the given Pydantic models.

    Each model is passed to Pydantic's ``model_process_schema`` with
    ``model_name_map`` and ``ref_prefix=REF_PREFIX``. The extra definitions it
    returns are merged into the result. The model's own schema is then stored
    under ``model_name_map[model]``.

    Raises ``KeyError`` if a model in ``flat_models`` has no entry in
    ``model_name_map``.
    """
    definitions: Dict[str, Dict] = {}
    for model in flat_models:
        schema, extra_definitions, _nested_models = model_process_schema(
            model, model_name_map=model_name_map, ref_prefix=REF_PREFIX
        )
        definitions.update(extra_definitions)
        definitions[model_name_map[model]] = schema
    return definitions
|
|
|
|
|
|
def get_path_param_names(path: str) -> Set[str]:
    """Return the names of the ``{...}`` placeholders in a path template.

    Each ``{...}`` match has every leading/trailing ``{`` and ``}`` stripped.
    Repeated names collapse because the result is a set.

    >>> sorted(get_path_param_names("/users/{user_id}/items/{item_id}"))
    ['item_id', 'user_id']
    """
    names: Set[str] = set()
    for placeholder in re.findall("{[^}]*}", path):
        names.add(placeholder.strip("{}"))
    return names
|
|
|
|
|
|
def create_cloned_field(field: ModelField) -> ModelField:
    """Return a clone of ``field`` in which every Pydantic model type is re-created.

    The new ``ModelField`` copies these attributes from ``field``:

    - name, ``has_alias``/``alias``, ``class_validators``;
    - default, required, ``model_config``, field info;
    - ``allow_none``, ``validate_always``, ``parse_json``, shape;
    - the validator lists.

    It then calls ``populate_validators()`` on the clone.

    Every ``BaseModel`` type reachable from ``field`` is replaced by a newly
    created model class with the same name, config, fields and
    ``__validators__``. That covers the field's own type, the types of its
    ``sub_fields``/``key_field``, and, recursively, the fields declared
    inside those models. Pydantic dataclasses are handled through their
    ``__pydantic_model__``.

    Each distinct model is cloned once per call and that clone is reused
    wherever the model appears again, so self-referencing models do not
    recurse forever.

    :param field: the field to clone; it is only read, never modified.
    :returns: the cloned ``ModelField``.
    """
    return _clone_field(field, {})


def _clone_model_type(
    original_type: Type[BaseModel],
    cloned_types: Dict[Type[BaseModel], Type[BaseModel]],
) -> Type[BaseModel]:
    """Return the clone of ``original_type``, creating and memoizing it on first use."""
    existing = cloned_types.get(original_type)
    if existing is not None:
        return existing
    use_type = create_model(
        original_type.__name__, __config__=original_type.__config__
    )
    # Register the clone before cloning its fields. A model that refers to
    # itself (directly or through other models) must resolve to this same
    # clone instead of recursing forever.
    cloned_types[original_type] = use_type
    for f in original_type.__fields__.values():
        # Clone the nested field as well. Reusing the original ``ModelField``
        # would keep nested model types pointing at the original classes,
        # while the same model reached via ``sub_fields``/``key_field`` is
        # cloned.
        use_type.__fields__[f.name] = _clone_field(f, cloned_types)
    use_type.__validators__ = original_type.__validators__
    return use_type


def _clone_field(
    field: ModelField, cloned_types: Dict[Type[BaseModel], Type[BaseModel]]
) -> ModelField:
    """Clone ``field``, reusing model clones already recorded in ``cloned_types``."""
    original_type = field.type_
    if is_dataclass(original_type) and hasattr(original_type, "__pydantic_model__"):
        original_type = original_type.__pydantic_model__  # type: ignore
    use_type = original_type
    if lenient_issubclass(original_type, BaseModel):
        use_type = _clone_model_type(
            cast(Type[BaseModel], original_type), cloned_types
        )
    # Build a minimal field first; its attributes are overwritten below with
    # the original's values.
    if PYDANTIC_1:
        new_field = ModelField(
            name=field.name,
            type_=use_type,
            class_validators={},
            default=None,
            required=False,
            model_config=BaseConfig,
            field_info=FieldInfo(None),
        )
    else:  # pragma: nocover
        new_field = ModelField(  # type: ignore
            name=field.name,
            type_=use_type,
            class_validators={},
            default=None,
            required=False,
            model_config=BaseConfig,
            schema=FieldInfo(None),
        )
    new_field.has_alias = field.has_alias
    new_field.alias = field.alias
    new_field.class_validators = field.class_validators
    new_field.default = field.default
    new_field.required = field.required
    new_field.model_config = field.model_config
    if PYDANTIC_1:
        new_field.field_info = field.field_info
    else:  # pragma: nocover
        new_field.schema = field.schema  # type: ignore
    new_field.allow_none = field.allow_none
    new_field.validate_always = field.validate_always
    if field.sub_fields:
        new_field.sub_fields = [
            _clone_field(sub_field, cloned_types) for sub_field in field.sub_fields
        ]
    if field.key_field:
        new_field.key_field = _clone_field(field.key_field, cloned_types)
    new_field.validators = field.validators
    if PYDANTIC_1:
        new_field.pre_validators = field.pre_validators
        new_field.post_validators = field.post_validators
    else:  # pragma: nocover
        new_field.whole_pre_validators = field.whole_pre_validators  # type: ignore
        new_field.whole_post_validators = field.whole_post_validators  # type: ignore
    new_field.parse_json = field.parse_json
    new_field.shape = field.shape
    # Rebuild the validator chain now that type, shape and sub-fields are set.
    try:
        new_field.populate_validators()
    except AttributeError:  # pragma: nocover
        # TODO: remove when removing support for Pydantic < 1.0.0
        new_field._populate_validators()  # type: ignore
    return new_field
|
|
|
|
|
|
def generate_operation_id_for_path(*, name: str, path: str, method: str) -> str:
    """Derive an operation id from a route name, path template and HTTP method.

    ``name`` and ``path`` are concatenated and every ``{``, ``}`` and ``/`` is
    replaced with ``_``. Then ``_`` plus the lower-cased method is appended.

    >>> generate_operation_id_for_path(name="read_item", path="/items/{item_id}", method="GET")
    'read_item_items__item_id__get'
    """
    # One-pass, per-character substitution; "_" is never itself replaced, so
    # this matches chaining three ``str.replace`` calls.
    sanitized = (name + path).translate(str.maketrans("{}/", "___"))
    return f"{sanitized}_{method.lower()}"
|
|
|