from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union

from fastapi import routing
from fastapi.encoders import DictIntStrAny, SetIntStr
from fastapi.exception_handlers import (
    http_exception_handler,
    request_validation_exception_handler,
)
from fastapi.exceptions import RequestValidationError
from fastapi.openapi.docs import (
    get_redoc_html,
    get_swagger_ui_html,
    get_swagger_ui_oauth2_redirect_html,
)
from fastapi.openapi.utils import get_openapi
from fastapi.params import Depends
from starlette.applications import Starlette
from starlette.exceptions import ExceptionMiddleware, HTTPException
from starlette.middleware.errors import ServerErrorMiddleware
from starlette.requests import Request
from starlette.responses import HTMLResponse, JSONResponse, Response
from starlette.routing import BaseRoute


def _exclude_or_empty(
    exclude: Union[SetIntStr, DictIntStrAny, None]
) -> Union[SetIntStr, DictIntStrAny]:
    """Return *exclude*, or a brand-new empty set when it is ``None``.

    The route-registration methods default ``response_model_exclude`` to
    ``None`` and call this helper, so each registered route receives its own
    empty set instead of all of them sharing one mutable ``set()`` default.
    """
    return set() if exclude is None else exclude


class FastAPI(Starlette):
    """Starlette application that routes through a ``routing.APIRouter``.

    Routes live on ``self.router``. When ``openapi_url`` is set, ``setup()``
    also serves the OpenAPI schema and, if configured, Swagger UI / ReDoc
    pages that load it.

    The route-registration methods (``add_api_route``, ``api_route`` and the
    HTTP-verb helpers ``get`` ... ``trace``) forward their keyword arguments
    to ``self.router`` with these normalisations:

    * ``tags=None`` is forwarded as ``[]``;
    * ``responses=None`` is forwarded as ``{}``;
    * ``response_model_exclude=None`` (the default) is forwarded as a new
      empty set.

    NOTE(review): ``Starlette.__init__`` is not called; ``__init__`` builds
    the router and middleware attributes itself. Confirm this stays in sync
    with the installed Starlette version.
    """

    def __init__(
        self,
        debug: bool = False,
        routes: List[BaseRoute] = None,
        template_directory: str = None,
        title: str = "Fast API",
        description: str = "",
        version: str = "0.1.0",
        openapi_url: Optional[str] = "/openapi.json",
        openapi_prefix: str = "",
        docs_url: Optional[str] = "/docs",
        redoc_url: Optional[str] = "/redoc",
        swagger_ui_oauth2_redirect_url: Optional[str] = "/docs/oauth2-redirect",
        **extra: Any,
    ) -> None:
        """Build the router, middleware chain and docs routes.

        Args:
            debug: Stored as ``self._debug`` and passed to both middlewares.
            routes: Initial routes handed to the ``APIRouter``.
            template_directory: Accepted but not used by this method.
            title: OpenAPI title; also prefixes the docs page titles.
            description: OpenAPI description.
            version: OpenAPI version string of the API.
            openapi_url: Path serving the schema. A falsy value disables the
                schema route and the docs pages.
            openapi_prefix: Trailing ``/`` stripped, then prepended to the
                schema URL the docs pages load and passed to ``get_openapi``.
            docs_url: Path of the Swagger UI page, or ``None`` to disable it.
            redoc_url: Path of the ReDoc page, or ``None`` to disable it.
            swagger_ui_oauth2_redirect_url: Path of the Swagger UI OAuth2
                redirect page, or ``None`` to disable it.
            **extra: Any other keyword arguments, stored as ``self.extra``.

        Raises:
            AssertionError: ``openapi_url`` is set but ``title`` or
                ``version`` is empty, or a docs page is enabled while
                ``openapi_url`` is falsy.
        """
        self._debug = debug
        self.router: routing.APIRouter = routing.APIRouter(
            routes, dependency_overrides_provider=self
        )
        # Wrapping order: ServerErrorMiddleware -> ExceptionMiddleware -> router.
        self.exception_middleware = ExceptionMiddleware(self.router, debug=debug)
        self.error_middleware = ServerErrorMiddleware(
            self.exception_middleware, debug=debug
        )
        self.title = title
        self.description = description
        self.version = version
        self.openapi_url = openapi_url
        self.openapi_prefix = openapi_prefix.rstrip("/")
        self.docs_url = docs_url
        self.redoc_url = redoc_url
        self.swagger_ui_oauth2_redirect_url = swagger_ui_oauth2_redirect_url
        self.extra = extra
        self.dependency_overrides: Dict[Callable, Callable] = {}
        self.openapi_version = "3.0.2"

        # Explicit checks rather than ``assert`` so the validation still runs
        # under ``python -O``; AssertionError is kept for backward compatibility.
        if self.openapi_url:
            if not self.title:
                raise AssertionError(
                    "A title must be provided for OpenAPI, e.g.: 'My API'"
                )
            if not self.version:
                raise AssertionError(
                    "A version must be provided for OpenAPI, e.g.: '2.1.0'"
                )
        if self.docs_url or self.redoc_url:
            if not self.openapi_url:
                raise AssertionError("The openapi_url is required for the docs")

        # Filled lazily by ``openapi()``.
        self.openapi_schema: Optional[Dict[str, Any]] = None
        self.setup()

    def openapi(self) -> Dict[str, Any]:
        """Return the OpenAPI schema, generating and caching it on first call."""
        if not self.openapi_schema:
            self.openapi_schema = get_openapi(
                title=self.title,
                version=self.version,
                openapi_version=self.openapi_version,
                description=self.description,
                routes=self.routes,
                openapi_prefix=self.openapi_prefix,
            )
        return self.openapi_schema

    def setup(self) -> None:
        """Add the schema/docs routes (when enabled) and default handlers.

        All routes added here use ``include_in_schema=False``. Afterwards,
        ``HTTPException`` and ``RequestValidationError`` are mapped to
        FastAPI's default exception handlers.
        """
        if self.openapi_url:

            async def openapi(req: Request) -> JSONResponse:
                return JSONResponse(self.openapi())

            self.add_route(self.openapi_url, openapi, include_in_schema=False)
            # The docs pages must fetch the schema through the external prefix.
            openapi_url = self.openapi_prefix + self.openapi_url
        if self.openapi_url and self.docs_url:

            async def swagger_ui_html(req: Request) -> HTMLResponse:
                return get_swagger_ui_html(
                    openapi_url=openapi_url,
                    title=self.title + " - Swagger UI",
                    oauth2_redirect_url=self.swagger_ui_oauth2_redirect_url,
                )

            self.add_route(self.docs_url, swagger_ui_html, include_in_schema=False)

            if self.swagger_ui_oauth2_redirect_url:

                async def swagger_ui_redirect(req: Request) -> HTMLResponse:
                    return get_swagger_ui_oauth2_redirect_html()

                self.add_route(
                    self.swagger_ui_oauth2_redirect_url,
                    swagger_ui_redirect,
                    include_in_schema=False,
                )
        if self.openapi_url and self.redoc_url:

            async def redoc_html(req: Request) -> HTMLResponse:
                return get_redoc_html(
                    openapi_url=openapi_url, title=self.title + " - ReDoc"
                )

            self.add_route(self.redoc_url, redoc_html, include_in_schema=False)
        self.add_exception_handler(HTTPException, http_exception_handler)
        self.add_exception_handler(
            RequestValidationError, request_validation_exception_handler
        )

    def add_api_route(
        self,
        path: str,
        endpoint: Callable,
        *,
        response_model: Type[Any] = None,
        status_code: int = 200,
        tags: List[str] = None,
        dependencies: Sequence[Depends] = None,
        summary: str = None,
        description: str = None,
        response_description: str = "Successful Response",
        responses: Dict[Union[int, str], Dict[str, Any]] = None,
        deprecated: bool = None,
        methods: List[str] = None,
        operation_id: str = None,
        response_model_include: Union[SetIntStr, DictIntStrAny] = None,
        response_model_exclude: Union[SetIntStr, DictIntStrAny] = None,
        response_model_by_alias: bool = True,
        response_model_skip_defaults: bool = False,
        include_in_schema: bool = True,
        response_class: Type[Response] = JSONResponse,
        name: str = None,
    ) -> None:
        """Register *endpoint* at *path* via ``self.router.add_api_route``.

        ``tags``, ``responses`` and ``response_model_exclude`` are normalised
        as described on the class; every other argument is forwarded as is.
        """
        self.router.add_api_route(
            path,
            endpoint=endpoint,
            response_model=response_model,
            status_code=status_code,
            tags=tags or [],
            dependencies=dependencies,
            summary=summary,
            description=description,
            response_description=response_description,
            responses=responses or {},
            deprecated=deprecated,
            methods=methods,
            operation_id=operation_id,
            response_model_include=response_model_include,
            response_model_exclude=_exclude_or_empty(response_model_exclude),
            response_model_by_alias=response_model_by_alias,
            response_model_skip_defaults=response_model_skip_defaults,
            include_in_schema=include_in_schema,
            response_class=response_class,
            name=name,
        )

    def api_route(
        self,
        path: str,
        *,
        response_model: Type[Any] = None,
        status_code: int = 200,
        tags: List[str] = None,
        dependencies: Sequence[Depends] = None,
        summary: str = None,
        description: str = None,
        response_description: str = "Successful Response",
        responses: Dict[Union[int, str], Dict[str, Any]] = None,
        deprecated: bool = None,
        methods: List[str] = None,
        operation_id: str = None,
        response_model_include: Union[SetIntStr, DictIntStrAny] = None,
        response_model_exclude: Union[SetIntStr, DictIntStrAny] = None,
        response_model_by_alias: bool = True,
        response_model_skip_defaults: bool = False,
        include_in_schema: bool = True,
        response_class: Type[Response] = JSONResponse,
        name: str = None,
    ) -> Callable:
        """Return a decorator that registers the decorated function at *path*.

        The decorator calls ``self.router.add_api_route`` with the same
        normalisations as ``add_api_route`` and returns the function
        unchanged.
        """

        def decorator(func: Callable) -> Callable:
            self.router.add_api_route(
                path,
                func,
                response_model=response_model,
                status_code=status_code,
                tags=tags or [],
                dependencies=dependencies,
                summary=summary,
                description=description,
                response_description=response_description,
                responses=responses or {},
                deprecated=deprecated,
                methods=methods,
                operation_id=operation_id,
                response_model_include=response_model_include,
                response_model_exclude=_exclude_or_empty(response_model_exclude),
                response_model_by_alias=response_model_by_alias,
                response_model_skip_defaults=response_model_skip_defaults,
                include_in_schema=include_in_schema,
                response_class=response_class,
                name=name,
            )
            return func

        return decorator

    def add_api_websocket_route(
        self, path: str, endpoint: Callable, name: str = None
    ) -> None:
        """Register a websocket *endpoint* at *path* on ``self.router``."""
        self.router.add_api_websocket_route(path, endpoint, name=name)

    def websocket(self, path: str, name: str = None) -> Callable:
        """Return a decorator that registers a websocket endpoint at *path*.

        The decorated function is returned unchanged.
        """

        def decorator(func: Callable) -> Callable:
            self.add_api_websocket_route(path, func, name=name)
            return func

        return decorator

    def include_router(
        self,
        router: routing.APIRouter,
        *,
        prefix: str = "",
        tags: List[str] = None,
        dependencies: Sequence[Depends] = None,
        responses: Dict[Union[int, str], Dict[str, Any]] = None,
    ) -> None:
        """Delegate to ``self.router.include_router``.

        ``responses=None`` is forwarded as ``{}``; the other arguments,
        including ``tags``, are forwarded unchanged.
        """
        self.router.include_router(
            router,
            prefix=prefix,
            tags=tags,
            dependencies=dependencies,
            responses=responses or {},
        )

    def get(
        self,
        path: str,
        *,
        response_model: Type[Any] = None,
        status_code: int = 200,
        tags: List[str] = None,
        dependencies: Sequence[Depends] = None,
        summary: str = None,
        description: str = None,
        response_description: str = "Successful Response",
        responses: Dict[Union[int, str], Dict[str, Any]] = None,
        deprecated: bool = None,
        operation_id: str = None,
        response_model_include: Union[SetIntStr, DictIntStrAny] = None,
        response_model_exclude: Union[SetIntStr, DictIntStrAny] = None,
        response_model_by_alias: bool = True,
        response_model_skip_defaults: bool = False,
        include_in_schema: bool = True,
        response_class: Type[Response] = JSONResponse,
        name: str = None,
    ) -> Callable:
        """Forward to ``self.router.get`` and return its result (GET)."""
        return self.router.get(
            path,
            response_model=response_model,
            status_code=status_code,
            tags=tags or [],
            dependencies=dependencies,
            summary=summary,
            description=description,
            response_description=response_description,
            responses=responses or {},
            deprecated=deprecated,
            operation_id=operation_id,
            response_model_include=response_model_include,
            response_model_exclude=_exclude_or_empty(response_model_exclude),
            response_model_by_alias=response_model_by_alias,
            response_model_skip_defaults=response_model_skip_defaults,
            include_in_schema=include_in_schema,
            response_class=response_class,
            name=name,
        )

    def put(
        self,
        path: str,
        *,
        response_model: Type[Any] = None,
        status_code: int = 200,
        tags: List[str] = None,
        dependencies: Sequence[Depends] = None,
        summary: str = None,
        description: str = None,
        response_description: str = "Successful Response",
        responses: Dict[Union[int, str], Dict[str, Any]] = None,
        deprecated: bool = None,
        operation_id: str = None,
        response_model_include: Union[SetIntStr, DictIntStrAny] = None,
        response_model_exclude: Union[SetIntStr, DictIntStrAny] = None,
        response_model_by_alias: bool = True,
        response_model_skip_defaults: bool = False,
        include_in_schema: bool = True,
        response_class: Type[Response] = JSONResponse,
        name: str = None,
    ) -> Callable:
        """Forward to ``self.router.put`` and return its result (PUT)."""
        return self.router.put(
            path,
            response_model=response_model,
            status_code=status_code,
            tags=tags or [],
            dependencies=dependencies,
            summary=summary,
            description=description,
            response_description=response_description,
            responses=responses or {},
            deprecated=deprecated,
            operation_id=operation_id,
            response_model_include=response_model_include,
            response_model_exclude=_exclude_or_empty(response_model_exclude),
            response_model_by_alias=response_model_by_alias,
            response_model_skip_defaults=response_model_skip_defaults,
            include_in_schema=include_in_schema,
            response_class=response_class,
            name=name,
        )

    def post(
        self,
        path: str,
        *,
        response_model: Type[Any] = None,
        status_code: int = 200,
        tags: List[str] = None,
        dependencies: Sequence[Depends] = None,
        summary: str = None,
        description: str = None,
        response_description: str = "Successful Response",
        responses: Dict[Union[int, str], Dict[str, Any]] = None,
        deprecated: bool = None,
        operation_id: str = None,
        response_model_include: Union[SetIntStr, DictIntStrAny] = None,
        response_model_exclude: Union[SetIntStr, DictIntStrAny] = None,
        response_model_by_alias: bool = True,
        response_model_skip_defaults: bool = False,
        include_in_schema: bool = True,
        response_class: Type[Response] = JSONResponse,
        name: str = None,
    ) -> Callable:
        """Forward to ``self.router.post`` and return its result (POST)."""
        return self.router.post(
            path,
            response_model=response_model,
            status_code=status_code,
            tags=tags or [],
            dependencies=dependencies,
            summary=summary,
            description=description,
            response_description=response_description,
            responses=responses or {},
            deprecated=deprecated,
            operation_id=operation_id,
            response_model_include=response_model_include,
            response_model_exclude=_exclude_or_empty(response_model_exclude),
            response_model_by_alias=response_model_by_alias,
            response_model_skip_defaults=response_model_skip_defaults,
            include_in_schema=include_in_schema,
            response_class=response_class,
            name=name,
        )

    def delete(
        self,
        path: str,
        *,
        response_model: Type[Any] = None,
        status_code: int = 200,
        tags: List[str] = None,
        dependencies: Sequence[Depends] = None,
        summary: str = None,
        description: str = None,
        response_description: str = "Successful Response",
        responses: Dict[Union[int, str], Dict[str, Any]] = None,
        deprecated: bool = None,
        operation_id: str = None,
        response_model_include: Union[SetIntStr, DictIntStrAny] = None,
        response_model_exclude: Union[SetIntStr, DictIntStrAny] = None,
        response_model_by_alias: bool = True,
        response_model_skip_defaults: bool = False,
        include_in_schema: bool = True,
        response_class: Type[Response] = JSONResponse,
        name: str = None,
    ) -> Callable:
        """Forward to ``self.router.delete`` and return its result (DELETE)."""
        return self.router.delete(
            path,
            response_model=response_model,
            status_code=status_code,
            tags=tags or [],
            dependencies=dependencies,
            summary=summary,
            description=description,
            response_description=response_description,
            responses=responses or {},
            deprecated=deprecated,
            operation_id=operation_id,
            response_model_include=response_model_include,
            response_model_exclude=_exclude_or_empty(response_model_exclude),
            response_model_by_alias=response_model_by_alias,
            response_model_skip_defaults=response_model_skip_defaults,
            include_in_schema=include_in_schema,
            response_class=response_class,
            name=name,
        )

    def options(
        self,
        path: str,
        *,
        response_model: Type[Any] = None,
        status_code: int = 200,
        tags: List[str] = None,
        dependencies: Sequence[Depends] = None,
        summary: str = None,
        description: str = None,
        response_description: str = "Successful Response",
        responses: Dict[Union[int, str], Dict[str, Any]] = None,
        deprecated: bool = None,
        operation_id: str = None,
        response_model_include: Union[SetIntStr, DictIntStrAny] = None,
        response_model_exclude: Union[SetIntStr, DictIntStrAny] = None,
        response_model_by_alias: bool = True,
        response_model_skip_defaults: bool = False,
        include_in_schema: bool = True,
        response_class: Type[Response] = JSONResponse,
        name: str = None,
    ) -> Callable:
        """Forward to ``self.router.options`` and return its result (OPTIONS)."""
        return self.router.options(
            path,
            response_model=response_model,
            status_code=status_code,
            tags=tags or [],
            dependencies=dependencies,
            summary=summary,
            description=description,
            response_description=response_description,
            responses=responses or {},
            deprecated=deprecated,
            operation_id=operation_id,
            response_model_include=response_model_include,
            response_model_exclude=_exclude_or_empty(response_model_exclude),
            response_model_by_alias=response_model_by_alias,
            response_model_skip_defaults=response_model_skip_defaults,
            include_in_schema=include_in_schema,
            response_class=response_class,
            name=name,
        )

    def head(
        self,
        path: str,
        *,
        response_model: Type[Any] = None,
        status_code: int = 200,
        tags: List[str] = None,
        dependencies: Sequence[Depends] = None,
        summary: str = None,
        description: str = None,
        response_description: str = "Successful Response",
        responses: Dict[Union[int, str], Dict[str, Any]] = None,
        deprecated: bool = None,
        operation_id: str = None,
        response_model_include: Union[SetIntStr, DictIntStrAny] = None,
        response_model_exclude: Union[SetIntStr, DictIntStrAny] = None,
        response_model_by_alias: bool = True,
        response_model_skip_defaults: bool = False,
        include_in_schema: bool = True,
        response_class: Type[Response] = JSONResponse,
        name: str = None,
    ) -> Callable:
        """Forward to ``self.router.head`` and return its result (HEAD)."""
        return self.router.head(
            path,
            response_model=response_model,
            status_code=status_code,
            tags=tags or [],
            dependencies=dependencies,
            summary=summary,
            description=description,
            response_description=response_description,
            responses=responses or {},
            deprecated=deprecated,
            operation_id=operation_id,
            response_model_include=response_model_include,
            response_model_exclude=_exclude_or_empty(response_model_exclude),
            response_model_by_alias=response_model_by_alias,
            response_model_skip_defaults=response_model_skip_defaults,
            include_in_schema=include_in_schema,
            response_class=response_class,
            name=name,
        )

    def patch(
        self,
        path: str,
        *,
        response_model: Type[Any] = None,
        status_code: int = 200,
        tags: List[str] = None,
        dependencies: Sequence[Depends] = None,
        summary: str = None,
        description: str = None,
        response_description: str = "Successful Response",
        responses: Dict[Union[int, str], Dict[str, Any]] = None,
        deprecated: bool = None,
        operation_id: str = None,
        response_model_include: Union[SetIntStr, DictIntStrAny] = None,
        response_model_exclude: Union[SetIntStr, DictIntStrAny] = None,
        response_model_by_alias: bool = True,
        response_model_skip_defaults: bool = False,
        include_in_schema: bool = True,
        response_class: Type[Response] = JSONResponse,
        name: str = None,
    ) -> Callable:
        """Forward to ``self.router.patch`` and return its result (PATCH)."""
        return self.router.patch(
            path,
            response_model=response_model,
            status_code=status_code,
            tags=tags or [],
            dependencies=dependencies,
            summary=summary,
            description=description,
            response_description=response_description,
            responses=responses or {},
            deprecated=deprecated,
            operation_id=operation_id,
            response_model_include=response_model_include,
            response_model_exclude=_exclude_or_empty(response_model_exclude),
            response_model_by_alias=response_model_by_alias,
            response_model_skip_defaults=response_model_skip_defaults,
            include_in_schema=include_in_schema,
            response_class=response_class,
            name=name,
        )

    def trace(
        self,
        path: str,
        *,
        response_model: Type[Any] = None,
        status_code: int = 200,
        tags: List[str] = None,
        dependencies: Sequence[Depends] = None,
        summary: str = None,
        description: str = None,
        response_description: str = "Successful Response",
        responses: Dict[Union[int, str], Dict[str, Any]] = None,
        deprecated: bool = None,
        operation_id: str = None,
        response_model_include: Union[SetIntStr, DictIntStrAny] = None,
        response_model_exclude: Union[SetIntStr, DictIntStrAny] = None,
        response_model_by_alias: bool = True,
        response_model_skip_defaults: bool = False,
        include_in_schema: bool = True,
        response_class: Type[Response] = JSONResponse,
        name: str = None,
    ) -> Callable:
        """Forward to ``self.router.trace`` and return its result (TRACE)."""
        return self.router.trace(
            path,
            response_model=response_model,
            status_code=status_code,
            tags=tags or [],
            dependencies=dependencies,
            summary=summary,
            description=description,
            response_description=response_description,
            responses=responses or {},
            deprecated=deprecated,
            operation_id=operation_id,
            response_model_include=response_model_include,
            response_model_exclude=_exclude_or_empty(response_model_exclude),
            response_model_by_alias=response_model_by_alias,
            response_model_skip_defaults=response_model_skip_defaults,
            include_in_schema=include_in_schema,
            response_class=response_class,
            name=name,
        )