18 changed files with 1376 additions and 703 deletions
@ -1,3 +1,3 @@ |
|||
"""Fast API framework, fast high performance, fast to learn, fast to code""" |
|||
|
|||
__version__ = '0.1' |
|||
__version__ = "0.1" |
|||
|
@ -0,0 +1,46 @@ |
|||
from typing import Any, Callable, Dict, List, Sequence, Tuple |
|||
|
|||
from starlette.concurrency import run_in_threadpool |
|||
from starlette.requests import Request |
|||
|
|||
from fastapi.security.base import SecurityBase |
|||
from pydantic import BaseConfig, Schema |
|||
from pydantic.error_wrappers import ErrorWrapper |
|||
from pydantic.errors import MissingError |
|||
from pydantic.fields import Field, Required |
|||
from pydantic.schema import get_annotation_from_schema |
|||
|
|||
param_supported_types = (str, int, float, bool) |
|||
|
|||
|
|||
class SecurityRequirement: |
|||
def __init__(self, security_scheme: SecurityBase, scopes: Sequence[str] = None): |
|||
self.security_scheme = security_scheme |
|||
self.scopes = scopes |
|||
|
|||
|
|||
class Dependant: |
|||
def __init__( |
|||
self, |
|||
*, |
|||
path_params: List[Field] = None, |
|||
query_params: List[Field] = None, |
|||
header_params: List[Field] = None, |
|||
cookie_params: List[Field] = None, |
|||
body_params: List[Field] = None, |
|||
dependencies: List["Dependant"] = None, |
|||
security_schemes: List[SecurityRequirement] = None, |
|||
name: str = None, |
|||
call: Callable = None, |
|||
request_param_name: str = None, |
|||
) -> None: |
|||
self.path_params = path_params or [] |
|||
self.query_params = query_params or [] |
|||
self.header_params = header_params or [] |
|||
self.cookie_params = cookie_params or [] |
|||
self.body_params = body_params or [] |
|||
self.dependencies = dependencies or [] |
|||
self.security_requirements = security_schemes or [] |
|||
self.request_param_name = request_param_name |
|||
self.name = name |
|||
self.call = call |
@ -0,0 +1,327 @@ |
|||
import asyncio |
|||
import inspect |
|||
from copy import deepcopy |
|||
from typing import Any, Callable, Dict, List, Tuple |
|||
|
|||
from starlette.concurrency import run_in_threadpool |
|||
from starlette.requests import Request |
|||
|
|||
from fastapi import params |
|||
from fastapi.dependencies.models import Dependant, SecurityRequirement |
|||
from fastapi.security.base import SecurityBase |
|||
from fastapi.utils import get_path_param_names |
|||
from pydantic import BaseConfig, Schema, create_model |
|||
from pydantic.error_wrappers import ErrorWrapper |
|||
from pydantic.errors import MissingError |
|||
from pydantic.fields import Field, Required |
|||
from pydantic.schema import get_annotation_from_schema |
|||
from pydantic.utils import lenient_issubclass |
|||
|
|||
param_supported_types = (str, int, float, bool) |
|||
|
|||
|
|||
def get_sub_dependant(*, param: inspect.Parameter, path: str): |
|||
depends: params.Depends = param.default |
|||
if depends.dependency: |
|||
dependency = depends.dependency |
|||
else: |
|||
dependency = param.annotation |
|||
assert callable(dependency) |
|||
sub_dependant = get_dependant(path=path, call=dependency, name=param.name) |
|||
if isinstance(depends, params.Security) and isinstance(dependency, SecurityBase): |
|||
security_requirement = SecurityRequirement( |
|||
security_scheme=dependency, scopes=depends.scopes |
|||
) |
|||
sub_dependant.security_requirements.append(security_requirement) |
|||
return sub_dependant |
|||
|
|||
|
|||
def get_flat_dependant(dependant: Dependant): |
|||
flat_dependant = Dependant( |
|||
path_params=dependant.path_params.copy(), |
|||
query_params=dependant.query_params.copy(), |
|||
header_params=dependant.header_params.copy(), |
|||
cookie_params=dependant.cookie_params.copy(), |
|||
body_params=dependant.body_params.copy(), |
|||
security_schemes=dependant.security_requirements.copy(), |
|||
) |
|||
for sub_dependant in dependant.dependencies: |
|||
if sub_dependant is dependant: |
|||
raise ValueError("recursion", dependant.dependencies) |
|||
flat_sub = get_flat_dependant(sub_dependant) |
|||
flat_dependant.path_params.extend(flat_sub.path_params) |
|||
flat_dependant.query_params.extend(flat_sub.query_params) |
|||
flat_dependant.header_params.extend(flat_sub.header_params) |
|||
flat_dependant.cookie_params.extend(flat_sub.cookie_params) |
|||
flat_dependant.body_params.extend(flat_sub.body_params) |
|||
flat_dependant.security_requirements.extend(flat_sub.security_requirements) |
|||
return flat_dependant |
|||
|
|||
|
|||
def get_dependant(*, path: str, call: Callable, name: str = None): |
|||
path_param_names = get_path_param_names(path) |
|||
endpoint_signature = inspect.signature(call) |
|||
signature_params = endpoint_signature.parameters |
|||
dependant = Dependant(call=call, name=name) |
|||
for param_name in signature_params: |
|||
param = signature_params[param_name] |
|||
if isinstance(param.default, params.Depends): |
|||
sub_dependant = get_sub_dependant(param=param, path=path) |
|||
dependant.dependencies.append(sub_dependant) |
|||
for param_name in signature_params: |
|||
param = signature_params[param_name] |
|||
if ( |
|||
(param.default == param.empty) or isinstance(param.default, params.Path) |
|||
) and (param_name in path_param_names): |
|||
assert lenient_issubclass( |
|||
param.annotation, param_supported_types |
|||
) or param.annotation == param.empty, f"Path params must be of type str, int, float or boot: {param}" |
|||
param = signature_params[param_name] |
|||
add_param_to_fields( |
|||
param=param, |
|||
dependant=dependant, |
|||
default_schema=params.Path, |
|||
force_type=params.ParamTypes.path, |
|||
) |
|||
elif (param.default == param.empty or param.default is None) and ( |
|||
param.annotation == param.empty |
|||
or lenient_issubclass(param.annotation, param_supported_types) |
|||
): |
|||
add_param_to_fields( |
|||
param=param, dependant=dependant, default_schema=params.Query |
|||
) |
|||
elif isinstance(param.default, params.Param): |
|||
if param.annotation != param.empty: |
|||
assert lenient_issubclass( |
|||
param.annotation, param_supported_types |
|||
), f"Parameters for Path, Query, Header and Cookies must be of type str, int, float or bool: {param}" |
|||
add_param_to_fields( |
|||
param=param, dependant=dependant, default_schema=params.Query |
|||
) |
|||
elif lenient_issubclass(param.annotation, Request): |
|||
dependant.request_param_name = param_name |
|||
elif not isinstance(param.default, params.Depends): |
|||
add_param_to_body_fields(param=param, dependant=dependant) |
|||
return dependant |
|||
|
|||
|
|||
def add_param_to_fields( |
|||
*, |
|||
param: inspect.Parameter, |
|||
dependant: Dependant, |
|||
default_schema=params.Param, |
|||
force_type: params.ParamTypes = None, |
|||
): |
|||
default_value = Required |
|||
if not param.default == param.empty: |
|||
default_value = param.default |
|||
if isinstance(default_value, params.Param): |
|||
schema = default_value |
|||
default_value = schema.default |
|||
if schema.in_ is None: |
|||
schema.in_ = default_schema.in_ |
|||
if force_type: |
|||
schema.in_ = force_type |
|||
else: |
|||
schema = default_schema(default_value) |
|||
required = default_value == Required |
|||
annotation = Any |
|||
if not param.annotation == param.empty: |
|||
annotation = param.annotation |
|||
annotation = get_annotation_from_schema(annotation, schema) |
|||
field = Field( |
|||
name=param.name, |
|||
type_=annotation, |
|||
default=None if required else default_value, |
|||
alias=schema.alias or param.name, |
|||
required=required, |
|||
model_config=BaseConfig(), |
|||
class_validators=[], |
|||
schema=schema, |
|||
) |
|||
if schema.in_ == params.ParamTypes.path: |
|||
dependant.path_params.append(field) |
|||
elif schema.in_ == params.ParamTypes.query: |
|||
dependant.query_params.append(field) |
|||
elif schema.in_ == params.ParamTypes.header: |
|||
dependant.header_params.append(field) |
|||
else: |
|||
assert ( |
|||
schema.in_ == params.ParamTypes.cookie |
|||
), f"non-body parameters must be in path, query, header or cookie: {param.name}" |
|||
dependant.cookie_params.append(field) |
|||
|
|||
|
|||
def add_param_to_body_fields(*, param: inspect.Parameter, dependant: Dependant): |
|||
default_value = Required |
|||
if not param.default == param.empty: |
|||
default_value = param.default |
|||
if isinstance(default_value, Schema): |
|||
schema = default_value |
|||
default_value = schema.default |
|||
else: |
|||
schema = Schema(default_value) |
|||
required = default_value == Required |
|||
annotation = get_annotation_from_schema(param.annotation, schema) |
|||
field = Field( |
|||
name=param.name, |
|||
type_=annotation, |
|||
default=None if required else default_value, |
|||
alias=schema.alias or param.name, |
|||
required=required, |
|||
model_config=BaseConfig, |
|||
class_validators=[], |
|||
schema=schema, |
|||
) |
|||
dependant.body_params.append(field) |
|||
|
|||
|
|||
def is_coroutine_callable(call: Callable = None): |
|||
if not call: |
|||
return False |
|||
if inspect.isfunction(call): |
|||
return asyncio.iscoroutinefunction(call) |
|||
if inspect.isclass(call): |
|||
return False |
|||
call = getattr(call, "__call__", None) |
|||
if not call: |
|||
return False |
|||
return asyncio.iscoroutinefunction(call) |
|||
|
|||
|
|||
async def solve_dependencies( |
|||
*, request: Request, dependant: Dependant, body: Dict[str, Any] = None |
|||
): |
|||
values: Dict[str, Any] = {} |
|||
errors: List[ErrorWrapper] = [] |
|||
for sub_dependant in dependant.dependencies: |
|||
sub_values, sub_errors = await solve_dependencies( |
|||
request=request, dependant=sub_dependant, body=body |
|||
) |
|||
if sub_errors: |
|||
return {}, errors |
|||
if sub_dependant.call and is_coroutine_callable(sub_dependant.call): |
|||
solved = await sub_dependant.call(**sub_values) |
|||
else: |
|||
solved = await run_in_threadpool(sub_dependant.call, **sub_values) |
|||
values[ |
|||
sub_dependant.name |
|||
] = solved # type: ignore # Sub-dependants always have a name |
|||
path_values, path_errors = request_params_to_args( |
|||
dependant.path_params, request.path_params |
|||
) |
|||
query_values, query_errors = request_params_to_args( |
|||
dependant.query_params, request.query_params |
|||
) |
|||
header_values, header_errors = request_params_to_args( |
|||
dependant.header_params, request.headers |
|||
) |
|||
cookie_values, cookie_errors = request_params_to_args( |
|||
dependant.cookie_params, request.cookies |
|||
) |
|||
values.update(path_values) |
|||
values.update(query_values) |
|||
values.update(header_values) |
|||
values.update(cookie_values) |
|||
errors = path_errors + query_errors + header_errors + cookie_errors |
|||
if dependant.body_params: |
|||
body_values, body_errors = await request_body_to_args( # type: ignore # body_params checked above |
|||
dependant.body_params, body |
|||
) |
|||
values.update(body_values) |
|||
errors.extend(body_errors) |
|||
if dependant.request_param_name: |
|||
values[dependant.request_param_name] = request |
|||
return values, errors |
|||
|
|||
|
|||
def request_params_to_args( |
|||
required_params: List[Field], received_params: Dict[str, Any] |
|||
) -> Tuple[Dict[str, Any], List[ErrorWrapper]]: |
|||
values = {} |
|||
errors = [] |
|||
for field in required_params: |
|||
value = received_params.get(field.alias) |
|||
if value is None: |
|||
if field.required: |
|||
errors.append( |
|||
ErrorWrapper(MissingError(), loc=field.alias, config=BaseConfig) |
|||
) |
|||
else: |
|||
values[field.name] = deepcopy(field.default) |
|||
continue |
|||
v_, errors_ = field.validate( |
|||
value, values, loc=(field.schema.in_.value, field.alias) |
|||
) |
|||
if isinstance(errors_, ErrorWrapper): |
|||
errors.append(errors_) |
|||
elif isinstance(errors_, list): |
|||
errors.extend(errors_) |
|||
else: |
|||
values[field.name] = v_ |
|||
return values, errors |
|||
|
|||
|
|||
async def request_body_to_args( |
|||
required_params: List[Field], received_body: Dict[str, Any] |
|||
) -> Tuple[Dict[str, Any], List[ErrorWrapper]]: |
|||
values = {} |
|||
errors = [] |
|||
if required_params: |
|||
field = required_params[0] |
|||
embed = getattr(field.schema, "embed", None) |
|||
if len(required_params) == 1 and not embed: |
|||
received_body = {field.alias: received_body} |
|||
for field in required_params: |
|||
value = received_body.get(field.alias) |
|||
if value is None: |
|||
if field.required: |
|||
errors.append( |
|||
ErrorWrapper( |
|||
MissingError(), loc=("body", field.alias), config=BaseConfig |
|||
) |
|||
) |
|||
else: |
|||
values[field.name] = deepcopy(field.default) |
|||
continue |
|||
v_, errors_ = field.validate(value, values, loc=("body", field.alias)) |
|||
if isinstance(errors_, ErrorWrapper): |
|||
errors.append(errors_) |
|||
elif isinstance(errors_, list): |
|||
errors.extend(errors_) |
|||
else: |
|||
values[field.name] = v_ |
|||
return values, errors |
|||
|
|||
|
|||
def get_body_field(*, dependant: Dependant, name: str): |
|||
flat_dependant = get_flat_dependant(dependant) |
|||
if not flat_dependant.body_params: |
|||
return None |
|||
first_param = flat_dependant.body_params[0] |
|||
embed = getattr(first_param.schema, "embed", None) |
|||
if len(flat_dependant.body_params) == 1 and not embed: |
|||
return first_param |
|||
model_name = "Body_" + name |
|||
BodyModel = create_model(model_name) |
|||
for f in flat_dependant.body_params: |
|||
BodyModel.__fields__[f.name] = f |
|||
required = any(True for f in flat_dependant.body_params if f.required) |
|||
if any(isinstance(f.schema, params.File) for f in flat_dependant.body_params): |
|||
BodySchema = params.File |
|||
elif any(isinstance(f.schema, params.Form) for f in flat_dependant.body_params): |
|||
BodySchema = params.Form |
|||
else: |
|||
BodySchema = params.Body |
|||
|
|||
field = Field( |
|||
name="body", |
|||
type_=BodyModel, |
|||
default=None, |
|||
required=required, |
|||
model_config=BaseConfig, |
|||
class_validators=[], |
|||
alias="body", |
|||
schema=BodySchema(None), |
|||
) |
|||
return field |
@ -1,33 +1,44 @@ |
|||
from enum import Enum |
|||
from types import GeneratorType |
|||
from typing import Set |
|||
|
|||
from pydantic import BaseModel |
|||
from enum import Enum |
|||
from pydantic.json import pydantic_encoder |
|||
|
|||
|
|||
def jsonable_encoder( |
|||
obj, include: Set[str] = None, exclude: Set[str] = set(), by_alias: bool = False, include_none=True, |
|||
obj, |
|||
include: Set[str] = None, |
|||
exclude: Set[str] = set(), |
|||
by_alias: bool = False, |
|||
include_none=True, |
|||
): |
|||
if isinstance(obj, BaseModel): |
|||
return jsonable_encoder( |
|||
obj.dict(include=include, exclude=exclude, by_alias=by_alias), include_none=include_none |
|||
obj.dict(include=include, exclude=exclude, by_alias=by_alias), |
|||
include_none=include_none, |
|||
) |
|||
elif isinstance(obj, Enum): |
|||
if isinstance(obj, Enum): |
|||
return obj.value |
|||
if isinstance(obj, (str, int, float, type(None))): |
|||
return obj |
|||
if isinstance(obj, dict): |
|||
return { |
|||
jsonable_encoder( |
|||
key, by_alias=by_alias, include_none=include_none, |
|||
): jsonable_encoder( |
|||
value, by_alias=by_alias, include_none=include_none, |
|||
) |
|||
for key, value in obj.items() if value is not None or include_none |
|||
key, by_alias=by_alias, include_none=include_none |
|||
): jsonable_encoder(value, by_alias=by_alias, include_none=include_none) |
|||
for key, value in obj.items() |
|||
if value is not None or include_none |
|||
} |
|||
if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)): |
|||
return [ |
|||
jsonable_encoder(item, include=include, exclude=exclude, by_alias=by_alias, include_none=include_none) |
|||
jsonable_encoder( |
|||
item, |
|||
include=include, |
|||
exclude=exclude, |
|||
by_alias=by_alias, |
|||
include_none=include_none, |
|||
) |
|||
for item in obj |
|||
] |
|||
return pydantic_encoder(obj) |
@ -0,0 +1,2 @@ |
|||
METHODS_WITH_BODY = set(("POST", "PUT")) |
|||
REF_PREFIX = "#/components/schemas/" |
@ -0,0 +1,347 @@ |
|||
import logging |
|||
from enum import Enum |
|||
from typing import Any, Dict, List, Optional, Union |
|||
|
|||
from pydantic import BaseModel, Schema as PSchema |
|||
from pydantic.types import UrlStr |
|||
|
|||
try: |
|||
import pydantic.types.EmailStr |
|||
from pydantic.types import EmailStr |
|||
except ImportError: |
|||
logging.warning( |
|||
"email-validator not installed, email fields will be treated as str" |
|||
) |
|||
|
|||
class EmailStr(str): |
|||
pass |
|||
|
|||
|
|||
class Contact(BaseModel): |
|||
name: Optional[str] = None |
|||
url: Optional[UrlStr] = None |
|||
email: Optional[EmailStr] = None |
|||
|
|||
|
|||
class License(BaseModel): |
|||
name: str |
|||
url: Optional[UrlStr] = None |
|||
|
|||
|
|||
class Info(BaseModel): |
|||
title: str |
|||
description: Optional[str] = None |
|||
termsOfService: Optional[str] = None |
|||
contact: Optional[Contact] = None |
|||
license: Optional[License] = None |
|||
version: str |
|||
|
|||
|
|||
class ServerVariable(BaseModel): |
|||
enum: Optional[List[str]] = None |
|||
default: str |
|||
description: Optional[str] = None |
|||
|
|||
|
|||
class Server(BaseModel): |
|||
url: UrlStr |
|||
description: Optional[str] = None |
|||
variables: Optional[Dict[str, ServerVariable]] = None |
|||
|
|||
|
|||
class Reference(BaseModel): |
|||
ref: str = PSchema(..., alias="$ref") |
|||
|
|||
|
|||
class Discriminator(BaseModel): |
|||
propertyName: str |
|||
mapping: Optional[Dict[str, str]] = None |
|||
|
|||
|
|||
class XML(BaseModel): |
|||
name: Optional[str] = None |
|||
namespace: Optional[str] = None |
|||
prefix: Optional[str] = None |
|||
attribute: Optional[bool] = None |
|||
wrapped: Optional[bool] = None |
|||
|
|||
|
|||
class ExternalDocumentation(BaseModel): |
|||
description: Optional[str] = None |
|||
url: UrlStr |
|||
|
|||
|
|||
class SchemaBase(BaseModel): |
|||
ref: Optional[str] = PSchema(None, alias="$ref") |
|||
title: Optional[str] = None |
|||
multipleOf: Optional[float] = None |
|||
maximum: Optional[float] = None |
|||
exclusiveMaximum: Optional[float] = None |
|||
minimum: Optional[float] = None |
|||
exclusiveMinimum: Optional[float] = None |
|||
maxLength: Optional[int] = PSchema(None, gte=0) |
|||
minLength: Optional[int] = PSchema(None, gte=0) |
|||
pattern: Optional[str] = None |
|||
maxItems: Optional[int] = PSchema(None, gte=0) |
|||
minItems: Optional[int] = PSchema(None, gte=0) |
|||
uniqueItems: Optional[bool] = None |
|||
maxProperties: Optional[int] = PSchema(None, gte=0) |
|||
minProperties: Optional[int] = PSchema(None, gte=0) |
|||
required: Optional[List[str]] = None |
|||
enum: Optional[List[str]] = None |
|||
type: Optional[str] = None |
|||
allOf: Optional[List[Any]] = None |
|||
oneOf: Optional[List[Any]] = None |
|||
anyOf: Optional[List[Any]] = None |
|||
not_: Optional[List[Any]] = PSchema(None, alias="not") |
|||
items: Optional[Any] = None |
|||
properties: Optional[Dict[str, Any]] = None |
|||
additionalProperties: Optional[Union[bool, Any]] = None |
|||
description: Optional[str] = None |
|||
format: Optional[str] = None |
|||
default: Optional[Any] = None |
|||
nullable: Optional[bool] = None |
|||
discriminator: Optional[Discriminator] = None |
|||
readOnly: Optional[bool] = None |
|||
writeOnly: Optional[bool] = None |
|||
xml: Optional[XML] = None |
|||
externalDocs: Optional[ExternalDocumentation] = None |
|||
example: Optional[Any] = None |
|||
deprecated: Optional[bool] = None |
|||
|
|||
|
|||
class Schema(SchemaBase): |
|||
allOf: Optional[List[SchemaBase]] = None |
|||
oneOf: Optional[List[SchemaBase]] = None |
|||
anyOf: Optional[List[SchemaBase]] = None |
|||
not_: Optional[List[SchemaBase]] = PSchema(None, alias="not") |
|||
items: Optional[SchemaBase] = None |
|||
properties: Optional[Dict[str, SchemaBase]] = None |
|||
additionalProperties: Optional[Union[bool, SchemaBase]] = None |
|||
|
|||
|
|||
class Example(BaseModel): |
|||
summary: Optional[str] = None |
|||
description: Optional[str] = None |
|||
value: Optional[Any] = None |
|||
externalValue: Optional[UrlStr] = None |
|||
|
|||
|
|||
class ParameterInType(Enum): |
|||
query = "query" |
|||
header = "header" |
|||
path = "path" |
|||
cookie = "cookie" |
|||
|
|||
|
|||
class Encoding(BaseModel): |
|||
contentType: Optional[str] = None |
|||
# Workaround OpenAPI recursive reference, using Any |
|||
headers: Optional[Dict[str, Union[Any, Reference]]] = None |
|||
style: Optional[str] = None |
|||
explode: Optional[bool] = None |
|||
allowReserved: Optional[bool] = None |
|||
|
|||
|
|||
class MediaType(BaseModel): |
|||
schema_: Optional[Union[Schema, Reference]] = PSchema(None, alias="schema") |
|||
example: Optional[Any] = None |
|||
examples: Optional[Dict[str, Union[Example, Reference]]] = None |
|||
encoding: Optional[Dict[str, Encoding]] = None |
|||
|
|||
|
|||
class ParameterBase(BaseModel): |
|||
description: Optional[str] = None |
|||
required: Optional[bool] = None |
|||
deprecated: Optional[bool] = None |
|||
# Serialization rules for simple scenarios |
|||
style: Optional[str] = None |
|||
explode: Optional[bool] = None |
|||
allowReserved: Optional[bool] = None |
|||
schema_: Optional[Union[Schema, Reference]] = PSchema(None, alias="schema") |
|||
example: Optional[Any] = None |
|||
examples: Optional[Dict[str, Union[Example, Reference]]] = None |
|||
# Serialization rules for more complex scenarios |
|||
content: Optional[Dict[str, MediaType]] = None |
|||
|
|||
|
|||
class Parameter(ParameterBase): |
|||
name: str |
|||
in_: ParameterInType = PSchema(..., alias="in") |
|||
|
|||
|
|||
class Header(ParameterBase): |
|||
pass |
|||
|
|||
|
|||
# Workaround OpenAPI recursive reference |
|||
class EncodingWithHeaders(Encoding): |
|||
headers: Optional[Dict[str, Union[Header, Reference]]] = None |
|||
|
|||
|
|||
class RequestBody(BaseModel): |
|||
description: Optional[str] = None |
|||
content: Dict[str, MediaType] |
|||
required: Optional[bool] = None |
|||
|
|||
|
|||
class Link(BaseModel): |
|||
operationRef: Optional[str] = None |
|||
operationId: Optional[str] = None |
|||
parameters: Optional[Dict[str, Union[Any, str]]] = None |
|||
requestBody: Optional[Union[Any, str]] = None |
|||
description: Optional[str] = None |
|||
server: Optional[Server] = None |
|||
|
|||
|
|||
class Response(BaseModel): |
|||
description: str |
|||
headers: Optional[Dict[str, Union[Header, Reference]]] = None |
|||
content: Optional[Dict[str, MediaType]] = None |
|||
links: Optional[Dict[str, Union[Link, Reference]]] = None |
|||
|
|||
|
|||
class Responses(BaseModel): |
|||
default: Response |
|||
|
|||
|
|||
class Operation(BaseModel): |
|||
tags: Optional[List[str]] = None |
|||
summary: Optional[str] = None |
|||
description: Optional[str] = None |
|||
externalDocs: Optional[ExternalDocumentation] = None |
|||
operationId: Optional[str] = None |
|||
parameters: Optional[List[Union[Parameter, Reference]]] = None |
|||
requestBody: Optional[Union[RequestBody, Reference]] = None |
|||
responses: Union[Responses, Dict[Union[str], Response]] |
|||
# Workaround OpenAPI recursive reference |
|||
callbacks: Optional[Dict[str, Union[Dict[str, Any], Reference]]] = None |
|||
deprecated: Optional[bool] = None |
|||
security: Optional[List[Dict[str, List[str]]]] = None |
|||
servers: Optional[List[Server]] = None |
|||
|
|||
|
|||
class PathItem(BaseModel): |
|||
ref: Optional[str] = PSchema(None, alias="$ref") |
|||
summary: Optional[str] = None |
|||
description: Optional[str] = None |
|||
get: Optional[Operation] = None |
|||
put: Optional[Operation] = None |
|||
post: Optional[Operation] = None |
|||
delete: Optional[Operation] = None |
|||
options: Optional[Operation] = None |
|||
head: Optional[Operation] = None |
|||
patch: Optional[Operation] = None |
|||
trace: Optional[Operation] = None |
|||
servers: Optional[List[Server]] = None |
|||
parameters: Optional[List[Union[Parameter, Reference]]] = None |
|||
|
|||
|
|||
# Workaround OpenAPI recursive reference |
|||
class OperationWithCallbacks(BaseModel): |
|||
callbacks: Optional[Dict[str, Union[Dict[str, PathItem], Reference]]] = None |
|||
|
|||
|
|||
class SecuritySchemeType(Enum): |
|||
apiKey = "apiKey" |
|||
http = "http" |
|||
oauth2 = "oauth2" |
|||
openIdConnect = "openIdConnect" |
|||
|
|||
|
|||
class SecurityBase(BaseModel): |
|||
type_: SecuritySchemeType = PSchema(..., alias="type") |
|||
description: Optional[str] = None |
|||
|
|||
|
|||
class APIKeyIn(Enum): |
|||
query = "query" |
|||
header = "header" |
|||
cookie = "cookie" |
|||
|
|||
|
|||
class APIKey(SecurityBase): |
|||
type_ = PSchema(SecuritySchemeType.apiKey, alias="type") |
|||
in_: APIKeyIn = PSchema(..., alias="in") |
|||
name: str |
|||
|
|||
|
|||
class HTTPBase(SecurityBase): |
|||
type_ = PSchema(SecuritySchemeType.http, alias="type") |
|||
scheme: str |
|||
|
|||
|
|||
class HTTPBearer(HTTPBase): |
|||
scheme = "bearer" |
|||
bearerFormat: Optional[str] = None |
|||
|
|||
|
|||
class OAuthFlow(BaseModel): |
|||
refreshUrl: Optional[str] = None |
|||
scopes: Dict[str, str] = {} |
|||
|
|||
|
|||
class OAuthFlowImplicit(OAuthFlow): |
|||
authorizationUrl: str |
|||
|
|||
|
|||
class OAuthFlowPassword(OAuthFlow): |
|||
tokenUrl: str |
|||
|
|||
|
|||
class OAuthFlowClientCredentials(OAuthFlow): |
|||
tokenUrl: str |
|||
|
|||
|
|||
class OAuthFlowAuthorizationCode(OAuthFlow): |
|||
authorizationUrl: str |
|||
tokenUrl: str |
|||
|
|||
|
|||
class OAuthFlows(BaseModel): |
|||
implicit: Optional[OAuthFlowImplicit] = None |
|||
password: Optional[OAuthFlowPassword] = None |
|||
clientCredentials: Optional[OAuthFlowClientCredentials] = None |
|||
authorizationCode: Optional[OAuthFlowAuthorizationCode] = None |
|||
|
|||
|
|||
class OAuth2(SecurityBase): |
|||
type_ = PSchema(SecuritySchemeType.oauth2, alias="type") |
|||
flows: OAuthFlows |
|||
|
|||
|
|||
class OpenIdConnect(SecurityBase): |
|||
type_ = PSchema(SecuritySchemeType.openIdConnect, alias="type") |
|||
openIdConnectUrl: str |
|||
|
|||
|
|||
SecurityScheme = Union[APIKey, HTTPBase, HTTPBearer, OAuth2, OpenIdConnect] |
|||
|
|||
|
|||
class Components(BaseModel): |
|||
schemas: Optional[Dict[str, Union[Schema, Reference]]] = None |
|||
responses: Optional[Dict[str, Union[Response, Reference]]] = None |
|||
parameters: Optional[Dict[str, Union[Parameter, Reference]]] = None |
|||
examples: Optional[Dict[str, Union[Example, Reference]]] = None |
|||
requestBodies: Optional[Dict[str, Union[RequestBody, Reference]]] = None |
|||
headers: Optional[Dict[str, Union[Header, Reference]]] = None |
|||
securitySchemes: Optional[Dict[str, Union[SecurityScheme, Reference]]] = None |
|||
links: Optional[Dict[str, Union[Link, Reference]]] = None |
|||
callbacks: Optional[Dict[str, Union[Dict[str, PathItem], Reference]]] = None |
|||
|
|||
|
|||
class Tag(BaseModel): |
|||
name: str |
|||
description: Optional[str] = None |
|||
externalDocs: Optional[ExternalDocumentation] = None |
|||
|
|||
|
|||
class OpenAPI(BaseModel): |
|||
openapi: str |
|||
info: Info |
|||
servers: Optional[List[Server]] = None |
|||
paths: Dict[str, PathItem] |
|||
components: Optional[Components] = None |
|||
security: Optional[List[Dict[str, List[str]]]] = None |
|||
tags: Optional[List[Tag]] = None |
|||
externalDocs: Optional[ExternalDocumentation] = None |
@ -0,0 +1,280 @@ |
|||
from typing import Any, Dict, Sequence, Type |
|||
|
|||
from starlette.responses import HTMLResponse, JSONResponse |
|||
from starlette.routing import BaseRoute |
|||
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY |
|||
|
|||
from fastapi import routing |
|||
from fastapi.dependencies.models import Dependant |
|||
from fastapi.dependencies.utils import get_flat_dependant |
|||
from fastapi.encoders import jsonable_encoder |
|||
from fastapi.openapi.constants import REF_PREFIX, METHODS_WITH_BODY |
|||
from fastapi.openapi.models import OpenAPI |
|||
from fastapi.params import Body |
|||
from fastapi.utils import get_flat_models_from_routes, get_model_definitions |
|||
from pydantic.fields import Field |
|||
from pydantic.schema import field_schema, get_model_name_map |
|||
from pydantic.utils import lenient_issubclass |
|||
|
|||
validation_error_definition = { |
|||
"title": "ValidationError", |
|||
"type": "object", |
|||
"properties": { |
|||
"loc": {"title": "Location", "type": "array", "items": {"type": "string"}}, |
|||
"msg": {"title": "Message", "type": "string"}, |
|||
"type": {"title": "Error Type", "type": "string"}, |
|||
}, |
|||
"required": ["loc", "msg", "type"], |
|||
} |
|||
|
|||
validation_error_response_definition = { |
|||
"title": "HTTPValidationError", |
|||
"type": "object", |
|||
"properties": { |
|||
"detail": { |
|||
"title": "Detail", |
|||
"type": "array", |
|||
"items": {"$ref": REF_PREFIX + "ValidationError"}, |
|||
} |
|||
}, |
|||
} |
|||
|
|||
|
|||
def get_openapi_params(dependant: Dependant): |
|||
flat_dependant = get_flat_dependant(dependant) |
|||
return ( |
|||
flat_dependant.path_params |
|||
+ flat_dependant.query_params |
|||
+ flat_dependant.header_params |
|||
+ flat_dependant.cookie_params |
|||
) |
|||
|
|||
def get_openapi_path(*, route: BaseRoute, model_name_map: Dict[Type, str]): |
|||
if not (route.include_in_schema and isinstance(route, routing.APIRoute)): |
|||
return None |
|||
path = {} |
|||
security_schemes = {} |
|||
definitions = {} |
|||
for method in route.methods: |
|||
operation: Dict[str, Any] = {} |
|||
if route.tags: |
|||
operation["tags"] = route.tags |
|||
if route.summary: |
|||
operation["summary"] = route.summary |
|||
if route.description: |
|||
operation["description"] = route.description |
|||
if route.operation_id: |
|||
operation["operationId"] = route.operation_id |
|||
else: |
|||
operation["operationId"] = route.name |
|||
if route.deprecated: |
|||
operation["deprecated"] = route.deprecated |
|||
parameters = [] |
|||
flat_dependant = get_flat_dependant(route.dependant) |
|||
security_definitions = {} |
|||
for security_requirement in flat_dependant.security_requirements: |
|||
security_definition = jsonable_encoder( |
|||
security_requirement.security_scheme, |
|||
exclude={"scheme_name"}, |
|||
by_alias=True, |
|||
include_none=False, |
|||
) |
|||
security_name = ( |
|||
getattr( |
|||
security_requirement.security_scheme, "scheme_name", None |
|||
) |
|||
or security_requirement.security_scheme.__class__.__name__ |
|||
) |
|||
security_definitions[security_name] = security_definition |
|||
operation.setdefault("security", []).append( |
|||
{security_name: security_requirement.scopes} |
|||
) |
|||
if security_definitions: |
|||
security_schemes.update( |
|||
security_definitions |
|||
) |
|||
all_route_params = get_openapi_params(route.dependant) |
|||
for param in all_route_params: |
|||
if "ValidationError" not in definitions: |
|||
definitions["ValidationError"] = validation_error_definition |
|||
definitions[ |
|||
"HTTPValidationError" |
|||
] = validation_error_response_definition |
|||
parameter = { |
|||
"name": param.alias, |
|||
"in": param.schema.in_.value, |
|||
"required": param.required, |
|||
"schema": field_schema(param, model_name_map={})[0], |
|||
} |
|||
if param.schema.description: |
|||
parameter["description"] = param.schema.description |
|||
if param.schema.deprecated: |
|||
parameter["deprecated"] = param.schema.deprecated |
|||
parameters.append(parameter) |
|||
if parameters: |
|||
operation["parameters"] = parameters |
|||
if method in METHODS_WITH_BODY: |
|||
body_field = route.body_field |
|||
if body_field: |
|||
assert isinstance(body_field, Field) |
|||
body_schema, _ = field_schema( |
|||
body_field, |
|||
model_name_map=model_name_map, |
|||
ref_prefix=REF_PREFIX, |
|||
) |
|||
if isinstance(body_field.schema, Body): |
|||
request_media_type = body_field.schema.media_type |
|||
else: |
|||
# Includes not declared media types (Schema) |
|||
request_media_type = "application/json" |
|||
required = body_field.required |
|||
request_body_oai = {} |
|||
if required: |
|||
request_body_oai["required"] = required |
|||
request_body_oai["content"] = { |
|||
request_media_type: {"schema": body_schema} |
|||
} |
|||
operation["requestBody"] = request_body_oai |
|||
response_code = str(route.response_code) |
|||
response_schema = {"type": "string"} |
|||
if lenient_issubclass(route.response_wrapper, JSONResponse): |
|||
response_media_type = "application/json" |
|||
if route.response_field: |
|||
response_schema, _ = field_schema( |
|||
route.response_field, |
|||
model_name_map=model_name_map, |
|||
ref_prefix=REF_PREFIX, |
|||
) |
|||
else: |
|||
response_schema = {} |
|||
elif lenient_issubclass(route.response_wrapper, HTMLResponse): |
|||
response_media_type = "text/html" |
|||
else: |
|||
response_media_type = "text/plain" |
|||
content = {response_media_type: {"schema": response_schema}} |
|||
operation["responses"] = { |
|||
response_code: { |
|||
"description": route.response_description, |
|||
"content": content, |
|||
} |
|||
} |
|||
if all_route_params or route.body_field: |
|||
operation["responses"][str(HTTP_422_UNPROCESSABLE_ENTITY)] = { |
|||
"description": "Validation Error", |
|||
"content": { |
|||
"application/json": { |
|||
"schema": {"$ref": REF_PREFIX + "HTTPValidationError"} |
|||
} |
|||
}, |
|||
} |
|||
path[method.lower()] = operation |
|||
return path, security_schemes, definitions |
|||
|
|||
|
|||
def get_openapi( |
|||
*, |
|||
title: str, |
|||
version: str, |
|||
openapi_version: str = "3.0.2", |
|||
description: str = None, |
|||
routes: Sequence[BaseRoute] |
|||
): |
|||
info = {"title": title, "version": version} |
|||
if description: |
|||
info["description"] = description |
|||
output = {"openapi": openapi_version, "info": info} |
|||
components: Dict[str, Dict] = {} |
|||
paths: Dict[str, Dict] = {} |
|||
flat_models = get_flat_models_from_routes(routes) |
|||
model_name_map = get_model_name_map(flat_models) |
|||
definitions = get_model_definitions( |
|||
flat_models=flat_models, model_name_map=model_name_map |
|||
) |
|||
for route in routes: |
|||
result = get_openapi_path(route=route, model_name_map=model_name_map) |
|||
if result: |
|||
path, security_schemes, path_definitions = result |
|||
if path: |
|||
paths.setdefault(route.path, {}).update(path) |
|||
if security_schemes: |
|||
components.setdefault("securitySchemes", {}).update(security_schemes) |
|||
if path_definitions: |
|||
definitions.update(path_definitions) |
|||
if definitions: |
|||
components.setdefault("schemas", {}).update(definitions) |
|||
if components: |
|||
output["components"] = components |
|||
output["paths"] = paths |
|||
return jsonable_encoder(OpenAPI(**output), by_alias=True, include_none=False) |
|||
|
|||
|
|||
def get_swagger_ui_html(*, openapi_url: str, title: str): |
|||
return HTMLResponse( |
|||
""" |
|||
<! doctype html> |
|||
<html> |
|||
<head> |
|||
<link type="text/css" rel="stylesheet" href="//unpkg.com/swagger-ui-dist@3/swagger-ui.css"> |
|||
<title> |
|||
""" + title + """ |
|||
</title> |
|||
</head> |
|||
<body> |
|||
<div id="swagger-ui"> |
|||
</div> |
|||
<script src="//unpkg.com/swagger-ui-dist@3/swagger-ui-bundle.js"></script> |
|||
<!-- `SwaggerUIBundle` is now available on the page --> |
|||
<script> |
|||
|
|||
const ui = SwaggerUIBundle({ |
|||
url: '""" |
|||
+ openapi_url |
|||
+ """', |
|||
dom_id: '#swagger-ui', |
|||
presets: [ |
|||
SwaggerUIBundle.presets.apis, |
|||
SwaggerUIBundle.SwaggerUIStandalonePreset |
|||
], |
|||
layout: "BaseLayout" |
|||
|
|||
}) |
|||
</script> |
|||
</body> |
|||
</html> |
|||
""", |
|||
media_type="text/html", |
|||
) |
|||
|
|||
|
|||
def get_redoc_html(*, openapi_url: str, title: str): |
|||
return HTMLResponse( |
|||
""" |
|||
<!DOCTYPE html> |
|||
<html> |
|||
<head> |
|||
<title> |
|||
""" + title + """ |
|||
</title> |
|||
<!-- needed for adaptive design --> |
|||
<meta charset="utf-8"/> |
|||
<meta name="viewport" content="width=device-width, initial-scale=1"> |
|||
<link href="https://fonts.googleapis.com/css?family=Montserrat:300,400,700|Roboto:300,400,700" rel="stylesheet"> |
|||
|
|||
<!-- |
|||
ReDoc doesn't change outer page styles |
|||
--> |
|||
<style> |
|||
body { |
|||
margin: 0; |
|||
padding: 0; |
|||
} |
|||
</style> |
|||
</head> |
|||
<body> |
|||
<redoc spec-url='""" + openapi_url + """'></redoc> |
|||
<script src="https://cdn.jsdelivr.net/npm/redoc@next/bundles/redoc.standalone.js"> </script> |
|||
</body> |
|||
</html> |
|||
""", |
|||
media_type="text/html", |
|||
) |
@ -0,0 +1,46 @@ |
|||
import re |
|||
from typing import Dict, Sequence, Set, Type |
|||
|
|||
from starlette.routing import BaseRoute |
|||
|
|||
from fastapi import routing |
|||
from fastapi.openapi.constants import REF_PREFIX |
|||
from pydantic import BaseModel |
|||
from pydantic.fields import Field |
|||
from pydantic.schema import get_flat_models_from_fields, model_process_schema |
|||
|
|||
|
|||
def get_flat_models_from_routes(routes: Sequence[BaseRoute]): |
|||
body_fields_from_routes = [] |
|||
responses_from_routes = [] |
|||
for route in routes: |
|||
if route.include_in_schema and isinstance(route, routing.APIRoute): |
|||
if route.body_field: |
|||
assert isinstance( |
|||
route.body_field, Field |
|||
), "A request body must be a Pydantic Field" |
|||
body_fields_from_routes.append(route.body_field) |
|||
if route.response_field: |
|||
responses_from_routes.append(route.response_field) |
|||
flat_models = get_flat_models_from_fields( |
|||
body_fields_from_routes + responses_from_routes |
|||
) |
|||
return flat_models |
|||
|
|||
|
|||
def get_model_definitions( |
|||
*, flat_models: Set[Type[BaseModel]], model_name_map: Dict[Type[BaseModel], str] |
|||
): |
|||
definitions: Dict[str, Dict] = {} |
|||
for model in flat_models: |
|||
m_schema, m_definitions = model_process_schema( |
|||
model, model_name_map=model_name_map, ref_prefix=REF_PREFIX |
|||
) |
|||
definitions.update(m_definitions) |
|||
model_name = model_name_map[model] |
|||
definitions[model_name] = m_schema |
|||
return definitions |
|||
|
|||
|
|||
def get_path_param_names(path: str): |
|||
return {item.strip("{}") for item in re.findall("{[^}]*}", path)} |
Loading…
Reference in new issue