Browse Source

Add middleware parameter

pull/11398/head
Marcelo Trylesinski 2 years ago
parent
commit
c5cfae8791
  1. 37
      fastapi/routing.py

37
fastapi/routing.py

@ -56,6 +56,7 @@ from pydantic import BaseModel
from starlette import routing from starlette import routing
from starlette.concurrency import run_in_threadpool from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from starlette.middleware import Middleware
from starlette.requests import Request from starlette.requests import Request
from starlette.responses import JSONResponse, Response from starlette.responses import JSONResponse, Response
from starlette.routing import ( from starlette.routing import (
@ -369,6 +370,7 @@ class APIWebSocketRoute(routing.WebSocketRoute):
endpoint: Callable[..., Any], endpoint: Callable[..., Any],
*, *,
name: Optional[str] = None, name: Optional[str] = None,
middleware: Optional[Sequence[Middleware]] = None,
dependencies: Optional[Sequence[params.Depends]] = None, dependencies: Optional[Sequence[params.Depends]] = None,
dependency_overrides_provider: Optional[Any] = None, dependency_overrides_provider: Optional[Any] = None,
) -> None: ) -> None:
@ -390,6 +392,9 @@ class APIWebSocketRoute(routing.WebSocketRoute):
dependency_overrides_provider=dependency_overrides_provider, dependency_overrides_provider=dependency_overrides_provider,
) )
) )
if middleware is not None:
for cls, args, kwargs in reversed(middleware):
self.app = cls(app=self.app, *args, **kwargs)
def matches(self, scope: Scope) -> Tuple[Match, Scope]: def matches(self, scope: Scope) -> Tuple[Match, Scope]:
match, child_scope = super().matches(scope) match, child_scope = super().matches(scope)
@ -432,6 +437,7 @@ class APIRoute(routing.Route):
generate_unique_id_function: Union[ generate_unique_id_function: Union[
Callable[["APIRoute"], str], DefaultPlaceholder Callable[["APIRoute"], str], DefaultPlaceholder
] = Default(generate_unique_id), ] = Default(generate_unique_id),
middleware: Optional[Sequence[Middleware]] = None,
) -> None: ) -> None:
self.path = path self.path = path
self.endpoint = endpoint self.endpoint = endpoint
@ -462,6 +468,9 @@ class APIRoute(routing.Route):
self.responses = responses or {} self.responses = responses or {}
self.name = get_name(endpoint) if name is None else name self.name = get_name(endpoint) if name is None else name
self.path_regex, self.path_format, self.param_convertors = compile_path(path) self.path_regex, self.path_format, self.param_convertors = compile_path(path)
if middleware is not None:
for cls, args, kwargs in reversed(middleware):
self.app = cls(app=self.app, *args, **kwargs)
if methods is None: if methods is None:
methods = ["GET"] methods = ["GET"]
self.methods: Set[str] = {method.upper() for method in methods} self.methods: Set[str] = {method.upper() for method in methods}
@ -795,6 +804,7 @@ class APIRouter(routing.Router):
""" """
), ),
] = Default(generate_unique_id), ] = Default(generate_unique_id),
middleware: Optional[Sequence[Middleware]] = None,
) -> None: ) -> None:
super().__init__( super().__init__(
routes=routes, routes=routes,
@ -803,6 +813,7 @@ class APIRouter(routing.Router):
on_startup=on_startup, on_startup=on_startup,
on_shutdown=on_shutdown, on_shutdown=on_shutdown,
lifespan=lifespan, lifespan=lifespan,
middleware=middleware,
) )
if prefix: if prefix:
assert prefix.startswith("/"), "A path prefix must start with '/'" assert prefix.startswith("/"), "A path prefix must start with '/'"
@ -873,6 +884,7 @@ class APIRouter(routing.Router):
generate_unique_id_function: Union[ generate_unique_id_function: Union[
Callable[[APIRoute], str], DefaultPlaceholder Callable[[APIRoute], str], DefaultPlaceholder
] = Default(generate_unique_id), ] = Default(generate_unique_id),
middleware: Optional[Sequence[Middleware]] = None,
) -> None: ) -> None:
route_class = route_class_override or self.route_class route_class = route_class_override or self.route_class
responses = responses or {} responses = responses or {}
@ -919,6 +931,7 @@ class APIRouter(routing.Router):
callbacks=current_callbacks, callbacks=current_callbacks,
openapi_extra=openapi_extra, openapi_extra=openapi_extra,
generate_unique_id_function=current_generate_unique_id, generate_unique_id_function=current_generate_unique_id,
middleware=middleware,
) )
self.routes.append(route) self.routes.append(route)
@ -951,6 +964,7 @@ class APIRouter(routing.Router):
generate_unique_id_function: Callable[[APIRoute], str] = Default( generate_unique_id_function: Callable[[APIRoute], str] = Default(
generate_unique_id generate_unique_id
), ),
middleware: Optional[Sequence[Middleware]] = None,
) -> Callable[[DecoratedCallable], DecoratedCallable]: ) -> Callable[[DecoratedCallable], DecoratedCallable]:
def decorator(func: DecoratedCallable) -> DecoratedCallable: def decorator(func: DecoratedCallable) -> DecoratedCallable:
self.add_api_route( self.add_api_route(
@ -979,6 +993,7 @@ class APIRouter(routing.Router):
callbacks=callbacks, callbacks=callbacks,
openapi_extra=openapi_extra, openapi_extra=openapi_extra,
generate_unique_id_function=generate_unique_id_function, generate_unique_id_function=generate_unique_id_function,
middleware=middleware,
) )
return func return func
@ -990,6 +1005,7 @@ class APIRouter(routing.Router):
endpoint: Callable[..., Any], endpoint: Callable[..., Any],
name: Optional[str] = None, name: Optional[str] = None,
*, *,
middleware: Optional[Sequence[Middleware]] = None,
dependencies: Optional[Sequence[params.Depends]] = None, dependencies: Optional[Sequence[params.Depends]] = None,
) -> None: ) -> None:
current_dependencies = self.dependencies.copy() current_dependencies = self.dependencies.copy()
@ -1000,6 +1016,7 @@ class APIRouter(routing.Router):
self.prefix + path, self.prefix + path,
endpoint=endpoint, endpoint=endpoint,
name=name, name=name,
middleware=middleware,
dependencies=current_dependencies, dependencies=current_dependencies,
dependency_overrides_provider=self.dependency_overrides_provider, dependency_overrides_provider=self.dependency_overrides_provider,
) )
@ -1024,6 +1041,14 @@ class APIRouter(routing.Router):
), ),
] = None, ] = None,
*, *,
middleware: Annotated[
Optional[Sequence[Middleware]],
Doc(
"""
A list of middleware to apply to the WebSocket.
"""
),
] = None,
dependencies: Annotated[ dependencies: Annotated[
Optional[Sequence[params.Depends]], Optional[Sequence[params.Depends]],
Doc( Doc(
@ -1066,7 +1091,7 @@ class APIRouter(routing.Router):
def decorator(func: DecoratedCallable) -> DecoratedCallable: def decorator(func: DecoratedCallable) -> DecoratedCallable:
self.add_api_websocket_route( self.add_api_websocket_route(
path, func, name=name, dependencies=dependencies path, func, name=name, middleware=middleware, dependencies=dependencies
) )
return func return func
@ -1192,6 +1217,15 @@ class APIRouter(routing.Router):
""" """
), ),
] = Default(generate_unique_id), ] = Default(generate_unique_id),
middleware: Annotated[
Optional[Sequence[Middleware]],
Doc(
"""
A list of middleware to apply to all the *path operations* in this
router.
"""
),
] = None,
) -> None: ) -> None:
""" """
Include another `APIRouter` in the same current `APIRouter`. Include another `APIRouter` in the same current `APIRouter`.
@ -1290,6 +1324,7 @@ class APIRouter(routing.Router):
callbacks=current_callbacks, callbacks=current_callbacks,
openapi_extra=route.openapi_extra, openapi_extra=route.openapi_extra,
generate_unique_id_function=current_generate_unique_id, generate_unique_id_function=current_generate_unique_id,
middleware=middleware,
) )
elif isinstance(route, routing.Route): elif isinstance(route, routing.Route):
methods = list(route.methods or []) methods = list(route.methods or [])

Loading…
Cancel
Save