@ -1,9 +1,10 @@
import http . client
from typing import Any , Dict , List , Optional , Sequence , Tuple , Type , cast
from enum import Enum
from typing import Any , Dict , List , Optional , Sequence , Set , Tuple , Type , Union , cast
from fastapi import routing
from fastapi . dependencies . models import Dependant
from fastapi . dependencies . utils import get_flat_dependant
from fastapi . dependencies . utils import get_flat_dependant , get_flat_params
from fastapi . encoders import jsonable_encoder
from fastapi . openapi . constants import (
METHODS_WITH_BODY ,
@ -15,11 +16,14 @@ from fastapi.params import Body, Param
from fastapi . utils import (
generate_operation_id_for_path ,
get_field_info ,
get_flat_models_from_routes ,
get_model_definitions ,
)
from pydantic import BaseModel
from pydantic . schema import field_schema , get_model_name_map
from pydantic . schema import (
field_schema ,
get_flat_models_from_fields ,
get_model_name_map ,
)
from pydantic . utils import lenient_issubclass
from starlette . responses import JSONResponse
from starlette . routing import BaseRoute
@ -64,16 +68,6 @@ status_code_ranges: Dict[str, str] = {
}
def get_openapi_params ( dependant : Dependant ) - > List [ ModelField ] :
flat_dependant = get_flat_dependant ( dependant , skip_repeats = True )
return (
flat_dependant . path_params
+ flat_dependant . query_params
+ flat_dependant . header_params
+ flat_dependant . cookie_params
)
def get_openapi_security_definitions ( flat_dependant : Dependant ) - > Tuple [ Dict , List ] :
security_definitions = { }
operation_security = [ ]
@ -90,17 +84,22 @@ def get_openapi_security_definitions(flat_dependant: Dependant) -> Tuple[Dict, L
def get_openapi_operation_parameters (
* ,
all_route_params : Sequence [ ModelField ] ,
model_name_map : Dict [ Union [ Type [ BaseModel ] , Type [ Enum ] ] , str ]
) - > List [ Dict [ str , Any ] ] :
parameters = [ ]
for param in all_route_params :
field_info = get_field_info ( param )
field_info = cast ( Param , field_info )
# ignore mypy error until enum schemas are released
parameter = {
" name " : param . alias ,
" in " : field_info . in_ . value ,
" required " : param . required ,
" schema " : field_schema ( param , model_name_map = { } ) [ 0 ] ,
" schema " : field_schema (
param , model_name_map = model_name_map , ref_prefix = REF_PREFIX # type: ignore
) [ 0 ] ,
}
if field_info . description :
parameter [ " description " ] = field_info . description
@ -111,13 +110,16 @@ def get_openapi_operation_parameters(
def get_openapi_operation_request_body (
* , body_field : Optional [ ModelField ] , model_name_map : Dict [ Type [ BaseModel ] , str ]
* ,
body_field : Optional [ ModelField ] ,
model_name_map : Dict [ Union [ Type [ BaseModel ] , Type [ Enum ] ] , str ]
) - > Optional [ Dict ] :
if not body_field :
return None
assert isinstance ( body_field , ModelField )
# ignore mypy error until enum schemas are released
body_schema , _ , _ = field_schema (
body_field , model_name_map = model_name_map , ref_prefix = REF_PREFIX
body_field , model_name_map = model_name_map , ref_prefix = REF_PREFIX # type: ignore
)
field_info = cast ( Body , get_field_info ( body_field ) )
request_media_type = field_info . media_type
@ -176,8 +178,10 @@ def get_openapi_path(
operation . setdefault ( " security " , [ ] ) . extend ( operation_security )
if security_definitions :
security_schemes . update ( security_definitions )
all_route_params = get_openapi_params ( route . dependant )
operation_parameters = get_openapi_operation_parameters ( all_route_params )
all_route_params = get_flat_params ( route . dependant )
operation_parameters = get_openapi_operation_parameters (
all_route_params = all_route_params , model_name_map = model_name_map
)
parameters . extend ( operation_parameters )
if parameters :
operation [ " parameters " ] = list (
@ -270,6 +274,38 @@ def get_openapi_path(
return path , security_schemes , definitions
def get_flat_models_from_routes (
routes : Sequence [ BaseRoute ] ,
) - > Set [ Union [ Type [ BaseModel ] , Type [ Enum ] ] ] :
body_fields_from_routes : List [ ModelField ] = [ ]
responses_from_routes : List [ ModelField ] = [ ]
request_fields_from_routes : List [ ModelField ] = [ ]
callback_flat_models : Set [ Union [ Type [ BaseModel ] , Type [ Enum ] ] ] = set ( )
for route in routes :
if getattr ( route , " include_in_schema " , None ) and isinstance (
route , routing . APIRoute
) :
if route . body_field :
assert isinstance (
route . body_field , ModelField
) , " A request body must be a Pydantic Field "
body_fields_from_routes . append ( route . body_field )
if route . response_field :
responses_from_routes . append ( route . response_field )
if route . response_fields :
responses_from_routes . extend ( route . response_fields . values ( ) )
if route . callbacks :
callback_flat_models | = get_flat_models_from_routes ( route . callbacks )
params = get_flat_params ( route . dependant )
request_fields_from_routes . extend ( params )
flat_models = callback_flat_models | get_flat_models_from_fields (
body_fields_from_routes + responses_from_routes + request_fields_from_routes ,
known_models = set ( ) ,
)
return flat_models
def get_openapi (
* ,
title : str ,
@ -286,9 +322,11 @@ def get_openapi(
components : Dict [ str , Dict ] = { }
paths : Dict [ str , Dict ] = { }
flat_models = get_flat_models_from_routes ( routes )
model_name_map = get_model_name_map ( flat_models )
# ignore mypy error until enum schemas are released
model_name_map = get_model_name_map ( flat_models ) # type: ignore
# ignore mypy error until enum schemas are released
definitions = get_model_definitions (
flat_models = flat_models , model_name_map = model_name_map
flat_models = flat_models , model_name_map = model_name_map # type: ignore
)
for route in routes :
if isinstance ( route , routing . APIRoute ) :