|
|
@ -2,7 +2,18 @@ import asyncio |
|
|
|
import enum |
|
|
|
import inspect |
|
|
|
import json |
|
|
|
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Type, Union |
|
|
|
from typing import ( |
|
|
|
Any, |
|
|
|
Callable, |
|
|
|
Coroutine, |
|
|
|
Dict, |
|
|
|
List, |
|
|
|
Optional, |
|
|
|
Sequence, |
|
|
|
Set, |
|
|
|
Type, |
|
|
|
Union, |
|
|
|
) |
|
|
|
|
|
|
|
from fastapi import params |
|
|
|
from fastapi.datastructures import Default, DefaultPlaceholder |
|
|
@ -16,6 +27,7 @@ from fastapi.dependencies.utils import ( |
|
|
|
from fastapi.encoders import DictIntStrAny, SetIntStr, jsonable_encoder |
|
|
|
from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError |
|
|
|
from fastapi.openapi.constants import STATUS_CODES_WITH_NO_BODY |
|
|
|
from fastapi.types import DecoratedCallable |
|
|
|
from fastapi.utils import ( |
|
|
|
create_cloned_field, |
|
|
|
create_response_field, |
|
|
@ -30,7 +42,8 @@ from starlette.concurrency import run_in_threadpool |
|
|
|
from starlette.exceptions import HTTPException |
|
|
|
from starlette.requests import Request |
|
|
|
from starlette.responses import JSONResponse, Response |
|
|
|
from starlette.routing import Mount # noqa |
|
|
|
from starlette.routing import BaseRoute |
|
|
|
from starlette.routing import Mount as Mount # noqa |
|
|
|
from starlette.routing import ( |
|
|
|
compile_path, |
|
|
|
get_name, |
|
|
@ -150,7 +163,7 @@ def get_request_handler( |
|
|
|
response_model_exclude_defaults: bool = False, |
|
|
|
response_model_exclude_none: bool = False, |
|
|
|
dependency_overrides_provider: Optional[Any] = None, |
|
|
|
) -> Callable: |
|
|
|
) -> Callable[[Request], Coroutine[Any, Any, Response]]: |
|
|
|
assert dependant.call is not None, "dependant.call must be a function" |
|
|
|
is_coroutine = asyncio.iscoroutinefunction(dependant.call) |
|
|
|
is_body_form = body_field and isinstance(body_field.field_info, params.Form) |
|
|
@ -207,7 +220,7 @@ def get_request_handler( |
|
|
|
response = actual_response_class( |
|
|
|
content=response_data, |
|
|
|
status_code=status_code, |
|
|
|
background=background_tasks, |
|
|
|
background=background_tasks, # type: ignore # in Starlette |
|
|
|
) |
|
|
|
response.headers.raw.extend(sub_response.headers.raw) |
|
|
|
if sub_response.status_code: |
|
|
@ -219,7 +232,7 @@ def get_request_handler( |
|
|
|
|
|
|
|
def get_websocket_app( |
|
|
|
dependant: Dependant, dependency_overrides_provider: Optional[Any] = None |
|
|
|
) -> Callable: |
|
|
|
) -> Callable[[WebSocket], Coroutine[Any, Any, Any]]: |
|
|
|
async def app(websocket: WebSocket) -> None: |
|
|
|
solved_result = await solve_dependencies( |
|
|
|
request=websocket, |
|
|
@ -240,7 +253,7 @@ class APIWebSocketRoute(routing.WebSocketRoute): |
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
path: str, |
|
|
|
endpoint: Callable, |
|
|
|
endpoint: Callable[..., Any], |
|
|
|
*, |
|
|
|
name: Optional[str] = None, |
|
|
|
dependency_overrides_provider: Optional[Any] = None, |
|
|
@ -262,7 +275,7 @@ class APIRoute(routing.Route): |
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
path: str, |
|
|
|
endpoint: Callable, |
|
|
|
endpoint: Callable[..., Any], |
|
|
|
*, |
|
|
|
response_model: Optional[Type[Any]] = None, |
|
|
|
status_code: int = 200, |
|
|
@ -287,7 +300,7 @@ class APIRoute(routing.Route): |
|
|
|
JSONResponse |
|
|
|
), |
|
|
|
dependency_overrides_provider: Optional[Any] = None, |
|
|
|
callbacks: Optional[List["APIRoute"]] = None, |
|
|
|
callbacks: Optional[List[BaseRoute]] = None, |
|
|
|
) -> None: |
|
|
|
# normalise enums e.g. http.HTTPStatus |
|
|
|
if isinstance(status_code, enum.IntEnum): |
|
|
@ -298,7 +311,7 @@ class APIRoute(routing.Route): |
|
|
|
self.path_regex, self.path_format, self.param_convertors = compile_path(path) |
|
|
|
if methods is None: |
|
|
|
methods = ["GET"] |
|
|
|
self.methods = set([method.upper() for method in methods]) |
|
|
|
self.methods: Set[str] = set([method.upper() for method in methods]) |
|
|
|
self.unique_id = generate_operation_id_for_path( |
|
|
|
name=self.name, path=self.path_format, method=list(methods)[0] |
|
|
|
) |
|
|
@ -375,7 +388,7 @@ class APIRoute(routing.Route): |
|
|
|
self.callbacks = callbacks |
|
|
|
self.app = request_response(self.get_route_handler()) |
|
|
|
|
|
|
|
def get_route_handler(self) -> Callable: |
|
|
|
def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]]: |
|
|
|
return get_request_handler( |
|
|
|
dependant=self.dependant, |
|
|
|
body_field=self.body_field, |
|
|
@ -401,23 +414,23 @@ class APIRouter(routing.Router): |
|
|
|
dependencies: Optional[Sequence[params.Depends]] = None, |
|
|
|
default_response_class: Type[Response] = Default(JSONResponse), |
|
|
|
responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None, |
|
|
|
callbacks: Optional[List[APIRoute]] = None, |
|
|
|
callbacks: Optional[List[BaseRoute]] = None, |
|
|
|
routes: Optional[List[routing.BaseRoute]] = None, |
|
|
|
redirect_slashes: bool = True, |
|
|
|
default: Optional[ASGIApp] = None, |
|
|
|
dependency_overrides_provider: Optional[Any] = None, |
|
|
|
route_class: Type[APIRoute] = APIRoute, |
|
|
|
on_startup: Optional[Sequence[Callable]] = None, |
|
|
|
on_shutdown: Optional[Sequence[Callable]] = None, |
|
|
|
deprecated: bool = None, |
|
|
|
on_startup: Optional[Sequence[Callable[[], Any]]] = None, |
|
|
|
on_shutdown: Optional[Sequence[Callable[[], Any]]] = None, |
|
|
|
deprecated: Optional[bool] = None, |
|
|
|
include_in_schema: bool = True, |
|
|
|
) -> None: |
|
|
|
super().__init__( |
|
|
|
routes=routes, |
|
|
|
routes=routes, # type: ignore # in Starlette |
|
|
|
redirect_slashes=redirect_slashes, |
|
|
|
default=default, |
|
|
|
on_startup=on_startup, |
|
|
|
on_shutdown=on_shutdown, |
|
|
|
default=default, # type: ignore # in Starlette |
|
|
|
on_startup=on_startup, # type: ignore # in Starlette |
|
|
|
on_shutdown=on_shutdown, # type: ignore # in Starlette |
|
|
|
) |
|
|
|
if prefix: |
|
|
|
assert prefix.startswith("/"), "A path prefix must start with '/'" |
|
|
@ -438,7 +451,7 @@ class APIRouter(routing.Router): |
|
|
|
def add_api_route( |
|
|
|
self, |
|
|
|
path: str, |
|
|
|
endpoint: Callable, |
|
|
|
endpoint: Callable[..., Any], |
|
|
|
*, |
|
|
|
response_model: Optional[Type[Any]] = None, |
|
|
|
status_code: int = 200, |
|
|
@ -463,7 +476,7 @@ class APIRouter(routing.Router): |
|
|
|
), |
|
|
|
name: Optional[str] = None, |
|
|
|
route_class_override: Optional[Type[APIRoute]] = None, |
|
|
|
callbacks: Optional[List[APIRoute]] = None, |
|
|
|
callbacks: Optional[List[BaseRoute]] = None, |
|
|
|
) -> None: |
|
|
|
route_class = route_class_override or self.route_class |
|
|
|
responses = responses or {} |
|
|
@ -532,9 +545,9 @@ class APIRouter(routing.Router): |
|
|
|
include_in_schema: bool = True, |
|
|
|
response_class: Type[Response] = Default(JSONResponse), |
|
|
|
name: Optional[str] = None, |
|
|
|
callbacks: Optional[List[APIRoute]] = None, |
|
|
|
) -> Callable: |
|
|
|
def decorator(func: Callable) -> Callable: |
|
|
|
callbacks: Optional[List[BaseRoute]] = None, |
|
|
|
) -> Callable[[DecoratedCallable], DecoratedCallable]: |
|
|
|
def decorator(func: DecoratedCallable) -> DecoratedCallable: |
|
|
|
self.add_api_route( |
|
|
|
path, |
|
|
|
func, |
|
|
@ -565,7 +578,7 @@ class APIRouter(routing.Router): |
|
|
|
return decorator |
|
|
|
|
|
|
|
def add_api_websocket_route( |
|
|
|
self, path: str, endpoint: Callable, name: Optional[str] = None |
|
|
|
self, path: str, endpoint: Callable[..., Any], name: Optional[str] = None |
|
|
|
) -> None: |
|
|
|
route = APIWebSocketRoute( |
|
|
|
path, |
|
|
@ -575,8 +588,10 @@ class APIRouter(routing.Router): |
|
|
|
) |
|
|
|
self.routes.append(route) |
|
|
|
|
|
|
|
def websocket(self, path: str, name: Optional[str] = None) -> Callable: |
|
|
|
def decorator(func: Callable) -> Callable: |
|
|
|
def websocket( |
|
|
|
self, path: str, name: Optional[str] = None |
|
|
|
) -> Callable[[DecoratedCallable], DecoratedCallable]: |
|
|
|
def decorator(func: DecoratedCallable) -> DecoratedCallable: |
|
|
|
self.add_api_websocket_route(path, func, name=name) |
|
|
|
return func |
|
|
|
|
|
|
@ -591,8 +606,8 @@ class APIRouter(routing.Router): |
|
|
|
dependencies: Optional[Sequence[params.Depends]] = None, |
|
|
|
default_response_class: Type[Response] = Default(JSONResponse), |
|
|
|
responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None, |
|
|
|
callbacks: Optional[List[APIRoute]] = None, |
|
|
|
deprecated: bool = None, |
|
|
|
callbacks: Optional[List[BaseRoute]] = None, |
|
|
|
deprecated: Optional[bool] = None, |
|
|
|
include_in_schema: bool = True, |
|
|
|
) -> None: |
|
|
|
if prefix: |
|
|
@ -663,10 +678,11 @@ class APIRouter(routing.Router): |
|
|
|
callbacks=current_callbacks, |
|
|
|
) |
|
|
|
elif isinstance(route, routing.Route): |
|
|
|
methods = list(route.methods or []) # type: ignore # in Starlette |
|
|
|
self.add_route( |
|
|
|
prefix + route.path, |
|
|
|
route.endpoint, |
|
|
|
methods=list(route.methods or []), |
|
|
|
methods=methods, |
|
|
|
include_in_schema=route.include_in_schema, |
|
|
|
name=route.name, |
|
|
|
) |
|
|
@ -706,8 +722,8 @@ class APIRouter(routing.Router): |
|
|
|
include_in_schema: bool = True, |
|
|
|
response_class: Type[Response] = Default(JSONResponse), |
|
|
|
name: Optional[str] = None, |
|
|
|
callbacks: Optional[List[APIRoute]] = None, |
|
|
|
) -> Callable: |
|
|
|
callbacks: Optional[List[BaseRoute]] = None, |
|
|
|
) -> Callable[[DecoratedCallable], DecoratedCallable]: |
|
|
|
return self.api_route( |
|
|
|
path=path, |
|
|
|
response_model=response_model, |
|
|
@ -756,8 +772,8 @@ class APIRouter(routing.Router): |
|
|
|
include_in_schema: bool = True, |
|
|
|
response_class: Type[Response] = Default(JSONResponse), |
|
|
|
name: Optional[str] = None, |
|
|
|
callbacks: Optional[List[APIRoute]] = None, |
|
|
|
) -> Callable: |
|
|
|
callbacks: Optional[List[BaseRoute]] = None, |
|
|
|
) -> Callable[[DecoratedCallable], DecoratedCallable]: |
|
|
|
return self.api_route( |
|
|
|
path=path, |
|
|
|
response_model=response_model, |
|
|
@ -806,8 +822,8 @@ class APIRouter(routing.Router): |
|
|
|
include_in_schema: bool = True, |
|
|
|
response_class: Type[Response] = Default(JSONResponse), |
|
|
|
name: Optional[str] = None, |
|
|
|
callbacks: Optional[List[APIRoute]] = None, |
|
|
|
) -> Callable: |
|
|
|
callbacks: Optional[List[BaseRoute]] = None, |
|
|
|
) -> Callable[[DecoratedCallable], DecoratedCallable]: |
|
|
|
return self.api_route( |
|
|
|
path=path, |
|
|
|
response_model=response_model, |
|
|
@ -856,8 +872,8 @@ class APIRouter(routing.Router): |
|
|
|
include_in_schema: bool = True, |
|
|
|
response_class: Type[Response] = Default(JSONResponse), |
|
|
|
name: Optional[str] = None, |
|
|
|
callbacks: Optional[List[APIRoute]] = None, |
|
|
|
) -> Callable: |
|
|
|
callbacks: Optional[List[BaseRoute]] = None, |
|
|
|
) -> Callable[[DecoratedCallable], DecoratedCallable]: |
|
|
|
return self.api_route( |
|
|
|
path=path, |
|
|
|
response_model=response_model, |
|
|
@ -906,8 +922,8 @@ class APIRouter(routing.Router): |
|
|
|
include_in_schema: bool = True, |
|
|
|
response_class: Type[Response] = Default(JSONResponse), |
|
|
|
name: Optional[str] = None, |
|
|
|
callbacks: Optional[List[APIRoute]] = None, |
|
|
|
) -> Callable: |
|
|
|
callbacks: Optional[List[BaseRoute]] = None, |
|
|
|
) -> Callable[[DecoratedCallable], DecoratedCallable]: |
|
|
|
return self.api_route( |
|
|
|
path=path, |
|
|
|
response_model=response_model, |
|
|
@ -956,8 +972,8 @@ class APIRouter(routing.Router): |
|
|
|
include_in_schema: bool = True, |
|
|
|
response_class: Type[Response] = Default(JSONResponse), |
|
|
|
name: Optional[str] = None, |
|
|
|
callbacks: Optional[List[APIRoute]] = None, |
|
|
|
) -> Callable: |
|
|
|
callbacks: Optional[List[BaseRoute]] = None, |
|
|
|
) -> Callable[[DecoratedCallable], DecoratedCallable]: |
|
|
|
return self.api_route( |
|
|
|
path=path, |
|
|
|
response_model=response_model, |
|
|
@ -1006,8 +1022,8 @@ class APIRouter(routing.Router): |
|
|
|
include_in_schema: bool = True, |
|
|
|
response_class: Type[Response] = Default(JSONResponse), |
|
|
|
name: Optional[str] = None, |
|
|
|
callbacks: Optional[List[APIRoute]] = None, |
|
|
|
) -> Callable: |
|
|
|
callbacks: Optional[List[BaseRoute]] = None, |
|
|
|
) -> Callable[[DecoratedCallable], DecoratedCallable]: |
|
|
|
return self.api_route( |
|
|
|
path=path, |
|
|
|
response_model=response_model, |
|
|
@ -1056,8 +1072,8 @@ class APIRouter(routing.Router): |
|
|
|
include_in_schema: bool = True, |
|
|
|
response_class: Type[Response] = Default(JSONResponse), |
|
|
|
name: Optional[str] = None, |
|
|
|
callbacks: Optional[List[APIRoute]] = None, |
|
|
|
) -> Callable: |
|
|
|
callbacks: Optional[List[BaseRoute]] = None, |
|
|
|
) -> Callable[[DecoratedCallable], DecoratedCallable]: |
|
|
|
|
|
|
|
return self.api_route( |
|
|
|
path=path, |
|
|
|