import asyncio
import inspect
import logging
from typing import Any, Callable, Dict, List, Optional, Type, Union

from fastapi import params
from fastapi.dependencies.models import Dependant
from fastapi.dependencies.utils import get_body_field, get_dependant, solve_dependencies
from fastapi.encoders import jsonable_encoder
from fastapi.openapi.models import AdditionalResponse, AdditionalResponseDescription
from fastapi.utils import UnconstrainedConfig
from pydantic import BaseModel, Schema
from pydantic.error_wrappers import ErrorWrapper, ValidationError
from pydantic.fields import Field
from pydantic.utils import lenient_issubclass
from starlette import routing
from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from starlette.routing import compile_path, get_name, request_response
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY


def serialize_response(*, field: Field = None, response: Response) -> Any:
    """Encode an endpoint's return value, validating it when a field is given.

    Args:
        field: Optional pydantic field used to validate the encoded value.
        response: The value returned by the endpoint.

    Returns:
        The ``jsonable_encoder`` output of ``response``; when ``field`` is
        set, the encoded form of the value produced by ``field.validate``.

    Raises:
        ValidationError: If ``field.validate`` reports one or more errors.
    """
    encoded = jsonable_encoder(response)
    if not field:
        return encoded

    value, errors_ = field.validate(encoded, {}, loc=("response",))
    # ``validate`` may hand back a single wrapper or a list of them;
    # normalise both shapes into one flat list.
    errors = []
    if isinstance(errors_, ErrorWrapper):
        errors.append(errors_)
    elif isinstance(errors_, list):
        errors.extend(errors_)
    if errors:
        raise ValidationError(errors)
    return jsonable_encoder(value)


def get_app(
    dependant: Dependant,
    body_field: Field = None,
    status_code: int = 200,
    content_type: Type[Response] = JSONResponse,
    response_field: Field = None,
) -> Callable:
    """Build the async request handler for one endpoint.

    Args:
        dependant: Dependency description whose ``call`` is the endpoint.
        body_field: Field describing the request body, if the endpoint has one.
        status_code: Status code used for responses built by the handler.
        content_type: Response class instantiated with the serialized data.
        response_field: Optional field used to validate the endpoint's output.

    Returns:
        A coroutine function taking a ``Request`` and returning a ``Response``.
    """
    assert dependant.call is not None, "dependant.call must be a function"
    is_coroutine = asyncio.iscoroutinefunction(dependant.call)
    is_body_form = body_field and isinstance(body_field.schema, params.Form)

    async def app(request: Request) -> Response:
        """Parse the body, solve dependencies, call the endpoint, respond."""
        try:
            body = None
            if body_field:
                if is_body_form:
                    form_fields = dict((await request.form()).items())
                    # An empty form is treated the same as no body at all.
                    if form_fields:
                        body = form_fields
                else:
                    body_bytes = await request.body()
                    # Only attempt JSON decoding when something was sent.
                    if body_bytes:
                        body = await request.json()
        except Exception as e:
            # The exception must go through a %s placeholder: passing it as a
            # bare extra argument makes the logging module fail to format the
            # record and the actual cause is lost.
            logging.error("Error getting request body: %s", e)
            raise HTTPException(
                status_code=400, detail="There was an error parsing the body"
            ) from e

        values, errors = await solve_dependencies(
            request=request, dependant=dependant, body=body
        )
        if errors:
            errors_out = ValidationError(errors)
            raise HTTPException(
                status_code=HTTP_422_UNPROCESSABLE_ENTITY, detail=errors_out.errors()
            )

        assert dependant.call is not None, "dependant.call must be a function"
        if is_coroutine:
            raw_response = await dependant.call(**values)
        else:
            # Non-coroutine endpoints are dispatched through run_in_threadpool.
            raw_response = await run_in_threadpool(dependant.call, **values)
        # Endpoints may return a ready-made Response; pass it through as is.
        if isinstance(raw_response, Response):
            return raw_response
        response_data = serialize_response(
            field=response_field, response=raw_response
        )
        return content_type(content=response_data, status_code=status_code)

    return app


def _describe_additional_response(
    add_response: AdditionalResponse, content_type: Type[Response]
) -> AdditionalResponseDescription:
    """Validate one additional response and build its description.

    Args:
        add_response: The additional response declaration to process.
        content_type: Response class of the route the declaration belongs to.

    Returns:
        An ``AdditionalResponseDescription`` carrying the declaration's
        description and content type, plus a schema field when the response
        is JSON and declares at least one model.

    Raises:
        ValueError: If any declared model is not a ``BaseModel`` subclass
            (including entries that are not classes at all).
    """
    response_models = list(add_response.models)
    try:
        valid_response_models = all(
            issubclass(model, BaseModel) for model in response_models
        )
    except TypeError:
        # issubclass() raises TypeError when handed something that is not a
        # class; treat that the same as "not a BaseModel subclass".
        valid_response_models = False
    if not valid_response_models:
        raise ValueError(
            "All response models must be "
            "a subclass of `pydantic.BaseModel` "
            "model."
        )

    is_json = add_response.content_type == "application/json" or lenient_issubclass(
        content_type, JSONResponse
    )
    schema_field = None
    if is_json and response_models:
        schema_field = Field(
            name=f"Additional_response_{add_response.status_code}",
            type_=Union[tuple(response_models)],
            class_validators=[],
            default=None,
            required=False,
            model_config=UnconstrainedConfig,
            schema=Schema(None),
        )
    return AdditionalResponseDescription(
        description=add_response.description,
        content_type=add_response.content_type,
        schema_field=schema_field,
    )


class APIRoute(routing.Route):
    """A route that carries response/validation metadata for its endpoint."""

    def __init__(
        self,
        path: str,
        endpoint: Callable,
        *,
        response_model: Type[BaseModel] = None,
        status_code: int = 200,
        tags: List[str] = None,
        summary: str = None,
        description: str = None,
        response_description: str = "Successful Response",
        additional_responses: List[AdditionalResponse] = [],
        deprecated: bool = None,
        name: str = None,
        methods: List[str] = None,
        operation_id: str = None,
        include_in_schema: bool = True,
        content_type: Type[Response] = JSONResponse,
    ) -> None:
        """Store the route metadata and build the request handler.

        ``additional_responses`` is normally a list of ``AdditionalResponse``
        declarations. ``APIRouter.include_router`` passes an already-built
        mapping of status code to description instead; such a mapping is
        copied as is.

        Raises:
            AssertionError: If ``path`` does not start with "/", if a
                response model is combined with a non-JSON ``content_type``,
                if an additional response reuses an existing status code, or
                if ``endpoint`` is not a function or method.
            ValueError: If an additional response declares a model that is
                not a ``BaseModel`` subclass.
        """
        assert path.startswith("/"), "Routed paths must always start with '/'"
        self.path = path
        self.endpoint = endpoint
        self.name = get_name(endpoint) if name is None else name
        self.response_model = response_model
        if self.response_model:
            assert lenient_issubclass(
                content_type, JSONResponse
            ), "To declare a type the response must be a JSON response"
            response_name = "Response_" + self.name
            self.response_field: Optional[Field] = Field(
                name=response_name,
                type_=self.response_model,
                class_validators=[],
                default=None,
                required=False,
                model_config=UnconstrainedConfig,
                schema=Schema(None),
            )
        else:
            self.response_field = None
        self.status_code = status_code
        self.tags = tags or []
        self.summary = summary
        self.description = description or self.endpoint.__doc__
        self.response_description = response_description

        self.additional_responses: Dict[int, AdditionalResponseDescription] = {}
        # The route's own status code and 422 are always considered taken.
        existed_codes = [self.status_code, 422]
        if isinstance(additional_responses, dict):
            self.additional_responses = additional_responses.copy()
        for add_response in additional_responses:
            # Iterating a mapping yields its integer status-code keys; those
            # entries were copied above and need no further processing.
            if isinstance(add_response, int):
                continue
            assert (
                add_response.status_code not in existed_codes
            ), f"(Duplicated Status Code): Response with status code [{add_response.status_code}] already defined!"
            existed_codes.append(add_response.status_code)
            self.additional_responses[
                add_response.status_code
            ] = _describe_additional_response(add_response, content_type)

        self.deprecated = deprecated
        if methods is None:
            methods = ["GET"]
        self.methods = methods
        self.operation_id = operation_id
        self.include_in_schema = include_in_schema
        self.content_type = content_type
        self.path_regex, self.path_format, self.param_convertors = compile_path(path)
        assert inspect.isfunction(endpoint) or inspect.ismethod(
            endpoint
        ), "An endpoint must be a function or method"
        self.dependant = get_dependant(path=path, call=self.endpoint)
        self.body_field = get_body_field(dependant=self.dependant, name=self.name)
        self.app = request_response(
            get_app(
                dependant=self.dependant,
                body_field=self.body_field,
                status_code=self.status_code,
                content_type=self.content_type,
                response_field=self.response_field,
            )
        )


class APIRouter(routing.Router):
    """Router whose routes are ``APIRoute`` instances."""

    def add_api_route(
        self,
        path: str,
        endpoint: Callable,
        *,
        response_model: Type[BaseModel] = None,
        status_code: int = 200,
        tags: List[str] = None,
        summary: str = None,
        description: str = None,
        response_description: str = "Successful Response",
        additional_responses: List[AdditionalResponse] = [],
        deprecated: bool = None,
        methods: List[str] = None,
        operation_id: str = None,
        include_in_schema: bool = True,
        content_type: Type[Response] = JSONResponse,
        name: str = None,
    ) -> None:
        """Create an ``APIRoute`` from the arguments and append it to ``routes``."""
        route = APIRoute(
            path,
            endpoint=endpoint,
            response_model=response_model,
            status_code=status_code,
            tags=tags or [],
            summary=summary,
            description=description,
            response_description=response_description,
            additional_responses=additional_responses,
            deprecated=deprecated,
            methods=methods,
            operation_id=operation_id,
            include_in_schema=include_in_schema,
            content_type=content_type,
            name=name,
        )
        self.routes.append(route)

    def api_route(
        self,
        path: str,
        *,
        response_model: Type[BaseModel] = None,
        status_code: int = 200,
        tags: List[str] = None,
        summary: str = None,
        description: str = None,
        response_description: str = "Successful Response",
        additional_responses: List[AdditionalResponse] = [],
        deprecated: bool = None,
        methods: List[str] = None,
        operation_id: str = None,
        include_in_schema: bool = True,
        content_type: Type[Response] = JSONResponse,
        name: str = None,
    ) -> Callable:
        """Return a decorator that registers a function via ``add_api_route``.

        The decorator returns the decorated function unchanged.
        """

        def decorator(func: Callable) -> Callable:
            self.add_api_route(
                path,
                func,
                response_model=response_model,
                status_code=status_code,
                tags=tags or [],
                summary=summary,
                description=description,
                response_description=response_description,
                additional_responses=additional_responses,
                deprecated=deprecated,
                methods=methods,
                operation_id=operation_id,
                include_in_schema=include_in_schema,
                content_type=content_type,
                name=name,
            )
            return func

        return decorator

    def include_router(
        self,
        router: "APIRouter",
        *,
        prefix: str = "",
        tags: List[str] = None,
        additional_responses: List[AdditionalResponse] = [],
    ) -> None:
        """Re-register every route of ``router`` on this router.

        Args:
            router: The router whose routes are copied.
            prefix: Path prefix prepended to each route path; when non-empty
                it must start with "/" and must not end with "/".
            tags: Extra tags appended to each ``APIRoute``'s own tags.
            additional_responses: Additional responses merged into each
                ``APIRoute``'s existing ones.

        Raises:
            AssertionError: If ``prefix`` is malformed or an additional
                response reuses a status code already defined on a route.
            ValueError: If an additional response declares a model that is
                not a ``BaseModel`` subclass.
        """
        if prefix:
            assert prefix.startswith("/"), "A path prefix must start with '/'"
            assert not prefix.endswith(
                "/"
            ), "A path prefix must not end with '/', as the routes will start with '/'"
        for route in router.routes:
            if isinstance(route, APIRoute):
                # Merge into a copy so the included router's own route is left
                # untouched; mutating it in place would make a second
                # include of the same router fail the duplicate-code check.
                merged_responses = dict(route.additional_responses)
                existed_codes = [422, route.status_code] + [
                    int(code) for code in merged_responses
                ]
                for add_response in additional_responses:
                    assert (
                        add_response.status_code not in existed_codes
                    ), f"(Duplicated Status Code): Response with status code [{add_response.status_code}] already defined!"
                    existed_codes.append(add_response.status_code)
                    merged_responses[
                        add_response.status_code
                    ] = _describe_additional_response(add_response, route.content_type)
                self.add_api_route(
                    prefix + route.path,
                    route.endpoint,
                    response_model=route.response_model,
                    status_code=route.status_code,
                    tags=(route.tags or []) + (tags or []),
                    summary=route.summary,
                    description=route.description,
                    response_description=route.response_description,
                    additional_responses=merged_responses,
                    deprecated=route.deprecated,
                    methods=route.methods,
                    operation_id=route.operation_id,
                    include_in_schema=route.include_in_schema,
                    content_type=route.content_type,
                    name=route.name,
                )
            elif isinstance(route, routing.Route):
                self.add_route(
                    prefix + route.path,
                    route.endpoint,
                    methods=route.methods,
                    include_in_schema=route.include_in_schema,
                    name=route.name,
                )

    def get(
        self,
        path: str,
        *,
        response_model: Type[BaseModel] = None,
        status_code: int = 200,
        tags: List[str] = None,
        summary: str = None,
        description: str = None,
        response_description: str = "Successful Response",
        additional_responses: List[AdditionalResponse] = [],
        deprecated: bool = None,
        operation_id: str = None,
        include_in_schema: bool = True,
        content_type: Type[Response] = JSONResponse,
        name: str = None,
    ) -> Callable:
        """Return the ``api_route`` decorator with ``methods=["GET"]``."""
        return self.api_route(
            path=path,
            response_model=response_model,
            status_code=status_code,
            tags=tags or [],
            summary=summary,
            description=description,
            response_description=response_description,
            additional_responses=additional_responses,
            deprecated=deprecated,
            methods=["GET"],
            operation_id=operation_id,
            include_in_schema=include_in_schema,
            content_type=content_type,
            name=name,
        )

    def put(
        self,
        path: str,
        *,
        response_model: Type[BaseModel] = None,
        status_code: int = 200,
        tags: List[str] = None,
        summary: str = None,
        description: str = None,
        response_description: str = "Successful Response",
        additional_responses: List[AdditionalResponse] = [],
        deprecated: bool = None,
        operation_id: str = None,
        include_in_schema: bool = True,
        content_type: Type[Response] = JSONResponse,
        name: str = None,
    ) -> Callable:
        """Return the ``api_route`` decorator with ``methods=["PUT"]``."""
        return self.api_route(
            path=path,
            response_model=response_model,
            status_code=status_code,
            tags=tags or [],
            summary=summary,
            description=description,
            response_description=response_description,
            additional_responses=additional_responses,
            deprecated=deprecated,
            methods=["PUT"],
            operation_id=operation_id,
            include_in_schema=include_in_schema,
            content_type=content_type,
            name=name,
        )

    def post(
        self,
        path: str,
        *,
        response_model: Type[BaseModel] = None,
        status_code: int = 200,
        tags: List[str] = None,
        summary: str = None,
        description: str = None,
        response_description: str = "Successful Response",
        additional_responses: List[AdditionalResponse] = [],
        deprecated: bool = None,
        operation_id: str = None,
        include_in_schema: bool = True,
        content_type: Type[Response] = JSONResponse,
        name: str = None,
    ) -> Callable:
        """Return the ``api_route`` decorator with ``methods=["POST"]``."""
        return self.api_route(
            path=path,
            response_model=response_model,
            status_code=status_code,
            tags=tags or [],
            summary=summary,
            description=description,
            response_description=response_description,
            additional_responses=additional_responses,
            deprecated=deprecated,
            methods=["POST"],
            operation_id=operation_id,
            include_in_schema=include_in_schema,
            content_type=content_type,
            name=name,
        )

    def delete(
        self,
        path: str,
        *,
        response_model: Type[BaseModel] = None,
        status_code: int = 200,
        tags: List[str] = None,
        summary: str = None,
        description: str = None,
        response_description: str = "Successful Response",
        additional_responses: List[AdditionalResponse] = [],
        deprecated: bool = None,
        operation_id: str = None,
        include_in_schema: bool = True,
        content_type: Type[Response] = JSONResponse,
        name: str = None,
    ) -> Callable:
        """Return the ``api_route`` decorator with ``methods=["DELETE"]``."""
        return self.api_route(
            path=path,
            response_model=response_model,
            status_code=status_code,
            tags=tags or [],
            summary=summary,
            description=description,
            response_description=response_description,
            additional_responses=additional_responses,
            deprecated=deprecated,
            methods=["DELETE"],
            operation_id=operation_id,
            include_in_schema=include_in_schema,
            content_type=content_type,
            name=name,
        )

    def options(
        self,
        path: str,
        *,
        response_model: Type[BaseModel] = None,
        status_code: int = 200,
        tags: List[str] = None,
        summary: str = None,
        description: str = None,
        response_description: str = "Successful Response",
        additional_responses: List[AdditionalResponse] = [],
        deprecated: bool = None,
        operation_id: str = None,
        include_in_schema: bool = True,
        content_type: Type[Response] = JSONResponse,
        name: str = None,
    ) -> Callable:
        """Return the ``api_route`` decorator with ``methods=["OPTIONS"]``."""
        return self.api_route(
            path=path,
            response_model=response_model,
            status_code=status_code,
            tags=tags or [],
            summary=summary,
            description=description,
            response_description=response_description,
            additional_responses=additional_responses,
            deprecated=deprecated,
            methods=["OPTIONS"],
            operation_id=operation_id,
            include_in_schema=include_in_schema,
            content_type=content_type,
            name=name,
        )

    def head(
        self,
        path: str,
        *,
        response_model: Type[BaseModel] = None,
        status_code: int = 200,
        tags: List[str] = None,
        summary: str = None,
        description: str = None,
        response_description: str = "Successful Response",
        additional_responses: List[AdditionalResponse] = [],
        deprecated: bool = None,
        operation_id: str = None,
        include_in_schema: bool = True,
        content_type: Type[Response] = JSONResponse,
        name: str = None,
    ) -> Callable:
        """Return the ``api_route`` decorator with ``methods=["HEAD"]``."""
        return self.api_route(
            path=path,
            response_model=response_model,
            status_code=status_code,
            tags=tags or [],
            summary=summary,
            description=description,
            response_description=response_description,
            additional_responses=additional_responses,
            deprecated=deprecated,
            methods=["HEAD"],
            operation_id=operation_id,
            include_in_schema=include_in_schema,
            content_type=content_type,
            name=name,
        )

    def patch(
        self,
        path: str,
        *,
        response_model: Type[BaseModel] = None,
        status_code: int = 200,
        tags: List[str] = None,
        summary: str = None,
        description: str = None,
        response_description: str = "Successful Response",
        additional_responses: List[AdditionalResponse] = [],
        deprecated: bool = None,
        operation_id: str = None,
        include_in_schema: bool = True,
        content_type: Type[Response] = JSONResponse,
        name: str = None,
    ) -> Callable:
        """Return the ``api_route`` decorator with ``methods=["PATCH"]``."""
        return self.api_route(
            path=path,
            response_model=response_model,
            status_code=status_code,
            tags=tags or [],
            summary=summary,
            description=description,
            response_description=response_description,
            additional_responses=additional_responses,
            deprecated=deprecated,
            methods=["PATCH"],
            operation_id=operation_id,
            include_in_schema=include_in_schema,
            content_type=content_type,
            name=name,
        )

    def trace(
        self,
        path: str,
        *,
        response_model: Type[BaseModel] = None,
        status_code: int = 200,
        tags: List[str] = None,
        summary: str = None,
        description: str = None,
        response_description: str = "Successful Response",
        additional_responses: List[AdditionalResponse] = [],
        deprecated: bool = None,
        operation_id: str = None,
        include_in_schema: bool = True,
        content_type: Type[Response] = JSONResponse,
        name: str = None,
    ) -> Callable:
        """Return the ``api_route`` decorator with ``methods=["TRACE"]``."""
        return self.api_route(
            path=path,
            response_model=response_model,
            status_code=status_code,
            tags=tags or [],
            summary=summary,
            description=description,
            response_description=response_description,
            additional_responses=additional_responses,
            deprecated=deprecated,
            methods=["TRACE"],
            operation_id=operation_id,
            include_in_schema=include_in_schema,
            content_type=content_type,
            name=name,
        )