Browse Source

Fix various generics throughout the public interface

Fix CooldownMapping generic typing and ensure other public methods 
have proper generics
pull/8287/head
Bryan Forbes 3 years ago
committed by GitHub
parent
commit
07ad6951fb
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 2
      discord/app_commands/commands.py
  2. 45
      discord/ext/commands/cooldowns.py
  3. 20
      discord/ext/commands/core.py

2
discord/app_commands/commands.py

@ -134,7 +134,7 @@ else:
AutocompleteCallback = Callable[..., Coro[T]]
CheckInputParameter = Union['Command[Any, ..., Any]', 'ContextMenu', CommandCallback, ContextMenuCallback]
CheckInputParameter = Union['Command[Any, ..., Any]', 'ContextMenu', 'CommandCallback[Any, ..., Any]', ContextMenuCallback]
# The re module doesn't support \p{} so we have to list characters from Thai and Devanagari manually.
THAI_COMBINING = r'\u0e31-\u0e3a\u0e47-\u0e4e'

45
discord/ext/commands/cooldowns.py

@ -40,7 +40,6 @@ if TYPE_CHECKING:
from typing_extensions import Self
from ...message import Message
from ._types import BotT
__all__ = (
'BucketType',
@ -50,7 +49,7 @@ __all__ = (
'MaxConcurrency',
)
T = TypeVar('T')
T_contra = TypeVar('T_contra', contravariant=True)
class BucketType(Enum):
@ -62,7 +61,7 @@ class BucketType(Enum):
category = 5
role = 6
def get_key(self, msg: Union[Message, Context[BotT]]) -> Any:
def get_key(self, msg: Union[Message, Context[Any]]) -> Any:
if self is BucketType.user:
return msg.author.id
elif self is BucketType.guild:
@ -80,24 +79,24 @@ class BucketType(Enum):
# receiving a DMChannel or GroupChannel which inherit from PrivateChannel and do
return (msg.channel if isinstance(msg.channel, PrivateChannel) else msg.author.top_role).id # type: ignore
def __call__(self, msg: Union[Message, Context[BotT]]) -> Any:
def __call__(self, msg: Union[Message, Context[Any]]) -> Any:
return self.get_key(msg)
class CooldownMapping(Generic[T]):
class CooldownMapping(Generic[T_contra]):
def __init__(
self,
original: Optional[Cooldown],
type: Callable[[T], Any],
type: Callable[[T_contra], Any],
) -> None:
if not callable(type):
raise TypeError('Cooldown type must be a BucketType or callable')
self._cache: Dict[Any, Cooldown] = {}
self._cooldown: Optional[Cooldown] = original
self._type: Callable[[T], Any] = type
self._type: Callable[[T_contra], Any] = type
def copy(self) -> CooldownMapping:
def copy(self) -> CooldownMapping[T_contra]:
ret = CooldownMapping(self._cooldown, self._type)
ret._cache = self._cache.copy()
return ret
@ -107,14 +106,14 @@ class CooldownMapping(Generic[T]):
return self._cooldown is not None
@property
def type(self) -> Callable[[T], Any]:
def type(self) -> Callable[[T_contra], Any]:
return self._type
@classmethod
def from_cooldown(cls, rate: float, per: float, type: Callable[[T], Any]) -> Self:
def from_cooldown(cls, rate: float, per: float, type: Callable[[T_contra], Any]) -> Self:
return cls(Cooldown(rate, per), type)
def _bucket_key(self, msg: T) -> Any:
def _bucket_key(self, msg: T_contra) -> Any:
return self._type(msg)
def _verify_cache_integrity(self, current: Optional[float] = None) -> None:
@ -126,10 +125,10 @@ class CooldownMapping(Generic[T]):
for k in dead_keys:
del self._cache[k]
def create_bucket(self, message: T) -> Cooldown:
def create_bucket(self, message: T_contra) -> Cooldown:
return self._cooldown.copy() # type: ignore
def get_bucket(self, message: T, current: Optional[float] = None) -> Optional[Cooldown]:
def get_bucket(self, message: T_contra, current: Optional[float] = None) -> Optional[Cooldown]:
if self._type is BucketType.default:
return self._cooldown
@ -144,23 +143,23 @@ class CooldownMapping(Generic[T]):
return bucket
def update_rate_limit(self, message: T, current: Optional[float] = None, tokens: int = 1) -> Optional[float]:
def update_rate_limit(self, message: T_contra, current: Optional[float] = None, tokens: int = 1) -> Optional[float]:
bucket = self.get_bucket(message, current)
if bucket is None:
return None
return bucket.update_rate_limit(current, tokens=tokens)
class DynamicCooldownMapping(CooldownMapping[T]):
class DynamicCooldownMapping(CooldownMapping[T_contra]):
def __init__(
self,
factory: Callable[[T], Optional[Cooldown]],
type: Callable[[T], Any],
factory: Callable[[T_contra], Optional[Cooldown]],
type: Callable[[T_contra], Any],
) -> None:
super().__init__(None, type)
self._factory: Callable[[T], Optional[Cooldown]] = factory
self._factory: Callable[[T_contra], Optional[Cooldown]] = factory
def copy(self) -> DynamicCooldownMapping:
def copy(self) -> DynamicCooldownMapping[T_contra]:
ret = DynamicCooldownMapping(self._factory, self._type)
ret._cache = self._cache.copy()
return ret
@ -169,7 +168,7 @@ class DynamicCooldownMapping(CooldownMapping[T]):
def valid(self) -> bool:
return True
def create_bucket(self, message: T) -> Optional[Cooldown]:
def create_bucket(self, message: T_contra) -> Optional[Cooldown]:
return self._factory(message)
@ -254,10 +253,10 @@ class MaxConcurrency:
def __repr__(self) -> str:
return f'<MaxConcurrency per={self.per!r} number={self.number} wait={self.wait}>'
def get_key(self, message: Message) -> Any:
def get_key(self, message: Union[Message, Context[Any]]) -> Any:
return self.per.get_key(message)
async def acquire(self, message: Message) -> None:
async def acquire(self, message: Union[Message, Context[Any]]) -> None:
key = self.get_key(message)
try:
@ -269,7 +268,7 @@ class MaxConcurrency:
if not acquired:
raise MaxConcurrencyReached(self.number, self.per)
async def release(self, message: Message) -> None:
async def release(self, message: Union[Message, Context[Any]]) -> None:
# Technically there's no reason for this function to be async
# But it might be more useful in the future
key = self.get_key(message)

20
discord/ext/commands/core.py

@ -91,9 +91,9 @@ __all__ = (
MISSING: Any = discord.utils.MISSING
T = TypeVar('T')
CommandT = TypeVar('CommandT', bound='Command')
CommandT = TypeVar('CommandT', bound='Command[Any, ..., Any]')
# CHT = TypeVar('CHT', bound='Check')
GroupT = TypeVar('GroupT', bound='Group')
GroupT = TypeVar('GroupT', bound='Group[Any, ..., Any]')
if TYPE_CHECKING:
P = ParamSpec('P')
@ -404,10 +404,10 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
if cooldown is None:
buckets = CooldownMapping(cooldown, BucketType.default)
elif isinstance(cooldown, CooldownMapping):
buckets: CooldownMapping[Context] = cooldown
buckets: CooldownMapping[Context[Any]] = cooldown
else:
raise TypeError("Cooldown must be a an instance of CooldownMapping or None.")
self._buckets: CooldownMapping[Context] = buckets
self._buckets: CooldownMapping[Context[Any]] = buckets
try:
max_concurrency = func.__commands_max_concurrency__
@ -452,15 +452,15 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
@property
def callback(
self,
) -> Union[Callable[Concatenate[CogT, Context, P], Coro[T]], Callable[Concatenate[Context, P], Coro[T]],]:
) -> Union[Callable[Concatenate[CogT, Context[Any], P], Coro[T]], Callable[Concatenate[Context[Any], P], Coro[T]],]:
return self._callback
@callback.setter
def callback(
self,
function: Union[
Callable[Concatenate[CogT, Context, P], Coro[T]],
Callable[Concatenate[Context, P], Coro[T]],
Callable[Concatenate[CogT, Context[Any], P], Coro[T]],
Callable[Concatenate[Context[Any], P], Coro[T]],
],
) -> None:
self._callback = function
@ -2394,7 +2394,7 @@ def is_nsfw() -> Check[Any]:
def cooldown(
rate: int,
per: float,
type: Union[BucketType, Callable[[Context], Any]] = BucketType.default,
type: Union[BucketType, Callable[[Context[Any]], Any]] = BucketType.default,
) -> Callable[[T], T]:
"""A decorator that adds a cooldown to a :class:`.Command`
@ -2433,8 +2433,8 @@ def cooldown(
def dynamic_cooldown(
cooldown: Callable[[Context], Optional[Cooldown]],
type: Union[BucketType, Callable[[Context], Any]],
cooldown: Callable[[Context[Any]], Optional[Cooldown]],
type: Union[BucketType, Callable[[Context[Any]], Any]],
) -> Callable[[T], T]:
"""A decorator that adds a dynamic cooldown to a :class:`.Command`

Loading…
Cancel
Save