|
|
@ -40,7 +40,6 @@ if TYPE_CHECKING: |
|
|
|
from typing_extensions import Self |
|
|
|
|
|
|
|
from ...message import Message |
|
|
|
from ._types import BotT |
|
|
|
|
|
|
|
__all__ = ( |
|
|
|
'BucketType', |
|
|
@ -50,7 +49,7 @@ __all__ = ( |
|
|
|
'MaxConcurrency', |
|
|
|
) |
|
|
|
|
|
|
|
T = TypeVar('T') |
|
|
|
T_contra = TypeVar('T_contra', contravariant=True) |
|
|
|
|
|
|
|
|
|
|
|
class BucketType(Enum): |
|
|
@ -62,7 +61,7 @@ class BucketType(Enum): |
|
|
|
category = 5 |
|
|
|
role = 6 |
|
|
|
|
|
|
|
def get_key(self, msg: Union[Message, Context[BotT]]) -> Any: |
|
|
|
def get_key(self, msg: Union[Message, Context[Any]]) -> Any: |
|
|
|
if self is BucketType.user: |
|
|
|
return msg.author.id |
|
|
|
elif self is BucketType.guild: |
|
|
@ -80,24 +79,24 @@ class BucketType(Enum): |
|
|
|
# receiving a DMChannel or GroupChannel which inherit from PrivateChannel and do |
|
|
|
return (msg.channel if isinstance(msg.channel, PrivateChannel) else msg.author.top_role).id # type: ignore |
|
|
|
|
|
|
|
def __call__(self, msg: Union[Message, Context[BotT]]) -> Any: |
|
|
|
def __call__(self, msg: Union[Message, Context[Any]]) -> Any: |
|
|
|
return self.get_key(msg) |
|
|
|
|
|
|
|
|
|
|
|
class CooldownMapping(Generic[T]): |
|
|
|
class CooldownMapping(Generic[T_contra]): |
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
original: Optional[Cooldown], |
|
|
|
type: Callable[[T], Any], |
|
|
|
type: Callable[[T_contra], Any], |
|
|
|
) -> None: |
|
|
|
if not callable(type): |
|
|
|
raise TypeError('Cooldown type must be a BucketType or callable') |
|
|
|
|
|
|
|
self._cache: Dict[Any, Cooldown] = {} |
|
|
|
self._cooldown: Optional[Cooldown] = original |
|
|
|
self._type: Callable[[T], Any] = type |
|
|
|
self._type: Callable[[T_contra], Any] = type |
|
|
|
|
|
|
|
def copy(self) -> CooldownMapping: |
|
|
|
def copy(self) -> CooldownMapping[T_contra]: |
|
|
|
ret = CooldownMapping(self._cooldown, self._type) |
|
|
|
ret._cache = self._cache.copy() |
|
|
|
return ret |
|
|
@ -107,14 +106,14 @@ class CooldownMapping(Generic[T]): |
|
|
|
return self._cooldown is not None |
|
|
|
|
|
|
|
@property |
|
|
|
def type(self) -> Callable[[T], Any]: |
|
|
|
def type(self) -> Callable[[T_contra], Any]: |
|
|
|
return self._type |
|
|
|
|
|
|
|
@classmethod |
|
|
|
def from_cooldown(cls, rate: float, per: float, type: Callable[[T], Any]) -> Self: |
|
|
|
def from_cooldown(cls, rate: float, per: float, type: Callable[[T_contra], Any]) -> Self: |
|
|
|
return cls(Cooldown(rate, per), type) |
|
|
|
|
|
|
|
def _bucket_key(self, msg: T) -> Any: |
|
|
|
def _bucket_key(self, msg: T_contra) -> Any: |
|
|
|
return self._type(msg) |
|
|
|
|
|
|
|
def _verify_cache_integrity(self, current: Optional[float] = None) -> None: |
|
|
@ -126,10 +125,10 @@ class CooldownMapping(Generic[T]): |
|
|
|
for k in dead_keys: |
|
|
|
del self._cache[k] |
|
|
|
|
|
|
|
def create_bucket(self, message: T) -> Cooldown: |
|
|
|
def create_bucket(self, message: T_contra) -> Cooldown: |
|
|
|
return self._cooldown.copy() # type: ignore |
|
|
|
|
|
|
|
def get_bucket(self, message: T, current: Optional[float] = None) -> Optional[Cooldown]: |
|
|
|
def get_bucket(self, message: T_contra, current: Optional[float] = None) -> Optional[Cooldown]: |
|
|
|
if self._type is BucketType.default: |
|
|
|
return self._cooldown |
|
|
|
|
|
|
@ -144,23 +143,23 @@ class CooldownMapping(Generic[T]): |
|
|
|
|
|
|
|
return bucket |
|
|
|
|
|
|
|
def update_rate_limit(self, message: T, current: Optional[float] = None, tokens: int = 1) -> Optional[float]: |
|
|
|
def update_rate_limit(self, message: T_contra, current: Optional[float] = None, tokens: int = 1) -> Optional[float]: |
|
|
|
bucket = self.get_bucket(message, current) |
|
|
|
if bucket is None: |
|
|
|
return None |
|
|
|
return bucket.update_rate_limit(current, tokens=tokens) |
|
|
|
|
|
|
|
|
|
|
|
class DynamicCooldownMapping(CooldownMapping[T]): |
|
|
|
class DynamicCooldownMapping(CooldownMapping[T_contra]): |
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
factory: Callable[[T], Optional[Cooldown]], |
|
|
|
type: Callable[[T], Any], |
|
|
|
factory: Callable[[T_contra], Optional[Cooldown]], |
|
|
|
type: Callable[[T_contra], Any], |
|
|
|
) -> None: |
|
|
|
super().__init__(None, type) |
|
|
|
self._factory: Callable[[T], Optional[Cooldown]] = factory |
|
|
|
self._factory: Callable[[T_contra], Optional[Cooldown]] = factory |
|
|
|
|
|
|
|
def copy(self) -> DynamicCooldownMapping: |
|
|
|
def copy(self) -> DynamicCooldownMapping[T_contra]: |
|
|
|
ret = DynamicCooldownMapping(self._factory, self._type) |
|
|
|
ret._cache = self._cache.copy() |
|
|
|
return ret |
|
|
@ -169,7 +168,7 @@ class DynamicCooldownMapping(CooldownMapping[T]): |
|
|
|
def valid(self) -> bool: |
|
|
|
return True |
|
|
|
|
|
|
|
def create_bucket(self, message: T) -> Optional[Cooldown]: |
|
|
|
def create_bucket(self, message: T_contra) -> Optional[Cooldown]: |
|
|
|
return self._factory(message) |
|
|
|
|
|
|
|
|
|
|
@ -254,10 +253,10 @@ class MaxConcurrency: |
|
|
|
def __repr__(self) -> str: |
|
|
|
return f'<MaxConcurrency per={self.per!r} number={self.number} wait={self.wait}>' |
|
|
|
|
|
|
|
def get_key(self, message: Message) -> Any: |
|
|
|
def get_key(self, message: Union[Message, Context[Any]]) -> Any: |
|
|
|
return self.per.get_key(message) |
|
|
|
|
|
|
|
async def acquire(self, message: Message) -> None: |
|
|
|
async def acquire(self, message: Union[Message, Context[Any]]) -> None: |
|
|
|
key = self.get_key(message) |
|
|
|
|
|
|
|
try: |
|
|
@ -269,7 +268,7 @@ class MaxConcurrency: |
|
|
|
if not acquired: |
|
|
|
raise MaxConcurrencyReached(self.number, self.per) |
|
|
|
|
|
|
|
async def release(self, message: Message) -> None: |
|
|
|
async def release(self, message: Union[Message, Context[Any]]) -> None: |
|
|
|
# Technically there's no reason for this function to be async |
|
|
|
# But it might be more useful in the future |
|
|
|
key = self.get_key(message) |
|
|
|