From bbba8c650fcb54e9cd69b6a0f6b162fcf277f9f5 Mon Sep 17 00:00:00 2001 From: Rapptz Date: Thu, 19 Jan 2023 07:00:09 -0500 Subject: [PATCH] Add missing generic parameters on various Interaction parameters --- discord/app_commands/tree.py | 20 ++++++++++---------- discord/client.py | 8 ++++---- discord/interactions.py | 8 ++++---- discord/ui/item.py | 3 ++- discord/ui/modal.py | 5 +++-- 5 files changed, 23 insertions(+), 21 deletions(-) diff --git a/discord/app_commands/tree.py b/discord/app_commands/tree.py index da038b536..61867633c 100644 --- a/discord/app_commands/tree.py +++ b/discord/app_commands/tree.py @@ -40,7 +40,6 @@ from typing import ( Sequence, Set, Tuple, - TypeVar, Union, overload, ) @@ -63,11 +62,12 @@ from .translator import Translator, locale_str from ..errors import ClientException, HTTPException from ..enums import AppCommandType, InteractionType from ..utils import MISSING, _get_as_snowflake, _is_submodule +from .._types import ClientT + if TYPE_CHECKING: from ..types.interactions import ApplicationCommandInteractionData, ApplicationCommandInteractionDataOption from ..interactions import Interaction - from ..client import Client from ..abc import Snowflake from .commands import ContextMenuCallback, CommandCallback, P, T @@ -78,8 +78,6 @@ if TYPE_CHECKING: __all__ = ('CommandTree',) -ClientT = TypeVar('ClientT', bound='Client') - _log = logging.getLogger(__name__) @@ -773,7 +771,7 @@ class CommandTree(Generic[ClientT]): for key in remove: del mapping[key] - async def on_error(self, interaction: Interaction, error: AppCommandError, /) -> None: + async def on_error(self, interaction: Interaction[ClientT], error: AppCommandError, /) -> None: """|coro| A callback that is called when any command raises an :exc:`AppCommandError`. @@ -1076,7 +1074,7 @@ class CommandTree(Generic[ClientT]): return [AppCommand(data=d, state=self._state) for d in data] - async def _dispatch_error(self, interaction: Interaction, error: AppCommandError, /) -> None: + async def _dispatch_error(self, interaction: Interaction[ClientT], error: AppCommandError, /) -> None: command = interaction.command interaction.command_failed = True try: @@ -1085,7 +1083,7 @@ class CommandTree(Generic[ClientT]): finally: await self.on_error(interaction, error) - def _from_interaction(self, interaction: Interaction) -> None: + def _from_interaction(self, interaction: Interaction[ClientT]) -> None: async def wrapper(): try: await self._call(interaction) @@ -1156,7 +1154,9 @@ class CommandTree(Generic[ClientT]): return (command, options) - async def _call_context_menu(self, interaction: Interaction, data: ApplicationCommandInteractionData, type: int) -> None: + async def _call_context_menu( + self, interaction: Interaction[ClientT], data: ApplicationCommandInteractionData, type: int + ) -> None: name = data['name'] guild_id = _get_as_snowflake(data, 'guild_id') ctx_menu = self._context_menus.get((name, guild_id, type)) @@ -1195,7 +1195,7 @@ class CommandTree(Generic[ClientT]): else: self.client.dispatch('app_command_completion', interaction, ctx_menu) - async def interaction_check(self, interaction: Interaction, /) -> bool: + async def interaction_check(self, interaction: Interaction[ClientT], /) -> bool: """|coro| A global check to determine if an :class:`~discord.Interaction` should @@ -1206,7 +1206,7 @@ class CommandTree(Generic[ClientT]): """ return True - async def _call(self, interaction: Interaction) -> None: + async def _call(self, interaction: Interaction[ClientT]) -> None: if not await self.interaction_check(interaction): interaction.command_failed = True return diff --git a/discord/client.py b/discord/client.py index f0156dd8f..116fea7be 100644 --- a/discord/client.py +++ b/discord/client.py @@ -1161,9 +1161,9 @@ class Client: event: Literal['app_command_completion'], /, *, - check: Optional[Callable[[Interaction, Union[Command, ContextMenu]], bool]], + check: Optional[Callable[[Interaction[Self], Union[Command, ContextMenu]], bool]], timeout: Optional[float] = None, - ) -> Tuple[Interaction, Union[Command, ContextMenu]]: + ) -> Tuple[Interaction[Self], Union[Command, ContextMenu]]: ... # AutoMod @@ -1447,9 +1447,9 @@ class Client: event: Literal['interaction'], /, *, - check: Optional[Callable[[Interaction], bool]], + check: Optional[Callable[[Interaction[Self]], bool]], timeout: Optional[float] = None, - ) -> Interaction: + ) -> Interaction[Self]: ... # Members diff --git a/discord/interactions.py b/discord/interactions.py index 478a1a8b3..b43999927 100644 --- a/discord/interactions.py +++ b/discord/interactions.py @@ -302,7 +302,7 @@ class Interaction(Generic[ClientT]): return tree._get_context_menu(data) @utils.cached_slot_property('_cs_response') - def response(self) -> InteractionResponse: + def response(self) -> InteractionResponse[ClientT]: """:class:`InteractionResponse`: Returns an object responsible for handling responding to the interaction. A response can only be done once. If secondary messages need to be sent, consider using :attr:`followup` @@ -548,7 +548,7 @@ class Interaction(Generic[ClientT]): return await translator.translate(string, locale=locale, context=context) -class InteractionResponse: +class InteractionResponse(Generic[ClientT]): """Represents a Discord interaction response. This type can be accessed through :attr:`Interaction.response`. @@ -561,8 +561,8 @@ class InteractionResponse: '_parent', ) - def __init__(self, parent: Interaction): - self._parent: Interaction = parent + def __init__(self, parent: Interaction[ClientT]): + self._parent: Interaction[ClientT] = parent self._response_type: Optional[InteractionResponseType] = None def is_done(self) -> bool: diff --git a/discord/ui/item.py b/discord/ui/item.py index 8703fbf90..443876c1a 100644 --- a/discord/ui/item.py +++ b/discord/ui/item.py @@ -27,6 +27,7 @@ from __future__ import annotations from typing import Any, Callable, Coroutine, Dict, Generic, Optional, TYPE_CHECKING, Tuple, Type, TypeVar from ..interactions import Interaction +from .._types import ClientT # fmt: off __all__ = ( @@ -119,7 +120,7 @@ class Item(Generic[V]): """Optional[:class:`View`]: The underlying view for this item.""" return self._view - async def callback(self, interaction: Interaction) -> Any: + async def callback(self, interaction: Interaction[ClientT]) -> Any: """|coro| The callback associated with this UI item. diff --git a/discord/ui/modal.py b/discord/ui/modal.py index 1b71fb9c0..615930d5c 100644 --- a/discord/ui/modal.py +++ b/discord/ui/modal.py @@ -31,6 +31,7 @@ from copy import deepcopy from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, ClassVar, List from ..utils import MISSING, find +from .._types import ClientT from .item import Item from .view import View @@ -134,7 +135,7 @@ class Modal(View): super().__init__(timeout=timeout) - async def on_submit(self, interaction: Interaction, /) -> None: + async def on_submit(self, interaction: Interaction[ClientT], /) -> None: """|coro| Called when the modal is submitted. @@ -146,7 +147,7 @@ class Modal(View): """ pass - async def on_error(self, interaction: Interaction, error: Exception, /) -> None: + async def on_error(self, interaction: Interaction[ClientT], error: Exception, /) -> None: """|coro| A callback that is called when :meth:`on_submit`