python asyncio api async fastapi framework json json-schema openapi openapi3 pydantic python-types python3 redoc rest starlette swagger swagger-ui uvicorn web
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
581 lines
21 KiB
581 lines
21 KiB
import typing
|
|
import inspect
|
|
|
|
from starlette.applications import Starlette
|
|
from starlette.middleware.lifespan import LifespanMiddleware
|
|
from starlette.middleware.errors import ServerErrorMiddleware
|
|
from starlette.exceptions import ExceptionMiddleware
|
|
from starlette.responses import JSONResponse, HTMLResponse, PlainTextResponse
|
|
from starlette.requests import Request
|
|
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
|
|
|
|
from pydantic import BaseModel, BaseConfig, Schema
|
|
from pydantic.utils import lenient_issubclass
|
|
from pydantic.fields import Field
|
|
from pydantic.schema import (
|
|
field_schema,
|
|
get_flat_models_from_models,
|
|
get_flat_models_from_fields,
|
|
get_model_name_map,
|
|
schema,
|
|
model_process_schema,
|
|
)
|
|
|
|
from .routing import APIRouter, APIRoute, get_openapi_params, get_flat_dependant
|
|
from .pydantic_utils import jsonable_encoder
|
|
|
|
|
|
def docs(openapi_url):
    """Return an HTML page that renders Swagger UI for an OpenAPI document.

    The Swagger UI stylesheet and bundle are loaded from the unpkg CDN
    (``swagger-ui-dist@3``), so the browser must be able to reach unpkg.com.

    :param openapi_url: URL (typically an absolute path such as
        ``"/openapi.json"``) from which Swagger UI fetches the OpenAPI JSON.
    :return: an ``HTMLResponse`` with media type ``text/html``.
    """
    import json

    # json.dumps produces a correctly quoted and escaped JavaScript string
    # literal, so quotes or backslashes in the URL cannot break the script.
    # Escaping "<" additionally prevents a "</script>" sequence from closing
    # the inline script element early.
    openapi_url_js = json.dumps(openapi_url).replace("<", "\\u003c")
    return HTMLResponse(
        """<!DOCTYPE html>
<html>
<head>
<link type="text/css" rel="stylesheet" href="//unpkg.com/swagger-ui-dist@3/swagger-ui.css">
</head>
<body>
<div id="swagger-ui">
</div>
<script src="//unpkg.com/swagger-ui-dist@3/swagger-ui-bundle.js"></script>
<!-- `SwaggerUIBundle` is now available on the page -->
<script>
const ui = SwaggerUIBundle({
    url: """
        + openapi_url_js
        + """,
    dom_id: '#swagger-ui',
    presets: [
        SwaggerUIBundle.presets.apis,
        SwaggerUIBundle.SwaggerUIStandalonePreset
    ],
    layout: "BaseLayout"
})
</script>
</body>
</html>
""",
        media_type="text/html",
    )
|
|
|
|
|
|
class FastAPI(Starlette):
    """Starlette application with typed API routes and a generated OpenAPI schema.

    Routes added through :meth:`add_api_route`, :meth:`api_route` or the
    per-method shortcuts (:meth:`get`, :meth:`post`, ...) are registered on an
    ``APIRouter``. :meth:`openapi` describes the schema-visible ones as an
    OpenAPI document, which is optionally served at ``openapi_url`` and
    rendered with Swagger UI at ``docs_url``.
    """

    def __init__(
        self,
        debug: bool = False,
        template_directory: str = None,
        title: str = "Fast API",
        description: str = "",
        version: str = "0.1.0",
        openapi_url: str = "/openapi.json",
        docs_url: str = "/docs",
        **extra: typing.Any,
    ) -> None:
        """Set up routing, the middleware stack and the optional docs routes.

        :param debug: forwarded to the exception and server-error middleware.
        :param template_directory: forwarded to ``load_template_env``.
        :param title: OpenAPI ``info.title``; must be non-empty when
            ``openapi_url`` is set.
        :param description: OpenAPI ``info.description``; omitted when empty.
        :param version: OpenAPI ``info.version``; must be non-empty when
            ``openapi_url`` is set.
        :param openapi_url: path serving the OpenAPI JSON; a falsy value
            disables the route.
        :param docs_url: path serving Swagger UI; a falsy value disables the
            route. Requires ``openapi_url``.
        :param extra: arbitrary extra settings, stored as ``self.extra``.
        :raises AssertionError: on the configuration errors described above.
        """
        self._debug = debug
        self.router = APIRouter()
        # Request flow: lifespan -> server errors -> handled exceptions -> router.
        self.exception_middleware = ExceptionMiddleware(self.router, debug=debug)
        self.error_middleware = ServerErrorMiddleware(
            self.exception_middleware, debug=debug
        )
        self.lifespan_middleware = LifespanMiddleware(self.error_middleware)
        self.schema_generator = None  # type: typing.Optional[BaseSchemaGenerator]
        self.template_env = self.load_template_env(template_directory)

        self.title = title
        self.description = description
        self.version = version
        self.openapi_url = openapi_url
        self.docs_url = docs_url
        self.extra = extra

        self.openapi_version = "3.0.2"

        # Only register the schema / docs endpoints when they are enabled;
        # registering a route for a falsy path would be meaningless or fail.
        if self.openapi_url:
            assert self.title, "A title must be provided for OpenAPI, e.g.: 'My API'"
            assert self.version, "A version must be provided for OpenAPI, e.g.: '2.1.0'"
            self.add_route(
                self.openapi_url,
                lambda req: JSONResponse(self.openapi()),
                include_in_schema=False,
            )

        if self.docs_url:
            assert self.openapi_url, "The openapi_url is required for the docs"
            self.add_route(
                self.docs_url,
                lambda req: docs(self.openapi_url),
                include_in_schema=False,
            )

    def add_api_route(
        self,
        path: str,
        endpoint: typing.Callable,
        methods: typing.List[str] = None,
        name: str = None,
        include_in_schema: bool = True,
        tags: typing.Optional[typing.List[str]] = None,
        summary: str = None,
        description: str = None,
        operation_id: str = None,
        deprecated: bool = None,
        response_type: typing.Type = None,
        response_description: str = "Successful Response",
        response_code: int = 200,
        response_wrapper: typing.Type = JSONResponse,
    ) -> None:
        """Register ``endpoint`` as an API route on the underlying ``APIRouter``.

        Every option is forwarded to ``APIRouter.add_api_route``. A missing
        ``tags`` becomes a fresh empty list, so routes never share a mutable
        default. Tags, summary, description, operation id, deprecation flag,
        response description/code and response wrapper are read back from the
        route by :meth:`openapi`.
        """
        self.router.add_api_route(
            path,
            endpoint=endpoint,
            methods=methods,
            name=name,
            include_in_schema=include_in_schema,
            tags=tags or [],
            summary=summary,
            description=description,
            operation_id=operation_id,
            deprecated=deprecated,
            response_type=response_type,
            response_description=response_description,
            response_code=response_code,
            response_wrapper=response_wrapper,
        )

    def api_route(
        self,
        path: str,
        methods: typing.List[str] = None,
        name: str = None,
        include_in_schema: bool = True,
        tags: typing.Optional[typing.List[str]] = None,
        summary: str = None,
        description: str = None,
        operation_id: str = None,
        deprecated: bool = None,
        response_type: typing.Type = None,
        response_description: str = "Successful Response",
        response_code: int = 200,
        response_wrapper: typing.Type = JSONResponse,
    ) -> typing.Callable:
        """Decorator form of :meth:`add_api_route`.

        The decorated function is registered with the given options and
        returned unchanged.
        """

        def decorator(func: typing.Callable) -> typing.Callable:
            self.add_api_route(
                path,
                func,
                methods=methods,
                name=name,
                include_in_schema=include_in_schema,
                tags=tags,
                summary=summary,
                description=description,
                operation_id=operation_id,
                deprecated=deprecated,
                response_type=response_type,
                response_description=response_description,
                response_code=response_code,
                response_wrapper=response_wrapper,
            )
            return func

        return decorator

    def get(
        self,
        path: str,
        name: str = None,
        include_in_schema: bool = True,
        tags: typing.Optional[typing.List[str]] = None,
        summary: str = None,
        description: str = None,
        operation_id: str = None,
        deprecated: bool = None,
        response_type: typing.Type = None,
        response_description: str = "Successful Response",
        response_code: int = 200,
        response_wrapper: typing.Type = JSONResponse,
    ):
        """GET shortcut: forward all options to ``self.router.get`` and return its result."""
        return self.router.get(
            path=path,
            name=name,
            include_in_schema=include_in_schema,
            tags=tags or [],
            summary=summary,
            description=description,
            operation_id=operation_id,
            deprecated=deprecated,
            response_type=response_type,
            response_description=response_description,
            response_code=response_code,
            response_wrapper=response_wrapper,
        )

    def put(
        self,
        path: str,
        name: str = None,
        include_in_schema: bool = True,
        tags: typing.Optional[typing.List[str]] = None,
        summary: str = None,
        description: str = None,
        operation_id: str = None,
        deprecated: bool = None,
        response_type: typing.Type = None,
        response_description: str = "Successful Response",
        response_code: int = 200,
        response_wrapper: typing.Type = JSONResponse,
    ):
        """PUT shortcut: forward all options to ``self.router.put`` and return its result."""
        return self.router.put(
            path=path,
            name=name,
            include_in_schema=include_in_schema,
            tags=tags or [],
            summary=summary,
            description=description,
            operation_id=operation_id,
            deprecated=deprecated,
            response_type=response_type,
            response_description=response_description,
            response_code=response_code,
            response_wrapper=response_wrapper,
        )

    def post(
        self,
        path: str,
        name: str = None,
        include_in_schema: bool = True,
        tags: typing.Optional[typing.List[str]] = None,
        summary: str = None,
        description: str = None,
        operation_id: str = None,
        deprecated: bool = None,
        response_type: typing.Type = None,
        response_description: str = "Successful Response",
        response_code: int = 200,
        response_wrapper: typing.Type = JSONResponse,
    ):
        """POST shortcut: forward all options to ``self.router.post`` and return its result."""
        return self.router.post(
            path=path,
            name=name,
            include_in_schema=include_in_schema,
            tags=tags or [],
            summary=summary,
            description=description,
            operation_id=operation_id,
            deprecated=deprecated,
            response_type=response_type,
            response_description=response_description,
            response_code=response_code,
            response_wrapper=response_wrapper,
        )

    def delete(
        self,
        path: str,
        name: str = None,
        include_in_schema: bool = True,
        tags: typing.Optional[typing.List[str]] = None,
        summary: str = None,
        description: str = None,
        operation_id: str = None,
        deprecated: bool = None,
        response_type: typing.Type = None,
        response_description: str = "Successful Response",
        response_code: int = 200,
        response_wrapper: typing.Type = JSONResponse,
    ):
        """DELETE shortcut: forward all options to ``self.router.delete`` and return its result."""
        return self.router.delete(
            path=path,
            name=name,
            include_in_schema=include_in_schema,
            tags=tags or [],
            summary=summary,
            description=description,
            operation_id=operation_id,
            deprecated=deprecated,
            response_type=response_type,
            response_description=response_description,
            response_code=response_code,
            response_wrapper=response_wrapper,
        )

    def options(
        self,
        path: str,
        name: str = None,
        include_in_schema: bool = True,
        tags: typing.Optional[typing.List[str]] = None,
        summary: str = None,
        description: str = None,
        operation_id: str = None,
        deprecated: bool = None,
        response_type: typing.Type = None,
        response_description: str = "Successful Response",
        response_code: int = 200,
        response_wrapper: typing.Type = JSONResponse,
    ):
        """OPTIONS shortcut: forward all options to ``self.router.options`` and return its result."""
        return self.router.options(
            path=path,
            name=name,
            include_in_schema=include_in_schema,
            tags=tags or [],
            summary=summary,
            description=description,
            operation_id=operation_id,
            deprecated=deprecated,
            response_type=response_type,
            response_description=response_description,
            response_code=response_code,
            response_wrapper=response_wrapper,
        )

    def head(
        self,
        path: str,
        name: str = None,
        include_in_schema: bool = True,
        tags: typing.Optional[typing.List[str]] = None,
        summary: str = None,
        description: str = None,
        operation_id: str = None,
        deprecated: bool = None,
        response_type: typing.Type = None,
        response_description: str = "Successful Response",
        response_code: int = 200,
        response_wrapper: typing.Type = JSONResponse,
    ):
        """HEAD shortcut: forward all options to ``self.router.head`` and return its result."""
        return self.router.head(
            path=path,
            name=name,
            include_in_schema=include_in_schema,
            tags=tags or [],
            summary=summary,
            description=description,
            operation_id=operation_id,
            deprecated=deprecated,
            response_type=response_type,
            response_description=response_description,
            response_code=response_code,
            response_wrapper=response_wrapper,
        )

    def patch(
        self,
        path: str,
        name: str = None,
        include_in_schema: bool = True,
        tags: typing.Optional[typing.List[str]] = None,
        summary: str = None,
        description: str = None,
        operation_id: str = None,
        deprecated: bool = None,
        response_type: typing.Type = None,
        response_description: str = "Successful Response",
        response_code: int = 200,
        response_wrapper: typing.Type = JSONResponse,
    ):
        """PATCH shortcut: forward all options to ``self.router.patch`` and return its result."""
        return self.router.patch(
            path=path,
            name=name,
            include_in_schema=include_in_schema,
            tags=tags or [],
            summary=summary,
            description=description,
            operation_id=operation_id,
            deprecated=deprecated,
            response_type=response_type,
            response_description=response_description,
            response_code=response_code,
            response_wrapper=response_wrapper,
        )

    def trace(
        self,
        path: str,
        name: str = None,
        include_in_schema: bool = True,
        tags: typing.Optional[typing.List[str]] = None,
        summary: str = None,
        description: str = None,
        operation_id: str = None,
        deprecated: bool = None,
        response_type: typing.Type = None,
        response_description: str = "Successful Response",
        response_code: int = 200,
        response_wrapper: typing.Type = JSONResponse,
    ):
        """TRACE shortcut: forward all options to ``self.router.trace`` and return its result."""
        return self.router.trace(
            path=path,
            name=name,
            include_in_schema=include_in_schema,
            tags=tags or [],
            summary=summary,
            description=description,
            operation_id=operation_id,
            deprecated=deprecated,
            response_type=response_type,
            response_description=response_description,
            response_code=response_code,
            response_wrapper=response_wrapper,
        )

    def openapi(self) -> typing.Dict[str, typing.Any]:
        """Build the OpenAPI document for every schema-visible ``APIRoute``.

        Routes that are not ``APIRoute`` instances, or that were registered with
        ``include_in_schema=False``, are skipped. Models used by request bodies
        and responses go under ``components.schemas``; security schemes found
        in route dependencies go under ``components.securitySchemes``.

        :return: a JSON-serialisable dict with ``openapi``, ``info``,
            optionally ``components``, and ``paths``.
        :raises AssertionError: if a route's request body is not a pydantic
            ``Field``.
        """
        info = {"title": self.title, "version": self.version}
        if self.description:
            info["description"] = self.description
        output = {"openapi": self.openapi_version, "info": info}  # type: typing.Dict[str, typing.Any]
        components = {}  # type: typing.Dict[str, typing.Any]
        paths = {}  # type: typing.Dict[str, typing.Any]
        # PATCH carries a request body just like POST and PUT.
        methods_with_body = {"POST", "PUT", "PATCH"}
        ref_prefix = "#/components/schemas/"

        # Check isinstance first: other route types may not define
        # ``include_in_schema`` at all.
        api_routes = [
            route
            for route in self.routes
            if isinstance(route, APIRoute) and route.include_in_schema
        ]

        # Gather every body/response field so each referenced model is
        # emitted exactly once under components/schemas.
        body_fields = []
        response_fields = []
        for route in api_routes:
            if route.request_body:
                assert isinstance(
                    route.request_body, Field
                ), "A request body must be a Pydantic BaseModel or Field"
                body_fields.append(route.request_body)
            if route.response_field:
                response_fields.append(route.response_field)
        flat_models = get_flat_models_from_fields(body_fields + response_fields)
        model_name_map = get_model_name_map(flat_models)
        definitions = {}  # type: typing.Dict[str, typing.Any]
        for model in flat_models:
            m_schema, m_definitions = model_process_schema(
                model, model_name_map=model_name_map, ref_prefix=ref_prefix
            )
            definitions.update(m_definitions)
            definitions[model_name_map[model]] = m_schema

        validation_error_definition = {
            "title": "ValidationError",
            "type": "object",
            "properties": {
                "loc": {
                    "title": "Location",
                    "type": "array",
                    "items": {"type": "string"},
                },
                "msg": {"title": "Message", "type": "string"},
                "type": {"title": "Error Type", "type": "string"},
            },
            "required": ["loc", "msg", "type"],
        }
        validation_error_response_definition = {
            "title": "HTTPValidationError",
            "type": "object",
            "properties": {
                "detail": {
                    "title": "Detail",
                    "type": "array",
                    "items": {"$ref": ref_prefix + "ValidationError"},
                }
            },
        }

        for route in api_routes:
            path = paths.setdefault(route.path, {})
            # These depend only on the route, not on the HTTP method.
            flat_dependant = get_flat_dependant(route.dependant)
            security_definitions = self._openapi_security_definitions(flat_dependant)
            if security_definitions:
                components.setdefault("securitySchemes", {}).update(
                    security_definitions
                )
            all_route_params = get_openapi_params(route.dependant)
            request_body = route.request_body
            # Any route that validates params or a body may answer 422, and
            # that response $refs the definitions below. Register them here so
            # body-only routes do not produce a dangling reference.
            can_fail_validation = bool(all_route_params or request_body)
            if can_fail_validation:
                definitions.setdefault("ValidationError", validation_error_definition)
                definitions.setdefault(
                    "HTTPValidationError", validation_error_response_definition
                )

            for method in route.methods:
                operation = self._openapi_operation_metadata(route)
                if security_definitions:
                    operation["security"] = [{name: []} for name in security_definitions]
                parameters = [self._openapi_parameter(param) for param in all_route_params]
                if parameters:
                    operation["parameters"] = parameters
                if method in methods_with_body and request_body:
                    operation["requestBody"] = self._openapi_request_body(
                        request_body, model_name_map, ref_prefix
                    )
                operation["responses"] = self._openapi_responses(
                    route, model_name_map, ref_prefix, can_fail_validation
                )
                path[method.lower()] = operation

        if definitions:
            components.setdefault("schemas", {}).update(definitions)
        if components:
            output["components"] = components
        output["paths"] = paths
        return output

    @staticmethod
    def _openapi_operation_metadata(route: APIRoute) -> typing.Dict[str, typing.Any]:
        """Build the descriptive part of an Operation (tags, summary, ids, ...)."""
        operation = {}  # type: typing.Dict[str, typing.Any]
        if route.tags:
            operation["tags"] = route.tags
        if route.summary:
            operation["summary"] = route.summary
        if route.description:
            operation["description"] = route.description
        # Fall back to the route name when no explicit operation id was given.
        operation["operationId"] = route.operation_id or route.name
        if route.deprecated:
            operation["deprecated"] = route.deprecated
        return operation

    @staticmethod
    def _openapi_security_definitions(flat_dependant) -> typing.Dict[str, typing.Any]:
        """Map scheme name -> encoded security scheme for a flattened dependant.

        The name is the scheme's ``scheme_name`` when set, otherwise its
        class name; ``scheme_name`` itself is excluded from the encoded body.
        """
        security_definitions = {}  # type: typing.Dict[str, typing.Any]
        for security_scheme in flat_dependant.security_schemes:
            security_name = (
                getattr(security_scheme, "scheme_name", None)
                or security_scheme.__class__.__name__
            )
            security_definitions[security_name] = jsonable_encoder(
                security_scheme,
                exclude=("scheme_name",),
                by_alias=True,
                include_none=False,
            )
        return security_definitions

    @staticmethod
    def _openapi_parameter(param: Field) -> typing.Dict[str, typing.Any]:
        """Describe one route parameter; its location comes from ``param.schema.in_``."""
        parameter = {
            "name": param.alias,
            "in": param.schema.in_.value,
            "required": param.required,
            "schema": field_schema(param, model_name_map={})[0],
        }  # type: typing.Dict[str, typing.Any]
        if param.schema.description:
            parameter["description"] = param.schema.description
        if param.schema.deprecated:
            parameter["deprecated"] = param.schema.deprecated
        return parameter

    @staticmethod
    def _openapi_request_body(
        request_body: Field, model_name_map: typing.Dict, ref_prefix: str
    ) -> typing.Dict[str, typing.Any]:
        """Describe a JSON request body as an OpenAPI Request Body Object."""
        body_schema, _ = field_schema(
            request_body, model_name_map=model_name_map, ref_prefix=ref_prefix
        )
        request_body_oai = {}  # type: typing.Dict[str, typing.Any]
        if request_body.required:
            request_body_oai["required"] = request_body.required
        request_body_oai["content"] = {"application/json": {"schema": body_schema}}
        return request_body_oai

    @staticmethod
    def _openapi_responses(
        route: APIRoute,
        model_name_map: typing.Dict,
        ref_prefix: str,
        can_fail_validation: bool,
    ) -> typing.Dict[str, typing.Any]:
        """Build the Responses Object: the success response, plus 422 if needed.

        The media type follows ``route.response_wrapper``: JSON responses use
        the response field's schema (or an empty schema when there is none),
        HTML and plain-text responses are described as strings.
        """
        response_schema = {"type": "string"}  # type: typing.Dict[str, typing.Any]
        if lenient_issubclass(route.response_wrapper, JSONResponse):
            response_media_type = "application/json"
            if route.response_field:
                response_schema, _ = field_schema(
                    route.response_field,
                    model_name_map=model_name_map,
                    ref_prefix=ref_prefix,
                )
            else:
                response_schema = {}
        elif lenient_issubclass(route.response_wrapper, HTMLResponse):
            response_media_type = "text/html"
        else:
            response_media_type = "text/plain"
        responses = {
            str(route.response_code): {
                "description": route.response_description,
                "content": {response_media_type: {"schema": response_schema}},
            }
        }  # type: typing.Dict[str, typing.Any]
        if can_fail_validation:
            responses[str(HTTP_422_UNPROCESSABLE_ENTITY)] = {
                "description": "Validation Error",
                "content": {
                    "application/json": {
                        "schema": {"$ref": ref_prefix + "HTTPValidationError"}
                    }
                },
            }
        return responses
|
|
|