|
|
@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type, Union, |
|
|
|
from fastapi import routing |
|
|
|
from fastapi._compat import ( |
|
|
|
GenerateJsonSchema, |
|
|
|
JsonSchemaValue, |
|
|
|
ModelField, |
|
|
|
Undefined, |
|
|
|
get_compat_model_name_map, |
|
|
@ -30,6 +31,7 @@ from fastapi.utils import ( |
|
|
|
from starlette.responses import JSONResponse |
|
|
|
from starlette.routing import BaseRoute |
|
|
|
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY |
|
|
|
from typing_extensions import Literal |
|
|
|
|
|
|
|
validation_error_definition = { |
|
|
|
"title": "ValidationError", |
|
|
@ -90,6 +92,9 @@ def get_openapi_operation_parameters( |
|
|
|
all_route_params: Sequence[ModelField], |
|
|
|
schema_generator: GenerateJsonSchema, |
|
|
|
model_name_map: ModelNameMap, |
|
|
|
field_mapping: Dict[ |
|
|
|
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue |
|
|
|
], |
|
|
|
) -> List[Dict[str, Any]]: |
|
|
|
parameters = [] |
|
|
|
for param in all_route_params: |
|
|
@ -101,6 +106,7 @@ def get_openapi_operation_parameters( |
|
|
|
field=param, |
|
|
|
schema_generator=schema_generator, |
|
|
|
model_name_map=model_name_map, |
|
|
|
field_mapping=field_mapping, |
|
|
|
) |
|
|
|
parameter = { |
|
|
|
"name": param.alias, |
|
|
@ -123,6 +129,9 @@ def get_openapi_operation_request_body( |
|
|
|
body_field: Optional[ModelField], |
|
|
|
schema_generator: GenerateJsonSchema, |
|
|
|
model_name_map: ModelNameMap, |
|
|
|
field_mapping: Dict[ |
|
|
|
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue |
|
|
|
], |
|
|
|
) -> Optional[Dict[str, Any]]: |
|
|
|
if not body_field: |
|
|
|
return None |
|
|
@ -131,6 +140,7 @@ def get_openapi_operation_request_body( |
|
|
|
field=body_field, |
|
|
|
schema_generator=schema_generator, |
|
|
|
model_name_map=model_name_map, |
|
|
|
field_mapping=field_mapping, |
|
|
|
) |
|
|
|
field_info = cast(Body, body_field.field_info) |
|
|
|
request_media_type = field_info.media_type |
|
|
@ -198,6 +208,9 @@ def get_openapi_path( |
|
|
|
operation_ids: Set[str], |
|
|
|
schema_generator: GenerateJsonSchema, |
|
|
|
model_name_map: ModelNameMap, |
|
|
|
field_mapping: Dict[ |
|
|
|
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue |
|
|
|
], |
|
|
|
) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: |
|
|
|
path = {} |
|
|
|
security_schemes: Dict[str, Any] = {} |
|
|
@ -228,6 +241,7 @@ def get_openapi_path( |
|
|
|
all_route_params=all_route_params, |
|
|
|
schema_generator=schema_generator, |
|
|
|
model_name_map=model_name_map, |
|
|
|
field_mapping=field_mapping, |
|
|
|
) |
|
|
|
parameters.extend(operation_parameters) |
|
|
|
if parameters: |
|
|
@ -248,6 +262,7 @@ def get_openapi_path( |
|
|
|
body_field=route.body_field, |
|
|
|
schema_generator=schema_generator, |
|
|
|
model_name_map=model_name_map, |
|
|
|
field_mapping=field_mapping, |
|
|
|
) |
|
|
|
if request_body_oai: |
|
|
|
operation["requestBody"] = request_body_oai |
|
|
@ -264,6 +279,7 @@ def get_openapi_path( |
|
|
|
operation_ids=operation_ids, |
|
|
|
schema_generator=schema_generator, |
|
|
|
model_name_map=model_name_map, |
|
|
|
field_mapping=field_mapping, |
|
|
|
) |
|
|
|
callbacks[callback.name] = {callback.path: cb_path} |
|
|
|
operation["callbacks"] = callbacks |
|
|
@ -293,6 +309,7 @@ def get_openapi_path( |
|
|
|
field=route.response_field, |
|
|
|
schema_generator=schema_generator, |
|
|
|
model_name_map=model_name_map, |
|
|
|
field_mapping=field_mapping, |
|
|
|
) |
|
|
|
else: |
|
|
|
response_schema = {} |
|
|
@ -325,6 +342,7 @@ def get_openapi_path( |
|
|
|
field=field, |
|
|
|
schema_generator=schema_generator, |
|
|
|
model_name_map=model_name_map, |
|
|
|
field_mapping=field_mapping, |
|
|
|
) |
|
|
|
media_type = route_response_media_type or "application/json" |
|
|
|
additional_schema = ( |
|
|
@ -437,7 +455,7 @@ def get_openapi( |
|
|
|
all_fields = get_fields_from_routes(list(routes or []) + list(webhooks or [])) |
|
|
|
model_name_map = get_compat_model_name_map(all_fields) |
|
|
|
schema_generator = GenerateJsonSchema(ref_template=REF_TEMPLATE) |
|
|
|
definitions = get_definitions( |
|
|
|
field_mapping, definitions = get_definitions( |
|
|
|
fields=all_fields, |
|
|
|
schema_generator=schema_generator, |
|
|
|
model_name_map=model_name_map, |
|
|
@ -449,6 +467,7 @@ def get_openapi( |
|
|
|
operation_ids=operation_ids, |
|
|
|
schema_generator=schema_generator, |
|
|
|
model_name_map=model_name_map, |
|
|
|
field_mapping=field_mapping, |
|
|
|
) |
|
|
|
if result: |
|
|
|
path, security_schemes, path_definitions = result |
|
|
@ -467,6 +486,7 @@ def get_openapi( |
|
|
|
operation_ids=operation_ids, |
|
|
|
schema_generator=schema_generator, |
|
|
|
model_name_map=model_name_map, |
|
|
|
field_mapping=field_mapping, |
|
|
|
) |
|
|
|
if result: |
|
|
|
path, security_schemes, path_definitions = result |
|
|
|