Browse Source

Make inspect.iscoroutinefunction use consistent across Python versions

pull/10438/head
Jakub Kuczys 3 months ago
parent
commit
f72295e231
No known key found for this signature in database GPG Key ID: 9F02686F15FCBCD3
  1. 27
      discord/app_commands/commands.py
  2. 4
      discord/app_commands/transformers.py
  3. 8
      discord/app_commands/tree.py
  4. 6
      discord/ext/commands/cog.py
  5. 6
      discord/ext/commands/core.py
  6. 10
      discord/ext/tasks/__init__.py
  7. 2
      discord/member.py
  8. 3
      discord/ui/button.py
  9. 4
      discord/ui/select.py

27
discord/app_commands/commands.py

@ -58,7 +58,16 @@ from ..message import Message
from ..user import User from ..user import User
from ..member import Member from ..member import Member
from ..permissions import Permissions from ..permissions import Permissions
from ..utils import resolve_annotation, MISSING, is_inside_class, maybe_coroutine, async_all, _shorten, _to_kebab_case from ..utils import (
resolve_annotation,
MISSING,
is_inside_class,
maybe_coroutine,
async_all,
_iscoroutinefunction,
_shorten,
_to_kebab_case,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from typing_extensions import ParamSpec, Concatenate, Unpack from typing_extensions import ParamSpec, Concatenate, Unpack
@ -346,7 +355,7 @@ def _populate_autocomplete(params: Dict[str, CommandParameter], autocomplete: Di
if callback is MISSING: if callback is MISSING:
continue continue
if not inspect.iscoroutinefunction(callback): if not _iscoroutinefunction(callback):
raise TypeError('autocomplete callback must be a coroutine function') raise TypeError('autocomplete callback must be a coroutine function')
if param.type not in (AppCommandOptionType.string, AppCommandOptionType.number, AppCommandOptionType.integer): if param.type not in (AppCommandOptionType.string, AppCommandOptionType.number, AppCommandOptionType.integer):
@ -1037,7 +1046,7 @@ class Command(Generic[GroupT, P, T]):
The coroutine passed is not actually a coroutine. The coroutine passed is not actually a coroutine.
""" """
if not inspect.iscoroutinefunction(coro): if not _iscoroutinefunction(coro):
raise TypeError('The error handler must be a coroutine.') raise TypeError('The error handler must be a coroutine.')
self.on_error = coro self.on_error = coro
@ -1098,7 +1107,7 @@ class Command(Generic[GroupT, P, T]):
""" """
def decorator(coro: AutocompleteCallback[GroupT, ChoiceT]) -> AutocompleteCallback[GroupT, ChoiceT]: def decorator(coro: AutocompleteCallback[GroupT, ChoiceT]) -> AutocompleteCallback[GroupT, ChoiceT]:
if not inspect.iscoroutinefunction(coro): if not _iscoroutinefunction(coro):
raise TypeError('The autocomplete callback must be a coroutine function.') raise TypeError('The autocomplete callback must be a coroutine function.')
try: try:
@ -1347,7 +1356,7 @@ class ContextMenu:
The coroutine passed is not actually a coroutine. The coroutine passed is not actually a coroutine.
""" """
if not inspect.iscoroutinefunction(coro): if not _iscoroutinefunction(coro):
raise TypeError('The error handler must be a coroutine.') raise TypeError('The error handler must be a coroutine.')
self.on_error = coro self.on_error = coro
@ -1840,7 +1849,7 @@ class Group:
The coroutine passed is not actually a coroutine, or is an invalid coroutine. The coroutine passed is not actually a coroutine, or is an invalid coroutine.
""" """
if not inspect.iscoroutinefunction(coro): if not _iscoroutinefunction(coro):
raise TypeError('The error handler must be a coroutine.') raise TypeError('The error handler must be a coroutine.')
params = inspect.signature(coro).parameters params = inspect.signature(coro).parameters
@ -1990,7 +1999,7 @@ class Group:
""" """
def decorator(func: CommandCallback[GroupT, P, T]) -> Command[GroupT, P, T]: def decorator(func: CommandCallback[GroupT, P, T]) -> Command[GroupT, P, T]:
if not inspect.iscoroutinefunction(func): if not _iscoroutinefunction(func):
raise TypeError('command function must be a coroutine function') raise TypeError('command function must be a coroutine function')
if description is MISSING: if description is MISSING:
@ -2051,7 +2060,7 @@ def command(
""" """
def decorator(func: CommandCallback[GroupT, P, T]) -> Command[GroupT, P, T]: def decorator(func: CommandCallback[GroupT, P, T]) -> Command[GroupT, P, T]:
if not inspect.iscoroutinefunction(func): if not _iscoroutinefunction(func):
raise TypeError('command function must be a coroutine function') raise TypeError('command function must be a coroutine function')
if description is MISSING: if description is MISSING:
@ -2123,7 +2132,7 @@ def context_menu(
""" """
def decorator(func: ContextMenuCallback) -> ContextMenu: def decorator(func: ContextMenuCallback) -> ContextMenu:
if not inspect.iscoroutinefunction(func): if not _iscoroutinefunction(func):
raise TypeError('context menu function must be a coroutine function') raise TypeError('context menu function must be a coroutine function')
actual_name = func.__name__.title() if name is MISSING else name actual_name = func.__name__.title() if name is MISSING else name

4
discord/app_commands/transformers.py

@ -53,7 +53,7 @@ from ..channel import StageChannel, VoiceChannel, TextChannel, CategoryChannel,
from ..abc import GuildChannel from ..abc import GuildChannel
from ..threads import Thread from ..threads import Thread
from ..enums import Enum as InternalEnum, AppCommandOptionType, ChannelType, Locale from ..enums import Enum as InternalEnum, AppCommandOptionType, ChannelType, Locale
from ..utils import MISSING, maybe_coroutine, _human_join, TIMESTAMP_PATTERN from ..utils import MISSING, maybe_coroutine, _human_join, _iscoroutinefunction, TIMESTAMP_PATTERN
from ..user import User from ..user import User
from ..role import Role from ..role import Role
from ..member import Member from ..member import Member
@ -814,7 +814,7 @@ def get_supported_annotation(
params = inspect.signature(transform_classmethod.__func__).parameters params = inspect.signature(transform_classmethod.__func__).parameters
if len(params) != 3: if len(params) != 3:
raise TypeError('Inline transformer with transform classmethod requires 3 parameters') raise TypeError('Inline transformer with transform classmethod requires 3 parameters')
if not inspect.iscoroutinefunction(transform_classmethod.__func__): if not _iscoroutinefunction(transform_classmethod.__func__):
raise TypeError('Inline transformer with transform classmethod must be a coroutine') raise TypeError('Inline transformer with transform classmethod must be a coroutine')
return (InlineTransformer(annotation), MISSING, False) return (InlineTransformer(annotation), MISSING, False)

8
discord/app_commands/tree.py

@ -62,7 +62,7 @@ from .installs import AppCommandContext, AppInstallationType
from .translator import Translator, locale_str from .translator import Translator, locale_str
from ..errors import ClientException, HTTPException from ..errors import ClientException, HTTPException
from ..enums import AppCommandType, InteractionType from ..enums import AppCommandType, InteractionType
from ..utils import MISSING, _get_as_snowflake, _is_submodule, _shorten from ..utils import MISSING, _get_as_snowflake, _iscoroutinefunction, _is_submodule, _shorten
from .._types import ClientT from .._types import ClientT
@ -839,7 +839,7 @@ class CommandTree(Generic[ClientT]):
not match the signature. not match the signature.
""" """
if not inspect.iscoroutinefunction(coro): if not _iscoroutinefunction(coro):
raise TypeError('The error handler must be a coroutine.') raise TypeError('The error handler must be a coroutine.')
params = inspect.signature(coro).parameters params = inspect.signature(coro).parameters
@ -908,7 +908,7 @@ class CommandTree(Generic[ClientT]):
""" """
def decorator(func: CommandCallback[Group, P, T]) -> Command[Group, P, T]: def decorator(func: CommandCallback[Group, P, T]) -> Command[Group, P, T]:
if not inspect.iscoroutinefunction(func): if not _iscoroutinefunction(func):
raise TypeError('command function must be a coroutine function') raise TypeError('command function must be a coroutine function')
if description is MISSING: if description is MISSING:
@ -1005,7 +1005,7 @@ class CommandTree(Generic[ClientT]):
""" """
def decorator(func: ContextMenuCallback) -> ContextMenu: def decorator(func: ContextMenuCallback) -> ContextMenu:
if not inspect.iscoroutinefunction(func): if not _iscoroutinefunction(func):
raise TypeError('context menu function must be a coroutine function') raise TypeError('context menu function must be a coroutine function')
actual_name = func.__name__.title() if name is MISSING else name actual_name = func.__name__.title() if name is MISSING else name

6
discord/ext/commands/cog.py

@ -28,7 +28,7 @@ import inspect
import discord import discord
import logging import logging
from discord import app_commands from discord import app_commands
from discord.utils import maybe_coroutine, _to_kebab_case from discord.utils import maybe_coroutine, _iscoroutinefunction, _to_kebab_case
from typing import ( from typing import (
Any, Any,
@ -233,7 +233,7 @@ class CogMeta(type):
if elem.startswith(('cog_', 'bot_')): if elem.startswith(('cog_', 'bot_')):
raise TypeError(no_bot_cog.format(base, elem)) raise TypeError(no_bot_cog.format(base, elem))
cog_app_commands[elem] = value cog_app_commands[elem] = value
elif inspect.iscoroutinefunction(value): elif _iscoroutinefunction(value):
try: try:
getattr(value, '__cog_listener__') getattr(value, '__cog_listener__')
except AttributeError: except AttributeError:
@ -522,7 +522,7 @@ class Cog(metaclass=CogMeta):
actual = func actual = func
if isinstance(actual, staticmethod): if isinstance(actual, staticmethod):
actual = actual.__func__ actual = actual.__func__
if not inspect.iscoroutinefunction(actual): if not _iscoroutinefunction(actual):
raise TypeError('Listener function must be a coroutine function.') raise TypeError('Listener function must be a coroutine function.')
actual.__cog_listener__ = True actual.__cog_listener__ = True
to_assign = name or actual.__name__ to_assign = name or actual.__name__

6
discord/ext/commands/core.py

@ -1945,7 +1945,7 @@ def check(predicate: UserCheck[ContextT], /) -> Check[ContextT]:
return func return func
if inspect.iscoroutinefunction(predicate): if discord.utils._iscoroutinefunction(predicate):
decorator.predicate = predicate decorator.predicate = predicate
else: else:
@ -2369,7 +2369,7 @@ def guild_only() -> Check[Any]:
return func return func
if inspect.iscoroutinefunction(predicate): if discord.utils._iscoroutinefunction(predicate):
decorator.predicate = predicate decorator.predicate = predicate
else: else:
@ -2444,7 +2444,7 @@ def is_nsfw() -> Check[Any]:
return func return func
if inspect.iscoroutinefunction(predicate): if discord.utils._iscoroutinefunction(predicate):
decorator.predicate = predicate decorator.predicate = predicate
else: else:

10
discord/ext/tasks/__init__.py

@ -46,7 +46,7 @@ import inspect
from collections.abc import Sequence from collections.abc import Sequence
from discord.backoff import ExponentialBackoff from discord.backoff import ExponentialBackoff
from discord.utils import MISSING from discord.utils import MISSING, _iscoroutinefunction
_log = logging.getLogger(__name__) _log = logging.getLogger(__name__)
@ -182,7 +182,7 @@ class Loop(Generic[LF]):
self._last_iteration: datetime.datetime = MISSING self._last_iteration: datetime.datetime = MISSING
self._next_iteration = None self._next_iteration = None
if not inspect.iscoroutinefunction(self.coro): if not _iscoroutinefunction(self.coro):
raise TypeError(f'Expected coroutine function, not {type(self.coro).__name__!r}.') raise TypeError(f'Expected coroutine function, not {type(self.coro).__name__!r}.')
async def _call_loop_function(self, name: str, *args: Any, **kwargs: Any) -> None: async def _call_loop_function(self, name: str, *args: Any, **kwargs: Any) -> None:
@ -574,7 +574,7 @@ class Loop(Generic[LF]):
The function was not a coroutine. The function was not a coroutine.
""" """
if not inspect.iscoroutinefunction(coro): if not _iscoroutinefunction(coro):
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__}.') raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__}.')
self._before_loop = coro self._before_loop = coro
@ -602,7 +602,7 @@ class Loop(Generic[LF]):
The function was not a coroutine. The function was not a coroutine.
""" """
if not inspect.iscoroutinefunction(coro): if not _iscoroutinefunction(coro):
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__}.') raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__}.')
self._after_loop = coro self._after_loop = coro
@ -632,7 +632,7 @@ class Loop(Generic[LF]):
TypeError TypeError
The function was not a coroutine. The function was not a coroutine.
""" """
if not inspect.iscoroutinefunction(coro): if not _iscoroutinefunction(coro):
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__}.') raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__}.')
self._error = coro # type: ignore self._error = coro # type: ignore

2
discord/member.py

@ -190,7 +190,7 @@ def flatten_user(cls: T) -> T:
# probably a member function by now # probably a member function by now
def generate_function(x): def generate_function(x):
# We want sphinx to properly show coroutine functions as coroutines # We want sphinx to properly show coroutine functions as coroutines
if inspect.iscoroutinefunction(value): if utils._iscoroutinefunction(value):
async def general(self, *args, **kwargs): # type: ignore async def general(self, *args, **kwargs): # type: ignore
return await getattr(self._user, x)(*args, **kwargs) return await getattr(self._user, x)(*args, **kwargs)

3
discord/ui/button.py

@ -34,6 +34,7 @@ from .item import Item, ContainedItemCallbackType as ItemCallbackType, _ItemCall
from ..enums import ButtonStyle, ComponentType from ..enums import ButtonStyle, ComponentType
from ..partial_emoji import PartialEmoji, _EmojiTag from ..partial_emoji import PartialEmoji, _EmojiTag
from ..components import Button as ButtonComponent from ..components import Button as ButtonComponent
from ..utils import _iscoroutinefunction
__all__ = ( __all__ = (
'Button', 'Button',
@ -370,7 +371,7 @@ def button(
""" """
def decorator(func: ItemCallbackType[S, Button[V]]) -> ItemCallbackType[S, Button[V]]: def decorator(func: ItemCallbackType[S, Button[V]]) -> ItemCallbackType[S, Button[V]]:
if not inspect.iscoroutinefunction(func): if not _iscoroutinefunction(func):
raise TypeError('button function must be a coroutine function') raise TypeError('button function must be a coroutine function')
func.__discord_ui_model_type__ = Button func.__discord_ui_model_type__ = Button

4
discord/ui/select.py

@ -47,7 +47,7 @@ from .item import Item, ContainedItemCallbackType as ItemCallbackType, _ItemCall
from ..enums import ChannelType, ComponentType, SelectDefaultValueType from ..enums import ChannelType, ComponentType, SelectDefaultValueType
from ..partial_emoji import PartialEmoji from ..partial_emoji import PartialEmoji
from ..emoji import Emoji from ..emoji import Emoji
from ..utils import MISSING, _human_join from ..utils import MISSING, _human_join, _iscoroutinefunction
from ..components import ( from ..components import (
SelectOption, SelectOption,
SelectMenu, SelectMenu,
@ -1209,7 +1209,7 @@ def select(
""" """
def decorator(func: ItemCallbackType[S, BaseSelectT]) -> ItemCallbackType[S, BaseSelectT]: def decorator(func: ItemCallbackType[S, BaseSelectT]) -> ItemCallbackType[S, BaseSelectT]:
if not inspect.iscoroutinefunction(func): if not _iscoroutinefunction(func):
raise TypeError('select function must be a coroutine function') raise TypeError('select function must be a coroutine function')
callback_cls = getattr(cls, '__origin__', cls) callback_cls = getattr(cls, '__origin__', cls)
if not issubclass(callback_cls, BaseSelect): if not issubclass(callback_cls, BaseSelect):

Loading…
Cancel
Save