Browse Source

Make inspect.iscoroutinefunction use consistent across Python versions

pull/10438/head
Jakub Kuczys 3 months ago
parent
commit
f72295e231
No known key found for this signature in database GPG Key ID: 9F02686F15FCBCD3
  1. 27
      discord/app_commands/commands.py
  2. 4
      discord/app_commands/transformers.py
  3. 8
      discord/app_commands/tree.py
  4. 6
      discord/ext/commands/cog.py
  5. 6
      discord/ext/commands/core.py
  6. 10
      discord/ext/tasks/__init__.py
  7. 2
      discord/member.py
  8. 3
      discord/ui/button.py
  9. 4
      discord/ui/select.py

27
discord/app_commands/commands.py

@ -58,7 +58,16 @@ from ..message import Message
from ..user import User
from ..member import Member
from ..permissions import Permissions
from ..utils import resolve_annotation, MISSING, is_inside_class, maybe_coroutine, async_all, _shorten, _to_kebab_case
from ..utils import (
resolve_annotation,
MISSING,
is_inside_class,
maybe_coroutine,
async_all,
_iscoroutinefunction,
_shorten,
_to_kebab_case,
)
if TYPE_CHECKING:
from typing_extensions import ParamSpec, Concatenate, Unpack
@ -346,7 +355,7 @@ def _populate_autocomplete(params: Dict[str, CommandParameter], autocomplete: Di
if callback is MISSING:
continue
if not inspect.iscoroutinefunction(callback):
if not _iscoroutinefunction(callback):
raise TypeError('autocomplete callback must be a coroutine function')
if param.type not in (AppCommandOptionType.string, AppCommandOptionType.number, AppCommandOptionType.integer):
@ -1037,7 +1046,7 @@ class Command(Generic[GroupT, P, T]):
The coroutine passed is not actually a coroutine.
"""
if not inspect.iscoroutinefunction(coro):
if not _iscoroutinefunction(coro):
raise TypeError('The error handler must be a coroutine.')
self.on_error = coro
@ -1098,7 +1107,7 @@ class Command(Generic[GroupT, P, T]):
"""
def decorator(coro: AutocompleteCallback[GroupT, ChoiceT]) -> AutocompleteCallback[GroupT, ChoiceT]:
if not inspect.iscoroutinefunction(coro):
if not _iscoroutinefunction(coro):
raise TypeError('The autocomplete callback must be a coroutine function.')
try:
@ -1347,7 +1356,7 @@ class ContextMenu:
The coroutine passed is not actually a coroutine.
"""
if not inspect.iscoroutinefunction(coro):
if not _iscoroutinefunction(coro):
raise TypeError('The error handler must be a coroutine.')
self.on_error = coro
@ -1840,7 +1849,7 @@ class Group:
The coroutine passed is not actually a coroutine, or is an invalid coroutine.
"""
if not inspect.iscoroutinefunction(coro):
if not _iscoroutinefunction(coro):
raise TypeError('The error handler must be a coroutine.')
params = inspect.signature(coro).parameters
@ -1990,7 +1999,7 @@ class Group:
"""
def decorator(func: CommandCallback[GroupT, P, T]) -> Command[GroupT, P, T]:
if not inspect.iscoroutinefunction(func):
if not _iscoroutinefunction(func):
raise TypeError('command function must be a coroutine function')
if description is MISSING:
@ -2051,7 +2060,7 @@ def command(
"""
def decorator(func: CommandCallback[GroupT, P, T]) -> Command[GroupT, P, T]:
if not inspect.iscoroutinefunction(func):
if not _iscoroutinefunction(func):
raise TypeError('command function must be a coroutine function')
if description is MISSING:
@ -2123,7 +2132,7 @@ def context_menu(
"""
def decorator(func: ContextMenuCallback) -> ContextMenu:
if not inspect.iscoroutinefunction(func):
if not _iscoroutinefunction(func):
raise TypeError('context menu function must be a coroutine function')
actual_name = func.__name__.title() if name is MISSING else name

4
discord/app_commands/transformers.py

@ -53,7 +53,7 @@ from ..channel import StageChannel, VoiceChannel, TextChannel, CategoryChannel,
from ..abc import GuildChannel
from ..threads import Thread
from ..enums import Enum as InternalEnum, AppCommandOptionType, ChannelType, Locale
from ..utils import MISSING, maybe_coroutine, _human_join, TIMESTAMP_PATTERN
from ..utils import MISSING, maybe_coroutine, _human_join, _iscoroutinefunction, TIMESTAMP_PATTERN
from ..user import User
from ..role import Role
from ..member import Member
@ -814,7 +814,7 @@ def get_supported_annotation(
params = inspect.signature(transform_classmethod.__func__).parameters
if len(params) != 3:
raise TypeError('Inline transformer with transform classmethod requires 3 parameters')
if not inspect.iscoroutinefunction(transform_classmethod.__func__):
if not _iscoroutinefunction(transform_classmethod.__func__):
raise TypeError('Inline transformer with transform classmethod must be a coroutine')
return (InlineTransformer(annotation), MISSING, False)

8
discord/app_commands/tree.py

@ -62,7 +62,7 @@ from .installs import AppCommandContext, AppInstallationType
from .translator import Translator, locale_str
from ..errors import ClientException, HTTPException
from ..enums import AppCommandType, InteractionType
from ..utils import MISSING, _get_as_snowflake, _is_submodule, _shorten
from ..utils import MISSING, _get_as_snowflake, _iscoroutinefunction, _is_submodule, _shorten
from .._types import ClientT
@ -839,7 +839,7 @@ class CommandTree(Generic[ClientT]):
not match the signature.
"""
if not inspect.iscoroutinefunction(coro):
if not _iscoroutinefunction(coro):
raise TypeError('The error handler must be a coroutine.')
params = inspect.signature(coro).parameters
@ -908,7 +908,7 @@ class CommandTree(Generic[ClientT]):
"""
def decorator(func: CommandCallback[Group, P, T]) -> Command[Group, P, T]:
if not inspect.iscoroutinefunction(func):
if not _iscoroutinefunction(func):
raise TypeError('command function must be a coroutine function')
if description is MISSING:
@ -1005,7 +1005,7 @@ class CommandTree(Generic[ClientT]):
"""
def decorator(func: ContextMenuCallback) -> ContextMenu:
if not inspect.iscoroutinefunction(func):
if not _iscoroutinefunction(func):
raise TypeError('context menu function must be a coroutine function')
actual_name = func.__name__.title() if name is MISSING else name

6
discord/ext/commands/cog.py

@ -28,7 +28,7 @@ import inspect
import discord
import logging
from discord import app_commands
from discord.utils import maybe_coroutine, _to_kebab_case
from discord.utils import maybe_coroutine, _iscoroutinefunction, _to_kebab_case
from typing import (
Any,
@ -233,7 +233,7 @@ class CogMeta(type):
if elem.startswith(('cog_', 'bot_')):
raise TypeError(no_bot_cog.format(base, elem))
cog_app_commands[elem] = value
elif inspect.iscoroutinefunction(value):
elif _iscoroutinefunction(value):
try:
getattr(value, '__cog_listener__')
except AttributeError:
@ -522,7 +522,7 @@ class Cog(metaclass=CogMeta):
actual = func
if isinstance(actual, staticmethod):
actual = actual.__func__
if not inspect.iscoroutinefunction(actual):
if not _iscoroutinefunction(actual):
raise TypeError('Listener function must be a coroutine function.')
actual.__cog_listener__ = True
to_assign = name or actual.__name__

6
discord/ext/commands/core.py

@ -1945,7 +1945,7 @@ def check(predicate: UserCheck[ContextT], /) -> Check[ContextT]:
return func
if inspect.iscoroutinefunction(predicate):
if discord.utils._iscoroutinefunction(predicate):
decorator.predicate = predicate
else:
@ -2369,7 +2369,7 @@ def guild_only() -> Check[Any]:
return func
if inspect.iscoroutinefunction(predicate):
if discord.utils._iscoroutinefunction(predicate):
decorator.predicate = predicate
else:
@ -2444,7 +2444,7 @@ def is_nsfw() -> Check[Any]:
return func
if inspect.iscoroutinefunction(predicate):
if discord.utils._iscoroutinefunction(predicate):
decorator.predicate = predicate
else:

10
discord/ext/tasks/__init__.py

@ -46,7 +46,7 @@ import inspect
from collections.abc import Sequence
from discord.backoff import ExponentialBackoff
from discord.utils import MISSING
from discord.utils import MISSING, _iscoroutinefunction
_log = logging.getLogger(__name__)
@ -182,7 +182,7 @@ class Loop(Generic[LF]):
self._last_iteration: datetime.datetime = MISSING
self._next_iteration = None
if not inspect.iscoroutinefunction(self.coro):
if not _iscoroutinefunction(self.coro):
raise TypeError(f'Expected coroutine function, not {type(self.coro).__name__!r}.')
async def _call_loop_function(self, name: str, *args: Any, **kwargs: Any) -> None:
@ -574,7 +574,7 @@ class Loop(Generic[LF]):
The function was not a coroutine.
"""
if not inspect.iscoroutinefunction(coro):
if not _iscoroutinefunction(coro):
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__}.')
self._before_loop = coro
@ -602,7 +602,7 @@ class Loop(Generic[LF]):
The function was not a coroutine.
"""
if not inspect.iscoroutinefunction(coro):
if not _iscoroutinefunction(coro):
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__}.')
self._after_loop = coro
@ -632,7 +632,7 @@ class Loop(Generic[LF]):
TypeError
The function was not a coroutine.
"""
if not inspect.iscoroutinefunction(coro):
if not _iscoroutinefunction(coro):
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__}.')
self._error = coro # type: ignore

2
discord/member.py

@ -190,7 +190,7 @@ def flatten_user(cls: T) -> T:
# probably a member function by now
def generate_function(x):
# We want sphinx to properly show coroutine functions as coroutines
if inspect.iscoroutinefunction(value):
if utils._iscoroutinefunction(value):
async def general(self, *args, **kwargs): # type: ignore
return await getattr(self._user, x)(*args, **kwargs)

3
discord/ui/button.py

@ -34,6 +34,7 @@ from .item import Item, ContainedItemCallbackType as ItemCallbackType, _ItemCall
from ..enums import ButtonStyle, ComponentType
from ..partial_emoji import PartialEmoji, _EmojiTag
from ..components import Button as ButtonComponent
from ..utils import _iscoroutinefunction
__all__ = (
'Button',
@ -370,7 +371,7 @@ def button(
"""
def decorator(func: ItemCallbackType[S, Button[V]]) -> ItemCallbackType[S, Button[V]]:
if not inspect.iscoroutinefunction(func):
if not _iscoroutinefunction(func):
raise TypeError('button function must be a coroutine function')
func.__discord_ui_model_type__ = Button

4
discord/ui/select.py

@ -47,7 +47,7 @@ from .item import Item, ContainedItemCallbackType as ItemCallbackType, _ItemCall
from ..enums import ChannelType, ComponentType, SelectDefaultValueType
from ..partial_emoji import PartialEmoji
from ..emoji import Emoji
from ..utils import MISSING, _human_join
from ..utils import MISSING, _human_join, _iscoroutinefunction
from ..components import (
SelectOption,
SelectMenu,
@ -1209,7 +1209,7 @@ def select(
"""
def decorator(func: ItemCallbackType[S, BaseSelectT]) -> ItemCallbackType[S, BaseSelectT]:
if not inspect.iscoroutinefunction(func):
if not _iscoroutinefunction(func):
raise TypeError('select function must be a coroutine function')
callback_cls = getattr(cls, '__origin__', cls)
if not issubclass(callback_cls, BaseSelect):

Loading…
Cancel
Save