You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

509 lines
17 KiB

import asyncio
import inspect
from typing import Callable, List, Type

from pydantic import BaseConfig, BaseModel, Schema
from pydantic.error_wrappers import ErrorWrapper, ValidationError
from pydantic.fields import Field
from pydantic.utils import lenient_issubclass
from starlette import routing
from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException
from starlette.formparsers import UploadFile
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from starlette.routing import get_name, request_response
from starlette.status import HTTP_400_BAD_REQUEST, HTTP_422_UNPROCESSABLE_ENTITY

from fastapi import params
from fastapi.dependencies.models import Dependant
from fastapi.dependencies.utils import get_body_field, get_dependant, solve_dependencies
from fastapi.encoders import jsonable_encoder
def serialize_response(*, field: Field = None, response):
    """Turn an endpoint's return value into JSON-compatible data.

    If ``field`` is provided, ``response`` is validated against it first and
    the *validated* value is encoded; otherwise ``response`` is encoded as-is.

    Raises:
        ValidationError: when validation against ``field`` reports errors.
    """
    if not field:
        return jsonable_encoder(response)

    validated, raw_errors = field.validate(response, {}, loc=("response",))
    # ``validate`` can report either a single ErrorWrapper or a list of them.
    if isinstance(raw_errors, ErrorWrapper):
        collected = [raw_errors]
    elif isinstance(raw_errors, list):
        collected = list(raw_errors)
    else:
        collected = []

    if collected:
        raise ValidationError(collected)
    return jsonable_encoder(validated)
async def _read_request_body(request: Request, body_field: Field = None):
    """Return the decoded request body expected by ``body_field``.

    Returns None when the route declares no body field or the request body is
    empty. Form bodies are returned as a dict in which uploaded files are
    replaced by their contents; any other body is decoded as JSON.

    Raises:
        HTTPException: 400 when a non-empty, non-form body is not valid JSON.
    """
    if not body_field:
        return None
    if isinstance(body_field.schema, params.Form):
        form = await request.form()
        body = {}
        for key, value in form.items():
            # Uploaded files are replaced by their full contents (read()).
            if isinstance(value, UploadFile):
                body[key] = await value.read()
            else:
                body[key] = value
        return body
    if not await request.body():
        # An empty body is passed on as None instead of crashing the JSON
        # decoder with a 500; presumably solve_dependencies then reports a
        # required body as missing (422) — confirm against its implementation.
        return None
    try:
        # NOTE: relies on Starlette's Request caching the body read above.
        return await request.json()
    except ValueError as exc:  # JSONDecodeError / UnicodeDecodeError subclass ValueError
        raise HTTPException(
            status_code=HTTP_400_BAD_REQUEST,
            detail="There was an error parsing the body",
        ) from exc


def _serialize_with_fallbacks(
    raw_response,
    response_field: Field,
    response_wrapper: Type[Response],
    response_code: int,
) -> Response:
    """Serialize ``raw_response`` and wrap it in ``response_wrapper``.

    Tries, in order, the value as-is, ``dict(raw_response)`` and
    ``vars(raw_response)``; the first candidate that serializes (and validates
    against ``response_field`` when one is given) is returned.

    Raises:
        ValueError: holding the list of exceptions from every failed attempt.
    """
    failures = []
    for convert in (lambda obj: obj, dict, vars):
        try:
            content = serialize_response(
                field=response_field, response=convert(raw_response)
            )
            return response_wrapper(content=content, status_code=response_code)
        except Exception as exc:  # deliberately broad: any failure tries the next candidate
            failures.append(exc)
    raise ValueError(failures)


def get_app(
    dependant: Dependant,
    body_field: Field = None,
    response_code: int = 200,
    response_wrapper: Type[Response] = JSONResponse,
    response_field: Field = None,
):
    """Build the request handler for one API route.

    Args:
        dependant: dependency description of the endpoint; ``dependant.call``
            is the endpoint callable that gets invoked.
        body_field: field describing the request body, or None if the route
            reads no body.
        response_code: HTTP status code for successful responses.
        response_wrapper: Response class used to wrap serialized results.
        response_field: optional field the return value is validated against.

    Returns:
        An ``async`` callable taking a ``Request`` and returning a ``Response``.
        Malformed JSON bodies yield a 400; dependency/body validation errors
        yield a 422.
    """
    is_coroutine = dependant.call and asyncio.iscoroutinefunction(dependant.call)

    async def app(request: Request) -> Response:
        body = await _read_request_body(request, body_field)
        values, errors = await solve_dependencies(
            request=request, dependant=dependant, body=body
        )
        if errors:
            # Parameter/body validation problems are reported to the client as a 422.
            raise HTTPException(
                status_code=HTTP_422_UNPROCESSABLE_ENTITY,
                detail=ValidationError(errors).errors(),
            )
        if is_coroutine:
            raw_response = await dependant.call(**values)
        else:
            # Synchronous endpoints run in a worker thread, off the event loop.
            raw_response = await run_in_threadpool(dependant.call, **values)
        if isinstance(raw_response, Response):
            # The endpoint built its own response; pass it through untouched.
            return raw_response
        if isinstance(raw_response, BaseModel):
            # NOTE: models are encoded directly, without validation against response_field.
            return response_wrapper(
                content=jsonable_encoder(raw_response), status_code=response_code
            )
        return _serialize_with_fallbacks(
            raw_response, response_field, response_wrapper, response_code
        )

    return app
class APIRoute(routing.Route):
    """Starlette ``Route`` whose ASGI app is generated from an endpoint.

    The endpoint is analysed with ``get_dependant`` / ``get_body_field`` and
    the handler built by ``get_app`` is wrapped with ``request_response``.
    Metadata such as tags, summary, description, operation_id, deprecated,
    response_description and include_in_schema is only stored on the
    instance here (presumably consumed by schema generation elsewhere).
    """

    def __init__(
        self,
        path: str,
        endpoint: Callable,
        *,
        methods: List[str] = None,
        name: str = None,
        include_in_schema: bool = True,
        tags: List[str] = None,
        summary: str = None,
        description: str = None,
        operation_id: str = None,
        deprecated: bool = None,
        response_type: Type = None,
        response_description: str = "Successful Response",
        response_code=200,
        response_wrapper=JSONResponse,
    ) -> None:
        # TODO: security params (per route and global), dependency injection,
        # code-structure refactor, tests and test coverage.
        assert path.startswith("/"), "Routed paths must always start '/'"
        self.path = path
        self.endpoint = endpoint
        self.name = name if name is not None else get_name(endpoint)
        self.include_in_schema = include_in_schema
        self.tags = tags or []
        self.summary = summary
        # Without an explicit description, fall back to the endpoint's docstring.
        self.description = description if description else endpoint.__doc__
        self.operation_id = operation_id
        self.deprecated = deprecated
        # Placeholder until the endpoint has been analysed below.
        self.body_field: Field = None
        self.response_description = response_description
        self.response_code = response_code
        self.response_wrapper = response_wrapper

        # A declared response type becomes a Field used to validate return values.
        self.response_type = None
        self.response_field = None
        if response_type:
            assert lenient_issubclass(
                response_wrapper, JSONResponse
            ), "To declare a type the response must be a JSON response"
            self.response_type = response_type
            self.response_field = Field(
                name="Response_" + self.name,
                type_=response_type,
                class_validators=[],
                default=None,
                required=False,
                model_config=BaseConfig(),
                schema=Schema(None),
            )

        self.methods = ["GET"] if methods is None else methods
        (
            self.path_regex,
            self.path_format,
            self.param_convertors,
        ) = self.compile_path(path)
        assert inspect.isfunction(endpoint) or inspect.ismethod(
            endpoint
        ), "An endpoint must be a function or method"

        self.dependant = get_dependant(path=path, call=endpoint)
        self.body_field = get_body_field(dependant=self.dependant, name=self.name)
        endpoint_app = get_app(
            dependant=self.dependant,
            body_field=self.body_field,
            response_code=response_code,
            response_wrapper=response_wrapper,
            response_field=self.response_field,
        )
        self.app = request_response(endpoint_app)
class APIRouter(routing.Router):
    """Starlette router that registers ``APIRoute`` instances.

    ``add_api_route`` registers an endpoint directly and ``api_route`` does the
    same as a decorator. ``get``, ``put``, ``post``, ``delete``, ``options``,
    ``head``, ``patch`` and ``trace`` are decorator shortcuts that fix
    ``methods`` to the single matching HTTP method.
    """

    def add_api_route(
        self,
        path: str,
        endpoint: Callable,
        methods: List[str] = None,
        name: str = None,
        include_in_schema: bool = True,
        tags: List[str] = None,
        summary: str = None,
        description: str = None,
        operation_id: str = None,
        deprecated: bool = None,
        response_type: Type = None,
        response_description: str = "Successful Response",
        response_code=200,
        response_wrapper=JSONResponse,
    ) -> None:
        """Build an ``APIRoute`` for ``endpoint`` at ``path`` and append it to ``self.routes``."""
        new_route = APIRoute(
            path,
            endpoint=endpoint,
            methods=methods,
            name=name,
            include_in_schema=include_in_schema,
            tags=tags or [],
            summary=summary,
            description=description,
            operation_id=operation_id,
            deprecated=deprecated,
            response_type=response_type,
            response_description=response_description,
            response_code=response_code,
            response_wrapper=response_wrapper,
        )
        self.routes.append(new_route)

    def api_route(
        self,
        path: str,
        methods: List[str] = None,
        name: str = None,
        include_in_schema: bool = True,
        tags: List[str] = None,
        summary: str = None,
        description: str = None,
        operation_id: str = None,
        deprecated: bool = None,
        response_type: Type = None,
        response_description: str = "Successful Response",
        response_code=200,
        response_wrapper=JSONResponse,
    ) -> Callable:
        """Return a decorator that registers the decorated function via ``add_api_route``.

        The decorated function itself is returned unchanged.
        """

        def decorator(func: Callable) -> Callable:
            self.add_api_route(
                path,
                func,
                methods=methods,
                name=name,
                include_in_schema=include_in_schema,
                tags=tags or [],
                summary=summary,
                description=description,
                operation_id=operation_id,
                deprecated=deprecated,
                response_type=response_type,
                response_description=response_description,
                response_code=response_code,
                response_wrapper=response_wrapper,
            )
            return func

        return decorator

    def _single_method_route(
        self, method: str, path: str, *, tags: List[str] = None, **options
    ) -> Callable:
        """Shared body of the HTTP-verb shortcuts: ``api_route`` limited to ``method``."""
        return self.api_route(path=path, methods=[method], tags=tags or [], **options)

    def get(
        self,
        path: str,
        name: str = None,
        include_in_schema: bool = True,
        tags: List[str] = None,
        summary: str = None,
        description: str = None,
        operation_id: str = None,
        deprecated: bool = None,
        response_type: Type = None,
        response_description: str = "Successful Response",
        response_code=200,
        response_wrapper=JSONResponse,
    ):
        """Decorator registering a ``GET`` route; see ``api_route``."""
        return self._single_method_route(
            "GET",
            path,
            name=name,
            include_in_schema=include_in_schema,
            tags=tags,
            summary=summary,
            description=description,
            operation_id=operation_id,
            deprecated=deprecated,
            response_type=response_type,
            response_description=response_description,
            response_code=response_code,
            response_wrapper=response_wrapper,
        )

    def put(
        self,
        path: str,
        name: str = None,
        include_in_schema: bool = True,
        tags: List[str] = None,
        summary: str = None,
        description: str = None,
        operation_id: str = None,
        deprecated: bool = None,
        response_type: Type = None,
        response_description: str = "Successful Response",
        response_code=200,
        response_wrapper=JSONResponse,
    ):
        """Decorator registering a ``PUT`` route; see ``api_route``."""
        return self._single_method_route(
            "PUT",
            path,
            name=name,
            include_in_schema=include_in_schema,
            tags=tags,
            summary=summary,
            description=description,
            operation_id=operation_id,
            deprecated=deprecated,
            response_type=response_type,
            response_description=response_description,
            response_code=response_code,
            response_wrapper=response_wrapper,
        )

    def post(
        self,
        path: str,
        name: str = None,
        include_in_schema: bool = True,
        tags: List[str] = None,
        summary: str = None,
        description: str = None,
        operation_id: str = None,
        deprecated: bool = None,
        response_type: Type = None,
        response_description: str = "Successful Response",
        response_code=200,
        response_wrapper=JSONResponse,
    ):
        """Decorator registering a ``POST`` route; see ``api_route``."""
        return self._single_method_route(
            "POST",
            path,
            name=name,
            include_in_schema=include_in_schema,
            tags=tags,
            summary=summary,
            description=description,
            operation_id=operation_id,
            deprecated=deprecated,
            response_type=response_type,
            response_description=response_description,
            response_code=response_code,
            response_wrapper=response_wrapper,
        )

    def delete(
        self,
        path: str,
        name: str = None,
        include_in_schema: bool = True,
        tags: List[str] = None,
        summary: str = None,
        description: str = None,
        operation_id: str = None,
        deprecated: bool = None,
        response_type: Type = None,
        response_description: str = "Successful Response",
        response_code=200,
        response_wrapper=JSONResponse,
    ):
        """Decorator registering a ``DELETE`` route; see ``api_route``."""
        return self._single_method_route(
            "DELETE",
            path,
            name=name,
            include_in_schema=include_in_schema,
            tags=tags,
            summary=summary,
            description=description,
            operation_id=operation_id,
            deprecated=deprecated,
            response_type=response_type,
            response_description=response_description,
            response_code=response_code,
            response_wrapper=response_wrapper,
        )

    def options(
        self,
        path: str,
        name: str = None,
        include_in_schema: bool = True,
        tags: List[str] = None,
        summary: str = None,
        description: str = None,
        operation_id: str = None,
        deprecated: bool = None,
        response_type: Type = None,
        response_description: str = "Successful Response",
        response_code=200,
        response_wrapper=JSONResponse,
    ):
        """Decorator registering an ``OPTIONS`` route; see ``api_route``."""
        return self._single_method_route(
            "OPTIONS",
            path,
            name=name,
            include_in_schema=include_in_schema,
            tags=tags,
            summary=summary,
            description=description,
            operation_id=operation_id,
            deprecated=deprecated,
            response_type=response_type,
            response_description=response_description,
            response_code=response_code,
            response_wrapper=response_wrapper,
        )

    def head(
        self,
        path: str,
        name: str = None,
        include_in_schema: bool = True,
        tags: List[str] = None,
        summary: str = None,
        description: str = None,
        operation_id: str = None,
        deprecated: bool = None,
        response_type: Type = None,
        response_description: str = "Successful Response",
        response_code=200,
        response_wrapper=JSONResponse,
    ):
        """Decorator registering a ``HEAD`` route; see ``api_route``."""
        return self._single_method_route(
            "HEAD",
            path,
            name=name,
            include_in_schema=include_in_schema,
            tags=tags,
            summary=summary,
            description=description,
            operation_id=operation_id,
            deprecated=deprecated,
            response_type=response_type,
            response_description=response_description,
            response_code=response_code,
            response_wrapper=response_wrapper,
        )

    def patch(
        self,
        path: str,
        name: str = None,
        include_in_schema: bool = True,
        tags: List[str] = None,
        summary: str = None,
        description: str = None,
        operation_id: str = None,
        deprecated: bool = None,
        response_type: Type = None,
        response_description: str = "Successful Response",
        response_code=200,
        response_wrapper=JSONResponse,
    ):
        """Decorator registering a ``PATCH`` route; see ``api_route``."""
        return self._single_method_route(
            "PATCH",
            path,
            name=name,
            include_in_schema=include_in_schema,
            tags=tags,
            summary=summary,
            description=description,
            operation_id=operation_id,
            deprecated=deprecated,
            response_type=response_type,
            response_description=response_description,
            response_code=response_code,
            response_wrapper=response_wrapper,
        )

    def trace(
        self,
        path: str,
        name: str = None,
        include_in_schema: bool = True,
        tags: List[str] = None,
        summary: str = None,
        description: str = None,
        operation_id: str = None,
        deprecated: bool = None,
        response_type: Type = None,
        response_description: str = "Successful Response",
        response_code=200,
        response_wrapper=JSONResponse,
    ):
        """Decorator registering a ``TRACE`` route; see ``api_route``."""
        return self._single_method_route(
            "TRACE",
            path,
            name=name,
            include_in_schema=include_in_schema,
            tags=tags,
            summary=summary,
            description=description,
            operation_id=operation_id,
            deprecated=deprecated,
            response_type=response_type,
            response_description=response_description,
            response_code=response_code,
            response_wrapper=response_wrapper,
        )