commit
406c092a3b
11 changed files with 1787 additions and 0 deletions
@ -0,0 +1,3 @@ |
|||
"""Fast API framework, fast high performance, fast to learn, fast to code""" |
|||
|
|||
__version__ = '0.1' |
@ -0,0 +1,581 @@ |
|||
import typing |
|||
import inspect |
|||
|
|||
from starlette.applications import Starlette |
|||
from starlette.middleware.lifespan import LifespanMiddleware |
|||
from starlette.middleware.errors import ServerErrorMiddleware |
|||
from starlette.exceptions import ExceptionMiddleware |
|||
from starlette.responses import JSONResponse, HTMLResponse, PlainTextResponse |
|||
from starlette.requests import Request |
|||
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY |
|||
|
|||
from pydantic import BaseModel, BaseConfig, Schema |
|||
from pydantic.utils import lenient_issubclass |
|||
from pydantic.fields import Field |
|||
from pydantic.schema import ( |
|||
field_schema, |
|||
get_flat_models_from_models, |
|||
get_flat_models_from_fields, |
|||
get_model_name_map, |
|||
schema, |
|||
model_process_schema, |
|||
) |
|||
|
|||
from .routing import APIRouter, APIRoute, get_openapi_params, get_flat_dependant |
|||
from .pydantic_utils import jsonable_encoder |
|||
|
|||
|
|||
def docs(openapi_url): |
|||
return HTMLResponse( |
|||
""" |
|||
<! doctype html> |
|||
<html> |
|||
<head> |
|||
<link type="text/css" rel="stylesheet" href="//unpkg.com/swagger-ui-dist@3/swagger-ui.css"> |
|||
</head> |
|||
<body> |
|||
<div id="swagger-ui"> |
|||
</div> |
|||
<script src="//unpkg.com/swagger-ui-dist@3/swagger-ui-bundle.js"></script> |
|||
<!-- `SwaggerUIBundle` is now available on the page --> |
|||
<script> |
|||
|
|||
const ui = SwaggerUIBundle({ |
|||
url: '""" + openapi_url + """', |
|||
dom_id: '#swagger-ui', |
|||
presets: [ |
|||
SwaggerUIBundle.presets.apis, |
|||
SwaggerUIBundle.SwaggerUIStandalonePreset |
|||
], |
|||
layout: "BaseLayout" |
|||
|
|||
}) |
|||
</script> |
|||
</body> |
|||
</html> |
|||
""", |
|||
media_type="text/html", |
|||
) |
|||
|
|||
|
|||
class FastAPI(Starlette): |
|||
def __init__( |
|||
self, |
|||
debug: bool = False, |
|||
template_directory: str = None, |
|||
title: str = "Fast API", |
|||
description: str = "", |
|||
version: str = "0.1.0", |
|||
openapi_url: str = "/openapi.json", |
|||
docs_url: str = "/docs", |
|||
**extra: typing.Dict[str, typing.Any], |
|||
) -> None: |
|||
self._debug = debug |
|||
self.router = APIRouter() |
|||
self.exception_middleware = ExceptionMiddleware(self.router, debug=debug) |
|||
self.error_middleware = ServerErrorMiddleware( |
|||
self.exception_middleware, debug=debug |
|||
) |
|||
self.lifespan_middleware = LifespanMiddleware(self.error_middleware) |
|||
self.schema_generator = None # type: typing.Optional[BaseSchemaGenerator] |
|||
self.template_env = self.load_template_env(template_directory) |
|||
|
|||
self.title = title |
|||
self.description = description |
|||
self.version = version |
|||
self.openapi_url = openapi_url |
|||
self.docs_url = docs_url |
|||
self.extra = extra |
|||
|
|||
self.openapi_version = "3.0.2" |
|||
|
|||
if self.openapi_url: |
|||
assert self.title, "A title must be provided for OpenAPI, e.g.: 'My API'" |
|||
assert self.version, "A version must be provided for OpenAPI, e.g.: '2.1.0'" |
|||
|
|||
if self.docs_url: |
|||
assert self.openapi_url, "The openapi_url is required for the docs" |
|||
|
|||
self.add_route( |
|||
self.openapi_url, |
|||
lambda req: JSONResponse(self.openapi()), |
|||
include_in_schema=False, |
|||
) |
|||
self.add_route(self.docs_url, lambda r: docs(self.openapi_url), include_in_schema=False) |
|||
|
|||
def add_api_route( |
|||
self, |
|||
path: str, |
|||
endpoint: typing.Callable, |
|||
methods: typing.List[str] = None, |
|||
name: str = None, |
|||
include_in_schema: bool = True, |
|||
tags: typing.List[str] = [], |
|||
summary: str = None, |
|||
description: str = None, |
|||
operation_id: str = None, |
|||
deprecated: bool = None, |
|||
response_type: typing.Type = None, |
|||
response_description: str = "Successful Response", |
|||
response_code=200, |
|||
response_wrapper=JSONResponse, |
|||
) -> None: |
|||
self.router.add_api_route( |
|||
path, |
|||
endpoint=endpoint, |
|||
methods=methods, |
|||
name=name, |
|||
include_in_schema=include_in_schema, |
|||
tags=tags, |
|||
summary=summary, |
|||
description=description, |
|||
operation_id=operation_id, |
|||
deprecated=deprecated, |
|||
response_type=response_type, |
|||
response_description=response_description, |
|||
response_code=response_code, |
|||
response_wrapper=response_wrapper, |
|||
) |
|||
|
|||
def api_route( |
|||
self, |
|||
path: str, |
|||
methods: typing.List[str] = None, |
|||
name: str = None, |
|||
include_in_schema: bool = True, |
|||
tags: typing.List[str] = [], |
|||
summary: str = None, |
|||
description: str = None, |
|||
operation_id: str = None, |
|||
deprecated: bool = None, |
|||
response_type: typing.Type = None, |
|||
response_description: str = "Successful Response", |
|||
response_code=200, |
|||
response_wrapper=JSONResponse, |
|||
) -> typing.Callable: |
|||
def decorator(func: typing.Callable) -> typing.Callable: |
|||
self.router.add_api_route( |
|||
path, |
|||
func, |
|||
methods=methods, |
|||
name=name, |
|||
include_in_schema=include_in_schema, |
|||
tags=tags, |
|||
summary=summary, |
|||
description=description, |
|||
operation_id=operation_id, |
|||
deprecated=deprecated, |
|||
response_type=response_type, |
|||
response_description=response_description, |
|||
response_code=response_code, |
|||
response_wrapper=response_wrapper, |
|||
) |
|||
return func |
|||
|
|||
return decorator |
|||
|
|||
def get( |
|||
self, |
|||
path: str, |
|||
name: str = None, |
|||
include_in_schema: bool = True, |
|||
tags: typing.List[str] = [], |
|||
summary: str = None, |
|||
description: str = None, |
|||
operation_id: str = None, |
|||
deprecated: bool = None, |
|||
response_type: typing.Type = None, |
|||
response_description: str = "Successful Response", |
|||
response_code=200, |
|||
response_wrapper=JSONResponse, |
|||
): |
|||
return self.router.get( |
|||
path=path, |
|||
name=name, |
|||
include_in_schema=include_in_schema, |
|||
tags=tags, |
|||
summary=summary, |
|||
description=description, |
|||
operation_id=operation_id, |
|||
deprecated=deprecated, |
|||
response_type=response_type, |
|||
response_description=response_description, |
|||
response_code=response_code, |
|||
response_wrapper=response_wrapper, |
|||
) |
|||
|
|||
def put( |
|||
self, |
|||
path: str, |
|||
name: str = None, |
|||
include_in_schema: bool = True, |
|||
tags: typing.List[str] = [], |
|||
summary: str = None, |
|||
description: str = None, |
|||
operation_id: str = None, |
|||
deprecated: bool = None, |
|||
response_type: typing.Type = None, |
|||
response_description: str = "Successful Response", |
|||
response_code=200, |
|||
response_wrapper=JSONResponse, |
|||
): |
|||
return self.router.put( |
|||
path=path, |
|||
name=name, |
|||
include_in_schema=include_in_schema, |
|||
tags=tags, |
|||
summary=summary, |
|||
description=description, |
|||
operation_id=operation_id, |
|||
deprecated=deprecated, |
|||
response_type=response_type, |
|||
response_description=response_description, |
|||
response_code=response_code, |
|||
response_wrapper=response_wrapper, |
|||
) |
|||
|
|||
def post( |
|||
self, |
|||
path: str, |
|||
name: str = None, |
|||
include_in_schema: bool = True, |
|||
tags: typing.List[str] = [], |
|||
summary: str = None, |
|||
description: str = None, |
|||
operation_id: str = None, |
|||
deprecated: bool = None, |
|||
response_type: typing.Type = None, |
|||
response_description: str = "Successful Response", |
|||
response_code=200, |
|||
response_wrapper=JSONResponse, |
|||
): |
|||
return self.router.post( |
|||
path=path, |
|||
name=name, |
|||
include_in_schema=include_in_schema, |
|||
tags=tags, |
|||
summary=summary, |
|||
description=description, |
|||
operation_id=operation_id, |
|||
deprecated=deprecated, |
|||
response_type=response_type, |
|||
response_description=response_description, |
|||
response_code=response_code, |
|||
response_wrapper=response_wrapper, |
|||
) |
|||
|
|||
def delete( |
|||
self, |
|||
path: str, |
|||
name: str = None, |
|||
include_in_schema: bool = True, |
|||
tags: typing.List[str] = [], |
|||
summary: str = None, |
|||
description: str = None, |
|||
operation_id: str = None, |
|||
deprecated: bool = None, |
|||
response_type: typing.Type = None, |
|||
response_description: str = "Successful Response", |
|||
response_code=200, |
|||
response_wrapper=JSONResponse, |
|||
): |
|||
return self.router.delete( |
|||
path=path, |
|||
name=name, |
|||
include_in_schema=include_in_schema, |
|||
tags=tags, |
|||
summary=summary, |
|||
description=description, |
|||
operation_id=operation_id, |
|||
deprecated=deprecated, |
|||
response_type=response_type, |
|||
response_description=response_description, |
|||
response_code=response_code, |
|||
response_wrapper=response_wrapper, |
|||
) |
|||
|
|||
def options( |
|||
self, |
|||
path: str, |
|||
name: str = None, |
|||
include_in_schema: bool = True, |
|||
tags: typing.List[str] = [], |
|||
summary: str = None, |
|||
description: str = None, |
|||
operation_id: str = None, |
|||
deprecated: bool = None, |
|||
response_type: typing.Type = None, |
|||
response_description: str = "Successful Response", |
|||
response_code=200, |
|||
response_wrapper=JSONResponse, |
|||
): |
|||
return self.router.options( |
|||
path=path, |
|||
name=name, |
|||
include_in_schema=include_in_schema, |
|||
tags=tags, |
|||
summary=summary, |
|||
description=description, |
|||
operation_id=operation_id, |
|||
deprecated=deprecated, |
|||
response_type=response_type, |
|||
response_description=response_description, |
|||
response_code=response_code, |
|||
response_wrapper=response_wrapper, |
|||
) |
|||
|
|||
def head( |
|||
self, |
|||
path: str, |
|||
name: str = None, |
|||
include_in_schema: bool = True, |
|||
tags: typing.List[str] = [], |
|||
summary: str = None, |
|||
description: str = None, |
|||
operation_id: str = None, |
|||
deprecated: bool = None, |
|||
response_type: typing.Type = None, |
|||
response_description: str = "Successful Response", |
|||
response_code=200, |
|||
response_wrapper=JSONResponse, |
|||
): |
|||
return self.router.head( |
|||
path=path, |
|||
name=name, |
|||
include_in_schema=include_in_schema, |
|||
tags=tags, |
|||
summary=summary, |
|||
description=description, |
|||
operation_id=operation_id, |
|||
deprecated=deprecated, |
|||
response_type=response_type, |
|||
response_description=response_description, |
|||
response_code=response_code, |
|||
response_wrapper=response_wrapper, |
|||
) |
|||
|
|||
def patch( |
|||
self, |
|||
path: str, |
|||
name: str = None, |
|||
include_in_schema: bool = True, |
|||
tags: typing.List[str] = [], |
|||
summary: str = None, |
|||
description: str = None, |
|||
operation_id: str = None, |
|||
deprecated: bool = None, |
|||
response_type: typing.Type = None, |
|||
response_description: str = "Successful Response", |
|||
response_code=200, |
|||
response_wrapper=JSONResponse, |
|||
): |
|||
return self.router.patch( |
|||
path=path, |
|||
name=name, |
|||
include_in_schema=include_in_schema, |
|||
tags=tags, |
|||
summary=summary, |
|||
description=description, |
|||
operation_id=operation_id, |
|||
deprecated=deprecated, |
|||
response_type=response_type, |
|||
response_description=response_description, |
|||
response_code=response_code, |
|||
response_wrapper=response_wrapper, |
|||
) |
|||
|
|||
def trace( |
|||
self, |
|||
path: str, |
|||
name: str = None, |
|||
include_in_schema: bool = True, |
|||
tags: typing.List[str] = [], |
|||
summary: str = None, |
|||
description: str = None, |
|||
operation_id: str = None, |
|||
deprecated: bool = None, |
|||
response_type: typing.Type = None, |
|||
response_description: str = "Successful Response", |
|||
response_code=200, |
|||
response_wrapper=JSONResponse, |
|||
): |
|||
return self.router.trace( |
|||
path=path, |
|||
name=name, |
|||
include_in_schema=include_in_schema, |
|||
tags=tags, |
|||
summary=summary, |
|||
description=description, |
|||
operation_id=operation_id, |
|||
deprecated=deprecated, |
|||
response_type=response_type, |
|||
response_description=response_description, |
|||
response_code=response_code, |
|||
response_wrapper=response_wrapper, |
|||
) |
|||
|
|||
def openapi(self): |
|||
info = {"title": self.title, "version": self.version} |
|||
if self.description: |
|||
info["description"] = self.description |
|||
output = {"openapi": self.openapi_version, "info": info} |
|||
components = {} |
|||
paths = {} |
|||
methods_with_body = set(("POST", "PUT")) |
|||
body_fields_from_routes = [] |
|||
responses_from_routes = [] |
|||
ref_prefix = "#/components/schemas/" |
|||
for route in self.routes: |
|||
route: APIRoute |
|||
if route.include_in_schema and isinstance(route, APIRoute): |
|||
if route.request_body: |
|||
assert isinstance( |
|||
route.request_body, Field |
|||
), "A request body must be a Pydantic BaseModel or Field" |
|||
body_fields_from_routes.append(route.request_body) |
|||
if route.response_field: |
|||
responses_from_routes.append(route.response_field) |
|||
flat_models = get_flat_models_from_fields( |
|||
body_fields_from_routes + responses_from_routes |
|||
) |
|||
model_name_map = get_model_name_map(flat_models) |
|||
definitions = {} |
|||
for model in flat_models: |
|||
m_schema, m_definitions = model_process_schema( |
|||
model, model_name_map=model_name_map, ref_prefix=ref_prefix |
|||
) |
|||
definitions.update(m_definitions) |
|||
model_name = model_name_map[model] |
|||
definitions[model_name] = m_schema |
|||
validation_error_definition = { |
|||
"title": "ValidationError", |
|||
"type": "object", |
|||
"properties": { |
|||
"loc": { |
|||
"title": "Location", |
|||
"type": "array", |
|||
"items": {"type": "string"}, |
|||
}, |
|||
"msg": {"title": "Message", "type": "string"}, |
|||
"type": {"title": "Error Type", "type": "string"}, |
|||
}, |
|||
"required": ["loc", "msg", "type"], |
|||
} |
|||
validation_error_response_definition = { |
|||
"title": "HTTPValidationError", |
|||
"type": "object", |
|||
"properties": { |
|||
"detail": { |
|||
"title": "Detail", |
|||
"type": "array", |
|||
"items": {"$ref": ref_prefix + "ValidationError"}, |
|||
} |
|||
}, |
|||
} |
|||
for route in self.routes: |
|||
route: APIRoute |
|||
if route.include_in_schema and isinstance(route, APIRoute): |
|||
path = paths.get(route.path, {}) |
|||
for method in route.methods: |
|||
operation = {} |
|||
if route.tags: |
|||
operation["tags"] = route.tags |
|||
if route.summary: |
|||
operation["summary"] = route.summary |
|||
if route.description: |
|||
operation["description"] = route.description |
|||
if route.operation_id: |
|||
operation["operationId"] = route.operation_id |
|||
else: |
|||
operation["operationId"] = route.name |
|||
if route.deprecated: |
|||
operation["deprecated"] = route.deprecated |
|||
parameters = [] |
|||
flat_dependant = get_flat_dependant(route.dependant) |
|||
security_definitions = {} |
|||
for security_scheme in flat_dependant.security_schemes: |
|||
security_definition = jsonable_encoder(security_scheme, exclude=("scheme_name",), by_alias=True, include_none=False) |
|||
security_name = getattr(security_scheme, "scheme_name", None) or security_scheme.__class__.__name__ |
|||
security_definitions[security_name] = security_definition |
|||
if security_definitions: |
|||
components.setdefault("securitySchemes", {}).update(security_definitions) |
|||
operation["security"] = [{name: []} for name in security_definitions] |
|||
all_route_params = get_openapi_params(route.dependant) |
|||
for param in all_route_params: |
|||
if "ValidationError" not in definitions: |
|||
definitions["ValidationError"] = validation_error_definition |
|||
definitions[ |
|||
"HTTPValidationError" |
|||
] = validation_error_response_definition |
|||
parameter = { |
|||
"name": param.alias, |
|||
"in": param.schema.in_.value, |
|||
"required": param.required, |
|||
"schema": field_schema(param, model_name_map={})[0], |
|||
} |
|||
if param.schema.description: |
|||
parameter["description"] = param.schema.description |
|||
if param.schema.deprecated: |
|||
parameter["deprecated"] = param.schema.deprecated |
|||
parameters.append(parameter) |
|||
if parameters: |
|||
operation["parameters"] = parameters |
|||
if method in methods_with_body: |
|||
request_body = getattr(route, "request_body", None) |
|||
if request_body: |
|||
assert isinstance(request_body, Field) |
|||
body_schema, _ = field_schema( |
|||
request_body, |
|||
model_name_map=model_name_map, |
|||
ref_prefix=ref_prefix, |
|||
) |
|||
required = request_body.required |
|||
request_body_oai = {} |
|||
if required: |
|||
request_body_oai["required"] = required |
|||
request_body_oai["content"] = { |
|||
"application/json": {"schema": body_schema} |
|||
} |
|||
operation["requestBody"] = request_body_oai |
|||
response_code = str(route.response_code) |
|||
response_schema = {"type": "string"} |
|||
if lenient_issubclass(route.response_wrapper, JSONResponse): |
|||
response_media_type = "application/json" |
|||
if route.response_field: |
|||
response_schema, _ = field_schema( |
|||
route.response_field, |
|||
model_name_map=model_name_map, |
|||
ref_prefix=ref_prefix, |
|||
) |
|||
else: |
|||
response_schema = {} |
|||
elif lenient_issubclass(route.response_wrapper, HTMLResponse): |
|||
response_media_type = "text/html" |
|||
else: |
|||
response_media_type = "text/plain" |
|||
content = {response_media_type: {"schema": response_schema}} |
|||
operation["responses"] = { |
|||
response_code: { |
|||
"description": route.response_description, |
|||
"content": content, |
|||
} |
|||
} |
|||
if all_route_params or getattr(route, "request_body", None): |
|||
operation["responses"][str(HTTP_422_UNPROCESSABLE_ENTITY)] = { |
|||
"description": "Validation Error", |
|||
"content": { |
|||
"application/json": { |
|||
"schema": { |
|||
"$ref": ref_prefix + "HTTPValidationError" |
|||
} |
|||
} |
|||
}, |
|||
} |
|||
path[method.lower()] = operation |
|||
paths[route.path] = path |
|||
if definitions: |
|||
components.setdefault("schemas", {}).update(definitions) |
|||
if components: |
|||
output["components"] = components |
|||
output["paths"] = paths |
|||
return output |
@ -0,0 +1,246 @@ |
|||
from typing import Sequence |
|||
from enum import Enum |
|||
from pydantic import Schema |
|||
|
|||
|
|||
class ParamTypes(Enum): |
|||
query = "query" |
|||
header = "header" |
|||
path = "path" |
|||
cookie = "cookie" |
|||
|
|||
|
|||
class Param(Schema): |
|||
in_: ParamTypes |
|||
def __init__( |
|||
self, |
|||
default, |
|||
*, |
|||
deprecated: bool = None, |
|||
alias: str = None, |
|||
title: str = None, |
|||
description: str = None, |
|||
gt: float = None, |
|||
ge: float = None, |
|||
lt: float = None, |
|||
le: float = None, |
|||
min_length: int = None, |
|||
max_length: int = None, |
|||
regex: str = None, |
|||
**extra: object, |
|||
): |
|||
self.deprecated = deprecated |
|||
super().__init__( |
|||
default, |
|||
alias=alias, |
|||
title=title, |
|||
description=description, |
|||
gt=gt, |
|||
ge=ge, |
|||
lt=lt, |
|||
le=le, |
|||
min_length=min_length, |
|||
max_length=max_length, |
|||
regex=regex, |
|||
**extra, |
|||
) |
|||
|
|||
|
|||
class Path(Param): |
|||
in_ = ParamTypes.path |
|||
|
|||
def __init__( |
|||
self, |
|||
default, |
|||
*, |
|||
deprecated: bool = None, |
|||
alias: str = None, |
|||
title: str = None, |
|||
description: str = None, |
|||
gt: float = None, |
|||
ge: float = None, |
|||
lt: float = None, |
|||
le: float = None, |
|||
min_length: int = None, |
|||
max_length: int = None, |
|||
regex: str = None, |
|||
**extra: object, |
|||
): |
|||
self.description = description |
|||
self.deprecated = deprecated |
|||
self.in_ = self.in_ |
|||
super().__init__( |
|||
default, |
|||
alias=alias, |
|||
title=title, |
|||
description=description, |
|||
gt=gt, |
|||
ge=ge, |
|||
lt=lt, |
|||
le=le, |
|||
min_length=min_length, |
|||
max_length=max_length, |
|||
regex=regex, |
|||
**extra, |
|||
) |
|||
|
|||
|
|||
class Query(Param): |
|||
in_ = ParamTypes.query |
|||
|
|||
def __init__( |
|||
self, |
|||
default, |
|||
*, |
|||
deprecated: bool = None, |
|||
alias: str = None, |
|||
title: str = None, |
|||
description: str = None, |
|||
gt: float = None, |
|||
ge: float = None, |
|||
lt: float = None, |
|||
le: float = None, |
|||
min_length: int = None, |
|||
max_length: int = None, |
|||
regex: str = None, |
|||
**extra: object, |
|||
): |
|||
self.description = description |
|||
self.deprecated = deprecated |
|||
super().__init__( |
|||
default, |
|||
alias=alias, |
|||
title=title, |
|||
description=description, |
|||
gt=gt, |
|||
ge=ge, |
|||
lt=lt, |
|||
le=le, |
|||
min_length=min_length, |
|||
max_length=max_length, |
|||
regex=regex, |
|||
**extra, |
|||
) |
|||
|
|||
|
|||
class Header(Param): |
|||
in_ = ParamTypes.header |
|||
|
|||
def __init__( |
|||
self, |
|||
default, |
|||
*, |
|||
deprecated: bool = None, |
|||
alias: str = None, |
|||
title: str = None, |
|||
description: str = None, |
|||
gt: float = None, |
|||
ge: float = None, |
|||
lt: float = None, |
|||
le: float = None, |
|||
min_length: int = None, |
|||
max_length: int = None, |
|||
regex: str = None, |
|||
**extra: object, |
|||
): |
|||
self.description = description |
|||
self.deprecated = deprecated |
|||
super().__init__( |
|||
default, |
|||
alias=alias, |
|||
title=title, |
|||
description=description, |
|||
gt=gt, |
|||
ge=ge, |
|||
lt=lt, |
|||
le=le, |
|||
min_length=min_length, |
|||
max_length=max_length, |
|||
regex=regex, |
|||
**extra, |
|||
) |
|||
|
|||
|
|||
class Cookie(Param): |
|||
in_ = ParamTypes.cookie |
|||
|
|||
def __init__( |
|||
self, |
|||
default, |
|||
*, |
|||
deprecated: bool = None, |
|||
alias: str = None, |
|||
title: str = None, |
|||
description: str = None, |
|||
gt: float = None, |
|||
ge: float = None, |
|||
lt: float = None, |
|||
le: float = None, |
|||
min_length: int = None, |
|||
max_length: int = None, |
|||
regex: str = None, |
|||
**extra: object, |
|||
): |
|||
self.description = description |
|||
self.deprecated = deprecated |
|||
super().__init__( |
|||
default, |
|||
alias=alias, |
|||
title=title, |
|||
description=description, |
|||
gt=gt, |
|||
ge=ge, |
|||
lt=lt, |
|||
le=le, |
|||
min_length=min_length, |
|||
max_length=max_length, |
|||
regex=regex, |
|||
**extra, |
|||
) |
|||
|
|||
|
|||
class Body(Schema): |
|||
def __init__( |
|||
self, |
|||
default, |
|||
*, |
|||
sub_key=False, |
|||
alias: str = None, |
|||
title: str = None, |
|||
description: str = None, |
|||
gt: float = None, |
|||
ge: float = None, |
|||
lt: float = None, |
|||
le: float = None, |
|||
min_length: int = None, |
|||
max_length: int = None, |
|||
regex: str = None, |
|||
**extra: object, |
|||
): |
|||
self.sub_key = sub_key |
|||
super().__init__( |
|||
default, |
|||
alias=alias, |
|||
title=title, |
|||
description=description, |
|||
gt=gt, |
|||
ge=ge, |
|||
lt=lt, |
|||
le=le, |
|||
min_length=min_length, |
|||
max_length=max_length, |
|||
regex=regex, |
|||
**extra, |
|||
) |
|||
|
|||
|
|||
class Depends: |
|||
def __init__(self, dependency = None): |
|||
self.dependency = dependency |
|||
|
|||
|
|||
class Security: |
|||
def __init__(self, security_scheme = None, scopes: Sequence[str] = None): |
|||
self.security_scheme = security_scheme |
|||
self.scopes = scopes |
|||
|
@ -0,0 +1,33 @@ |
|||
from types import GeneratorType |
|||
from typing import Set |
|||
from pydantic import BaseModel |
|||
from enum import Enum |
|||
from pydantic.json import pydantic_encoder |
|||
|
|||
|
|||
def jsonable_encoder( |
|||
obj, include: Set[str] = None, exclude: Set[str] = set(), by_alias: bool = False, include_none=True, |
|||
): |
|||
if isinstance(obj, BaseModel): |
|||
return jsonable_encoder( |
|||
obj.dict(include=include, exclude=exclude, by_alias=by_alias), include_none=include_none |
|||
) |
|||
elif isinstance(obj, Enum): |
|||
return obj.value |
|||
if isinstance(obj, (str, int, float, type(None))): |
|||
return obj |
|||
if isinstance(obj, dict): |
|||
return { |
|||
jsonable_encoder( |
|||
key, by_alias=by_alias, include_none=include_none, |
|||
): jsonable_encoder( |
|||
value, by_alias=by_alias, include_none=include_none, |
|||
) |
|||
for key, value in obj.items() if value is not None or include_none |
|||
} |
|||
if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)): |
|||
return [ |
|||
jsonable_encoder(item, include=include, exclude=exclude, by_alias=by_alias, include_none=include_none) |
|||
for item in obj |
|||
] |
|||
return pydantic_encoder(obj) |
@ -0,0 +1,785 @@ |
|||
import asyncio |
|||
import inspect |
|||
import re |
|||
import typing |
|||
from copy import deepcopy |
|||
|
|||
from starlette import routing |
|||
from starlette.routing import get_name, request_response |
|||
from starlette.requests import Request |
|||
from starlette.responses import Response, JSONResponse |
|||
from starlette.concurrency import run_in_threadpool |
|||
from starlette.exceptions import HTTPException |
|||
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY |
|||
|
|||
|
|||
from pydantic.fields import Field, Required |
|||
from pydantic.schema import get_annotation_from_schema |
|||
from pydantic import BaseConfig, BaseModel, create_model, Schema |
|||
from pydantic.error_wrappers import ErrorWrapper, ValidationError |
|||
from pydantic.errors import MissingError |
|||
from pydantic.utils import lenient_issubclass |
|||
from .pydantic_utils import jsonable_encoder |
|||
|
|||
from fastapi import params |
|||
from fastapi.security.base import SecurityBase |
|||
|
|||
|
|||
param_supported_types = (str, int, float, bool) |
|||
|
|||
|
|||
class Dependant: |
|||
def __init__( |
|||
self, |
|||
*, |
|||
path_params: typing.List[Field] = None, |
|||
query_params: typing.List[Field] = None, |
|||
header_params: typing.List[Field] = None, |
|||
cookie_params: typing.List[Field] = None, |
|||
body_params: typing.List[Field] = None, |
|||
dependencies: typing.List["Dependant"] = None, |
|||
security_schemes: typing.List[Field] = None, |
|||
name: str = None, |
|||
call: typing.Callable = None, |
|||
request_param_name: str = None, |
|||
) -> None: |
|||
self.path_params: typing.List[Field] = path_params or [] |
|||
self.query_params: typing.List[Field] = query_params or [] |
|||
self.header_params: typing.List[Field] = header_params or [] |
|||
self.cookie_params: typing.List[Field] = cookie_params or [] |
|||
self.body_params: typing.List[Field] = body_params or [] |
|||
self.dependencies: typing.List[Dependant] = dependencies or [] |
|||
self.security_schemes: typing.List[Field] = security_schemes or [] |
|||
self.request_param_name = request_param_name |
|||
self.name = name |
|||
self.call: typing.Callable = call |
|||
|
|||
|
|||
def request_params_to_args( |
|||
required_params: typing.List[Field], received_params: typing.Dict[str, typing.Any] |
|||
) -> typing.Tuple[typing.Dict[str, typing.Any], typing.List[ErrorWrapper]]: |
|||
values = {} |
|||
errors = [] |
|||
for field in required_params: |
|||
value = received_params.get(field.alias) |
|||
if value is None: |
|||
if field.required: |
|||
errors.append( |
|||
ErrorWrapper(MissingError(), loc=field.alias, config=BaseConfig) |
|||
) |
|||
else: |
|||
values[field.name] = deepcopy(field.default) |
|||
continue |
|||
v_, errors_ = field.validate( |
|||
value, values, loc=(field.schema.in_.value, field.alias) |
|||
) |
|||
if isinstance(errors_, ErrorWrapper): |
|||
errors_: ErrorWrapper |
|||
errors.append(errors_) |
|||
elif isinstance(errors_, list): |
|||
errors.extend(errors_) |
|||
else: |
|||
values[field.name] = v_ |
|||
return values, errors |
|||
|
|||
|
|||
def request_body_to_args( |
|||
required_params: typing.List[Field], received_body: typing.Dict[str, typing.Any] |
|||
) -> typing.Tuple[typing.Dict[str, typing.Any], typing.List[ErrorWrapper]]: |
|||
values = {} |
|||
errors = [] |
|||
if required_params: |
|||
field = required_params[0] |
|||
sub_key = getattr(field.schema, "sub_key", None) |
|||
if len(required_params) == 1 and not sub_key: |
|||
received_body = {field.alias: received_body} |
|||
for field in required_params: |
|||
value = received_body.get(field.alias) |
|||
if value is None: |
|||
if field.required: |
|||
errors.append( |
|||
ErrorWrapper( |
|||
MissingError(), loc=("body", field.alias), config=BaseConfig |
|||
) |
|||
) |
|||
else: |
|||
values[field.name] = deepcopy(field.default) |
|||
continue |
|||
|
|||
v_, errors_ = field.validate(value, values, loc=("body", field.alias)) |
|||
if isinstance(errors_, ErrorWrapper): |
|||
errors_: ErrorWrapper |
|||
errors.append(errors_) |
|||
elif isinstance(errors_, list): |
|||
errors.extend(errors_) |
|||
else: |
|||
values[field.name] = v_ |
|||
return values, errors |
|||
|
|||
|
|||
def add_param_to_fields( |
|||
*, |
|||
param: inspect.Parameter, |
|||
dependant: Dependant, |
|||
default_schema=params.Param, |
|||
force_type: params.ParamTypes = None, |
|||
): |
|||
default_value = Required |
|||
if not param.default == param.empty: |
|||
default_value = param.default |
|||
if isinstance(default_value, params.Param): |
|||
schema = default_value |
|||
default_value = schema.default |
|||
if schema.in_ is None: |
|||
schema.in_ = default_schema.in_ |
|||
if force_type: |
|||
schema.in_ = force_type |
|||
else: |
|||
schema = default_schema(default_value) |
|||
required = default_value == Required |
|||
annotation = typing.Any |
|||
if not param.annotation == param.empty: |
|||
annotation = param.annotation |
|||
annotation = get_annotation_from_schema(annotation, schema) |
|||
Config = BaseConfig |
|||
field = Field( |
|||
name=param.name, |
|||
type_=annotation, |
|||
default=None if required else default_value, |
|||
alias=schema.alias or param.name, |
|||
required=required, |
|||
model_config=Config, |
|||
class_validators=[], |
|||
schema=schema, |
|||
) |
|||
if schema.in_ == params.ParamTypes.path: |
|||
dependant.path_params.append(field) |
|||
elif schema.in_ == params.ParamTypes.query: |
|||
dependant.query_params.append(field) |
|||
elif schema.in_ == params.ParamTypes.header: |
|||
dependant.header_params.append(field) |
|||
else: |
|||
assert ( |
|||
schema.in_ == params.ParamTypes.cookie |
|||
), f"non-body parameters must be in path, query, header or cookie: {param.name}" |
|||
dependant.cookie_params.append(field) |
|||
|
|||
|
|||
def add_param_to_body_fields(*, param: inspect.Parameter, dependant: Dependant): |
|||
default_value = Required |
|||
if not param.default == param.empty: |
|||
default_value = param.default |
|||
if isinstance(default_value, Schema): |
|||
schema = default_value |
|||
default_value = schema.default |
|||
else: |
|||
schema = Schema(default_value) |
|||
required = default_value == Required |
|||
annotation = get_annotation_from_schema(param.annotation, schema) |
|||
field = Field( |
|||
name=param.name, |
|||
type_=annotation, |
|||
default=None if required else default_value, |
|||
alias=schema.alias or param.name, |
|||
required=required, |
|||
model_config=BaseConfig, |
|||
class_validators=[], |
|||
schema=schema, |
|||
) |
|||
dependant.body_params.append(field) |
|||
|
|||
|
|||
def get_sub_dependant( |
|||
*, param: inspect.Parameter, path: str |
|||
): |
|||
depends: params.Depends = param.default |
|||
if depends.dependency: |
|||
dependency = depends.dependency |
|||
else: |
|||
dependency = param.annotation |
|||
assert callable(dependency) |
|||
sub_dependant = get_dependant(path=path, call=dependency, name=param.name) |
|||
if isinstance(dependency, SecurityBase): |
|||
sub_dependant.security_schemes.append(dependency) |
|||
return sub_dependant |
|||
|
|||
|
|||
def get_flat_dependant(dependant: Dependant): |
|||
flat_dependant = Dependant( |
|||
path_params=dependant.path_params.copy(), |
|||
query_params=dependant.query_params.copy(), |
|||
header_params=dependant.header_params.copy(), |
|||
cookie_params=dependant.cookie_params.copy(), |
|||
body_params=dependant.body_params.copy(), |
|||
security_schemes=dependant.security_schemes.copy(), |
|||
) |
|||
for sub_dependant in dependant.dependencies: |
|||
if sub_dependant is dependant: |
|||
raise ValueError("recursion", dependant.dependencies) |
|||
flat_sub = get_flat_dependant(sub_dependant) |
|||
flat_dependant.path_params.extend(flat_sub.path_params) |
|||
flat_dependant.query_params.extend(flat_sub.query_params) |
|||
flat_dependant.header_params.extend(flat_sub.header_params) |
|||
flat_dependant.cookie_params.extend(flat_sub.cookie_params) |
|||
flat_dependant.body_params.extend(flat_sub.body_params) |
|||
flat_dependant.security_schemes.extend(flat_sub.security_schemes) |
|||
return flat_dependant |
|||
|
|||
|
|||
def get_path_param_names(path: str): |
|||
return {item.strip("{}") for item in re.findall("{[^}]*}", path)} |
|||
|
|||
|
|||
def get_dependant(*, path: str, call: typing.Callable, name: str = None): |
|||
path_param_names = get_path_param_names(path) |
|||
endpoint_signature = inspect.signature(call) |
|||
signature_params = endpoint_signature.parameters |
|||
dependant = Dependant(call=call, name=name) |
|||
for param_name in signature_params: |
|||
param = signature_params[param_name] |
|||
if isinstance(param.default, params.Depends): |
|||
sub_dependant = get_sub_dependant(param=param, path=path) |
|||
dependant.dependencies.append(sub_dependant) |
|||
for param_name in signature_params: |
|||
param = signature_params[param_name] |
|||
if ( |
|||
(param.default == param.empty) or isinstance(param.default, params.Path) |
|||
) and (param_name in path_param_names): |
|||
assert lenient_issubclass( |
|||
param.annotation, param_supported_types |
|||
), f"Path params must be of type str, int, float or boot: {param}" |
|||
param = signature_params[param_name] |
|||
add_param_to_fields( |
|||
param=param, |
|||
dependant=dependant, |
|||
default_schema=params.Path, |
|||
force_type=params.ParamTypes.path, |
|||
) |
|||
elif (param.default == param.empty or param.default is None) and ( |
|||
param.annotation == param.empty |
|||
or lenient_issubclass(param.annotation, param_supported_types) |
|||
): |
|||
add_param_to_fields( |
|||
param=param, dependant=dependant, default_schema=params.Query |
|||
) |
|||
elif isinstance(param.default, params.Param): |
|||
if param.annotation != param.empty: |
|||
assert lenient_issubclass( |
|||
param.annotation, param_supported_types |
|||
), f"Parameters for Path, Query, Header and Cookies must be of type str, int, float or bool: {param}" |
|||
add_param_to_fields( |
|||
param=param, dependant=dependant, default_schema=params.Query |
|||
) |
|||
elif lenient_issubclass(param.annotation, Request): |
|||
dependant.request_param_name = param_name |
|||
elif not isinstance(param.default, params.Depends): |
|||
add_param_to_body_fields(param=param, dependant=dependant) |
|||
return dependant |
|||
|
|||
|
|||
def is_coroutine_callable(call: typing.Callable): |
|||
if inspect.isfunction(call): |
|||
return asyncio.iscoroutinefunction(call) |
|||
elif inspect.isclass(call): |
|||
return False |
|||
else: |
|||
call = getattr(call, "__call__", None) |
|||
if not call: |
|||
return False |
|||
else: |
|||
return asyncio.iscoroutinefunction(call) |
|||
|
|||
|
|||
async def solve_dependencies(*, request: Request, dependant: Dependant): |
|||
values = {} |
|||
errors = [] |
|||
for sub_dependant in dependant.dependencies: |
|||
sub_values, sub_errors = await solve_dependencies( |
|||
request=request, dependant=sub_dependant |
|||
) |
|||
if sub_errors: |
|||
return {}, errors |
|||
if is_coroutine_callable(sub_dependant.call): |
|||
solved = await sub_dependant.call(**sub_values) |
|||
else: |
|||
solved = await run_in_threadpool(sub_dependant.call, **sub_values) |
|||
values[sub_dependant.name] = solved |
|||
path_values, path_errors = request_params_to_args( |
|||
dependant.path_params, request.path_params |
|||
) |
|||
query_values, query_errors = request_params_to_args( |
|||
dependant.query_params, request.query_params |
|||
) |
|||
header_values, header_errors = request_params_to_args( |
|||
dependant.header_params, request.headers |
|||
) |
|||
cookie_values, cookie_errors = request_params_to_args( |
|||
dependant.cookie_params, request.cookies |
|||
) |
|||
values.update(path_values) |
|||
values.update(query_values) |
|||
values.update(header_values) |
|||
values.update(cookie_values) |
|||
errors = path_errors + query_errors + header_errors + cookie_errors |
|||
if dependant.body_params: |
|||
body = await request.json() |
|||
body_values, body_errors = request_body_to_args(dependant.body_params, body) |
|||
values.update(body_values) |
|||
errors.extend(body_errors) |
|||
if dependant.request_param_name: |
|||
values[dependant.request_param_name] = request |
|||
return values, errors |
|||
|
|||
|
|||
def get_app(dependant: Dependant): |
|||
is_coroutine = asyncio.iscoroutinefunction(dependant.call) |
|||
|
|||
async def app(request: Request) -> Response: |
|||
values, errors = await solve_dependencies(request=request, dependant=dependant) |
|||
if errors: |
|||
errors_out = ValidationError(errors) |
|||
raise HTTPException( |
|||
status_code=HTTP_422_UNPROCESSABLE_ENTITY, detail=errors_out.errors() |
|||
) |
|||
else: |
|||
if is_coroutine: |
|||
raw_response = await dependant.call(**values) |
|||
else: |
|||
raw_response = await run_in_threadpool(dependant.call, **values) |
|||
if isinstance(raw_response, Response): |
|||
return raw_response |
|||
else: |
|||
return JSONResponse(content=jsonable_encoder(raw_response)) |
|||
return app |
|||
|
|||
|
|||
def get_openapi_params(dependant: Dependant): |
|||
flat_dependant = get_flat_dependant(dependant) |
|||
return ( |
|||
flat_dependant.path_params |
|||
+ flat_dependant.query_params |
|||
+ flat_dependant.header_params |
|||
+ flat_dependant.cookie_params |
|||
) |
|||
|
|||
|
|||
class APIRoute(routing.Route): |
|||
def __init__( |
|||
self, |
|||
path: str, |
|||
endpoint: typing.Callable, |
|||
*, |
|||
methods: typing.List[str] = None, |
|||
name: str = None, |
|||
include_in_schema: bool = True, |
|||
tags: typing.List[str] = [], |
|||
summary: str = None, |
|||
description: str = None, |
|||
operation_id: str = None, |
|||
deprecated: bool = None, |
|||
response_type: typing.Type = None, |
|||
response_description: str = "Successful Response", |
|||
response_code=200, |
|||
response_wrapper=JSONResponse, |
|||
) -> None: |
|||
# TODO define how to read and provide security params, and how to have them globally too |
|||
# TODO implement dependencies and injection |
|||
# TODO refactor code structure |
|||
# TODO create testing |
|||
# TODO testing coverage |
|||
assert path.startswith("/"), "Routed paths must always start '/'" |
|||
self.path = path |
|||
self.endpoint = endpoint |
|||
self.name = get_name(endpoint) if name is None else name |
|||
self.include_in_schema = include_in_schema |
|||
self.tags = tags |
|||
self.summary = summary |
|||
self.description = description |
|||
self.operation_id = operation_id |
|||
self.deprecated = deprecated |
|||
self.request_body: typing.Union[BaseModel, Field, None] = None |
|||
self.response_description = response_description |
|||
self.response_code = response_code |
|||
self.response_wrapper = response_wrapper |
|||
self.response_field = None |
|||
if response_type: |
|||
assert lenient_issubclass( |
|||
response_wrapper, JSONResponse |
|||
), "To declare a type the response must be a JSON response" |
|||
self.response_type = response_type |
|||
response_name = "Response_" + self.name |
|||
self.response_field = Field( |
|||
name=response_name, |
|||
type_=self.response_type, |
|||
class_validators=[], |
|||
default=None, |
|||
required=False, |
|||
model_config=BaseConfig(), |
|||
schema=Schema(None), |
|||
) |
|||
else: |
|||
self.response_type = None |
|||
if methods is None: |
|||
methods = ["GET"] |
|||
self.methods = methods |
|||
self.path_regex, self.path_format, self.param_convertors = self.compile_path( |
|||
path |
|||
) |
|||
assert inspect.isfunction(endpoint) or inspect.ismethod( |
|||
endpoint |
|||
), f"An endpoint must be a function or method" |
|||
|
|||
self.dependant = get_dependant(path=path, call=self.endpoint) |
|||
# flat_dependant = get_flat_dependant(self.dependant) |
|||
# path_param_names = get_path_param_names(path) |
|||
# for path_param in path_param_names: |
|||
# assert path_param in { |
|||
# f.alias for f in flat_dependant.path_params |
|||
# }, f"Path parameter must be defined as a function parameter or be defined by a dependency: {path_param}" |
|||
|
|||
if self.dependant.body_params: |
|||
first_param = self.dependant.body_params[0] |
|||
sub_key = getattr(first_param.schema, "sub_key", None) |
|||
if len(self.dependant.body_params) == 1 and not sub_key: |
|||
self.request_body = first_param |
|||
else: |
|||
model_name = "Body_" + self.name |
|||
BodyModel = create_model(model_name) |
|||
for f in self.dependant.body_params: |
|||
BodyModel.__fields__[f.name] = f |
|||
required = any(True for f in self.dependant.body_params if f.required) |
|||
field = Field( |
|||
name="body", |
|||
type_=BodyModel, |
|||
default=None, |
|||
required=required, |
|||
model_config=BaseConfig, |
|||
class_validators=[], |
|||
alias="body", |
|||
schema=Schema(None), |
|||
) |
|||
self.request_body = field |
|||
|
|||
self.app = request_response(get_app(dependant=self.dependant)) |
|||
|
|||
|
|||
class APIRouter(routing.Router): |
|||
def add_api_route( |
|||
self, |
|||
path: str, |
|||
endpoint: typing.Callable, |
|||
methods: typing.List[str] = None, |
|||
name: str = None, |
|||
include_in_schema: bool = True, |
|||
tags: typing.List[str] = [], |
|||
summary: str = None, |
|||
description: str = None, |
|||
operation_id: str = None, |
|||
deprecated: bool = None, |
|||
response_type: typing.Type = None, |
|||
response_description: str = "Successful Response", |
|||
response_code=200, |
|||
response_wrapper=JSONResponse, |
|||
) -> None: |
|||
route = APIRoute( |
|||
path, |
|||
endpoint=endpoint, |
|||
methods=methods, |
|||
name=name, |
|||
include_in_schema=include_in_schema, |
|||
tags=tags, |
|||
summary=summary, |
|||
description=description, |
|||
operation_id=operation_id, |
|||
deprecated=deprecated, |
|||
response_type=response_type, |
|||
response_description=response_description, |
|||
response_code=response_code, |
|||
response_wrapper=response_wrapper, |
|||
) |
|||
self.routes.append(route) |
|||
|
|||
def api_route( |
|||
self, |
|||
path: str, |
|||
methods: typing.List[str] = None, |
|||
name: str = None, |
|||
include_in_schema: bool = True, |
|||
tags: typing.List[str] = [], |
|||
summary: str = None, |
|||
description: str = None, |
|||
operation_id: str = None, |
|||
deprecated: bool = None, |
|||
response_type: typing.Type = None, |
|||
response_description: str = "Successful Response", |
|||
response_code=200, |
|||
response_wrapper=JSONResponse, |
|||
) -> typing.Callable: |
|||
def decorator(func: typing.Callable) -> typing.Callable: |
|||
self.add_api_route( |
|||
path, |
|||
func, |
|||
methods=methods, |
|||
name=name, |
|||
include_in_schema=include_in_schema, |
|||
tags=tags, |
|||
summary=summary, |
|||
description=description, |
|||
operation_id=operation_id, |
|||
deprecated=deprecated, |
|||
response_type=response_type, |
|||
response_description=response_description, |
|||
response_code=response_code, |
|||
response_wrapper=response_wrapper, |
|||
) |
|||
return func |
|||
|
|||
return decorator |
|||
|
|||
def get( |
|||
self, |
|||
path: str, |
|||
name: str = None, |
|||
include_in_schema: bool = True, |
|||
tags: typing.List[str] = [], |
|||
summary: str = None, |
|||
description: str = None, |
|||
operation_id: str = None, |
|||
deprecated: bool = None, |
|||
response_type: typing.Type = None, |
|||
response_description: str = "Successful Response", |
|||
response_code=200, |
|||
response_wrapper=JSONResponse, |
|||
): |
|||
return self.api_route( |
|||
path=path, |
|||
methods=["GET"], |
|||
name=name, |
|||
include_in_schema=include_in_schema, |
|||
tags=tags, |
|||
summary=summary, |
|||
description=description, |
|||
operation_id=operation_id, |
|||
deprecated=deprecated, |
|||
response_type=response_type, |
|||
response_description=response_description, |
|||
response_code=response_code, |
|||
response_wrapper=response_wrapper, |
|||
) |
|||
|
|||
def put( |
|||
self, |
|||
path: str, |
|||
name: str = None, |
|||
include_in_schema: bool = True, |
|||
tags: typing.List[str] = [], |
|||
summary: str = None, |
|||
description: str = None, |
|||
operation_id: str = None, |
|||
deprecated: bool = None, |
|||
response_type: typing.Type = None, |
|||
response_description: str = "Successful Response", |
|||
response_code=200, |
|||
response_wrapper=JSONResponse, |
|||
): |
|||
return self.api_route( |
|||
path=path, |
|||
methods=["PUT"], |
|||
name=name, |
|||
include_in_schema=include_in_schema, |
|||
tags=tags, |
|||
summary=summary, |
|||
description=description, |
|||
operation_id=operation_id, |
|||
deprecated=deprecated, |
|||
response_type=response_type, |
|||
response_description=response_description, |
|||
response_code=response_code, |
|||
response_wrapper=response_wrapper, |
|||
) |
|||
|
|||
def post( |
|||
self, |
|||
path: str, |
|||
name: str = None, |
|||
include_in_schema: bool = True, |
|||
tags: typing.List[str] = [], |
|||
summary: str = None, |
|||
description: str = None, |
|||
operation_id: str = None, |
|||
deprecated: bool = None, |
|||
response_type: typing.Type = None, |
|||
response_description: str = "Successful Response", |
|||
response_code=200, |
|||
response_wrapper=JSONResponse, |
|||
): |
|||
return self.api_route( |
|||
path=path, |
|||
methods=["POST"], |
|||
name=name, |
|||
include_in_schema=include_in_schema, |
|||
tags=tags, |
|||
summary=summary, |
|||
description=description, |
|||
operation_id=operation_id, |
|||
deprecated=deprecated, |
|||
response_type=response_type, |
|||
response_description=response_description, |
|||
response_code=response_code, |
|||
response_wrapper=response_wrapper, |
|||
) |
|||
|
|||
def delete( |
|||
self, |
|||
path: str, |
|||
name: str = None, |
|||
include_in_schema: bool = True, |
|||
tags: typing.List[str] = [], |
|||
summary: str = None, |
|||
description: str = None, |
|||
operation_id: str = None, |
|||
deprecated: bool = None, |
|||
response_type: typing.Type = None, |
|||
response_description: str = "Successful Response", |
|||
response_code=200, |
|||
response_wrapper=JSONResponse, |
|||
): |
|||
return self.api_route( |
|||
path=path, |
|||
methods=["DELETE"], |
|||
name=name, |
|||
include_in_schema=include_in_schema, |
|||
tags=tags, |
|||
summary=summary, |
|||
description=description, |
|||
operation_id=operation_id, |
|||
deprecated=deprecated, |
|||
response_type=response_type, |
|||
response_description=response_description, |
|||
response_code=response_code, |
|||
response_wrapper=response_wrapper, |
|||
) |
|||
|
|||
def options( |
|||
self, |
|||
path: str, |
|||
name: str = None, |
|||
include_in_schema: bool = True, |
|||
tags: typing.List[str] = [], |
|||
summary: str = None, |
|||
description: str = None, |
|||
operation_id: str = None, |
|||
deprecated: bool = None, |
|||
response_type: typing.Type = None, |
|||
response_description: str = "Successful Response", |
|||
response_code=200, |
|||
response_wrapper=JSONResponse, |
|||
): |
|||
return self.api_route( |
|||
path=path, |
|||
methods=["OPTIONS"], |
|||
name=name, |
|||
include_in_schema=include_in_schema, |
|||
tags=tags, |
|||
summary=summary, |
|||
description=description, |
|||
operation_id=operation_id, |
|||
deprecated=deprecated, |
|||
response_type=response_type, |
|||
response_description=response_description, |
|||
response_code=response_code, |
|||
response_wrapper=response_wrapper, |
|||
) |
|||
|
|||
def head( |
|||
self, |
|||
path: str, |
|||
name: str = None, |
|||
include_in_schema: bool = True, |
|||
tags: typing.List[str] = [], |
|||
summary: str = None, |
|||
description: str = None, |
|||
operation_id: str = None, |
|||
deprecated: bool = None, |
|||
response_type: typing.Type = None, |
|||
response_description: str = "Successful Response", |
|||
response_code=200, |
|||
response_wrapper=JSONResponse, |
|||
): |
|||
return self.api_route( |
|||
path=path, |
|||
methods=["HEAD"], |
|||
name=name, |
|||
include_in_schema=include_in_schema, |
|||
tags=tags, |
|||
summary=summary, |
|||
description=description, |
|||
operation_id=operation_id, |
|||
deprecated=deprecated, |
|||
response_type=response_type, |
|||
response_description=response_description, |
|||
response_code=response_code, |
|||
response_wrapper=response_wrapper, |
|||
) |
|||
|
|||
def patch( |
|||
self, |
|||
path: str, |
|||
name: str = None, |
|||
include_in_schema: bool = True, |
|||
tags: typing.List[str] = [], |
|||
summary: str = None, |
|||
description: str = None, |
|||
operation_id: str = None, |
|||
deprecated: bool = None, |
|||
response_type: typing.Type = None, |
|||
response_description: str = "Successful Response", |
|||
response_code=200, |
|||
response_wrapper=JSONResponse, |
|||
): |
|||
return self.api_route( |
|||
path=path, |
|||
methods=["PATCH"], |
|||
name=name, |
|||
include_in_schema=include_in_schema, |
|||
tags=tags, |
|||
summary=summary, |
|||
description=description, |
|||
operation_id=operation_id, |
|||
deprecated=deprecated, |
|||
response_type=response_type, |
|||
response_description=response_description, |
|||
response_code=response_code, |
|||
response_wrapper=response_wrapper, |
|||
) |
|||
|
|||
def trace( |
|||
self, |
|||
path: str, |
|||
name: str = None, |
|||
include_in_schema: bool = True, |
|||
tags: typing.List[str] = [], |
|||
summary: str = None, |
|||
description: str = None, |
|||
operation_id: str = None, |
|||
deprecated: bool = None, |
|||
response_type: typing.Type = None, |
|||
response_description: str = "Successful Response", |
|||
response_code=200, |
|||
response_wrapper=JSONResponse, |
|||
): |
|||
return self.api_route( |
|||
path=path, |
|||
methods=["TRACE"], |
|||
name=name, |
|||
include_in_schema=include_in_schema, |
|||
tags=tags, |
|||
summary=summary, |
|||
description=description, |
|||
operation_id=operation_id, |
|||
deprecated=deprecated, |
|||
response_type=response_type, |
|||
response_description=response_description, |
|||
response_code=response_code, |
|||
response_wrapper=response_wrapper, |
|||
) |
@ -0,0 +1,40 @@ |
|||
from starlette.requests import Request |
|||
|
|||
from pydantic import Schema |
|||
from enum import Enum |
|||
from .base import SecurityBase, Types |
|||
|
|||
__all__ = ["APIKeyIn", "APIKeyBase", "APIKeyQuery", "APIKeyHeader", "APIKeyCookie"] |
|||
|
|||
|
|||
class APIKeyIn(Enum): |
|||
query = "query" |
|||
header = "header" |
|||
cookie = "cookie" |
|||
|
|||
|
|||
class APIKeyBase(SecurityBase): |
|||
type_ = Schema(Types.apiKey, alias="type") |
|||
in_: str = Schema(..., alias="in") |
|||
name: str |
|||
|
|||
|
|||
class APIKeyQuery(APIKeyBase): |
|||
in_ = Schema(APIKeyIn.query, alias="in") |
|||
|
|||
async def __call__(self, requests: Request): |
|||
return requests.query_params.get(self.name) |
|||
|
|||
|
|||
class APIKeyHeader(APIKeyBase): |
|||
in_ = Schema(APIKeyIn.header, alias="in") |
|||
|
|||
async def __call__(self, requests: Request): |
|||
return requests.headers.get(self.name) |
|||
|
|||
|
|||
class APIKeyCookie(APIKeyBase): |
|||
in_ = Schema(APIKeyIn.cookie, alias="in") |
|||
|
|||
async def __call__(self, requests: Request): |
|||
return requests.cookies.get(self.name) |
@ -0,0 +1,17 @@ |
|||
from enum import Enum |
|||
from pydantic import BaseModel, Schema |
|||
|
|||
__all__ = ["Types", "SecurityBase"] |
|||
|
|||
|
|||
class Types(Enum): |
|||
apiKey = "apiKey" |
|||
http = "http" |
|||
oauth2 = "oauth2" |
|||
openIdConnect = "openIdConnect" |
|||
|
|||
|
|||
class SecurityBase(BaseModel): |
|||
scheme_name: str = None |
|||
type_: Types = Schema(..., alias="type") |
|||
description: str = None |
@ -0,0 +1,27 @@ |
|||
|
|||
from starlette.requests import Request |
|||
from pydantic import Schema |
|||
from .base import SecurityBase, Types |
|||
|
|||
__all__ = ["HTTPBase", "HTTPBasic", "HTTPBearer", "HTTPDigest"] |
|||
|
|||
|
|||
class HTTPBase(SecurityBase): |
|||
type_ = Schema(Types.http, alias="type") |
|||
scheme: str |
|||
|
|||
async def __call__(self, request: Request): |
|||
return request.headers.get("Authorization") |
|||
|
|||
|
|||
class HTTPBasic(HTTPBase): |
|||
scheme = "basic" |
|||
|
|||
|
|||
class HTTPBearer(HTTPBase): |
|||
scheme = "bearer" |
|||
bearerFormat: str = None |
|||
|
|||
|
|||
class HTTPDigest(HTTPBase): |
|||
scheme = "digest" |
@ -0,0 +1,45 @@ |
|||
from typing import Dict |
|||
|
|||
from pydantic import BaseModel, Schema |
|||
|
|||
from starlette.requests import Request |
|||
from .base import SecurityBase, Types |
|||
|
|||
# __all__ = ["HTTPBase", "HTTPBasic", "HTTPBearer", "HTTPDigest"] |
|||
|
|||
|
|||
class OAuthFlow(BaseModel): |
|||
refreshUrl: str = None |
|||
scopes: Dict[str, str] = {} |
|||
|
|||
|
|||
class OAuthFlowImplicit(OAuthFlow): |
|||
authorizationUrl: str |
|||
|
|||
|
|||
class OAuthFlowPassword(OAuthFlow): |
|||
tokenUrl: str |
|||
|
|||
|
|||
class OAuthFlowClientCredentials(OAuthFlow): |
|||
tokenUrl: str |
|||
|
|||
|
|||
class OAuthFlowAuthorizationCode(OAuthFlow): |
|||
authorizationUrl: str |
|||
tokenUrl: str |
|||
|
|||
|
|||
class OAuthFlows(BaseModel): |
|||
implicit: OAuthFlowImplicit = None |
|||
password: OAuthFlowPassword = None |
|||
clientCredentials: OAuthFlowClientCredentials = None |
|||
authorizationCode: OAuthFlowAuthorizationCode = None |
|||
|
|||
|
|||
class OAuth2(SecurityBase): |
|||
type_ = Schema(Types.oauth2, alias="type") |
|||
flows: OAuthFlows |
|||
|
|||
async def __call__(self, request: Request): |
|||
return request.headers.get("Authorization") |
@ -0,0 +1,10 @@ |
|||
from starlette.requests import Request |
|||
|
|||
from .base import SecurityBase, Types |
|||
|
|||
class OpenIdConnect(SecurityBase): |
|||
type_ = Types.openIdConnect |
|||
openIdConnectUrl: str |
|||
|
|||
async def __call__(self, request: Request): |
|||
return request.headers.get("Authorization") |
Loading…
Reference in new issue