Browse Source

Merge b2855f0326 into 6df50d40fe

pull/13626/merge
jmahoney-eab 5 days ago
committed by GitHub
parent
commit
13abc29396
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 11
      fastapi/applications.py
  2. 56
      fastapi/routing.py
  3. 2
      tests/test_tutorial/test_bigger_applications/test_main.py
  4. 52
      tests/test_ws_router.py

11
fastapi/applications.py

@ -1179,12 +1179,14 @@ class FastAPI(Starlette):
name: Optional[str] = None, name: Optional[str] = None,
*, *,
dependencies: Optional[Sequence[Depends]] = None, dependencies: Optional[Sequence[Depends]] = None,
tags: Optional[List[Union[str, Enum]]] = None,
) -> None: ) -> None:
self.router.add_api_websocket_route( self.router.add_api_websocket_route(
path, path,
endpoint, endpoint,
name=name, name=name,
dependencies=dependencies, dependencies=dependencies,
tags=tags,
) )
def websocket( def websocket(
@ -1218,6 +1220,14 @@ class FastAPI(Starlette):
""" """
), ),
] = None, ] = None,
tags: Annotated[
Optional[List[Union[str, Enum]]],
Doc(
"""
A list of tags to be applied to this WebSocket.
"""
),
] = None,
) -> Callable[[DecoratedCallable], DecoratedCallable]: ) -> Callable[[DecoratedCallable], DecoratedCallable]:
""" """
Decorate a WebSocket function. Decorate a WebSocket function.
@ -1247,6 +1257,7 @@ class FastAPI(Starlette):
func, func,
name=name, name=name,
dependencies=dependencies, dependencies=dependencies,
tags=tags,
) )
return func return func

56
fastapi/routing.py

@ -12,6 +12,7 @@ from typing import (
Collection, Collection,
Coroutine, Coroutine,
Dict, Dict,
Iterable,
List, List,
Mapping, Mapping,
Optional, Optional,
@ -393,12 +394,14 @@ class APIWebSocketRoute(routing.WebSocketRoute):
endpoint: Callable[..., Any], endpoint: Callable[..., Any],
*, *,
name: Optional[str] = None, name: Optional[str] = None,
tags: Optional[List[Union[str, Enum]]] = None,
dependencies: Optional[Sequence[params.Depends]] = None, dependencies: Optional[Sequence[params.Depends]] = None,
dependency_overrides_provider: Optional[Any] = None, dependency_overrides_provider: Optional[Any] = None,
) -> None: ) -> None:
self.path = path self.path = path
self.endpoint = endpoint self.endpoint = endpoint
self.name = get_name(endpoint) if name is None else name self.name = get_name(endpoint) if name is None else name
self.tags: List[Union[str, Enum]] = tags or []
self.dependencies = list(dependencies or []) self.dependencies = list(dependencies or [])
self.path_regex, self.path_format, self.param_convertors = compile_path(path) self.path_regex, self.path_format, self.param_convertors = compile_path(path)
self.dependant = get_dependant(path=self.path_format, call=self.endpoint) self.dependant = get_dependant(path=self.path_format, call=self.endpoint)
@ -919,9 +922,7 @@ class APIRouter(routing.Router):
current_response_class = get_value_or_default( current_response_class = get_value_or_default(
response_class, self.default_response_class response_class, self.default_response_class
) )
current_tags = self.tags.copy() current_tags = self.combine_tags(tags or [])
if tags:
current_tags.extend(tags)
current_dependencies = self.dependencies.copy() current_dependencies = self.dependencies.copy()
if dependencies: if dependencies:
current_dependencies.extend(dependencies) current_dependencies.extend(dependencies)
@ -936,7 +937,7 @@ class APIRouter(routing.Router):
endpoint=endpoint, endpoint=endpoint,
response_model=response_model, response_model=response_model,
status_code=status_code, status_code=status_code,
tags=current_tags, tags=list(current_tags),
dependencies=current_dependencies, dependencies=current_dependencies,
summary=summary, summary=summary,
description=description, description=description,
@ -1029,16 +1030,20 @@ class APIRouter(routing.Router):
endpoint: Callable[..., Any], endpoint: Callable[..., Any],
name: Optional[str] = None, name: Optional[str] = None,
*, *,
tags: Optional[List[Union[str, Enum]]] = None,
dependencies: Optional[Sequence[params.Depends]] = None, dependencies: Optional[Sequence[params.Depends]] = None,
) -> None: ) -> None:
current_dependencies = self.dependencies.copy() current_dependencies = self.dependencies.copy()
if dependencies: if dependencies:
current_dependencies.extend(dependencies) current_dependencies.extend(dependencies)
current_tags = self.combine_tags(tags)
route = APIWebSocketRoute( route = APIWebSocketRoute(
self.prefix + path, self.prefix + path,
endpoint=endpoint, endpoint=endpoint,
name=name, name=name,
tags=current_tags,
dependencies=current_dependencies, dependencies=current_dependencies,
dependency_overrides_provider=self.dependency_overrides_provider, dependency_overrides_provider=self.dependency_overrides_provider,
) )
@ -1063,6 +1068,14 @@ class APIRouter(routing.Router):
), ),
] = None, ] = None,
*, *,
tags: Annotated[
Optional[List[Union[str, Enum]]],
Doc(
"""
A list of tags to be applied to this WebSocket.
"""
),
] = None,
dependencies: Annotated[ dependencies: Annotated[
Optional[Sequence[params.Depends]], Optional[Sequence[params.Depends]],
Doc( Doc(
@ -1105,7 +1118,7 @@ class APIRouter(routing.Router):
def decorator(func: DecoratedCallable) -> DecoratedCallable: def decorator(func: DecoratedCallable) -> DecoratedCallable:
self.add_api_websocket_route( self.add_api_websocket_route(
path, func, name=name, dependencies=dependencies path, func, name=name, tags=tags, dependencies=dependencies
) )
return func return func
@ -1279,11 +1292,7 @@ class APIRouter(routing.Router):
default_response_class, default_response_class,
self.default_response_class, self.default_response_class,
) )
current_tags = [] current_tags = self.combine_tags(tags, route)
if tags:
current_tags.extend(tags)
if route.tags:
current_tags.extend(route.tags)
current_dependencies: List[params.Depends] = [] current_dependencies: List[params.Depends] = []
if dependencies: if dependencies:
current_dependencies.extend(dependencies) current_dependencies.extend(dependencies)
@ -1345,11 +1354,14 @@ class APIRouter(routing.Router):
current_dependencies.extend(dependencies) current_dependencies.extend(dependencies)
if route.dependencies: if route.dependencies:
current_dependencies.extend(route.dependencies) current_dependencies.extend(route.dependencies)
current_tags = self.combine_tags(tags, route)
self.add_api_websocket_route( self.add_api_websocket_route(
prefix + route.path, prefix + route.path,
route.endpoint, route.endpoint,
dependencies=current_dependencies, dependencies=current_dependencies,
name=route.name, name=route.name,
tags=current_tags,
) )
elif isinstance(route, routing.WebSocketRoute): elif isinstance(route, routing.WebSocketRoute):
self.add_websocket_route( self.add_websocket_route(
@ -4438,3 +4450,27 @@ class APIRouter(routing.Router):
return func return func
return decorator return decorator
def combine_tags(
self,
*entities: Annotated[
Union[None, str, routing.Route, Sequence],
Doc(
"""
Combine the router's current tags with those of the provided entities.
Supports None, strings, iterables, and Route objects with a `tags` attribute.
"""
),
],
) -> List[str]:
tags = set(self.tags or [])
for entity in entities:
if entity is None:
continue
if isinstance(entity, str):
tags.add(entity)
elif isinstance(entity, Iterable):
tags.update(entity)
elif hasattr(entity, "tags"):
tags = tags.union(entity.tags)
return sorted(tags)

2
tests/test_tutorial/test_bigger_applications/test_main.py

@ -580,7 +580,7 @@ def test_openapi_schema(client: TestClient):
}, },
}, },
"put": { "put": {
"tags": ["items", "custom"], "tags": ["custom", "items"],
"summary": "Update Item", "summary": "Update Item",
"operationId": "update_item_items__item_id__put", "operationId": "update_item_items__item_id__put",
"parameters": [ "parameters": [

52
tests/test_ws_router.py

@ -13,9 +13,9 @@ from fastapi import (
from fastapi.middleware import Middleware from fastapi.middleware import Middleware
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
router = APIRouter() router = APIRouter(tags=["base"])
prefix_router = APIRouter() prefix_router = APIRouter(tags=["prefix"])
native_prefix_route = APIRouter(prefix="/native") native_prefix_router = APIRouter(prefix="/native", tags=["native"])
app = FastAPI() app = FastAPI()
@ -68,7 +68,7 @@ async def router_ws_decorator_depends(
await websocket.close() await websocket.close()
@native_prefix_route.websocket("/") @native_prefix_router.websocket("/")
async def router_native_prefix_ws(websocket: WebSocket): async def router_native_prefix_ws(websocket: WebSocket):
await websocket.accept() await websocket.accept()
await websocket.send_text("Hello, router with native prefix!") await websocket.send_text("Hello, router with native prefix!")
@ -104,11 +104,33 @@ async def router_ws_custom_error(websocket: WebSocket):
raise CustomError() raise CustomError()
@app.websocket("/test_tags", name="test-app-tags", tags=["test-app-tags"])
@router.websocket("/test_tags/", name="test-router-tags", tags=["test-router-tags"])
async def router_ws_test_tags(websocket: WebSocket):
pass # pragma: no cover
@prefix_router.websocket(
"/test_tags/", name="test-prefix-router-tags", tags=["test-prefix-router-tags"]
)
async def prefix_router_ws_test_tags(websocket: WebSocket):
pass # pragma: no cover
@native_prefix_router.websocket(
"/test_tags/",
name="test-native-prefix-router-tags",
tags=["test-native-prefix-router-tags"],
)
async def native_prefix_router_ws_test_tags(websocket: WebSocket):
pass # pragma: no cover
def make_app(app=None, **kwargs): def make_app(app=None, **kwargs):
app = app or FastAPI(**kwargs) app = app or FastAPI(**kwargs)
app.include_router(router) app.include_router(router)
app.include_router(prefix_router, prefix="/prefix") app.include_router(prefix_router, prefix="/prefix")
app.include_router(native_prefix_route) app.include_router(native_prefix_router)
return app return app
@ -269,3 +291,23 @@ def test_depend_err_handler():
pass # pragma: no cover pass # pragma: no cover
assert e.value.code == 1002 assert e.value.code == 1002
assert "foo" in e.value.reason assert "foo" in e.value.reason
@pytest.mark.parametrize(
"route_name,route_tags",
[
("test-app-tags", ["test-app-tags"]),
("test-router-tags", ["base", "test-router-tags"]),
("test-prefix-router-tags", ["prefix", "test-prefix-router-tags"]),
(
"test-native-prefix-router-tags",
["native", "test-native-prefix-router-tags"],
),
],
)
def test_websocket_tags(route_name, route_tags):
"""
Verify that it is possible to add tags to websocket routes
"""
route = next(route for route in app.routes if route.name == route_name)
assert sorted(route.tags) == sorted(route_tags)

Loading…
Cancel
Save