diff --git a/fastapi/routing.py b/fastapi/routing.py index 457481e32..9ee4f0d97 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -392,12 +392,14 @@ class APIWebSocketRoute(routing.WebSocketRoute): endpoint: Callable[..., Any], *, name: Optional[str] = None, + tags: Optional[List[Union[str, Enum]]] = None, dependencies: Optional[Sequence[params.Depends]] = None, dependency_overrides_provider: Optional[Any] = None, ) -> None: self.path = path self.endpoint = endpoint self.name = get_name(endpoint) if name is None else name + self.tags: List[Union[str, Enum]] = tags or [] self.dependencies = list(dependencies or []) self.path_regex, self.path_format, self.param_convertors = compile_path(path) self.dependant = get_dependant(path=self.path_format, call=self.endpoint) @@ -1028,6 +1030,7 @@ class APIRouter(routing.Router): endpoint: Callable[..., Any], name: Optional[str] = None, *, + tags: Optional[List[Union[str, Enum]]] = None, dependencies: Optional[Sequence[params.Depends]] = None, ) -> None: current_dependencies = self.dependencies.copy() @@ -1038,6 +1041,7 @@ class APIRouter(routing.Router): self.prefix + path, endpoint=endpoint, name=name, + tags=tags, dependencies=current_dependencies, dependency_overrides_provider=self.dependency_overrides_provider, ) @@ -1062,6 +1066,14 @@ class APIRouter(routing.Router): ), ] = None, *, + tags: Annotated[ + Optional[List[Union[str, Enum]]], + Doc( + """ + A list of tags to be applied to this WebSocket. + """ + ), + ] = None, dependencies: Annotated[ Optional[Sequence[params.Depends]], Doc( @@ -1104,7 +1116,7 @@ class APIRouter(routing.Router): def decorator(func: DecoratedCallable) -> DecoratedCallable: self.add_api_websocket_route( - path, func, name=name, dependencies=dependencies + path, func, name=name, tags=tags, dependencies=dependencies ) return func @@ -1344,11 +1356,18 @@ class APIRouter(routing.Router): current_dependencies.extend(dependencies) if route.dependencies: current_dependencies.extend(route.dependencies) + + current_tags = [] + if tags: + current_tags.extend(tags) + if route.tags: + current_tags.extend(route.tags) self.add_api_websocket_route( prefix + route.path, route.endpoint, dependencies=current_dependencies, name=route.name, + tags=current_tags, ) elif isinstance(route, routing.WebSocketRoute): self.add_websocket_route( diff --git a/tests/test_ws_router.py b/tests/test_ws_router.py index 240a42bb0..3b24c04ff 100644 --- a/tests/test_ws_router.py +++ b/tests/test_ws_router.py @@ -103,6 +103,12 @@ class CustomError(Exception): async def router_ws_custom_error(websocket: WebSocket): raise CustomError() +@router.websocket("/test_tags/", name='test-tags', tags=["test"]) +async def router_ws_test_tags(websocket: WebSocket): + await websocket.accept() + await websocket.send_text("Hello, router with tags!") + await websocket.close() + def make_app(app=None, **kwargs): app = app or FastAPI(**kwargs) @@ -269,3 +275,11 @@ def test_depend_err_handler(): pass # pragma: no cover assert e.value.code == 1002 assert "foo" in e.value.reason + + +def test_websocket_tags(): + """ + Verify that it is possible to add tags to websocket routes + """ + route = next(route for route in app.routes if route.name == 'test-tags') + assert route.tags == ["test"] \ No newline at end of file