|
|
@ -25,7 +25,7 @@ DEALINGS IN THE SOFTWARE. |
|
|
|
from __future__ import annotations |
|
|
|
|
|
|
|
|
|
|
|
from typing import Any, Callable, Deque, Dict, Optional, TYPE_CHECKING |
|
|
|
from typing import Any, Callable, Deque, Dict, Optional, Union, Generic, TypeVar, TYPE_CHECKING |
|
|
|
from discord.enums import Enum |
|
|
|
import time |
|
|
|
import asyncio |
|
|
@ -33,12 +33,14 @@ from collections import deque |
|
|
|
|
|
|
|
from ...abc import PrivateChannel |
|
|
|
from .errors import MaxConcurrencyReached |
|
|
|
from .context import Context |
|
|
|
from discord.app_commands import Cooldown as Cooldown |
|
|
|
|
|
|
|
if TYPE_CHECKING: |
|
|
|
from typing_extensions import Self |
|
|
|
|
|
|
|
from ...message import Message |
|
|
|
from ._types import BotT |
|
|
|
|
|
|
|
__all__ = ( |
|
|
|
'BucketType', |
|
|
@ -48,6 +50,8 @@ __all__ = ( |
|
|
|
'MaxConcurrency', |
|
|
|
) |
|
|
|
|
|
|
|
T = TypeVar('T') |
|
|
|
|
|
|
|
|
|
|
|
class BucketType(Enum): |
|
|
|
default = 0 |
|
|
@ -58,7 +62,7 @@ class BucketType(Enum): |
|
|
|
category = 5 |
|
|
|
role = 6 |
|
|
|
|
|
|
|
def get_key(self, msg: Message) -> Any: |
|
|
|
def get_key(self, msg: Union[Message, Context[BotT]]) -> Any: |
|
|
|
if self is BucketType.user: |
|
|
|
return msg.author.id |
|
|
|
elif self is BucketType.guild: |
|
|
@ -76,22 +80,22 @@ class BucketType(Enum): |
|
|
|
# receiving a DMChannel or GroupChannel which inherit from PrivateChannel and do |
|
|
|
return (msg.channel if isinstance(msg.channel, PrivateChannel) else msg.author.top_role).id # type: ignore |
|
|
|
|
|
|
|
def __call__(self, msg: Message) -> Any: |
|
|
|
def __call__(self, msg: Union[Message, Context[BotT]]) -> Any: |
|
|
|
return self.get_key(msg) |
|
|
|
|
|
|
|
|
|
|
|
class CooldownMapping: |
|
|
|
class CooldownMapping(Generic[T]): |
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
original: Optional[Cooldown], |
|
|
|
type: Callable[[Message], Any], |
|
|
|
type: Callable[[T], Any], |
|
|
|
) -> None: |
|
|
|
if not callable(type): |
|
|
|
raise TypeError('Cooldown type must be a BucketType or callable') |
|
|
|
|
|
|
|
self._cache: Dict[Any, Cooldown] = {} |
|
|
|
self._cooldown: Optional[Cooldown] = original |
|
|
|
self._type: Callable[[Message], Any] = type |
|
|
|
self._type: Callable[[T], Any] = type |
|
|
|
|
|
|
|
def copy(self) -> CooldownMapping: |
|
|
|
ret = CooldownMapping(self._cooldown, self._type) |
|
|
@ -103,14 +107,14 @@ class CooldownMapping: |
|
|
|
return self._cooldown is not None |
|
|
|
|
|
|
|
@property |
|
|
|
def type(self) -> Callable[[Message], Any]: |
|
|
|
def type(self) -> Callable[[T], Any]: |
|
|
|
return self._type |
|
|
|
|
|
|
|
@classmethod |
|
|
|
def from_cooldown(cls, rate: float, per: float, type: Callable[[Message], Any]) -> Self: |
|
|
|
def from_cooldown(cls, rate: float, per: float, type: Callable[[T], Any]) -> Self: |
|
|
|
return cls(Cooldown(rate, per), type) |
|
|
|
|
|
|
|
def _bucket_key(self, msg: Message) -> Any: |
|
|
|
def _bucket_key(self, msg: T) -> Any: |
|
|
|
return self._type(msg) |
|
|
|
|
|
|
|
def _verify_cache_integrity(self, current: Optional[float] = None) -> None: |
|
|
@ -122,10 +126,10 @@ class CooldownMapping: |
|
|
|
for k in dead_keys: |
|
|
|
del self._cache[k] |
|
|
|
|
|
|
|
def create_bucket(self, message: Message) -> Cooldown: |
|
|
|
def create_bucket(self, message: T) -> Cooldown: |
|
|
|
return self._cooldown.copy() # type: ignore |
|
|
|
|
|
|
|
def get_bucket(self, message: Message, current: Optional[float] = None) -> Optional[Cooldown]: |
|
|
|
def get_bucket(self, message: T, current: Optional[float] = None) -> Optional[Cooldown]: |
|
|
|
if self._type is BucketType.default: |
|
|
|
return self._cooldown |
|
|
|
|
|
|
@ -140,21 +144,21 @@ class CooldownMapping: |
|
|
|
|
|
|
|
return bucket |
|
|
|
|
|
|
|
def update_rate_limit(self, message: Message, current: Optional[float] = None, tokens: int = 1) -> Optional[float]: |
|
|
|
def update_rate_limit(self, message: T, current: Optional[float] = None, tokens: int = 1) -> Optional[float]: |
|
|
|
bucket = self.get_bucket(message, current) |
|
|
|
if bucket is None: |
|
|
|
return None |
|
|
|
return bucket.update_rate_limit(current, tokens=tokens) |
|
|
|
|
|
|
|
|
|
|
|
class DynamicCooldownMapping(CooldownMapping): |
|
|
|
class DynamicCooldownMapping(CooldownMapping[T]): |
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
factory: Callable[[Message], Optional[Cooldown]], |
|
|
|
type: Callable[[Message], Any], |
|
|
|
factory: Callable[[T], Optional[Cooldown]], |
|
|
|
type: Callable[[T], Any], |
|
|
|
) -> None: |
|
|
|
super().__init__(None, type) |
|
|
|
self._factory: Callable[[Message], Optional[Cooldown]] = factory |
|
|
|
self._factory: Callable[[T], Optional[Cooldown]] = factory |
|
|
|
|
|
|
|
def copy(self) -> DynamicCooldownMapping: |
|
|
|
ret = DynamicCooldownMapping(self._factory, self._type) |
|
|
@ -165,7 +169,7 @@ class DynamicCooldownMapping(CooldownMapping): |
|
|
|
def valid(self) -> bool: |
|
|
|
return True |
|
|
|
|
|
|
|
def create_bucket(self, message: Message) -> Optional[Cooldown]: |
|
|
|
def create_bucket(self, message: T) -> Optional[Cooldown]: |
|
|
|
return self._factory(message) |
|
|
|
|
|
|
|
|
|
|
|