|
|
@ -56,6 +56,7 @@ from pydantic import BaseModel |
|
|
|
from starlette import routing |
|
|
|
from starlette.concurrency import run_in_threadpool |
|
|
|
from starlette.exceptions import HTTPException |
|
|
|
from starlette.middleware import Middleware |
|
|
|
from starlette.requests import Request |
|
|
|
from starlette.responses import JSONResponse, Response |
|
|
|
from starlette.routing import ( |
|
|
@ -369,6 +370,7 @@ class APIWebSocketRoute(routing.WebSocketRoute): |
|
|
|
endpoint: Callable[..., Any], |
|
|
|
*, |
|
|
|
name: Optional[str] = None, |
|
|
|
middleware: Optional[Sequence[Middleware]] = None, |
|
|
|
dependencies: Optional[Sequence[params.Depends]] = None, |
|
|
|
dependency_overrides_provider: Optional[Any] = None, |
|
|
|
) -> None: |
|
|
@ -390,6 +392,9 @@ class APIWebSocketRoute(routing.WebSocketRoute): |
|
|
|
dependency_overrides_provider=dependency_overrides_provider, |
|
|
|
) |
|
|
|
) |
|
|
|
if middleware is not None: |
|
|
|
for cls, args, kwargs in reversed(middleware): |
|
|
|
self.app = cls(app=self.app, *args, **kwargs) |
|
|
|
|
|
|
|
def matches(self, scope: Scope) -> Tuple[Match, Scope]: |
|
|
|
match, child_scope = super().matches(scope) |
|
|
@ -432,6 +437,7 @@ class APIRoute(routing.Route): |
|
|
|
generate_unique_id_function: Union[ |
|
|
|
Callable[["APIRoute"], str], DefaultPlaceholder |
|
|
|
] = Default(generate_unique_id), |
|
|
|
middleware: Optional[Sequence[Middleware]] = None, |
|
|
|
) -> None: |
|
|
|
self.path = path |
|
|
|
self.endpoint = endpoint |
|
|
@ -462,6 +468,9 @@ class APIRoute(routing.Route): |
|
|
|
self.responses = responses or {} |
|
|
|
self.name = get_name(endpoint) if name is None else name |
|
|
|
self.path_regex, self.path_format, self.param_convertors = compile_path(path) |
|
|
|
if middleware is not None: |
|
|
|
for cls, args, kwargs in reversed(middleware): |
|
|
|
self.app = cls(app=self.app, *args, **kwargs) |
|
|
|
if methods is None: |
|
|
|
methods = ["GET"] |
|
|
|
self.methods: Set[str] = {method.upper() for method in methods} |
|
|
@ -795,6 +804,7 @@ class APIRouter(routing.Router): |
|
|
|
""" |
|
|
|
), |
|
|
|
] = Default(generate_unique_id), |
|
|
|
middleware: Optional[Sequence[Middleware]] = None, |
|
|
|
) -> None: |
|
|
|
super().__init__( |
|
|
|
routes=routes, |
|
|
@ -803,6 +813,7 @@ class APIRouter(routing.Router): |
|
|
|
on_startup=on_startup, |
|
|
|
on_shutdown=on_shutdown, |
|
|
|
lifespan=lifespan, |
|
|
|
middleware=middleware, |
|
|
|
) |
|
|
|
if prefix: |
|
|
|
assert prefix.startswith("/"), "A path prefix must start with '/'" |
|
|
@ -873,6 +884,7 @@ class APIRouter(routing.Router): |
|
|
|
generate_unique_id_function: Union[ |
|
|
|
Callable[[APIRoute], str], DefaultPlaceholder |
|
|
|
] = Default(generate_unique_id), |
|
|
|
middleware: Optional[Sequence[Middleware]] = None, |
|
|
|
) -> None: |
|
|
|
route_class = route_class_override or self.route_class |
|
|
|
responses = responses or {} |
|
|
@ -919,6 +931,7 @@ class APIRouter(routing.Router): |
|
|
|
callbacks=current_callbacks, |
|
|
|
openapi_extra=openapi_extra, |
|
|
|
generate_unique_id_function=current_generate_unique_id, |
|
|
|
middleware=middleware, |
|
|
|
) |
|
|
|
self.routes.append(route) |
|
|
|
|
|
|
@ -951,6 +964,7 @@ class APIRouter(routing.Router): |
|
|
|
generate_unique_id_function: Callable[[APIRoute], str] = Default( |
|
|
|
generate_unique_id |
|
|
|
), |
|
|
|
middleware: Optional[Sequence[Middleware]] = None, |
|
|
|
) -> Callable[[DecoratedCallable], DecoratedCallable]: |
|
|
|
def decorator(func: DecoratedCallable) -> DecoratedCallable: |
|
|
|
self.add_api_route( |
|
|
@ -979,6 +993,7 @@ class APIRouter(routing.Router): |
|
|
|
callbacks=callbacks, |
|
|
|
openapi_extra=openapi_extra, |
|
|
|
generate_unique_id_function=generate_unique_id_function, |
|
|
|
middleware=middleware, |
|
|
|
) |
|
|
|
return func |
|
|
|
|
|
|
@ -990,6 +1005,7 @@ class APIRouter(routing.Router): |
|
|
|
endpoint: Callable[..., Any], |
|
|
|
name: Optional[str] = None, |
|
|
|
*, |
|
|
|
middleware: Optional[Sequence[Middleware]] = None, |
|
|
|
dependencies: Optional[Sequence[params.Depends]] = None, |
|
|
|
) -> None: |
|
|
|
current_dependencies = self.dependencies.copy() |
|
|
@ -1000,6 +1016,7 @@ class APIRouter(routing.Router): |
|
|
|
self.prefix + path, |
|
|
|
endpoint=endpoint, |
|
|
|
name=name, |
|
|
|
middleware=middleware, |
|
|
|
dependencies=current_dependencies, |
|
|
|
dependency_overrides_provider=self.dependency_overrides_provider, |
|
|
|
) |
|
|
@ -1024,6 +1041,14 @@ class APIRouter(routing.Router): |
|
|
|
), |
|
|
|
] = None, |
|
|
|
*, |
|
|
|
middleware: Annotated[ |
|
|
|
Optional[Sequence[Middleware]], |
|
|
|
Doc( |
|
|
|
""" |
|
|
|
A list of middleware to apply to the WebSocket. |
|
|
|
""" |
|
|
|
), |
|
|
|
] = None, |
|
|
|
dependencies: Annotated[ |
|
|
|
Optional[Sequence[params.Depends]], |
|
|
|
Doc( |
|
|
@ -1066,7 +1091,7 @@ class APIRouter(routing.Router): |
|
|
|
|
|
|
|
def decorator(func: DecoratedCallable) -> DecoratedCallable: |
|
|
|
self.add_api_websocket_route( |
|
|
|
path, func, name=name, dependencies=dependencies |
|
|
|
path, func, name=name, middleware=middleware, dependencies=dependencies |
|
|
|
) |
|
|
|
return func |
|
|
|
|
|
|
@ -1192,6 +1217,15 @@ class APIRouter(routing.Router): |
|
|
|
""" |
|
|
|
), |
|
|
|
] = Default(generate_unique_id), |
|
|
|
middleware: Annotated[ |
|
|
|
Optional[Sequence[Middleware]], |
|
|
|
Doc( |
|
|
|
""" |
|
|
|
A list of middleware to apply to all the *path operations* in this |
|
|
|
router. |
|
|
|
""" |
|
|
|
), |
|
|
|
] = None, |
|
|
|
) -> None: |
|
|
|
""" |
|
|
|
Include another `APIRouter` in the same current `APIRouter`. |
|
|
@ -1290,6 +1324,7 @@ class APIRouter(routing.Router): |
|
|
|
callbacks=current_callbacks, |
|
|
|
openapi_extra=route.openapi_extra, |
|
|
|
generate_unique_id_function=current_generate_unique_id, |
|
|
|
middleware=middleware, |
|
|
|
) |
|
|
|
elif isinstance(route, routing.Route): |
|
|
|
methods = list(route.methods or []) |
|
|
|