18 changed files with 1376 additions and 703 deletions
@ -1,3 +1,3 @@ |
|||||
"""Fast API framework, fast high performance, fast to learn, fast to code""" |
"""Fast API framework, fast high performance, fast to learn, fast to code""" |
||||
|
|
||||
__version__ = '0.1' |
__version__ = "0.1" |
||||
|
@ -0,0 +1,46 @@ |
|||||
|
from typing import Any, Callable, Dict, List, Sequence, Tuple |
||||
|
|
||||
|
from starlette.concurrency import run_in_threadpool |
||||
|
from starlette.requests import Request |
||||
|
|
||||
|
from fastapi.security.base import SecurityBase |
||||
|
from pydantic import BaseConfig, Schema |
||||
|
from pydantic.error_wrappers import ErrorWrapper |
||||
|
from pydantic.errors import MissingError |
||||
|
from pydantic.fields import Field, Required |
||||
|
from pydantic.schema import get_annotation_from_schema |
||||
|
|
||||
|
param_supported_types = (str, int, float, bool) |
||||
|
|
||||
|
|
||||
|
class SecurityRequirement: |
||||
|
def __init__(self, security_scheme: SecurityBase, scopes: Sequence[str] = None): |
||||
|
self.security_scheme = security_scheme |
||||
|
self.scopes = scopes |
||||
|
|
||||
|
|
||||
|
class Dependant: |
||||
|
def __init__( |
||||
|
self, |
||||
|
*, |
||||
|
path_params: List[Field] = None, |
||||
|
query_params: List[Field] = None, |
||||
|
header_params: List[Field] = None, |
||||
|
cookie_params: List[Field] = None, |
||||
|
body_params: List[Field] = None, |
||||
|
dependencies: List["Dependant"] = None, |
||||
|
security_schemes: List[SecurityRequirement] = None, |
||||
|
name: str = None, |
||||
|
call: Callable = None, |
||||
|
request_param_name: str = None, |
||||
|
) -> None: |
||||
|
self.path_params = path_params or [] |
||||
|
self.query_params = query_params or [] |
||||
|
self.header_params = header_params or [] |
||||
|
self.cookie_params = cookie_params or [] |
||||
|
self.body_params = body_params or [] |
||||
|
self.dependencies = dependencies or [] |
||||
|
self.security_requirements = security_schemes or [] |
||||
|
self.request_param_name = request_param_name |
||||
|
self.name = name |
||||
|
self.call = call |
@ -0,0 +1,327 @@ |
|||||
|
import asyncio |
||||
|
import inspect |
||||
|
from copy import deepcopy |
||||
|
from typing import Any, Callable, Dict, List, Tuple |
||||
|
|
||||
|
from starlette.concurrency import run_in_threadpool |
||||
|
from starlette.requests import Request |
||||
|
|
||||
|
from fastapi import params |
||||
|
from fastapi.dependencies.models import Dependant, SecurityRequirement |
||||
|
from fastapi.security.base import SecurityBase |
||||
|
from fastapi.utils import get_path_param_names |
||||
|
from pydantic import BaseConfig, Schema, create_model |
||||
|
from pydantic.error_wrappers import ErrorWrapper |
||||
|
from pydantic.errors import MissingError |
||||
|
from pydantic.fields import Field, Required |
||||
|
from pydantic.schema import get_annotation_from_schema |
||||
|
from pydantic.utils import lenient_issubclass |
||||
|
|
||||
|
param_supported_types = (str, int, float, bool) |
||||
|
|
||||
|
|
||||
|
def get_sub_dependant(*, param: inspect.Parameter, path: str): |
||||
|
depends: params.Depends = param.default |
||||
|
if depends.dependency: |
||||
|
dependency = depends.dependency |
||||
|
else: |
||||
|
dependency = param.annotation |
||||
|
assert callable(dependency) |
||||
|
sub_dependant = get_dependant(path=path, call=dependency, name=param.name) |
||||
|
if isinstance(depends, params.Security) and isinstance(dependency, SecurityBase): |
||||
|
security_requirement = SecurityRequirement( |
||||
|
security_scheme=dependency, scopes=depends.scopes |
||||
|
) |
||||
|
sub_dependant.security_requirements.append(security_requirement) |
||||
|
return sub_dependant |
||||
|
|
||||
|
|
||||
|
def get_flat_dependant(dependant: Dependant): |
||||
|
flat_dependant = Dependant( |
||||
|
path_params=dependant.path_params.copy(), |
||||
|
query_params=dependant.query_params.copy(), |
||||
|
header_params=dependant.header_params.copy(), |
||||
|
cookie_params=dependant.cookie_params.copy(), |
||||
|
body_params=dependant.body_params.copy(), |
||||
|
security_schemes=dependant.security_requirements.copy(), |
||||
|
) |
||||
|
for sub_dependant in dependant.dependencies: |
||||
|
if sub_dependant is dependant: |
||||
|
raise ValueError("recursion", dependant.dependencies) |
||||
|
flat_sub = get_flat_dependant(sub_dependant) |
||||
|
flat_dependant.path_params.extend(flat_sub.path_params) |
||||
|
flat_dependant.query_params.extend(flat_sub.query_params) |
||||
|
flat_dependant.header_params.extend(flat_sub.header_params) |
||||
|
flat_dependant.cookie_params.extend(flat_sub.cookie_params) |
||||
|
flat_dependant.body_params.extend(flat_sub.body_params) |
||||
|
flat_dependant.security_requirements.extend(flat_sub.security_requirements) |
||||
|
return flat_dependant |
||||
|
|
||||
|
|
||||
|
def get_dependant(*, path: str, call: Callable, name: str = None): |
||||
|
path_param_names = get_path_param_names(path) |
||||
|
endpoint_signature = inspect.signature(call) |
||||
|
signature_params = endpoint_signature.parameters |
||||
|
dependant = Dependant(call=call, name=name) |
||||
|
for param_name in signature_params: |
||||
|
param = signature_params[param_name] |
||||
|
if isinstance(param.default, params.Depends): |
||||
|
sub_dependant = get_sub_dependant(param=param, path=path) |
||||
|
dependant.dependencies.append(sub_dependant) |
||||
|
for param_name in signature_params: |
||||
|
param = signature_params[param_name] |
||||
|
if ( |
||||
|
(param.default == param.empty) or isinstance(param.default, params.Path) |
||||
|
) and (param_name in path_param_names): |
||||
|
assert lenient_issubclass( |
||||
|
param.annotation, param_supported_types |
||||
|
) or param.annotation == param.empty, f"Path params must be of type str, int, float or boot: {param}" |
||||
|
param = signature_params[param_name] |
||||
|
add_param_to_fields( |
||||
|
param=param, |
||||
|
dependant=dependant, |
||||
|
default_schema=params.Path, |
||||
|
force_type=params.ParamTypes.path, |
||||
|
) |
||||
|
elif (param.default == param.empty or param.default is None) and ( |
||||
|
param.annotation == param.empty |
||||
|
or lenient_issubclass(param.annotation, param_supported_types) |
||||
|
): |
||||
|
add_param_to_fields( |
||||
|
param=param, dependant=dependant, default_schema=params.Query |
||||
|
) |
||||
|
elif isinstance(param.default, params.Param): |
||||
|
if param.annotation != param.empty: |
||||
|
assert lenient_issubclass( |
||||
|
param.annotation, param_supported_types |
||||
|
), f"Parameters for Path, Query, Header and Cookies must be of type str, int, float or bool: {param}" |
||||
|
add_param_to_fields( |
||||
|
param=param, dependant=dependant, default_schema=params.Query |
||||
|
) |
||||
|
elif lenient_issubclass(param.annotation, Request): |
||||
|
dependant.request_param_name = param_name |
||||
|
elif not isinstance(param.default, params.Depends): |
||||
|
add_param_to_body_fields(param=param, dependant=dependant) |
||||
|
return dependant |
||||
|
|
||||
|
|
||||
|
def add_param_to_fields( |
||||
|
*, |
||||
|
param: inspect.Parameter, |
||||
|
dependant: Dependant, |
||||
|
default_schema=params.Param, |
||||
|
force_type: params.ParamTypes = None, |
||||
|
): |
||||
|
default_value = Required |
||||
|
if not param.default == param.empty: |
||||
|
default_value = param.default |
||||
|
if isinstance(default_value, params.Param): |
||||
|
schema = default_value |
||||
|
default_value = schema.default |
||||
|
if schema.in_ is None: |
||||
|
schema.in_ = default_schema.in_ |
||||
|
if force_type: |
||||
|
schema.in_ = force_type |
||||
|
else: |
||||
|
schema = default_schema(default_value) |
||||
|
required = default_value == Required |
||||
|
annotation = Any |
||||
|
if not param.annotation == param.empty: |
||||
|
annotation = param.annotation |
||||
|
annotation = get_annotation_from_schema(annotation, schema) |
||||
|
field = Field( |
||||
|
name=param.name, |
||||
|
type_=annotation, |
||||
|
default=None if required else default_value, |
||||
|
alias=schema.alias or param.name, |
||||
|
required=required, |
||||
|
model_config=BaseConfig(), |
||||
|
class_validators=[], |
||||
|
schema=schema, |
||||
|
) |
||||
|
if schema.in_ == params.ParamTypes.path: |
||||
|
dependant.path_params.append(field) |
||||
|
elif schema.in_ == params.ParamTypes.query: |
||||
|
dependant.query_params.append(field) |
||||
|
elif schema.in_ == params.ParamTypes.header: |
||||
|
dependant.header_params.append(field) |
||||
|
else: |
||||
|
assert ( |
||||
|
schema.in_ == params.ParamTypes.cookie |
||||
|
), f"non-body parameters must be in path, query, header or cookie: {param.name}" |
||||
|
dependant.cookie_params.append(field) |
||||
|
|
||||
|
|
||||
|
def add_param_to_body_fields(*, param: inspect.Parameter, dependant: Dependant): |
||||
|
default_value = Required |
||||
|
if not param.default == param.empty: |
||||
|
default_value = param.default |
||||
|
if isinstance(default_value, Schema): |
||||
|
schema = default_value |
||||
|
default_value = schema.default |
||||
|
else: |
||||
|
schema = Schema(default_value) |
||||
|
required = default_value == Required |
||||
|
annotation = get_annotation_from_schema(param.annotation, schema) |
||||
|
field = Field( |
||||
|
name=param.name, |
||||
|
type_=annotation, |
||||
|
default=None if required else default_value, |
||||
|
alias=schema.alias or param.name, |
||||
|
required=required, |
||||
|
model_config=BaseConfig, |
||||
|
class_validators=[], |
||||
|
schema=schema, |
||||
|
) |
||||
|
dependant.body_params.append(field) |
||||
|
|
||||
|
|
||||
|
def is_coroutine_callable(call: Callable = None): |
||||
|
if not call: |
||||
|
return False |
||||
|
if inspect.isfunction(call): |
||||
|
return asyncio.iscoroutinefunction(call) |
||||
|
if inspect.isclass(call): |
||||
|
return False |
||||
|
call = getattr(call, "__call__", None) |
||||
|
if not call: |
||||
|
return False |
||||
|
return asyncio.iscoroutinefunction(call) |
||||
|
|
||||
|
|
||||
|
async def solve_dependencies( |
||||
|
*, request: Request, dependant: Dependant, body: Dict[str, Any] = None |
||||
|
): |
||||
|
values: Dict[str, Any] = {} |
||||
|
errors: List[ErrorWrapper] = [] |
||||
|
for sub_dependant in dependant.dependencies: |
||||
|
sub_values, sub_errors = await solve_dependencies( |
||||
|
request=request, dependant=sub_dependant, body=body |
||||
|
) |
||||
|
if sub_errors: |
||||
|
return {}, errors |
||||
|
if sub_dependant.call and is_coroutine_callable(sub_dependant.call): |
||||
|
solved = await sub_dependant.call(**sub_values) |
||||
|
else: |
||||
|
solved = await run_in_threadpool(sub_dependant.call, **sub_values) |
||||
|
values[ |
||||
|
sub_dependant.name |
||||
|
] = solved # type: ignore # Sub-dependants always have a name |
||||
|
path_values, path_errors = request_params_to_args( |
||||
|
dependant.path_params, request.path_params |
||||
|
) |
||||
|
query_values, query_errors = request_params_to_args( |
||||
|
dependant.query_params, request.query_params |
||||
|
) |
||||
|
header_values, header_errors = request_params_to_args( |
||||
|
dependant.header_params, request.headers |
||||
|
) |
||||
|
cookie_values, cookie_errors = request_params_to_args( |
||||
|
dependant.cookie_params, request.cookies |
||||
|
) |
||||
|
values.update(path_values) |
||||
|
values.update(query_values) |
||||
|
values.update(header_values) |
||||
|
values.update(cookie_values) |
||||
|
errors = path_errors + query_errors + header_errors + cookie_errors |
||||
|
if dependant.body_params: |
||||
|
body_values, body_errors = await request_body_to_args( # type: ignore # body_params checked above |
||||
|
dependant.body_params, body |
||||
|
) |
||||
|
values.update(body_values) |
||||
|
errors.extend(body_errors) |
||||
|
if dependant.request_param_name: |
||||
|
values[dependant.request_param_name] = request |
||||
|
return values, errors |
||||
|
|
||||
|
|
||||
|
def request_params_to_args( |
||||
|
required_params: List[Field], received_params: Dict[str, Any] |
||||
|
) -> Tuple[Dict[str, Any], List[ErrorWrapper]]: |
||||
|
values = {} |
||||
|
errors = [] |
||||
|
for field in required_params: |
||||
|
value = received_params.get(field.alias) |
||||
|
if value is None: |
||||
|
if field.required: |
||||
|
errors.append( |
||||
|
ErrorWrapper(MissingError(), loc=field.alias, config=BaseConfig) |
||||
|
) |
||||
|
else: |
||||
|
values[field.name] = deepcopy(field.default) |
||||
|
continue |
||||
|
v_, errors_ = field.validate( |
||||
|
value, values, loc=(field.schema.in_.value, field.alias) |
||||
|
) |
||||
|
if isinstance(errors_, ErrorWrapper): |
||||
|
errors.append(errors_) |
||||
|
elif isinstance(errors_, list): |
||||
|
errors.extend(errors_) |
||||
|
else: |
||||
|
values[field.name] = v_ |
||||
|
return values, errors |
||||
|
|
||||
|
|
||||
|
async def request_body_to_args( |
||||
|
required_params: List[Field], received_body: Dict[str, Any] |
||||
|
) -> Tuple[Dict[str, Any], List[ErrorWrapper]]: |
||||
|
values = {} |
||||
|
errors = [] |
||||
|
if required_params: |
||||
|
field = required_params[0] |
||||
|
embed = getattr(field.schema, "embed", None) |
||||
|
if len(required_params) == 1 and not embed: |
||||
|
received_body = {field.alias: received_body} |
||||
|
for field in required_params: |
||||
|
value = received_body.get(field.alias) |
||||
|
if value is None: |
||||
|
if field.required: |
||||
|
errors.append( |
||||
|
ErrorWrapper( |
||||
|
MissingError(), loc=("body", field.alias), config=BaseConfig |
||||
|
) |
||||
|
) |
||||
|
else: |
||||
|
values[field.name] = deepcopy(field.default) |
||||
|
continue |
||||
|
v_, errors_ = field.validate(value, values, loc=("body", field.alias)) |
||||
|
if isinstance(errors_, ErrorWrapper): |
||||
|
errors.append(errors_) |
||||
|
elif isinstance(errors_, list): |
||||
|
errors.extend(errors_) |
||||
|
else: |
||||
|
values[field.name] = v_ |
||||
|
return values, errors |
||||
|
|
||||
|
|
||||
|
def get_body_field(*, dependant: Dependant, name: str): |
||||
|
flat_dependant = get_flat_dependant(dependant) |
||||
|
if not flat_dependant.body_params: |
||||
|
return None |
||||
|
first_param = flat_dependant.body_params[0] |
||||
|
embed = getattr(first_param.schema, "embed", None) |
||||
|
if len(flat_dependant.body_params) == 1 and not embed: |
||||
|
return first_param |
||||
|
model_name = "Body_" + name |
||||
|
BodyModel = create_model(model_name) |
||||
|
for f in flat_dependant.body_params: |
||||
|
BodyModel.__fields__[f.name] = f |
||||
|
required = any(True for f in flat_dependant.body_params if f.required) |
||||
|
if any(isinstance(f.schema, params.File) for f in flat_dependant.body_params): |
||||
|
BodySchema = params.File |
||||
|
elif any(isinstance(f.schema, params.Form) for f in flat_dependant.body_params): |
||||
|
BodySchema = params.Form |
||||
|
else: |
||||
|
BodySchema = params.Body |
||||
|
|
||||
|
field = Field( |
||||
|
name="body", |
||||
|
type_=BodyModel, |
||||
|
default=None, |
||||
|
required=required, |
||||
|
model_config=BaseConfig, |
||||
|
class_validators=[], |
||||
|
alias="body", |
||||
|
schema=BodySchema(None), |
||||
|
) |
||||
|
return field |
@ -1,33 +1,44 @@ |
|||||
|
from enum import Enum |
||||
from types import GeneratorType |
from types import GeneratorType |
||||
from typing import Set |
from typing import Set |
||||
|
|
||||
from pydantic import BaseModel |
from pydantic import BaseModel |
||||
from enum import Enum |
|
||||
from pydantic.json import pydantic_encoder |
from pydantic.json import pydantic_encoder |
||||
|
|
||||
|
|
||||
def jsonable_encoder( |
def jsonable_encoder( |
||||
obj, include: Set[str] = None, exclude: Set[str] = set(), by_alias: bool = False, include_none=True, |
obj, |
||||
|
include: Set[str] = None, |
||||
|
exclude: Set[str] = set(), |
||||
|
by_alias: bool = False, |
||||
|
include_none=True, |
||||
): |
): |
||||
if isinstance(obj, BaseModel): |
if isinstance(obj, BaseModel): |
||||
return jsonable_encoder( |
return jsonable_encoder( |
||||
obj.dict(include=include, exclude=exclude, by_alias=by_alias), include_none=include_none |
obj.dict(include=include, exclude=exclude, by_alias=by_alias), |
||||
|
include_none=include_none, |
||||
) |
) |
||||
elif isinstance(obj, Enum): |
if isinstance(obj, Enum): |
||||
return obj.value |
return obj.value |
||||
if isinstance(obj, (str, int, float, type(None))): |
if isinstance(obj, (str, int, float, type(None))): |
||||
return obj |
return obj |
||||
if isinstance(obj, dict): |
if isinstance(obj, dict): |
||||
return { |
return { |
||||
jsonable_encoder( |
jsonable_encoder( |
||||
key, by_alias=by_alias, include_none=include_none, |
key, by_alias=by_alias, include_none=include_none |
||||
): jsonable_encoder( |
): jsonable_encoder(value, by_alias=by_alias, include_none=include_none) |
||||
value, by_alias=by_alias, include_none=include_none, |
for key, value in obj.items() |
||||
) |
if value is not None or include_none |
||||
for key, value in obj.items() if value is not None or include_none |
|
||||
} |
} |
||||
if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)): |
if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)): |
||||
return [ |
return [ |
||||
jsonable_encoder(item, include=include, exclude=exclude, by_alias=by_alias, include_none=include_none) |
jsonable_encoder( |
||||
|
item, |
||||
|
include=include, |
||||
|
exclude=exclude, |
||||
|
by_alias=by_alias, |
||||
|
include_none=include_none, |
||||
|
) |
||||
for item in obj |
for item in obj |
||||
] |
] |
||||
return pydantic_encoder(obj) |
return pydantic_encoder(obj) |
@ -0,0 +1,2 @@ |
|||||
|
METHODS_WITH_BODY = set(("POST", "PUT")) |
||||
|
REF_PREFIX = "#/components/schemas/" |
@ -0,0 +1,347 @@ |
|||||
|
import logging |
||||
|
from enum import Enum |
||||
|
from typing import Any, Dict, List, Optional, Union |
||||
|
|
||||
|
from pydantic import BaseModel, Schema as PSchema |
||||
|
from pydantic.types import UrlStr |
||||
|
|
||||
|
try: |
||||
|
import pydantic.types.EmailStr |
||||
|
from pydantic.types import EmailStr |
||||
|
except ImportError: |
||||
|
logging.warning( |
||||
|
"email-validator not installed, email fields will be treated as str" |
||||
|
) |
||||
|
|
||||
|
class EmailStr(str): |
||||
|
pass |
||||
|
|
||||
|
|
||||
|
class Contact(BaseModel): |
||||
|
name: Optional[str] = None |
||||
|
url: Optional[UrlStr] = None |
||||
|
email: Optional[EmailStr] = None |
||||
|
|
||||
|
|
||||
|
class License(BaseModel): |
||||
|
name: str |
||||
|
url: Optional[UrlStr] = None |
||||
|
|
||||
|
|
||||
|
class Info(BaseModel): |
||||
|
title: str |
||||
|
description: Optional[str] = None |
||||
|
termsOfService: Optional[str] = None |
||||
|
contact: Optional[Contact] = None |
||||
|
license: Optional[License] = None |
||||
|
version: str |
||||
|
|
||||
|
|
||||
|
class ServerVariable(BaseModel): |
||||
|
enum: Optional[List[str]] = None |
||||
|
default: str |
||||
|
description: Optional[str] = None |
||||
|
|
||||
|
|
||||
|
class Server(BaseModel): |
||||
|
url: UrlStr |
||||
|
description: Optional[str] = None |
||||
|
variables: Optional[Dict[str, ServerVariable]] = None |
||||
|
|
||||
|
|
||||
|
class Reference(BaseModel): |
||||
|
ref: str = PSchema(..., alias="$ref") |
||||
|
|
||||
|
|
||||
|
class Discriminator(BaseModel): |
||||
|
propertyName: str |
||||
|
mapping: Optional[Dict[str, str]] = None |
||||
|
|
||||
|
|
||||
|
class XML(BaseModel): |
||||
|
name: Optional[str] = None |
||||
|
namespace: Optional[str] = None |
||||
|
prefix: Optional[str] = None |
||||
|
attribute: Optional[bool] = None |
||||
|
wrapped: Optional[bool] = None |
||||
|
|
||||
|
|
||||
|
class ExternalDocumentation(BaseModel): |
||||
|
description: Optional[str] = None |
||||
|
url: UrlStr |
||||
|
|
||||
|
|
||||
|
class SchemaBase(BaseModel): |
||||
|
ref: Optional[str] = PSchema(None, alias="$ref") |
||||
|
title: Optional[str] = None |
||||
|
multipleOf: Optional[float] = None |
||||
|
maximum: Optional[float] = None |
||||
|
exclusiveMaximum: Optional[float] = None |
||||
|
minimum: Optional[float] = None |
||||
|
exclusiveMinimum: Optional[float] = None |
||||
|
maxLength: Optional[int] = PSchema(None, gte=0) |
||||
|
minLength: Optional[int] = PSchema(None, gte=0) |
||||
|
pattern: Optional[str] = None |
||||
|
maxItems: Optional[int] = PSchema(None, gte=0) |
||||
|
minItems: Optional[int] = PSchema(None, gte=0) |
||||
|
uniqueItems: Optional[bool] = None |
||||
|
maxProperties: Optional[int] = PSchema(None, gte=0) |
||||
|
minProperties: Optional[int] = PSchema(None, gte=0) |
||||
|
required: Optional[List[str]] = None |
||||
|
enum: Optional[List[str]] = None |
||||
|
type: Optional[str] = None |
||||
|
allOf: Optional[List[Any]] = None |
||||
|
oneOf: Optional[List[Any]] = None |
||||
|
anyOf: Optional[List[Any]] = None |
||||
|
not_: Optional[List[Any]] = PSchema(None, alias="not") |
||||
|
items: Optional[Any] = None |
||||
|
properties: Optional[Dict[str, Any]] = None |
||||
|
additionalProperties: Optional[Union[bool, Any]] = None |
||||
|
description: Optional[str] = None |
||||
|
format: Optional[str] = None |
||||
|
default: Optional[Any] = None |
||||
|
nullable: Optional[bool] = None |
||||
|
discriminator: Optional[Discriminator] = None |
||||
|
readOnly: Optional[bool] = None |
||||
|
writeOnly: Optional[bool] = None |
||||
|
xml: Optional[XML] = None |
||||
|
externalDocs: Optional[ExternalDocumentation] = None |
||||
|
example: Optional[Any] = None |
||||
|
deprecated: Optional[bool] = None |
||||
|
|
||||
|
|
||||
|
class Schema(SchemaBase): |
||||
|
allOf: Optional[List[SchemaBase]] = None |
||||
|
oneOf: Optional[List[SchemaBase]] = None |
||||
|
anyOf: Optional[List[SchemaBase]] = None |
||||
|
not_: Optional[List[SchemaBase]] = PSchema(None, alias="not") |
||||
|
items: Optional[SchemaBase] = None |
||||
|
properties: Optional[Dict[str, SchemaBase]] = None |
||||
|
additionalProperties: Optional[Union[bool, SchemaBase]] = None |
||||
|
|
||||
|
|
||||
|
class Example(BaseModel): |
||||
|
summary: Optional[str] = None |
||||
|
description: Optional[str] = None |
||||
|
value: Optional[Any] = None |
||||
|
externalValue: Optional[UrlStr] = None |
||||
|
|
||||
|
|
||||
|
class ParameterInType(Enum): |
||||
|
query = "query" |
||||
|
header = "header" |
||||
|
path = "path" |
||||
|
cookie = "cookie" |
||||
|
|
||||
|
|
||||
|
class Encoding(BaseModel): |
||||
|
contentType: Optional[str] = None |
||||
|
# Workaround OpenAPI recursive reference, using Any |
||||
|
headers: Optional[Dict[str, Union[Any, Reference]]] = None |
||||
|
style: Optional[str] = None |
||||
|
explode: Optional[bool] = None |
||||
|
allowReserved: Optional[bool] = None |
||||
|
|
||||
|
|
||||
|
class MediaType(BaseModel): |
||||
|
schema_: Optional[Union[Schema, Reference]] = PSchema(None, alias="schema") |
||||
|
example: Optional[Any] = None |
||||
|
examples: Optional[Dict[str, Union[Example, Reference]]] = None |
||||
|
encoding: Optional[Dict[str, Encoding]] = None |
||||
|
|
||||
|
|
||||
|
class ParameterBase(BaseModel): |
||||
|
description: Optional[str] = None |
||||
|
required: Optional[bool] = None |
||||
|
deprecated: Optional[bool] = None |
||||
|
# Serialization rules for simple scenarios |
||||
|
style: Optional[str] = None |
||||
|
explode: Optional[bool] = None |
||||
|
allowReserved: Optional[bool] = None |
||||
|
schema_: Optional[Union[Schema, Reference]] = PSchema(None, alias="schema") |
||||
|
example: Optional[Any] = None |
||||
|
examples: Optional[Dict[str, Union[Example, Reference]]] = None |
||||
|
# Serialization rules for more complex scenarios |
||||
|
content: Optional[Dict[str, MediaType]] = None |
||||
|
|
||||
|
|
||||
|
class Parameter(ParameterBase): |
||||
|
name: str |
||||
|
in_: ParameterInType = PSchema(..., alias="in") |
||||
|
|
||||
|
|
||||
|
class Header(ParameterBase): |
||||
|
pass |
||||
|
|
||||
|
|
||||
|
# Workaround OpenAPI recursive reference |
||||
|
class EncodingWithHeaders(Encoding): |
||||
|
headers: Optional[Dict[str, Union[Header, Reference]]] = None |
||||
|
|
||||
|
|
||||
|
class RequestBody(BaseModel): |
||||
|
description: Optional[str] = None |
||||
|
content: Dict[str, MediaType] |
||||
|
required: Optional[bool] = None |
||||
|
|
||||
|
|
||||
|
class Link(BaseModel): |
||||
|
operationRef: Optional[str] = None |
||||
|
operationId: Optional[str] = None |
||||
|
parameters: Optional[Dict[str, Union[Any, str]]] = None |
||||
|
requestBody: Optional[Union[Any, str]] = None |
||||
|
description: Optional[str] = None |
||||
|
server: Optional[Server] = None |
||||
|
|
||||
|
|
||||
|
class Response(BaseModel): |
||||
|
description: str |
||||
|
headers: Optional[Dict[str, Union[Header, Reference]]] = None |
||||
|
content: Optional[Dict[str, MediaType]] = None |
||||
|
links: Optional[Dict[str, Union[Link, Reference]]] = None |
||||
|
|
||||
|
|
||||
|
class Responses(BaseModel): |
||||
|
default: Response |
||||
|
|
||||
|
|
||||
|
class Operation(BaseModel): |
||||
|
tags: Optional[List[str]] = None |
||||
|
summary: Optional[str] = None |
||||
|
description: Optional[str] = None |
||||
|
externalDocs: Optional[ExternalDocumentation] = None |
||||
|
operationId: Optional[str] = None |
||||
|
parameters: Optional[List[Union[Parameter, Reference]]] = None |
||||
|
requestBody: Optional[Union[RequestBody, Reference]] = None |
||||
|
responses: Union[Responses, Dict[Union[str], Response]] |
||||
|
# Workaround OpenAPI recursive reference |
||||
|
callbacks: Optional[Dict[str, Union[Dict[str, Any], Reference]]] = None |
||||
|
deprecated: Optional[bool] = None |
||||
|
security: Optional[List[Dict[str, List[str]]]] = None |
||||
|
servers: Optional[List[Server]] = None |
||||
|
|
||||
|
|
||||
|
class PathItem(BaseModel): |
||||
|
ref: Optional[str] = PSchema(None, alias="$ref") |
||||
|
summary: Optional[str] = None |
||||
|
description: Optional[str] = None |
||||
|
get: Optional[Operation] = None |
||||
|
put: Optional[Operation] = None |
||||
|
post: Optional[Operation] = None |
||||
|
delete: Optional[Operation] = None |
||||
|
options: Optional[Operation] = None |
||||
|
head: Optional[Operation] = None |
||||
|
patch: Optional[Operation] = None |
||||
|
trace: Optional[Operation] = None |
||||
|
servers: Optional[List[Server]] = None |
||||
|
parameters: Optional[List[Union[Parameter, Reference]]] = None |
||||
|
|
||||
|
|
||||
|
# Workaround OpenAPI recursive reference |
||||
|
class OperationWithCallbacks(BaseModel): |
||||
|
callbacks: Optional[Dict[str, Union[Dict[str, PathItem], Reference]]] = None |
||||
|
|
||||
|
|
||||
|
class SecuritySchemeType(Enum): |
||||
|
apiKey = "apiKey" |
||||
|
http = "http" |
||||
|
oauth2 = "oauth2" |
||||
|
openIdConnect = "openIdConnect" |
||||
|
|
||||
|
|
||||
|
class SecurityBase(BaseModel): |
||||
|
type_: SecuritySchemeType = PSchema(..., alias="type") |
||||
|
description: Optional[str] = None |
||||
|
|
||||
|
|
||||
|
class APIKeyIn(Enum): |
||||
|
query = "query" |
||||
|
header = "header" |
||||
|
cookie = "cookie" |
||||
|
|
||||
|
|
||||
|
class APIKey(SecurityBase): |
||||
|
type_ = PSchema(SecuritySchemeType.apiKey, alias="type") |
||||
|
in_: APIKeyIn = PSchema(..., alias="in") |
||||
|
name: str |
||||
|
|
||||
|
|
||||
|
class HTTPBase(SecurityBase): |
||||
|
type_ = PSchema(SecuritySchemeType.http, alias="type") |
||||
|
scheme: str |
||||
|
|
||||
|
|
||||
|
class HTTPBearer(HTTPBase): |
||||
|
scheme = "bearer" |
||||
|
bearerFormat: Optional[str] = None |
||||
|
|
||||
|
|
||||
|
class OAuthFlow(BaseModel): |
||||
|
refreshUrl: Optional[str] = None |
||||
|
scopes: Dict[str, str] = {} |
||||
|
|
||||
|
|
||||
|
class OAuthFlowImplicit(OAuthFlow): |
||||
|
authorizationUrl: str |
||||
|
|
||||
|
|
||||
|
class OAuthFlowPassword(OAuthFlow): |
||||
|
tokenUrl: str |
||||
|
|
||||
|
|
||||
|
class OAuthFlowClientCredentials(OAuthFlow): |
||||
|
tokenUrl: str |
||||
|
|
||||
|
|
||||
|
class OAuthFlowAuthorizationCode(OAuthFlow): |
||||
|
authorizationUrl: str |
||||
|
tokenUrl: str |
||||
|
|
||||
|
|
||||
|
class OAuthFlows(BaseModel): |
||||
|
implicit: Optional[OAuthFlowImplicit] = None |
||||
|
password: Optional[OAuthFlowPassword] = None |
||||
|
clientCredentials: Optional[OAuthFlowClientCredentials] = None |
||||
|
authorizationCode: Optional[OAuthFlowAuthorizationCode] = None |
||||
|
|
||||
|
|
||||
|
class OAuth2(SecurityBase): |
||||
|
type_ = PSchema(SecuritySchemeType.oauth2, alias="type") |
||||
|
flows: OAuthFlows |
||||
|
|
||||
|
|
||||
|
class OpenIdConnect(SecurityBase): |
||||
|
type_ = PSchema(SecuritySchemeType.openIdConnect, alias="type") |
||||
|
openIdConnectUrl: str |
||||
|
|
||||
|
|
||||
|
SecurityScheme = Union[APIKey, HTTPBase, HTTPBearer, OAuth2, OpenIdConnect] |
||||
|
|
||||
|
|
||||
|
class Components(BaseModel): |
||||
|
schemas: Optional[Dict[str, Union[Schema, Reference]]] = None |
||||
|
responses: Optional[Dict[str, Union[Response, Reference]]] = None |
||||
|
parameters: Optional[Dict[str, Union[Parameter, Reference]]] = None |
||||
|
examples: Optional[Dict[str, Union[Example, Reference]]] = None |
||||
|
requestBodies: Optional[Dict[str, Union[RequestBody, Reference]]] = None |
||||
|
headers: Optional[Dict[str, Union[Header, Reference]]] = None |
||||
|
securitySchemes: Optional[Dict[str, Union[SecurityScheme, Reference]]] = None |
||||
|
links: Optional[Dict[str, Union[Link, Reference]]] = None |
||||
|
callbacks: Optional[Dict[str, Union[Dict[str, PathItem], Reference]]] = None |
||||
|
|
||||
|
|
||||
|
class Tag(BaseModel): |
||||
|
name: str |
||||
|
description: Optional[str] = None |
||||
|
externalDocs: Optional[ExternalDocumentation] = None |
||||
|
|
||||
|
|
||||
|
class OpenAPI(BaseModel): |
||||
|
openapi: str |
||||
|
info: Info |
||||
|
servers: Optional[List[Server]] = None |
||||
|
paths: Dict[str, PathItem] |
||||
|
components: Optional[Components] = None |
||||
|
security: Optional[List[Dict[str, List[str]]]] = None |
||||
|
tags: Optional[List[Tag]] = None |
||||
|
externalDocs: Optional[ExternalDocumentation] = None |
@ -0,0 +1,280 @@ |
|||||
|
from typing import Any, Dict, Sequence, Type |
||||
|
|
||||
|
from starlette.responses import HTMLResponse, JSONResponse |
||||
|
from starlette.routing import BaseRoute |
||||
|
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY |
||||
|
|
||||
|
from fastapi import routing |
||||
|
from fastapi.dependencies.models import Dependant |
||||
|
from fastapi.dependencies.utils import get_flat_dependant |
||||
|
from fastapi.encoders import jsonable_encoder |
||||
|
from fastapi.openapi.constants import REF_PREFIX, METHODS_WITH_BODY |
||||
|
from fastapi.openapi.models import OpenAPI |
||||
|
from fastapi.params import Body |
||||
|
from fastapi.utils import get_flat_models_from_routes, get_model_definitions |
||||
|
from pydantic.fields import Field |
||||
|
from pydantic.schema import field_schema, get_model_name_map |
||||
|
from pydantic.utils import lenient_issubclass |
||||
|
|
||||
|
validation_error_definition = { |
||||
|
"title": "ValidationError", |
||||
|
"type": "object", |
||||
|
"properties": { |
||||
|
"loc": {"title": "Location", "type": "array", "items": {"type": "string"}}, |
||||
|
"msg": {"title": "Message", "type": "string"}, |
||||
|
"type": {"title": "Error Type", "type": "string"}, |
||||
|
}, |
||||
|
"required": ["loc", "msg", "type"], |
||||
|
} |
||||
|
|
||||
|
validation_error_response_definition = { |
||||
|
"title": "HTTPValidationError", |
||||
|
"type": "object", |
||||
|
"properties": { |
||||
|
"detail": { |
||||
|
"title": "Detail", |
||||
|
"type": "array", |
||||
|
"items": {"$ref": REF_PREFIX + "ValidationError"}, |
||||
|
} |
||||
|
}, |
||||
|
} |
||||
|
|
||||
|
|
||||
|
def get_openapi_params(dependant: Dependant): |
||||
|
flat_dependant = get_flat_dependant(dependant) |
||||
|
return ( |
||||
|
flat_dependant.path_params |
||||
|
+ flat_dependant.query_params |
||||
|
+ flat_dependant.header_params |
||||
|
+ flat_dependant.cookie_params |
||||
|
) |
||||
|
|
||||
|
def get_openapi_path(*, route: BaseRoute, model_name_map: Dict[Type, str]): |
||||
|
if not (route.include_in_schema and isinstance(route, routing.APIRoute)): |
||||
|
return None |
||||
|
path = {} |
||||
|
security_schemes = {} |
||||
|
definitions = {} |
||||
|
for method in route.methods: |
||||
|
operation: Dict[str, Any] = {} |
||||
|
if route.tags: |
||||
|
operation["tags"] = route.tags |
||||
|
if route.summary: |
||||
|
operation["summary"] = route.summary |
||||
|
if route.description: |
||||
|
operation["description"] = route.description |
||||
|
if route.operation_id: |
||||
|
operation["operationId"] = route.operation_id |
||||
|
else: |
||||
|
operation["operationId"] = route.name |
||||
|
if route.deprecated: |
||||
|
operation["deprecated"] = route.deprecated |
||||
|
parameters = [] |
||||
|
flat_dependant = get_flat_dependant(route.dependant) |
||||
|
security_definitions = {} |
||||
|
for security_requirement in flat_dependant.security_requirements: |
||||
|
security_definition = jsonable_encoder( |
||||
|
security_requirement.security_scheme, |
||||
|
exclude={"scheme_name"}, |
||||
|
by_alias=True, |
||||
|
include_none=False, |
||||
|
) |
||||
|
security_name = ( |
||||
|
getattr( |
||||
|
security_requirement.security_scheme, "scheme_name", None |
||||
|
) |
||||
|
or security_requirement.security_scheme.__class__.__name__ |
||||
|
) |
||||
|
security_definitions[security_name] = security_definition |
||||
|
operation.setdefault("security", []).append( |
||||
|
{security_name: security_requirement.scopes} |
||||
|
) |
||||
|
if security_definitions: |
||||
|
security_schemes.update( |
||||
|
security_definitions |
||||
|
) |
||||
|
all_route_params = get_openapi_params(route.dependant) |
||||
|
for param in all_route_params: |
||||
|
if "ValidationError" not in definitions: |
||||
|
definitions["ValidationError"] = validation_error_definition |
||||
|
definitions[ |
||||
|
"HTTPValidationError" |
||||
|
] = validation_error_response_definition |
||||
|
parameter = { |
||||
|
"name": param.alias, |
||||
|
"in": param.schema.in_.value, |
||||
|
"required": param.required, |
||||
|
"schema": field_schema(param, model_name_map={})[0], |
||||
|
} |
||||
|
if param.schema.description: |
||||
|
parameter["description"] = param.schema.description |
||||
|
if param.schema.deprecated: |
||||
|
parameter["deprecated"] = param.schema.deprecated |
||||
|
parameters.append(parameter) |
||||
|
if parameters: |
||||
|
operation["parameters"] = parameters |
||||
|
if method in METHODS_WITH_BODY: |
||||
|
body_field = route.body_field |
||||
|
if body_field: |
||||
|
assert isinstance(body_field, Field) |
||||
|
body_schema, _ = field_schema( |
||||
|
body_field, |
||||
|
model_name_map=model_name_map, |
||||
|
ref_prefix=REF_PREFIX, |
||||
|
) |
||||
|
if isinstance(body_field.schema, Body): |
||||
|
request_media_type = body_field.schema.media_type |
||||
|
else: |
||||
|
# Includes not declared media types (Schema) |
||||
|
request_media_type = "application/json" |
||||
|
required = body_field.required |
||||
|
request_body_oai = {} |
||||
|
if required: |
||||
|
request_body_oai["required"] = required |
||||
|
request_body_oai["content"] = { |
||||
|
request_media_type: {"schema": body_schema} |
||||
|
} |
||||
|
operation["requestBody"] = request_body_oai |
||||
|
response_code = str(route.response_code) |
||||
|
response_schema = {"type": "string"} |
||||
|
if lenient_issubclass(route.response_wrapper, JSONResponse): |
||||
|
response_media_type = "application/json" |
||||
|
if route.response_field: |
||||
|
response_schema, _ = field_schema( |
||||
|
route.response_field, |
||||
|
model_name_map=model_name_map, |
||||
|
ref_prefix=REF_PREFIX, |
||||
|
) |
||||
|
else: |
||||
|
response_schema = {} |
||||
|
elif lenient_issubclass(route.response_wrapper, HTMLResponse): |
||||
|
response_media_type = "text/html" |
||||
|
else: |
||||
|
response_media_type = "text/plain" |
||||
|
content = {response_media_type: {"schema": response_schema}} |
||||
|
operation["responses"] = { |
||||
|
response_code: { |
||||
|
"description": route.response_description, |
||||
|
"content": content, |
||||
|
} |
||||
|
} |
||||
|
if all_route_params or route.body_field: |
||||
|
operation["responses"][str(HTTP_422_UNPROCESSABLE_ENTITY)] = { |
||||
|
"description": "Validation Error", |
||||
|
"content": { |
||||
|
"application/json": { |
||||
|
"schema": {"$ref": REF_PREFIX + "HTTPValidationError"} |
||||
|
} |
||||
|
}, |
||||
|
} |
||||
|
path[method.lower()] = operation |
||||
|
return path, security_schemes, definitions |
||||
|
|
||||
|
|
||||
|
def get_openapi( |
||||
|
*, |
||||
|
title: str, |
||||
|
version: str, |
||||
|
openapi_version: str = "3.0.2", |
||||
|
description: str = None, |
||||
|
routes: Sequence[BaseRoute] |
||||
|
): |
||||
|
info = {"title": title, "version": version} |
||||
|
if description: |
||||
|
info["description"] = description |
||||
|
output = {"openapi": openapi_version, "info": info} |
||||
|
components: Dict[str, Dict] = {} |
||||
|
paths: Dict[str, Dict] = {} |
||||
|
flat_models = get_flat_models_from_routes(routes) |
||||
|
model_name_map = get_model_name_map(flat_models) |
||||
|
definitions = get_model_definitions( |
||||
|
flat_models=flat_models, model_name_map=model_name_map |
||||
|
) |
||||
|
for route in routes: |
||||
|
result = get_openapi_path(route=route, model_name_map=model_name_map) |
||||
|
if result: |
||||
|
path, security_schemes, path_definitions = result |
||||
|
if path: |
||||
|
paths.setdefault(route.path, {}).update(path) |
||||
|
if security_schemes: |
||||
|
components.setdefault("securitySchemes", {}).update(security_schemes) |
||||
|
if path_definitions: |
||||
|
definitions.update(path_definitions) |
||||
|
if definitions: |
||||
|
components.setdefault("schemas", {}).update(definitions) |
||||
|
if components: |
||||
|
output["components"] = components |
||||
|
output["paths"] = paths |
||||
|
return jsonable_encoder(OpenAPI(**output), by_alias=True, include_none=False) |
||||
|
|
||||
|
|
||||
|
def get_swagger_ui_html(*, openapi_url: str, title: str): |
||||
|
return HTMLResponse( |
||||
|
""" |
||||
|
<! doctype html> |
||||
|
<html> |
||||
|
<head> |
||||
|
<link type="text/css" rel="stylesheet" href="//unpkg.com/swagger-ui-dist@3/swagger-ui.css"> |
||||
|
<title> |
||||
|
""" + title + """ |
||||
|
</title> |
||||
|
</head> |
||||
|
<body> |
||||
|
<div id="swagger-ui"> |
||||
|
</div> |
||||
|
<script src="//unpkg.com/swagger-ui-dist@3/swagger-ui-bundle.js"></script> |
||||
|
<!-- `SwaggerUIBundle` is now available on the page --> |
||||
|
<script> |
||||
|
|
||||
|
const ui = SwaggerUIBundle({ |
||||
|
url: '""" |
||||
|
+ openapi_url |
||||
|
+ """', |
||||
|
dom_id: '#swagger-ui', |
||||
|
presets: [ |
||||
|
SwaggerUIBundle.presets.apis, |
||||
|
SwaggerUIBundle.SwaggerUIStandalonePreset |
||||
|
], |
||||
|
layout: "BaseLayout" |
||||
|
|
||||
|
}) |
||||
|
</script> |
||||
|
</body> |
||||
|
</html> |
||||
|
""", |
||||
|
media_type="text/html", |
||||
|
) |
||||
|
|
||||
|
|
||||
|
def get_redoc_html(*, openapi_url: str, title: str): |
||||
|
return HTMLResponse( |
||||
|
""" |
||||
|
<!DOCTYPE html> |
||||
|
<html> |
||||
|
<head> |
||||
|
<title> |
||||
|
""" + title + """ |
||||
|
</title> |
||||
|
<!-- needed for adaptive design --> |
||||
|
<meta charset="utf-8"/> |
||||
|
<meta name="viewport" content="width=device-width, initial-scale=1"> |
||||
|
<link href="https://fonts.googleapis.com/css?family=Montserrat:300,400,700|Roboto:300,400,700" rel="stylesheet"> |
||||
|
|
||||
|
<!-- |
||||
|
ReDoc doesn't change outer page styles |
||||
|
--> |
||||
|
<style> |
||||
|
body { |
||||
|
margin: 0; |
||||
|
padding: 0; |
||||
|
} |
||||
|
</style> |
||||
|
</head> |
||||
|
<body> |
||||
|
<redoc spec-url='""" + openapi_url + """'></redoc> |
||||
|
<script src="https://cdn.jsdelivr.net/npm/redoc@next/bundles/redoc.standalone.js"> </script> |
||||
|
</body> |
||||
|
</html> |
||||
|
""", |
||||
|
media_type="text/html", |
||||
|
) |
@ -0,0 +1,46 @@ |
|||||
|
import re |
||||
|
from typing import Dict, Sequence, Set, Type |
||||
|
|
||||
|
from starlette.routing import BaseRoute |
||||
|
|
||||
|
from fastapi import routing |
||||
|
from fastapi.openapi.constants import REF_PREFIX |
||||
|
from pydantic import BaseModel |
||||
|
from pydantic.fields import Field |
||||
|
from pydantic.schema import get_flat_models_from_fields, model_process_schema |
||||
|
|
||||
|
|
||||
|
def get_flat_models_from_routes(routes: Sequence[BaseRoute]): |
||||
|
body_fields_from_routes = [] |
||||
|
responses_from_routes = [] |
||||
|
for route in routes: |
||||
|
if route.include_in_schema and isinstance(route, routing.APIRoute): |
||||
|
if route.body_field: |
||||
|
assert isinstance( |
||||
|
route.body_field, Field |
||||
|
), "A request body must be a Pydantic Field" |
||||
|
body_fields_from_routes.append(route.body_field) |
||||
|
if route.response_field: |
||||
|
responses_from_routes.append(route.response_field) |
||||
|
flat_models = get_flat_models_from_fields( |
||||
|
body_fields_from_routes + responses_from_routes |
||||
|
) |
||||
|
return flat_models |
||||
|
|
||||
|
|
||||
|
def get_model_definitions( |
||||
|
*, flat_models: Set[Type[BaseModel]], model_name_map: Dict[Type[BaseModel], str] |
||||
|
): |
||||
|
definitions: Dict[str, Dict] = {} |
||||
|
for model in flat_models: |
||||
|
m_schema, m_definitions = model_process_schema( |
||||
|
model, model_name_map=model_name_map, ref_prefix=REF_PREFIX |
||||
|
) |
||||
|
definitions.update(m_definitions) |
||||
|
model_name = model_name_map[model] |
||||
|
definitions[model_name] = m_schema |
||||
|
return definitions |
||||
|
|
||||
|
|
||||
|
def get_path_param_names(path: str): |
||||
|
return {item.strip("{}") for item in re.findall("{[^}]*}", path)} |
Loading…
Reference in new issue