From 3da3827227788a8a033f6cb53a6a0fa6e9ab2aa6 Mon Sep 17 00:00:00 2001 From: John Mahoney Date: Fri, 18 Apr 2025 00:23:33 -0500 Subject: [PATCH 01/10] Added tags to websockets. --- fastapi/routing.py | 21 ++++++++++++++++++++- tests/test_ws_router.py | 14 ++++++++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/fastapi/routing.py b/fastapi/routing.py index 457481e32..9ee4f0d97 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -392,12 +392,14 @@ class APIWebSocketRoute(routing.WebSocketRoute): endpoint: Callable[..., Any], *, name: Optional[str] = None, + tags: Optional[List[Union[str, Enum]]] = None, dependencies: Optional[Sequence[params.Depends]] = None, dependency_overrides_provider: Optional[Any] = None, ) -> None: self.path = path self.endpoint = endpoint self.name = get_name(endpoint) if name is None else name + self.tags: List[Union[str, Enum]] = tags or [] self.dependencies = list(dependencies or []) self.path_regex, self.path_format, self.param_convertors = compile_path(path) self.dependant = get_dependant(path=self.path_format, call=self.endpoint) @@ -1028,6 +1030,7 @@ class APIRouter(routing.Router): endpoint: Callable[..., Any], name: Optional[str] = None, *, + tags: Optional[List[Union[str, Enum]]] = None, dependencies: Optional[Sequence[params.Depends]] = None, ) -> None: current_dependencies = self.dependencies.copy() @@ -1038,6 +1041,7 @@ class APIRouter(routing.Router): self.prefix + path, endpoint=endpoint, name=name, + tags=tags, dependencies=current_dependencies, dependency_overrides_provider=self.dependency_overrides_provider, ) @@ -1062,6 +1066,14 @@ class APIRouter(routing.Router): ), ] = None, *, + tags: Annotated[ + Optional[List[Union[str, Enum]]], + Doc( + """ + A list of tags to be applied to this WebSocket. + """ + ), + ] = None, dependencies: Annotated[ Optional[Sequence[params.Depends]], Doc( @@ -1104,7 +1116,7 @@ class APIRouter(routing.Router): def decorator(func: DecoratedCallable) -> DecoratedCallable: self.add_api_websocket_route( - path, func, name=name, dependencies=dependencies + path, func, name=name, tags=tags, dependencies=dependencies ) return func @@ -1344,11 +1356,18 @@ class APIRouter(routing.Router): current_dependencies.extend(dependencies) if route.dependencies: current_dependencies.extend(route.dependencies) + + current_tags = [] + if tags: + current_tags.extend(tags) + if route.tags: + current_tags.extend(route.tags) self.add_api_websocket_route( prefix + route.path, route.endpoint, dependencies=current_dependencies, name=route.name, + tags=current_tags, ) elif isinstance(route, routing.WebSocketRoute): self.add_websocket_route( diff --git a/tests/test_ws_router.py b/tests/test_ws_router.py index 240a42bb0..3b24c04ff 100644 --- a/tests/test_ws_router.py +++ b/tests/test_ws_router.py @@ -103,6 +103,12 @@ class CustomError(Exception): async def router_ws_custom_error(websocket: WebSocket): raise CustomError() +@router.websocket("/test_tags/", name='test-tags', tags=["test"]) +async def router_ws_test_tags(websocket: WebSocket): + await websocket.accept() + await websocket.send_text("Hello, router with tags!") + await websocket.close() + def make_app(app=None, **kwargs): app = app or FastAPI(**kwargs) @@ -269,3 +275,11 @@ def test_depend_err_handler(): pass # pragma: no cover assert e.value.code == 1002 assert "foo" in e.value.reason + + +def test_websocket_tags(): + """ + Verify that it is possible to add tags to websocket routes + """ + route = next(route for route in app.routes if route.name == 'test-tags') + assert route.tags == ["test"] \ No newline at end of file From 0969b08abec85d44a045c867ee59e14e5bf4c32a Mon Sep 17 00:00:00 2001 From: John Mahoney Date: Fri, 18 Apr 2025 16:00:42 -0500 Subject: [PATCH 02/10] Add newline at end --- tests/test_ws_router.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_ws_router.py b/tests/test_ws_router.py index 3b24c04ff..7c0fc7ab3 100644 --- a/tests/test_ws_router.py +++ b/tests/test_ws_router.py @@ -282,4 +282,4 @@ def test_websocket_tags(): Verify that it is possible to add tags to websocket routes """ route = next(route for route in app.routes if route.name == 'test-tags') - assert route.tags == ["test"] \ No newline at end of file + assert route.tags == ["test"] From 87749c1466c3d50deb17bae57098a44135948001 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 18 Apr 2025 21:01:40 +0000 Subject: [PATCH 03/10] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20f?= =?UTF-8?q?ormat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_ws_router.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_ws_router.py b/tests/test_ws_router.py index 7c0fc7ab3..0fc2527bb 100644 --- a/tests/test_ws_router.py +++ b/tests/test_ws_router.py @@ -103,7 +103,8 @@ class CustomError(Exception): async def router_ws_custom_error(websocket: WebSocket): raise CustomError() -@router.websocket("/test_tags/", name='test-tags', tags=["test"]) + +@router.websocket("/test_tags/", name="test-tags", tags=["test"]) async def router_ws_test_tags(websocket: WebSocket): await websocket.accept() await websocket.send_text("Hello, router with tags!") @@ -281,5 +282,5 @@ def test_websocket_tags(): """ Verify that it is possible to add tags to websocket routes """ - route = next(route for route in app.routes if route.name == 'test-tags') + route = next(route for route in app.routes if route.name == "test-tags") assert route.tags == ["test"] From 1e19d6a60ef0c1e7c17c61fb7173237a2a4ba3a5 Mon Sep 17 00:00:00 2001 From: John Mahoney Date: Fri, 18 Apr 2025 16:06:10 -0500 Subject: [PATCH 04/10] Added pragma no cover --- tests/test_ws_router.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_ws_router.py b/tests/test_ws_router.py index 7c0fc7ab3..c79a5304e 100644 --- a/tests/test_ws_router.py +++ b/tests/test_ws_router.py @@ -105,9 +105,7 @@ async def router_ws_custom_error(websocket: WebSocket): @router.websocket("/test_tags/", name='test-tags', tags=["test"]) async def router_ws_test_tags(websocket: WebSocket): - await websocket.accept() - await websocket.send_text("Hello, router with tags!") - await websocket.close() + pass # pragma: no cover def make_app(app=None, **kwargs): From 4f418eefe0d21a703289fd3521cf1e3b59252e81 Mon Sep 17 00:00:00 2001 From: John Mahoney Date: Mon, 21 Apr 2025 12:17:41 -0500 Subject: [PATCH 05/10] Refactor some stuff --- fastapi/applications.py | 11 ++++++ fastapi/routing.py | 36 +++++++++++-------- .../test_bigger_applications/test_main.py | 2 +- tests/test_ws_router.py | 35 +++++++++++++----- 4 files changed, 59 insertions(+), 25 deletions(-) diff --git a/fastapi/applications.py b/fastapi/applications.py index 6d427cdc2..5c2d31a5d 100644 --- a/fastapi/applications.py +++ b/fastapi/applications.py @@ -1179,12 +1179,14 @@ class FastAPI(Starlette): name: Optional[str] = None, *, dependencies: Optional[Sequence[Depends]] = None, + tags: Optional[List[Union[str, Enum]]] = None, ) -> None: self.router.add_api_websocket_route( path, endpoint, name=name, dependencies=dependencies, + tags=tags, ) def websocket( @@ -1218,6 +1220,14 @@ class FastAPI(Starlette): """ ), ] = None, + tags: Annotated[ + Optional[List[Union[str, Enum]]], + Doc( + """ + A list of tags to be applied to this WebSocket. + """ + ) + ] = None ) -> Callable[[DecoratedCallable], DecoratedCallable]: """ Decorate a WebSocket function. @@ -1247,6 +1257,7 @@ class FastAPI(Starlette): func, name=name, dependencies=dependencies, + tags=tags, ) return func diff --git a/fastapi/routing.py b/fastapi/routing.py index 9ee4f0d97..dfd206099 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -11,6 +11,7 @@ from typing import ( Callable, Coroutine, Dict, + Iterable, List, Mapping, Optional, @@ -920,9 +921,7 @@ class APIRouter(routing.Router): current_response_class = get_value_or_default( response_class, self.default_response_class ) - current_tags = self.tags.copy() - if tags: - current_tags.extend(tags) + current_tags = self.combine_tags(tags or []) current_dependencies = self.dependencies.copy() if dependencies: current_dependencies.extend(dependencies) @@ -937,7 +936,7 @@ class APIRouter(routing.Router): endpoint=endpoint, response_model=response_model, status_code=status_code, - tags=current_tags, + tags=list(current_tags), dependencies=current_dependencies, summary=summary, description=description, @@ -1037,11 +1036,13 @@ class APIRouter(routing.Router): if dependencies: current_dependencies.extend(dependencies) + current_tags = self.combine_tags(tags) + route = APIWebSocketRoute( self.prefix + path, endpoint=endpoint, name=name, - tags=tags, + tags=current_tags, dependencies=current_dependencies, dependency_overrides_provider=self.dependency_overrides_provider, ) @@ -1290,11 +1291,7 @@ class APIRouter(routing.Router): default_response_class, self.default_response_class, ) - current_tags = [] - if tags: - current_tags.extend(tags) - if route.tags: - current_tags.extend(route.tags) + current_tags = self.combine_tags(tags, route) current_dependencies: List[params.Depends] = [] if dependencies: current_dependencies.extend(dependencies) @@ -1357,11 +1354,7 @@ class APIRouter(routing.Router): if route.dependencies: current_dependencies.extend(route.dependencies) - current_tags = [] - if tags: - current_tags.extend(tags) - if route.tags: - current_tags.extend(route.tags) + current_tags = self.combine_tags(tags, route) self.add_api_websocket_route( prefix + route.path, route.endpoint, @@ -4456,3 +4449,16 @@ class APIRouter(routing.Router): return func return decorator + + def combine_tags(self, *entities): + tags = set(self.tags or []) + for entity in entities: + if entity is None: + continue + if isinstance(entity, str): + tags.add(entity) + elif isinstance(entity, Iterable): + tags.update(entity) + elif hasattr(entity, "tags"): + tags = tags.union(entity.tags) + return sorted(tags) diff --git a/tests/test_tutorial/test_bigger_applications/test_main.py b/tests/test_tutorial/test_bigger_applications/test_main.py index fe40fad7d..6deedb57e 100644 --- a/tests/test_tutorial/test_bigger_applications/test_main.py +++ b/tests/test_tutorial/test_bigger_applications/test_main.py @@ -580,7 +580,7 @@ def test_openapi_schema(client: TestClient): }, }, "put": { - "tags": ["items", "custom"], + "tags": ["custom", "items"], "summary": "Update Item", "operationId": "update_item_items__item_id__put", "parameters": [ diff --git a/tests/test_ws_router.py b/tests/test_ws_router.py index bc3988326..740a3ccae 100644 --- a/tests/test_ws_router.py +++ b/tests/test_ws_router.py @@ -13,9 +13,9 @@ from fastapi import ( from fastapi.middleware import Middleware from fastapi.testclient import TestClient -router = APIRouter() -prefix_router = APIRouter() -native_prefix_route = APIRouter(prefix="/native") +router = APIRouter(tags=['base']) +prefix_router = APIRouter(tags=['prefix']) +native_prefix_router = APIRouter(prefix="/native", tags=['native']) app = FastAPI() @@ -68,7 +68,7 @@ async def router_ws_decorator_depends( await websocket.close() -@native_prefix_route.websocket("/") +@native_prefix_router.websocket("/") async def router_native_prefix_ws(websocket: WebSocket): await websocket.accept() await websocket.send_text("Hello, router with native prefix!") @@ -104,16 +104,27 @@ async def router_ws_custom_error(websocket: WebSocket): raise CustomError() -@router.websocket("/test_tags/", name="test-tags", tags=["test"]) +@app.websocket("/test_tags", name="test-app-tags", tags=["test-app-tags"]) + +@router.websocket("/test_tags/", name="test-router-tags", tags=["test-router-tags"]) async def router_ws_test_tags(websocket: WebSocket): pass # pragma: no cover +@prefix_router.websocket("/test_tags/", name="test-prefix-router-tags", tags=["test-prefix-router-tags"]) +async def prefix_router_ws_test_tags(websocket: WebSocket): + pass # pragma: no cover + +@native_prefix_router.websocket("/test_tags/", name="test-native-prefix-router-tags", + tags=["test-native-prefix-router-tags"]) +async def native_prefix_router_ws_test_tags(websocket: WebSocket): + pass # pragma: no cover + def make_app(app=None, **kwargs): app = app or FastAPI(**kwargs) app.include_router(router) app.include_router(prefix_router, prefix="/prefix") - app.include_router(native_prefix_route) + app.include_router(native_prefix_router) return app @@ -276,9 +287,15 @@ def test_depend_err_handler(): assert "foo" in e.value.reason -def test_websocket_tags(): +@pytest.mark.parametrize("route_name,route_tags", [ + ("test-app-tags", ["test-app-tags"]), + ("test-router-tags", ["base", "test-router-tags"]), + ("test-prefix-router-tags", ["prefix", "test-prefix-router-tags"]), + ("test-native-prefix-router-tags", ["native", "test-native-prefix-router-tags"]), +]) +def test_websocket_tags(route_name, route_tags): """ Verify that it is possible to add tags to websocket routes """ - route = next(route for route in app.routes if route.name == "test-tags") - assert route.tags == ["test"] + route = next(route for route in app.routes if route.name == route_name) + assert sorted(route.tags) == sorted(route_tags) From 220190fb38369d7f1a184364a5eeb627f58e2ffb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 21 Apr 2025 17:17:50 +0000 Subject: [PATCH 06/10] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20f?= =?UTF-8?q?ormat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastapi/applications.py | 4 ++-- tests/test_ws_router.py | 38 +++++++++++++++++++++++++------------- 2 files changed, 27 insertions(+), 15 deletions(-) diff --git a/fastapi/applications.py b/fastapi/applications.py index 5c2d31a5d..8d5325a68 100644 --- a/fastapi/applications.py +++ b/fastapi/applications.py @@ -1226,8 +1226,8 @@ class FastAPI(Starlette): """ A list of tags to be applied to this WebSocket. """ - ) - ] = None + ), + ] = None, ) -> Callable[[DecoratedCallable], DecoratedCallable]: """ Decorate a WebSocket function. diff --git a/tests/test_ws_router.py b/tests/test_ws_router.py index 740a3ccae..c917dc17b 100644 --- a/tests/test_ws_router.py +++ b/tests/test_ws_router.py @@ -13,9 +13,9 @@ from fastapi import ( from fastapi.middleware import Middleware from fastapi.testclient import TestClient -router = APIRouter(tags=['base']) -prefix_router = APIRouter(tags=['prefix']) -native_prefix_router = APIRouter(prefix="/native", tags=['native']) +router = APIRouter(tags=["base"]) +prefix_router = APIRouter(tags=["prefix"]) +native_prefix_router = APIRouter(prefix="/native", tags=["native"]) app = FastAPI() @@ -105,17 +105,23 @@ async def router_ws_custom_error(websocket: WebSocket): @app.websocket("/test_tags", name="test-app-tags", tags=["test-app-tags"]) - @router.websocket("/test_tags/", name="test-router-tags", tags=["test-router-tags"]) async def router_ws_test_tags(websocket: WebSocket): pass # pragma: no cover -@prefix_router.websocket("/test_tags/", name="test-prefix-router-tags", tags=["test-prefix-router-tags"]) + +@prefix_router.websocket( + "/test_tags/", name="test-prefix-router-tags", tags=["test-prefix-router-tags"] +) async def prefix_router_ws_test_tags(websocket: WebSocket): pass # pragma: no cover -@native_prefix_router.websocket("/test_tags/", name="test-native-prefix-router-tags", - tags=["test-native-prefix-router-tags"]) + +@native_prefix_router.websocket( + "/test_tags/", + name="test-native-prefix-router-tags", + tags=["test-native-prefix-router-tags"], +) async def native_prefix_router_ws_test_tags(websocket: WebSocket): pass # pragma: no cover @@ -287,12 +293,18 @@ def test_depend_err_handler(): assert "foo" in e.value.reason -@pytest.mark.parametrize("route_name,route_tags", [ - ("test-app-tags", ["test-app-tags"]), - ("test-router-tags", ["base", "test-router-tags"]), - ("test-prefix-router-tags", ["prefix", "test-prefix-router-tags"]), - ("test-native-prefix-router-tags", ["native", "test-native-prefix-router-tags"]), -]) +@pytest.mark.parametrize( + "route_name,route_tags", + [ + ("test-app-tags", ["test-app-tags"]), + ("test-router-tags", ["base", "test-router-tags"]), + ("test-prefix-router-tags", ["prefix", "test-prefix-router-tags"]), + ( + "test-native-prefix-router-tags", + ["native", "test-native-prefix-router-tags"], + ), + ], +) def test_websocket_tags(route_name, route_tags): """ Verify that it is possible to add tags to websocket routes From 33feaecb169a389fe609d19b61ebfa69ef9c9437 Mon Sep 17 00:00:00 2001 From: John Mahoney Date: Mon, 21 Apr 2025 12:44:55 -0500 Subject: [PATCH 07/10] Fix untyped call --- fastapi/routing.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/fastapi/routing.py b/fastapi/routing.py index dfd206099..84a2840aa 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -4450,7 +4450,18 @@ class APIRouter(routing.Router): return decorator - def combine_tags(self, *entities): + def combine_tags( + self, + *entities: Annotated[ + None | str | routing.Route | Iterable, + Doc( + """ + Combine the router's current tags with those of the provided entities. + Supports None, strings, iterables, and Route objects with a `tags` attribute. + """ + ), + ] + ) -> List[str]: tags = set(self.tags or []) for entity in entities: if entity is None: From 7a5795613aa373d1074289fe9a8990bd075058c7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 21 Apr 2025 17:45:04 +0000 Subject: [PATCH 08/10] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20f?= =?UTF-8?q?ormat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastapi/routing.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/fastapi/routing.py b/fastapi/routing.py index 84a2840aa..2f09899bc 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -4451,16 +4451,16 @@ class APIRouter(routing.Router): return decorator def combine_tags( - self, - *entities: Annotated[ - None | str | routing.Route | Iterable, - Doc( - """ + self, + *entities: Annotated[ + None | str | routing.Route | Iterable, + Doc( + """ Combine the router's current tags with those of the provided entities. Supports None, strings, iterables, and Route objects with a `tags` attribute. """ - ), - ] + ), + ], ) -> List[str]: tags = set(self.tags or []) for entity in entities: From 50ea206665aa99c1197ab675a4f48079a12cdafd Mon Sep 17 00:00:00 2001 From: John Mahoney Date: Mon, 21 Apr 2025 16:06:33 -0500 Subject: [PATCH 09/10] Fix typings for earlier python versions --- fastapi/routing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastapi/routing.py b/fastapi/routing.py index 84a2840aa..e0ea705f3 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -4453,7 +4453,7 @@ class APIRouter(routing.Router): def combine_tags( self, *entities: Annotated[ - None | str | routing.Route | Iterable, + Union[None, str, routing.Route, Sequence], Doc( """ Combine the router's current tags with those of the provided entities. From b2855f032605594f7bf91186ed1dff068dd90cf7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 21 Apr 2025 21:07:51 +0000 Subject: [PATCH 10/10] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20f?= =?UTF-8?q?ormat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastapi/routing.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/fastapi/routing.py b/fastapi/routing.py index e0ea705f3..5e2df3369 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -4451,16 +4451,16 @@ class APIRouter(routing.Router): return decorator def combine_tags( - self, - *entities: Annotated[ - Union[None, str, routing.Route, Sequence], - Doc( - """ + self, + *entities: Annotated[ + Union[None, str, routing.Route, Sequence], + Doc( + """ Combine the router's current tags with those of the provided entities. Supports None, strings, iterables, and Route objects with a `tags` attribute. """ - ), - ] + ), + ], ) -> List[str]: tags = set(self.tags or []) for entity in entities: