|
|
@ -12,6 +12,7 @@ from typing import ( |
|
|
|
Collection, |
|
|
|
Coroutine, |
|
|
|
Dict, |
|
|
|
Iterable, |
|
|
|
List, |
|
|
|
Mapping, |
|
|
|
Optional, |
|
|
@ -393,12 +394,14 @@ class APIWebSocketRoute(routing.WebSocketRoute): |
|
|
|
endpoint: Callable[..., Any], |
|
|
|
*, |
|
|
|
name: Optional[str] = None, |
|
|
|
tags: Optional[List[Union[str, Enum]]] = None, |
|
|
|
dependencies: Optional[Sequence[params.Depends]] = None, |
|
|
|
dependency_overrides_provider: Optional[Any] = None, |
|
|
|
) -> None: |
|
|
|
self.path = path |
|
|
|
self.endpoint = endpoint |
|
|
|
self.name = get_name(endpoint) if name is None else name |
|
|
|
self.tags: List[Union[str, Enum]] = tags or [] |
|
|
|
self.dependencies = list(dependencies or []) |
|
|
|
self.path_regex, self.path_format, self.param_convertors = compile_path(path) |
|
|
|
self.dependant = get_dependant(path=self.path_format, call=self.endpoint) |
|
|
@ -919,9 +922,7 @@ class APIRouter(routing.Router): |
|
|
|
current_response_class = get_value_or_default( |
|
|
|
response_class, self.default_response_class |
|
|
|
) |
|
|
|
current_tags = self.tags.copy() |
|
|
|
if tags: |
|
|
|
current_tags.extend(tags) |
|
|
|
current_tags = self.combine_tags(tags or []) |
|
|
|
current_dependencies = self.dependencies.copy() |
|
|
|
if dependencies: |
|
|
|
current_dependencies.extend(dependencies) |
|
|
@ -936,7 +937,7 @@ class APIRouter(routing.Router): |
|
|
|
endpoint=endpoint, |
|
|
|
response_model=response_model, |
|
|
|
status_code=status_code, |
|
|
|
tags=current_tags, |
|
|
|
tags=list(current_tags), |
|
|
|
dependencies=current_dependencies, |
|
|
|
summary=summary, |
|
|
|
description=description, |
|
|
@ -1029,16 +1030,20 @@ class APIRouter(routing.Router): |
|
|
|
endpoint: Callable[..., Any], |
|
|
|
name: Optional[str] = None, |
|
|
|
*, |
|
|
|
tags: Optional[List[Union[str, Enum]]] = None, |
|
|
|
dependencies: Optional[Sequence[params.Depends]] = None, |
|
|
|
) -> None: |
|
|
|
current_dependencies = self.dependencies.copy() |
|
|
|
if dependencies: |
|
|
|
current_dependencies.extend(dependencies) |
|
|
|
|
|
|
|
current_tags = self.combine_tags(tags) |
|
|
|
|
|
|
|
route = APIWebSocketRoute( |
|
|
|
self.prefix + path, |
|
|
|
endpoint=endpoint, |
|
|
|
name=name, |
|
|
|
tags=current_tags, |
|
|
|
dependencies=current_dependencies, |
|
|
|
dependency_overrides_provider=self.dependency_overrides_provider, |
|
|
|
) |
|
|
@ -1063,6 +1068,14 @@ class APIRouter(routing.Router): |
|
|
|
), |
|
|
|
] = None, |
|
|
|
*, |
|
|
|
tags: Annotated[ |
|
|
|
Optional[List[Union[str, Enum]]], |
|
|
|
Doc( |
|
|
|
""" |
|
|
|
A list of tags to be applied to this WebSocket. |
|
|
|
""" |
|
|
|
), |
|
|
|
] = None, |
|
|
|
dependencies: Annotated[ |
|
|
|
Optional[Sequence[params.Depends]], |
|
|
|
Doc( |
|
|
@ -1105,7 +1118,7 @@ class APIRouter(routing.Router): |
|
|
|
|
|
|
|
def decorator(func: DecoratedCallable) -> DecoratedCallable: |
|
|
|
self.add_api_websocket_route( |
|
|
|
path, func, name=name, dependencies=dependencies |
|
|
|
path, func, name=name, tags=tags, dependencies=dependencies |
|
|
|
) |
|
|
|
return func |
|
|
|
|
|
|
@ -1279,11 +1292,7 @@ class APIRouter(routing.Router): |
|
|
|
default_response_class, |
|
|
|
self.default_response_class, |
|
|
|
) |
|
|
|
current_tags = [] |
|
|
|
if tags: |
|
|
|
current_tags.extend(tags) |
|
|
|
if route.tags: |
|
|
|
current_tags.extend(route.tags) |
|
|
|
current_tags = self.combine_tags(tags, route) |
|
|
|
current_dependencies: List[params.Depends] = [] |
|
|
|
if dependencies: |
|
|
|
current_dependencies.extend(dependencies) |
|
|
@ -1345,11 +1354,14 @@ class APIRouter(routing.Router): |
|
|
|
current_dependencies.extend(dependencies) |
|
|
|
if route.dependencies: |
|
|
|
current_dependencies.extend(route.dependencies) |
|
|
|
|
|
|
|
current_tags = self.combine_tags(tags, route) |
|
|
|
self.add_api_websocket_route( |
|
|
|
prefix + route.path, |
|
|
|
route.endpoint, |
|
|
|
dependencies=current_dependencies, |
|
|
|
name=route.name, |
|
|
|
tags=current_tags, |
|
|
|
) |
|
|
|
elif isinstance(route, routing.WebSocketRoute): |
|
|
|
self.add_websocket_route( |
|
|
@ -4438,3 +4450,27 @@ class APIRouter(routing.Router): |
|
|
|
return func |
|
|
|
|
|
|
|
return decorator |
|
|
|
|
|
|
|
def combine_tags( |
|
|
|
self, |
|
|
|
*entities: Annotated[ |
|
|
|
Union[None, str, routing.Route, Sequence], |
|
|
|
Doc( |
|
|
|
""" |
|
|
|
Combine the router's current tags with those of the provided entities. |
|
|
|
Supports None, strings, iterables, and Route objects with a `tags` attribute. |
|
|
|
""" |
|
|
|
), |
|
|
|
], |
|
|
|
) -> List[str]: |
|
|
|
tags = set(self.tags or []) |
|
|
|
for entity in entities: |
|
|
|
if entity is None: |
|
|
|
continue |
|
|
|
if isinstance(entity, str): |
|
|
|
tags.add(entity) |
|
|
|
elif isinstance(entity, Iterable): |
|
|
|
tags.update(entity) |
|
|
|
elif hasattr(entity, "tags"): |
|
|
|
tags = tags.union(entity.tags) |
|
|
|
return sorted(tags) |
|
|
|