Browse Source

Improve typing of app command transformers

This allows subclasses of transformers to specify a specialization for
interaction without violating covariance of parameter types
pull/9963/head
Michael H 6 months ago
committed by GitHub
parent
commit
3e168a93bf
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 32
      discord/app_commands/transformers.py

32
discord/app_commands/transformers.py

@ -34,6 +34,7 @@ from typing import (
ClassVar,
Coroutine,
Dict,
Generic,
List,
Literal,
Optional,
@ -56,6 +57,7 @@ from ..user import User
from ..role import Role
from ..member import Member
from ..message import Attachment
from .._types import ClientT
__all__ = (
'Transformer',
@ -191,7 +193,7 @@ class CommandParameter:
return self.name if self._rename is MISSING else str(self._rename)
class Transformer:
class Transformer(Generic[ClientT]):
"""The base class that allows a type annotation in an application command parameter
to map into a :class:`~discord.AppCommandOptionType` and transform the raw value into one
from this type.
@ -304,7 +306,7 @@ class Transformer:
else:
return name
async def transform(self, interaction: Interaction, value: Any, /) -> Any:
async def transform(self, interaction: Interaction[ClientT], value: Any, /) -> Any:
"""|maybecoro|
Transforms the converted option value into another value.
@ -324,7 +326,7 @@ class Transformer:
raise NotImplementedError('Derived classes need to implement this.')
async def autocomplete(
self, interaction: Interaction, value: Union[int, float, str], /
self, interaction: Interaction[ClientT], value: Union[int, float, str], /
) -> List[Choice[Union[int, float, str]]]:
"""|coro|
@ -352,7 +354,7 @@ class Transformer:
raise NotImplementedError('Derived classes can implement this.')
class IdentityTransformer(Transformer):
class IdentityTransformer(Transformer[ClientT]):
def __init__(self, type: AppCommandOptionType) -> None:
self._type = type
@ -360,7 +362,7 @@ class IdentityTransformer(Transformer):
def type(self) -> AppCommandOptionType:
return self._type
async def transform(self, interaction: Interaction, value: Any, /) -> Any:
async def transform(self, interaction: Interaction[ClientT], value: Any, /) -> Any:
return value
@ -489,7 +491,7 @@ class EnumNameTransformer(Transformer):
return self._enum[value]
class InlineTransformer(Transformer):
class InlineTransformer(Transformer[ClientT]):
def __init__(self, annotation: Any) -> None:
super().__init__()
self.annotation: Any = annotation
@ -502,7 +504,7 @@ class InlineTransformer(Transformer):
def type(self) -> AppCommandOptionType:
return AppCommandOptionType.string
async def transform(self, interaction: Interaction, value: Any, /) -> Any:
async def transform(self, interaction: Interaction[ClientT], value: Any, /) -> Any:
return await self.annotation.transform(interaction, value)
@ -611,18 +613,18 @@ else:
return transformer
class MemberTransformer(Transformer):
class MemberTransformer(Transformer[ClientT]):
@property
def type(self) -> AppCommandOptionType:
return AppCommandOptionType.user
async def transform(self, interaction: Interaction, value: Any, /) -> Member:
async def transform(self, interaction: Interaction[ClientT], value: Any, /) -> Member:
if not isinstance(value, Member):
raise TransformerError(value, self.type, self)
return value
class BaseChannelTransformer(Transformer):
class BaseChannelTransformer(Transformer[ClientT]):
def __init__(self, *channel_types: Type[Any]) -> None:
super().__init__()
if len(channel_types) == 1:
@ -654,22 +656,22 @@ class BaseChannelTransformer(Transformer):
def channel_types(self) -> List[ChannelType]:
return self._channel_types
async def transform(self, interaction: Interaction, value: Any, /):
async def transform(self, interaction: Interaction[ClientT], value: Any, /):
resolved = value.resolve()
if resolved is None or not isinstance(resolved, self._types):
raise TransformerError(value, AppCommandOptionType.channel, self)
return resolved
class RawChannelTransformer(BaseChannelTransformer):
async def transform(self, interaction: Interaction, value: Any, /):
class RawChannelTransformer(BaseChannelTransformer[ClientT]):
async def transform(self, interaction: Interaction[ClientT], value: Any, /):
if not isinstance(value, self._types):
raise TransformerError(value, AppCommandOptionType.channel, self)
return value
class UnionChannelTransformer(BaseChannelTransformer):
async def transform(self, interaction: Interaction, value: Any, /):
class UnionChannelTransformer(BaseChannelTransformer[ClientT]):
async def transform(self, interaction: Interaction[ClientT], value: Any, /):
if isinstance(value, self._types):
return value

Loading…
Cancel
Save