From 52a5f103075c77684fa62d6389ed7d77f9118da7 Mon Sep 17 00:00:00 2001 From: Josh Date: Mon, 13 Jun 2022 05:30:45 +1000 Subject: [PATCH] Fix type annotations to adhere to latest pyright release --- .github/workflows/lint.yml | 2 +- discord/errors.py | 8 ++------ discord/ext/commands/_types.py | 2 +- discord/ext/commands/bot.py | 3 +-- discord/ext/commands/context.py | 8 ++++++-- discord/ext/commands/converter.py | 5 +++-- discord/ext/commands/core.py | 24 ++++++++++++------------ discord/ext/commands/help.py | 5 ++--- discord/scheduled_event.py | 4 ++-- discord/utils.py | 8 ++++++-- discord/welcome_screen.py | 4 ++-- 11 files changed, 38 insertions(+), 35 deletions(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 4e63b97b1..82aafc5db 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -33,7 +33,7 @@ jobs: - name: Run Pyright uses: jakebailey/pyright-action@v1 with: - version: '1.1.242' + version: '1.1.253' warnings: false no-comments: ${{ matrix.python-version != '3.x' }} diff --git a/discord/errors.py b/discord/errors.py index 2519e5678..22d34e7b3 100644 --- a/discord/errors.py +++ b/discord/errors.py @@ -29,13 +29,9 @@ from .utils import _get_as_snowflake if TYPE_CHECKING: from aiohttp import ClientResponse, ClientWebSocketResponse + from requests import Response - try: - from requests import Response - - _ResponseType = Union[ClientResponse, Response] - except ModuleNotFoundError: - _ResponseType = ClientResponse + _ResponseType = Union[ClientResponse, Response] __all__ = ( 'DiscordException', diff --git a/discord/ext/commands/_types.py b/discord/ext/commands/_types.py index 7f2ff6ec4..9856a75d4 100644 --- a/discord/ext/commands/_types.py +++ b/discord/ext/commands/_types.py @@ -61,7 +61,7 @@ BotT = TypeVar('BotT', bound=_Bot, covariant=True) ContextT_co = TypeVar('ContextT_co', bound='Context[Any]', covariant=True) -class Check(Protocol[ContextT_co]): +class Check(Protocol[ContextT_co]): # type: ignore # TypeVar is expected to be invariant predicate: Callable[[ContextT_co], Coroutine[Any, Any, bool]] diff --git a/discord/ext/commands/bot.py b/discord/ext/commands/bot.py index a37b36d0c..57d8af9a5 100644 --- a/discord/ext/commands/bot.py +++ b/discord/ext/commands/bot.py @@ -379,8 +379,7 @@ class BotBase(GroupMixin[None]): if len(data) == 0: return True - # type-checker doesn't distinguish between functions and methods - return await discord.utils.async_all(f(ctx) for f in data) # type: ignore + return await discord.utils.async_all(f(ctx) for f in data) async def is_owner(self, user: User, /) -> bool: """|coro| diff --git a/discord/ext/commands/context.py b/discord/ext/commands/context.py index 59b509148..319b5c471 100644 --- a/discord/ext/commands/context.py +++ b/discord/ext/commands/context.py @@ -34,7 +34,7 @@ from discord.message import Message from ._types import BotT if TYPE_CHECKING: - from typing_extensions import ParamSpec + from typing_extensions import ParamSpec, TypeGuard from discord.abc import MessageableChannel from discord.commands import MessageCommand @@ -65,6 +65,10 @@ else: P = TypeVar('P') +def is_cog(obj: Any) -> TypeGuard[Cog]: + return hasattr(obj, '__cog_commands__') + + class Context(discord.abc.Messageable, Generic[BotT]): r"""Represents the context in which a command is being invoked under. @@ -393,7 +397,7 @@ class Context(discord.abc.Messageable, Generic[BotT]): await cmd.prepare_help_command(self, entity.qualified_name) try: - if hasattr(entity, '__cog_commands__'): + if is_cog(entity): injected = wrap_callback(cmd.send_cog_help) return await injected(entity) elif isinstance(entity, Group): diff --git a/discord/ext/commands/converter.py b/discord/ext/commands/converter.py index 9abedda7f..64e775509 100644 --- a/discord/ext/commands/converter.py +++ b/discord/ext/commands/converter.py @@ -234,6 +234,7 @@ class MemberConverter(IDConverter[discord.Member]): guild = ctx.guild result = None user_id = None + if match is None: # not a mention... if guild: @@ -247,7 +248,7 @@ class MemberConverter(IDConverter[discord.Member]): else: result = _get_from_guilds(bot, 'get_member', user_id) - if result is None: + if not isinstance(result, discord.Member): if guild is None: raise MemberNotFound(argument) @@ -1172,7 +1173,7 @@ async def _actual_conversion(ctx: Context[BotT], converter: Any, argument: str, except CommandError: raise except Exception as exc: - raise ConversionError(converter, exc) from exc + raise ConversionError(converter, exc) from exc # type: ignore try: return converter(argument) diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index 82022cae5..c9c2bd42d 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -351,8 +351,8 @@ class Command(_BaseCommand, Generic[CogT, P, T]): def __init__( self, func: Union[ - Callable[Concatenate[CogT, ContextT, P], Coro[T]], - Callable[Concatenate[ContextT, P], Coro[T]], + Callable[Concatenate[CogT, Context[Any], P], Coro[T]], + Callable[Concatenate[Context[Any], P], Coro[T]], ], /, **kwargs: Any, @@ -396,7 +396,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): except AttributeError: checks = kwargs.get('checks', []) - self.checks: List[UserCheck[ContextT]] = checks + self.checks: List[UserCheck[Context[Any]]] = checks try: cooldown = func.__commands_cooldown__ @@ -468,7 +468,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): self.params: Dict[str, Parameter] = get_signature_parameters(function, globalns) - def add_check(self, func: UserCheck[ContextT], /) -> None: + def add_check(self, func: UserCheck[Context[Any]], /) -> None: """Adds a check to the command. This is the non-decorator interface to :func:`.check`. @@ -487,7 +487,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): self.checks.append(func) - def remove_check(self, func: UserCheck[ContextT], /) -> None: + def remove_check(self, func: UserCheck[Context[Any]], /) -> None: """Removes a check from the command. This function is idempotent and will not raise an exception @@ -1236,7 +1236,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): # since we have no checks, then we just return True. return True - return await discord.utils.async_all(predicate(ctx) for predicate in predicates) # type: ignore + return await discord.utils.async_all(predicate(ctx) for predicate in predicates) finally: ctx.command = original @@ -1435,7 +1435,7 @@ class GroupMixin(Generic[CogT]): def command( self: GroupMixin[CogT], name: str = ..., - cls: Type[CommandT] = ..., + cls: Type[CommandT] = ..., # type: ignore # previous overload handles case where cls is not set *args: Any, **kwargs: Any, ) -> Callable[ @@ -1495,7 +1495,7 @@ class GroupMixin(Generic[CogT]): def group( self: GroupMixin[CogT], name: str = ..., - cls: Type[GroupT] = ..., + cls: Type[GroupT] = ..., # type: ignore # previous overload handles case where cls is not set *args: Any, **kwargs: Any, ) -> Callable[ @@ -1687,7 +1687,7 @@ def command( @overload def command( name: str = ..., - cls: Type[CommandT] = ..., + cls: Type[CommandT] = ..., # type: ignore # previous overload handles case where cls is not set **attrs: Any, ) -> Callable[ [ @@ -1757,7 +1757,7 @@ def group( @overload def group( name: str = ..., - cls: Type[GroupT] = ..., + cls: Type[GroupT] = ..., # type: ignore # previous overload handles case where cls is not set **attrs: Any, ) -> Callable[ [ @@ -1865,9 +1865,9 @@ def check(predicate: UserCheck[ContextT], /) -> Check[ContextT]: The predicate to check if the command should be invoked. """ - def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]: + def decorator(func: Union[Command[Any, ..., Any], CoroFunc]) -> Union[Command[Any, ..., Any], CoroFunc]: if isinstance(func, Command): - func.checks.append(predicate) + func.checks.append(predicate) # type: ignore else: if not hasattr(func, '__commands_checks__'): func.__commands_checks__ = [] diff --git a/discord/ext/commands/help.py b/discord/ext/commands/help.py index d9c173991..34e47a5a6 100644 --- a/discord/ext/commands/help.py +++ b/discord/ext/commands/help.py @@ -62,7 +62,6 @@ if TYPE_CHECKING: from ._types import ( UserCheck, - ContextT, BotT, _Bot, ) @@ -377,7 +376,7 @@ class HelpCommand: bot.remove_command(self._command_impl.name) self._command_impl._eject_cog() - def add_check(self, func: UserCheck[ContextT], /) -> None: + def add_check(self, func: UserCheck[Context[Any]], /) -> None: """ Adds a check to the help command. @@ -395,7 +394,7 @@ class HelpCommand: self._command_impl.add_check(func) - def remove_check(self, func: UserCheck[ContextT], /) -> None: + def remove_check(self, func: UserCheck[Context[Any]], /) -> None: """ Removes a check from the help command. diff --git a/discord/scheduled_event.py b/discord/scheduled_event.py index 0ec6a875a..4b49b8a62 100644 --- a/discord/scheduled_event.py +++ b/discord/scheduled_event.py @@ -35,7 +35,7 @@ from .utils import parse_time, _get_as_snowflake, _bytes_to_base64_data, MISSING if TYPE_CHECKING: from .types.scheduled_event import ( - GuildScheduledEvent as GuildScheduledEventPayload, + GuildScheduledEvent as BaseGuildScheduledEventPayload, GuildScheduledEventWithUserCount as GuildScheduledEventWithUserCountPayload, EntityMetadata, ) @@ -46,7 +46,7 @@ if TYPE_CHECKING: from .state import ConnectionState from .user import User - GuildScheduledEventPayload = Union[GuildScheduledEventPayload, GuildScheduledEventWithUserCountPayload] + GuildScheduledEventPayload = Union[BaseGuildScheduledEventPayload, GuildScheduledEventWithUserCountPayload] # fmt: off __all__ = ( diff --git a/discord/utils.py b/discord/utils.py index 49e824de2..c1915b877 100644 --- a/discord/utils.py +++ b/discord/utils.py @@ -142,7 +142,7 @@ if TYPE_CHECKING: from aiohttp import ClientSession from functools import cached_property as cached_property - from typing_extensions import ParamSpec, Self + from typing_extensions import ParamSpec, Self, TypeGuard from .permissions import Permissions from .abc import Messageable, Snowflake @@ -709,7 +709,11 @@ async def maybe_coroutine(f: MaybeAwaitableFunc[P, T], *args: P.args, **kwargs: return value # type: ignore -async def async_all(gen: Iterable[Awaitable[T]], *, check: Callable[[T], bool] = _isawaitable) -> bool: +async def async_all( + gen: Iterable[Union[T, Awaitable[T]]], + *, + check: Callable[[Union[T, Awaitable[T]]], TypeGuard[Awaitable[T]]] = _isawaitable, +) -> bool: for elem in gen: if check(elem): elem = await elem diff --git a/discord/welcome_screen.py b/discord/welcome_screen.py index 7e54151ef..130e6a8e2 100644 --- a/discord/welcome_screen.py +++ b/discord/welcome_screen.py @@ -126,8 +126,8 @@ class WelcomeScreen: state = self.guild._state channels = data.get('welcome_channels', []) - self.welcome_channels = [WelcomeChannel._from_dict(data=channel, state=state) for channel in channels] - self.description = data.get('description', '') + self.welcome_channels: List[WelcomeChannel] = [WelcomeChannel._from_dict(data=channel, state=state) for channel in channels] + self.description: str = data.get('description', '') def __repr__(self) -> str: return f''