From c5cfae87916fbe10550ba8d10a40f40488d4c928 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Mon, 22 Jan 2024 22:29:29 +0100 Subject: [PATCH] Add middleware parameter --- fastapi/routing.py | 37 ++++++++++++++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/fastapi/routing.py b/fastapi/routing.py index acebabfca..6863947ed 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -56,6 +56,7 @@ from pydantic import BaseModel from starlette import routing from starlette.concurrency import run_in_threadpool from starlette.exceptions import HTTPException +from starlette.middleware import Middleware from starlette.requests import Request from starlette.responses import JSONResponse, Response from starlette.routing import ( @@ -369,6 +370,7 @@ class APIWebSocketRoute(routing.WebSocketRoute): endpoint: Callable[..., Any], *, name: Optional[str] = None, + middleware: Optional[Sequence[Middleware]] = None, dependencies: Optional[Sequence[params.Depends]] = None, dependency_overrides_provider: Optional[Any] = None, ) -> None: @@ -390,6 +392,9 @@ class APIWebSocketRoute(routing.WebSocketRoute): dependency_overrides_provider=dependency_overrides_provider, ) ) + if middleware is not None: + for cls, args, kwargs in reversed(middleware): + self.app = cls(app=self.app, *args, **kwargs) def matches(self, scope: Scope) -> Tuple[Match, Scope]: match, child_scope = super().matches(scope) @@ -432,6 +437,7 @@ class APIRoute(routing.Route): generate_unique_id_function: Union[ Callable[["APIRoute"], str], DefaultPlaceholder ] = Default(generate_unique_id), + middleware: Optional[Sequence[Middleware]] = None, ) -> None: self.path = path self.endpoint = endpoint @@ -462,6 +468,9 @@ class APIRoute(routing.Route): self.responses = responses or {} self.name = get_name(endpoint) if name is None else name self.path_regex, self.path_format, self.param_convertors = compile_path(path) + if middleware is not None: + for cls, args, kwargs in reversed(middleware): + self.app = cls(app=self.app, *args, **kwargs) if methods is None: methods = ["GET"] self.methods: Set[str] = {method.upper() for method in methods} @@ -795,6 +804,7 @@ class APIRouter(routing.Router): """ ), ] = Default(generate_unique_id), + middleware: Optional[Sequence[Middleware]] = None, ) -> None: super().__init__( routes=routes, @@ -803,6 +813,7 @@ class APIRouter(routing.Router): on_startup=on_startup, on_shutdown=on_shutdown, lifespan=lifespan, + middleware=middleware, ) if prefix: assert prefix.startswith("/"), "A path prefix must start with '/'" @@ -873,6 +884,7 @@ class APIRouter(routing.Router): generate_unique_id_function: Union[ Callable[[APIRoute], str], DefaultPlaceholder ] = Default(generate_unique_id), + middleware: Optional[Sequence[Middleware]] = None, ) -> None: route_class = route_class_override or self.route_class responses = responses or {} @@ -919,6 +931,7 @@ class APIRouter(routing.Router): callbacks=current_callbacks, openapi_extra=openapi_extra, generate_unique_id_function=current_generate_unique_id, + middleware=middleware, ) self.routes.append(route) @@ -951,6 +964,7 @@ class APIRouter(routing.Router): generate_unique_id_function: Callable[[APIRoute], str] = Default( generate_unique_id ), + middleware: Optional[Sequence[Middleware]] = None, ) -> Callable[[DecoratedCallable], DecoratedCallable]: def decorator(func: DecoratedCallable) -> DecoratedCallable: self.add_api_route( @@ -979,6 +993,7 @@ class APIRouter(routing.Router): callbacks=callbacks, openapi_extra=openapi_extra, generate_unique_id_function=generate_unique_id_function, + middleware=middleware, ) return func @@ -990,6 +1005,7 @@ class APIRouter(routing.Router): endpoint: Callable[..., Any], name: Optional[str] = None, *, + middleware: Optional[Sequence[Middleware]] = None, dependencies: Optional[Sequence[params.Depends]] = None, ) -> None: current_dependencies = self.dependencies.copy() @@ -1000,6 +1016,7 @@ class APIRouter(routing.Router): self.prefix + path, endpoint=endpoint, name=name, + middleware=middleware, dependencies=current_dependencies, dependency_overrides_provider=self.dependency_overrides_provider, ) @@ -1024,6 +1041,14 @@ class APIRouter(routing.Router): ), ] = None, *, + middleware: Annotated[ + Optional[Sequence[Middleware]], + Doc( + """ + A list of middleware to apply to the WebSocket. + """ + ), + ] = None, dependencies: Annotated[ Optional[Sequence[params.Depends]], Doc( @@ -1066,7 +1091,7 @@ class APIRouter(routing.Router): def decorator(func: DecoratedCallable) -> DecoratedCallable: self.add_api_websocket_route( - path, func, name=name, dependencies=dependencies + path, func, name=name, middleware=middleware, dependencies=dependencies ) return func @@ -1192,6 +1217,15 @@ class APIRouter(routing.Router): """ ), ] = Default(generate_unique_id), + middleware: Annotated[ + Optional[Sequence[Middleware]], + Doc( + """ + A list of middleware to apply to all the *path operations* in this + router. + """ + ), + ] = None, ) -> None: """ Include another `APIRouter` in the same current `APIRouter`. @@ -1290,6 +1324,7 @@ class APIRouter(routing.Router): callbacks=current_callbacks, openapi_extra=route.openapi_extra, generate_unique_id_function=current_generate_unique_id, + middleware=middleware, ) elif isinstance(route, routing.Route): methods = list(route.methods or [])