Browse Source

Fix type-errors in commands extension

pull/7494/head
Josh 3 years ago
committed by GitHub
parent
commit
39c5a4fdc3
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 3
      discord/client.py
  2. 13
      discord/ext/commands/bot.py
  3. 19
      discord/ext/commands/converter.py
  4. 147
      discord/ext/commands/core.py
  5. 3
      discord/ext/commands/view.py

3
discord/client.py

@ -78,7 +78,8 @@ from .sticker import GuildSticker, StandardSticker, StickerPack, _sticker_factor
if TYPE_CHECKING: if TYPE_CHECKING:
from .types.guild import Guild as GuildPayload from .types.guild import Guild as GuildPayload
from .abc import SnowflakeTime, PrivateChannel, GuildChannel, Snowflake from .abc import SnowflakeTime, Snowflake, PrivateChannel
from .guild import GuildChannel
from .channel import DMChannel from .channel import DMChannel
from .message import Message from .message import Message
from .member import Member from .member import Member

13
discord/ext/commands/bot.py

@ -33,7 +33,7 @@ import importlib.util
import sys import sys
import traceback import traceback
import types import types
from typing import Any, Callable, Mapping, List, Dict, TYPE_CHECKING, Optional, TypeVar, Type, Union from typing import Any, Callable, Mapping, List, Dict, TYPE_CHECKING, Optional, TypeVar, Type, Union, overload
import discord import discord
@ -65,6 +65,7 @@ MISSING: Any = discord.utils.MISSING
T = TypeVar('T') T = TypeVar('T')
CFT = TypeVar('CFT', bound='CoroFunc') CFT = TypeVar('CFT', bound='CoroFunc')
CXT = TypeVar('CXT', bound='Context') CXT = TypeVar('CXT', bound='Context')
BT = TypeVar('BT', bound='Union[Bot, AutoShardedBot]')
def when_mentioned(bot: Union[Bot, AutoShardedBot], msg: Message) -> List[str]: def when_mentioned(bot: Union[Bot, AutoShardedBot], msg: Message) -> List[str]:
@ -932,7 +933,15 @@ class BotBase(GroupMixin):
return ret return ret
async def get_context(self, message: Message, *, cls: Type[CXT] = Context) -> CXT: @overload
async def get_context(self: BT, message: Message) -> Context[BT]:
...
@overload
async def get_context(self, message: Message, *, cls: Type[CXT] = ...) -> CXT:
...
async def get_context(self, message: Message, *, cls: Type[Context] = Context) -> Any:
r"""|coro| r"""|coro|
Returns the invocation context from the message. Returns the invocation context from the message.

19
discord/ext/commands/converter.py

@ -41,6 +41,7 @@ from typing import (
Tuple, Tuple,
Union, Union,
runtime_checkable, runtime_checkable,
overload,
) )
import discord import discord
@ -48,7 +49,8 @@ from .errors import *
if TYPE_CHECKING: if TYPE_CHECKING:
from .context import Context from .context import Context
from discord.message import PartialMessageableChannel from discord.state import Channel
from discord.threads import Thread
from .bot import Bot, AutoShardedBot from .bot import Bot, AutoShardedBot
_Bot = Union[Bot, AutoShardedBot] _Bot = Union[Bot, AutoShardedBot]
@ -357,7 +359,7 @@ class PartialMessageConverter(Converter[discord.PartialMessage]):
@staticmethod @staticmethod
def _resolve_channel( def _resolve_channel(
ctx: Context[_Bot], guild_id: Optional[int], channel_id: Optional[int] ctx: Context[_Bot], guild_id: Optional[int], channel_id: Optional[int]
) -> Optional[PartialMessageableChannel]: ) -> Optional[Union[Channel, Thread]]:
if channel_id is None: if channel_id is None:
# we were passed just a message id so we can assume the channel is the current context channel # we were passed just a message id so we can assume the channel is the current context channel
return ctx.channel return ctx.channel
@ -373,8 +375,8 @@ class PartialMessageConverter(Converter[discord.PartialMessage]):
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.PartialMessage: async def convert(self, ctx: Context[_Bot], argument: str) -> discord.PartialMessage:
guild_id, message_id, channel_id = self._get_id_matches(ctx, argument) guild_id, message_id, channel_id = self._get_id_matches(ctx, argument)
channel = self._resolve_channel(ctx, guild_id, channel_id) channel = self._resolve_channel(ctx, guild_id, channel_id)
if not channel: if not channel or not isinstance(channel, discord.abc.Messageable):
raise ChannelNotFound(channel_id) raise ChannelNotFound(channel_id) # type: ignore - channel_id won't be None here
return discord.PartialMessage(channel=channel, id=message_id) return discord.PartialMessage(channel=channel, id=message_id)
@ -399,14 +401,14 @@ class MessageConverter(IDConverter[discord.Message]):
if message: if message:
return message return message
channel = PartialMessageConverter._resolve_channel(ctx, guild_id, channel_id) channel = PartialMessageConverter._resolve_channel(ctx, guild_id, channel_id)
if not channel: if not channel or not isinstance(channel, discord.abc.Messageable):
raise ChannelNotFound(channel_id) raise ChannelNotFound(channel_id) # type: ignore - channel_id won't be None here
try: try:
return await channel.fetch_message(message_id) return await channel.fetch_message(message_id)
except discord.NotFound: except discord.NotFound:
raise MessageNotFound(argument) raise MessageNotFound(argument)
except discord.Forbidden: except discord.Forbidden:
raise ChannelNotReadable(channel) raise ChannelNotReadable(channel) # type: ignore - type-checker thinks channel could be a DMChannel at this point
class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]): class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]):
@ -449,7 +451,8 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]):
else: else:
channel_id = int(match.group(1)) channel_id = int(match.group(1))
if guild: if guild:
result = guild.get_channel(channel_id) # guild.get_channel returns an explicit union instead of the base class
result = guild.get_channel(channel_id) # type: ignore
else: else:
result = _get_from_guilds(bot, 'get_channel', channel_id) result = _get_from_guilds(bot, 'get_channel', channel_id)

147
discord/ext/commands/core.py

@ -99,7 +99,7 @@ __all__ = (
MISSING: Any = discord.utils.MISSING MISSING: Any = discord.utils.MISSING
T = TypeVar('T') T = TypeVar('T')
CogT = TypeVar('CogT', bound='Cog') CogT = TypeVar('CogT', bound='Optional[Cog]')
CommandT = TypeVar('CommandT', bound='Command') CommandT = TypeVar('CommandT', bound='Command')
ContextT = TypeVar('ContextT', bound='Context') ContextT = TypeVar('ContextT', bound='Context')
# CHT = TypeVar('CHT', bound='Check') # CHT = TypeVar('CHT', bound='Check')
@ -307,7 +307,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
Callable[Concatenate[ContextT, P], Coro[T]], Callable[Concatenate[ContextT, P], Coro[T]],
], ],
**kwargs: Any, **kwargs: Any,
): ) -> None:
if not asyncio.iscoroutinefunction(func): if not asyncio.iscoroutinefunction(func):
raise TypeError('Callback must be a coroutine.') raise TypeError('Callback must be a coroutine.')
@ -372,7 +372,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
self.require_var_positional: bool = kwargs.get('require_var_positional', False) self.require_var_positional: bool = kwargs.get('require_var_positional', False)
self.ignore_extra: bool = kwargs.get('ignore_extra', True) self.ignore_extra: bool = kwargs.get('ignore_extra', True)
self.cooldown_after_parsing: bool = kwargs.get('cooldown_after_parsing', False) self.cooldown_after_parsing: bool = kwargs.get('cooldown_after_parsing', False)
self.cog: Optional[CogT] = None self.cog: CogT = None
# bandaid for the fact that sometimes parent can be the bot instance # bandaid for the fact that sometimes parent can be the bot instance
parent = kwargs.get('parent') parent = kwargs.get('parent')
@ -1321,9 +1321,8 @@ class GroupMixin(Generic[CogT]):
@overload @overload
def command( def command(
self, self: GroupMixin[CogT],
name: str = ..., name: str = ...,
cls: Type[Command[CogT, P, T]] = ...,
*args: Any, *args: Any,
**kwargs: Any, **kwargs: Any,
) -> Callable[ ) -> Callable[
@ -1339,21 +1338,29 @@ class GroupMixin(Generic[CogT]):
@overload @overload
def command( def command(
self, self: GroupMixin[CogT],
name: str = ..., name: str = ...,
cls: Type[CommandT] = ..., cls: Type[CommandT] = ...,
*args: Any, *args: Any,
**kwargs: Any, **kwargs: Any,
) -> Callable[[Callable[Concatenate[ContextT, P], Coro[Any]]], CommandT]: ) -> Callable[
[
Union[
Callable[Concatenate[CogT, ContextT, P], Coro[T]],
Callable[Concatenate[ContextT, P], Coro[T]],
]
],
CommandT,
]:
... ...
def command( def command(
self, self,
name: str = MISSING, name: str = MISSING,
cls: Type[CommandT] = MISSING, cls: Type[Command] = MISSING,
*args: Any, *args: Any,
**kwargs: Any, **kwargs: Any,
) -> Callable[[Callable[Concatenate[ContextT, P], Coro[Any]]], CommandT]: ) -> Any:
"""A shortcut decorator that invokes :func:`.command` and adds it to """A shortcut decorator that invokes :func:`.command` and adds it to
the internal command list via :meth:`~.GroupMixin.add_command`. the internal command list via :meth:`~.GroupMixin.add_command`.
@ -1363,7 +1370,8 @@ class GroupMixin(Generic[CogT]):
A decorator that converts the provided method into a Command, adds it to the bot, then returns it. A decorator that converts the provided method into a Command, adds it to the bot, then returns it.
""" """
def decorator(func: Callable[Concatenate[ContextT, P], Coro[Any]]) -> CommandT: def decorator(func):
kwargs.setdefault('parent', self) kwargs.setdefault('parent', self)
result = command(name=name, cls=cls, *args, **kwargs)(func) result = command(name=name, cls=cls, *args, **kwargs)(func)
self.add_command(result) self.add_command(result)
@ -1373,34 +1381,46 @@ class GroupMixin(Generic[CogT]):
@overload @overload
def group( def group(
self, self: GroupMixin[CogT],
name: str = ..., name: str = ...,
cls: Type[Group[CogT, P, T]] = ...,
*args: Any, *args: Any,
**kwargs: Any, **kwargs: Any,
) -> Callable[ ) -> Callable[
[Union[Callable[Concatenate[CogT, ContextT, P], Coro[T]], Callable[Concatenate[ContextT, P], Coro[T]]]], [
Union[
Callable[Concatenate[CogT, ContextT, P], Coro[T]],
Callable[Concatenate[ContextT, P], Coro[T]],
]
],
Group[CogT, P, T], Group[CogT, P, T],
]: ]:
... ...
@overload @overload
def group( def group(
self, self: GroupMixin[CogT],
name: str = ..., name: str = ...,
cls: Type[GroupT] = ..., cls: Type[GroupT] = ...,
*args: Any, *args: Any,
**kwargs: Any, **kwargs: Any,
) -> Callable[[Callable[Concatenate[ContextT, P], Coro[Any]]], GroupT]: ) -> Callable[
[
Union[
Callable[Concatenate[CogT, ContextT, P], Coro[T]],
Callable[Concatenate[ContextT, P], Coro[T]],
]
],
GroupT,
]:
... ...
def group( def group(
self, self,
name: str = MISSING, name: str = MISSING,
cls: Type[GroupT] = MISSING, cls: Type[Group] = MISSING,
*args: Any, *args: Any,
**kwargs: Any, **kwargs: Any,
) -> Callable[[Callable[Concatenate[ContextT, P], Coro[Any]]], GroupT]: ) -> Any:
"""A shortcut decorator that invokes :func:`.group` and adds it to """A shortcut decorator that invokes :func:`.group` and adds it to
the internal command list via :meth:`~.GroupMixin.add_command`. the internal command list via :meth:`~.GroupMixin.add_command`.
@ -1410,7 +1430,7 @@ class GroupMixin(Generic[CogT]):
A decorator that converts the provided method into a Group, adds it to the bot, then returns it. A decorator that converts the provided method into a Group, adds it to the bot, then returns it.
""" """
def decorator(func: Callable[Concatenate[ContextT, P], Coro[Any]]) -> GroupT: def decorator(func):
kwargs.setdefault('parent', self) kwargs.setdefault('parent', self)
result = group(name=name, cls=cls, *args, **kwargs)(func) result = group(name=name, cls=cls, *args, **kwargs)(func)
self.add_command(result) self.add_command(result)
@ -1533,21 +1553,39 @@ class Group(GroupMixin[CogT], Command[CogT, P, T]):
# Decorators # Decorators
if TYPE_CHECKING:
# Using a class to emulate a function allows for overloading the inner function in the decorator.
class _CommandDecorator:
@overload
def __call__(self, func: Callable[Concatenate[CogT, ContextT, P], Coro[T]], /) -> Command[CogT, P, T]:
...
@overload
def __call__(self, func: Callable[Concatenate[ContextT, P], Coro[T]], /) -> Command[None, P, T]:
...
def __call__(self, func: Callable[..., Coro[T]], /) -> Any:
...
class _GroupDecorator:
@overload
def __call__(self, func: Callable[Concatenate[CogT, ContextT, P], Coro[T]], /) -> Group[CogT, P, T]:
...
@overload
def __call__(self, func: Callable[Concatenate[ContextT, P], Coro[T]], /) -> Group[None, P, T]:
...
def __call__(self, func: Callable[..., Coro[T]], /) -> Any:
...
@overload @overload
def command( def command(
name: str = ..., name: str = ...,
cls: Type[Command[CogT, P, T]] = ...,
**attrs: Any, **attrs: Any,
) -> Callable[ ) -> _CommandDecorator:
[
Union[
Callable[Concatenate[CogT, ContextT, P], Coro[T]],
Callable[Concatenate[ContextT, P], Coro[T]],
]
],
Command[CogT, P, T],
]:
... ...
@ -1559,8 +1597,8 @@ def command(
) -> Callable[ ) -> Callable[
[ [
Union[ Union[
Callable[Concatenate[CogT, ContextT, P], Coro[Any]],
Callable[Concatenate[ContextT, P], Coro[Any]], Callable[Concatenate[ContextT, P], Coro[Any]],
Callable[Concatenate[CogT, ContextT, P], Coro[Any]], # type: ignore - CogT is used here to allow covariance
] ]
], ],
CommandT, CommandT,
@ -1570,17 +1608,9 @@ def command(
def command( def command(
name: str = MISSING, name: str = MISSING,
cls: Type[CommandT] = MISSING, cls: Type[Command] = MISSING,
**attrs: Any, **attrs: Any,
) -> Callable[ ) -> Any:
[
Union[
Callable[Concatenate[ContextT, P], Coro[Any]],
Callable[Concatenate[CogT, ContextT, P], Coro[T]],
]
],
Union[Command[CogT, P, T], CommandT],
]:
"""A decorator that transforms a function into a :class:`.Command` """A decorator that transforms a function into a :class:`.Command`
or if called with :func:`.group`, :class:`.Group`. or if called with :func:`.group`, :class:`.Group`.
@ -1611,14 +1641,9 @@ def command(
If the function is not a coroutine or is already a command. If the function is not a coroutine or is already a command.
""" """
if cls is MISSING: if cls is MISSING:
cls = Command # type: ignore cls = Command
def decorator( def decorator(func):
func: Union[
Callable[Concatenate[ContextT, P], Coro[Any]],
Callable[Concatenate[CogT, ContextT, P], Coro[Any]],
]
) -> CommandT:
if isinstance(func, Command): if isinstance(func, Command):
raise TypeError('Callback is already a command.') raise TypeError('Callback is already a command.')
return cls(func, name=name, **attrs) return cls(func, name=name, **attrs)
@ -1629,17 +1654,8 @@ def command(
@overload @overload
def group( def group(
name: str = ..., name: str = ...,
cls: Type[Group[CogT, P, T]] = ...,
**attrs: Any, **attrs: Any,
) -> Callable[ ) -> _GroupDecorator:
[
Union[
Callable[Concatenate[CogT, ContextT, P], Coro[T]],
Callable[Concatenate[ContextT, P], Coro[T]],
]
],
Group[CogT, P, T],
]:
... ...
@ -1651,7 +1667,7 @@ def group(
) -> Callable[ ) -> Callable[
[ [
Union[ Union[
Callable[Concatenate[CogT, ContextT, P], Coro[Any]], Callable[Concatenate[CogT, ContextT, P], Coro[Any]], # type: ignore - CogT is used here to allow covariance
Callable[Concatenate[ContextT, P], Coro[Any]], Callable[Concatenate[ContextT, P], Coro[Any]],
] ]
], ],
@ -1662,17 +1678,9 @@ def group(
def group( def group(
name: str = MISSING, name: str = MISSING,
cls: Type[GroupT] = MISSING, cls: Type[Group] = MISSING,
**attrs: Any, **attrs: Any,
) -> Callable[ ) -> Any:
[
Union[
Callable[Concatenate[ContextT, P], Coro[Any]],
Callable[Concatenate[CogT, ContextT, P], Coro[T]],
]
],
Union[Group[CogT, P, T], GroupT],
]:
"""A decorator that transforms a function into a :class:`.Group`. """A decorator that transforms a function into a :class:`.Group`.
This is similar to the :func:`.command` decorator but the ``cls`` This is similar to the :func:`.command` decorator but the ``cls``
@ -1682,8 +1690,9 @@ def group(
The ``cls`` parameter can now be passed. The ``cls`` parameter can now be passed.
""" """
if cls is MISSING: if cls is MISSING:
cls = Group # type: ignore cls = Group
return command(name=name, cls=cls, **attrs) # type: ignore
return command(name=name, cls=cls, **attrs)
def check(predicate: Check) -> Callable[[T], T]: def check(predicate: Check) -> Callable[[T], T]:

3
discord/ext/commands/view.py

@ -21,7 +21,6 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from .errors import UnexpectedQuoteError, InvalidEndOfQuotedStringError, ExpectedClosingQuoteError from .errors import UnexpectedQuoteError, InvalidEndOfQuotedStringError, ExpectedClosingQuoteError
# map from opening quotes to closing quotes # map from opening quotes to closing quotes
@ -177,7 +176,7 @@ class StringView:
next_char = self.get() next_char = self.get()
valid_eof = not next_char or next_char.isspace() valid_eof = not next_char or next_char.isspace()
if not valid_eof: if not valid_eof:
raise InvalidEndOfQuotedStringError(next_char) raise InvalidEndOfQuotedStringError(next_char) # type: ignore - this will always be a string
# we're quoted so it's okay # we're quoted so it's okay
return ''.join(result) return ''.join(result)

Loading…
Cancel
Save