diff --git a/discord/app_commands/transformers.py b/discord/app_commands/transformers.py index d012c52b9..e7b001727 100644 --- a/discord/app_commands/transformers.py +++ b/discord/app_commands/transformers.py @@ -34,6 +34,7 @@ from typing import ( ClassVar, Coroutine, Dict, + Generic, List, Literal, Optional, @@ -56,6 +57,7 @@ from ..user import User from ..role import Role from ..member import Member from ..message import Attachment +from .._types import ClientT __all__ = ( 'Transformer', @@ -191,7 +193,7 @@ class CommandParameter: return self.name if self._rename is MISSING else str(self._rename) -class Transformer: +class Transformer(Generic[ClientT]): """The base class that allows a type annotation in an application command parameter to map into a :class:`~discord.AppCommandOptionType` and transform the raw value into one from this type. @@ -304,7 +306,7 @@ class Transformer: else: return name - async def transform(self, interaction: Interaction, value: Any, /) -> Any: + async def transform(self, interaction: Interaction[ClientT], value: Any, /) -> Any: """|maybecoro| Transforms the converted option value into another value. @@ -324,7 +326,7 @@ class Transformer: raise NotImplementedError('Derived classes need to implement this.') async def autocomplete( - self, interaction: Interaction, value: Union[int, float, str], / + self, interaction: Interaction[ClientT], value: Union[int, float, str], / ) -> List[Choice[Union[int, float, str]]]: """|coro| @@ -352,7 +354,7 @@ class Transformer: raise NotImplementedError('Derived classes can implement this.') -class IdentityTransformer(Transformer): +class IdentityTransformer(Transformer[ClientT]): def __init__(self, type: AppCommandOptionType) -> None: self._type = type @@ -360,7 +362,7 @@ class IdentityTransformer(Transformer): def type(self) -> AppCommandOptionType: return self._type - async def transform(self, interaction: Interaction, value: Any, /) -> Any: + async def transform(self, interaction: Interaction[ClientT], value: Any, /) -> Any: return value @@ -489,7 +491,7 @@ class EnumNameTransformer(Transformer): return self._enum[value] -class InlineTransformer(Transformer): +class InlineTransformer(Transformer[ClientT]): def __init__(self, annotation: Any) -> None: super().__init__() self.annotation: Any = annotation @@ -502,7 +504,7 @@ class InlineTransformer(Transformer): def type(self) -> AppCommandOptionType: return AppCommandOptionType.string - async def transform(self, interaction: Interaction, value: Any, /) -> Any: + async def transform(self, interaction: Interaction[ClientT], value: Any, /) -> Any: return await self.annotation.transform(interaction, value) @@ -611,18 +613,18 @@ else: return transformer -class MemberTransformer(Transformer): +class MemberTransformer(Transformer[ClientT]): @property def type(self) -> AppCommandOptionType: return AppCommandOptionType.user - async def transform(self, interaction: Interaction, value: Any, /) -> Member: + async def transform(self, interaction: Interaction[ClientT], value: Any, /) -> Member: if not isinstance(value, Member): raise TransformerError(value, self.type, self) return value -class BaseChannelTransformer(Transformer): +class BaseChannelTransformer(Transformer[ClientT]): def __init__(self, *channel_types: Type[Any]) -> None: super().__init__() if len(channel_types) == 1: @@ -654,22 +656,22 @@ class BaseChannelTransformer(Transformer): def channel_types(self) -> List[ChannelType]: return self._channel_types - async def transform(self, interaction: Interaction, value: Any, /): + async def transform(self, interaction: Interaction[ClientT], value: Any, /): resolved = value.resolve() if resolved is None or not isinstance(resolved, self._types): raise TransformerError(value, AppCommandOptionType.channel, self) return resolved -class RawChannelTransformer(BaseChannelTransformer): - async def transform(self, interaction: Interaction, value: Any, /): +class RawChannelTransformer(BaseChannelTransformer[ClientT]): + async def transform(self, interaction: Interaction[ClientT], value: Any, /): if not isinstance(value, self._types): raise TransformerError(value, AppCommandOptionType.channel, self) return value -class UnionChannelTransformer(BaseChannelTransformer): - async def transform(self, interaction: Interaction, value: Any, /): +class UnionChannelTransformer(BaseChannelTransformer[ClientT]): + async def transform(self, interaction: Interaction[ClientT], value: Any, /): if isinstance(value, self._types): return value