diff --git a/fastapi/applications.py b/fastapi/applications.py index 05c7bd2be..7591ce4a2 100644 --- a/fastapi/applications.py +++ b/fastapi/applications.py @@ -1179,12 +1179,14 @@ class FastAPI(Starlette): name: Optional[str] = None, *, dependencies: Optional[Sequence[Depends]] = None, + tags: Optional[List[Union[str, Enum]]] = None, ) -> None: self.router.add_api_websocket_route( path, endpoint, name=name, dependencies=dependencies, + tags=tags, ) def websocket( @@ -1218,6 +1220,14 @@ class FastAPI(Starlette): """ ), ] = None, + tags: Annotated[ + Optional[List[Union[str, Enum]]], + Doc( + """ + A list of tags to be applied to this WebSocket. + """ + ), + ] = None, ) -> Callable[[DecoratedCallable], DecoratedCallable]: """ Decorate a WebSocket function. @@ -1247,6 +1257,7 @@ class FastAPI(Starlette): func, name=name, dependencies=dependencies, + tags=tags, ) return func diff --git a/fastapi/routing.py b/fastapi/routing.py index 54c75a027..37471255c 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -12,6 +12,7 @@ from typing import ( Collection, Coroutine, Dict, + Iterable, List, Mapping, Optional, @@ -393,12 +394,14 @@ class APIWebSocketRoute(routing.WebSocketRoute): endpoint: Callable[..., Any], *, name: Optional[str] = None, + tags: Optional[List[Union[str, Enum]]] = None, dependencies: Optional[Sequence[params.Depends]] = None, dependency_overrides_provider: Optional[Any] = None, ) -> None: self.path = path self.endpoint = endpoint self.name = get_name(endpoint) if name is None else name + self.tags: List[Union[str, Enum]] = tags or [] self.dependencies = list(dependencies or []) self.path_regex, self.path_format, self.param_convertors = compile_path(path) self.dependant = get_dependant(path=self.path_format, call=self.endpoint) @@ -919,9 +922,7 @@ class APIRouter(routing.Router): current_response_class = get_value_or_default( response_class, self.default_response_class ) - current_tags = self.tags.copy() - if tags: - current_tags.extend(tags) + current_tags = self.combine_tags(tags or []) current_dependencies = self.dependencies.copy() if dependencies: current_dependencies.extend(dependencies) @@ -936,7 +937,7 @@ class APIRouter(routing.Router): endpoint=endpoint, response_model=response_model, status_code=status_code, - tags=current_tags, + tags=list(current_tags), dependencies=current_dependencies, summary=summary, description=description, @@ -1029,16 +1030,20 @@ class APIRouter(routing.Router): endpoint: Callable[..., Any], name: Optional[str] = None, *, + tags: Optional[List[Union[str, Enum]]] = None, dependencies: Optional[Sequence[params.Depends]] = None, ) -> None: current_dependencies = self.dependencies.copy() if dependencies: current_dependencies.extend(dependencies) + current_tags = self.combine_tags(tags) + route = APIWebSocketRoute( self.prefix + path, endpoint=endpoint, name=name, + tags=current_tags, dependencies=current_dependencies, dependency_overrides_provider=self.dependency_overrides_provider, ) @@ -1063,6 +1068,14 @@ class APIRouter(routing.Router): ), ] = None, *, + tags: Annotated[ + Optional[List[Union[str, Enum]]], + Doc( + """ + A list of tags to be applied to this WebSocket. + """ + ), + ] = None, dependencies: Annotated[ Optional[Sequence[params.Depends]], Doc( @@ -1105,7 +1118,7 @@ class APIRouter(routing.Router): def decorator(func: DecoratedCallable) -> DecoratedCallable: self.add_api_websocket_route( - path, func, name=name, dependencies=dependencies + path, func, name=name, tags=tags, dependencies=dependencies ) return func @@ -1279,11 +1292,7 @@ class APIRouter(routing.Router): default_response_class, self.default_response_class, ) - current_tags = [] - if tags: - current_tags.extend(tags) - if route.tags: - current_tags.extend(route.tags) + current_tags = self.combine_tags(tags, route) current_dependencies: List[params.Depends] = [] if dependencies: current_dependencies.extend(dependencies) @@ -1345,11 +1354,14 @@ class APIRouter(routing.Router): current_dependencies.extend(dependencies) if route.dependencies: current_dependencies.extend(route.dependencies) + + current_tags = self.combine_tags(tags, route) self.add_api_websocket_route( prefix + route.path, route.endpoint, dependencies=current_dependencies, name=route.name, + tags=current_tags, ) elif isinstance(route, routing.WebSocketRoute): self.add_websocket_route( @@ -4438,3 +4450,27 @@ class APIRouter(routing.Router): return func return decorator + + def combine_tags( + self, + *entities: Annotated[ + Union[None, str, routing.Route, Sequence], + Doc( + """ + Combine the router's current tags with those of the provided entities. + Supports None, strings, iterables, and Route objects with a `tags` attribute. + """ + ), + ], + ) -> List[str]: + tags = set(self.tags or []) + for entity in entities: + if entity is None: + continue + if isinstance(entity, str): + tags.add(entity) + elif isinstance(entity, Iterable): + tags.update(entity) + elif hasattr(entity, "tags"): + tags = tags.union(entity.tags) + return sorted(tags) diff --git a/tests/test_tutorial/test_bigger_applications/test_main.py b/tests/test_tutorial/test_bigger_applications/test_main.py index fe40fad7d..6deedb57e 100644 --- a/tests/test_tutorial/test_bigger_applications/test_main.py +++ b/tests/test_tutorial/test_bigger_applications/test_main.py @@ -580,7 +580,7 @@ def test_openapi_schema(client: TestClient): }, }, "put": { - "tags": ["items", "custom"], + "tags": ["custom", "items"], "summary": "Update Item", "operationId": "update_item_items__item_id__put", "parameters": [ diff --git a/tests/test_ws_router.py b/tests/test_ws_router.py index 240a42bb0..c917dc17b 100644 --- a/tests/test_ws_router.py +++ b/tests/test_ws_router.py @@ -13,9 +13,9 @@ from fastapi import ( from fastapi.middleware import Middleware from fastapi.testclient import TestClient -router = APIRouter() -prefix_router = APIRouter() -native_prefix_route = APIRouter(prefix="/native") +router = APIRouter(tags=["base"]) +prefix_router = APIRouter(tags=["prefix"]) +native_prefix_router = APIRouter(prefix="/native", tags=["native"]) app = FastAPI() @@ -68,7 +68,7 @@ async def router_ws_decorator_depends( await websocket.close() -@native_prefix_route.websocket("/") +@native_prefix_router.websocket("/") async def router_native_prefix_ws(websocket: WebSocket): await websocket.accept() await websocket.send_text("Hello, router with native prefix!") @@ -104,11 +104,33 @@ async def router_ws_custom_error(websocket: WebSocket): raise CustomError() +@app.websocket("/test_tags", name="test-app-tags", tags=["test-app-tags"]) +@router.websocket("/test_tags/", name="test-router-tags", tags=["test-router-tags"]) +async def router_ws_test_tags(websocket: WebSocket): + pass # pragma: no cover + + +@prefix_router.websocket( + "/test_tags/", name="test-prefix-router-tags", tags=["test-prefix-router-tags"] +) +async def prefix_router_ws_test_tags(websocket: WebSocket): + pass # pragma: no cover + + +@native_prefix_router.websocket( + "/test_tags/", + name="test-native-prefix-router-tags", + tags=["test-native-prefix-router-tags"], +) +async def native_prefix_router_ws_test_tags(websocket: WebSocket): + pass # pragma: no cover + + def make_app(app=None, **kwargs): app = app or FastAPI(**kwargs) app.include_router(router) app.include_router(prefix_router, prefix="/prefix") - app.include_router(native_prefix_route) + app.include_router(native_prefix_router) return app @@ -269,3 +291,23 @@ def test_depend_err_handler(): pass # pragma: no cover assert e.value.code == 1002 assert "foo" in e.value.reason + + +@pytest.mark.parametrize( + "route_name,route_tags", + [ + ("test-app-tags", ["test-app-tags"]), + ("test-router-tags", ["base", "test-router-tags"]), + ("test-prefix-router-tags", ["prefix", "test-prefix-router-tags"]), + ( + "test-native-prefix-router-tags", + ["native", "test-native-prefix-router-tags"], + ), + ], +) +def test_websocket_tags(route_name, route_tags): + """ + Verify that it is possible to add tags to websocket routes + """ + route = next(route for route in app.routes if route.name == route_name) + assert sorted(route.tags) == sorted(route_tags)