diff --git a/discord/ext/commands/__init__.py b/discord/ext/commands/__init__.py index 3da57d80b..61d66090e 100644 --- a/discord/ext/commands/__init__.py +++ b/discord/ext/commands/__init__.py @@ -9,11 +9,12 @@ An extension module to facilitate creation of bot commands. """ from .bot import * +from .cog import * from .context import * -from .core import * -from .errors import * -from .help import * from .converter import * from .cooldowns import * -from .cog import * +from .core import * +from .errors import * from .flags import * +from .help import * +from .parameters import * diff --git a/discord/ext/commands/context.py b/discord/ext/commands/context.py index 09b63a4e5..c66b32035 100644 --- a/discord/ext/commands/context.py +++ b/discord/ext/commands/context.py @@ -23,18 +23,15 @@ DEALINGS IN THE SOFTWARE. """ from __future__ import annotations -import inspect import re - -from typing import Any, Dict, Generic, List, Optional, TYPE_CHECKING, TypeVar, Union - -from ._types import BotT +from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, TypeVar, Union import discord.abc import discord.utils - from discord.message import Message +from ._types import BotT + if TYPE_CHECKING: from typing_extensions import ParamSpec @@ -47,6 +44,7 @@ if TYPE_CHECKING: from .cog import Cog from .core import Command + from .parameters import Parameter from .view import StringView # fmt: off @@ -90,7 +88,7 @@ class Context(discord.abc.Messageable, Generic[BotT]): A dictionary of transformed arguments that were passed into the command. Similar to :attr:`args`\, if this is accessed in the :func:`.on_command_error` event then this dict could be incomplete. - current_parameter: Optional[:class:`inspect.Parameter`] + current_parameter: Optional[:class:`Parameter`] The parameter that is currently being inspected and converted. This is only of use for within converters. @@ -143,7 +141,7 @@ class Context(discord.abc.Messageable, Generic[BotT]): invoked_subcommand: Optional[Command[Any, ..., Any]] = None, subcommand_passed: Optional[str] = None, command_failed: bool = False, - current_parameter: Optional[inspect.Parameter] = None, + current_parameter: Optional[Parameter] = None, current_argument: Optional[str] = None, ): self.message: Message = message @@ -158,7 +156,7 @@ class Context(discord.abc.Messageable, Generic[BotT]): self.invoked_subcommand: Optional[Command[Any, ..., Any]] = invoked_subcommand self.subcommand_passed: Optional[str] = subcommand_passed self.command_failed: bool = command_failed - self.current_parameter: Optional[inspect.Parameter] = current_parameter + self.current_parameter: Optional[Parameter] = current_parameter self.current_argument: Optional[str] = current_argument self._state: ConnectionState = self.message._state @@ -357,7 +355,7 @@ class Context(discord.abc.Messageable, Generic[BotT]): Any The result of the help command, if any. """ - from .core import Group, Command, wrap_callback + from .core import Command, Group, wrap_callback from .errors import CommandError bot = self.bot diff --git a/discord/ext/commands/converter.py b/discord/ext/commands/converter.py index 93ffd1a6c..7bf29412d 100644 --- a/discord/ext/commands/converter.py +++ b/discord/ext/commands/converter.py @@ -24,35 +24,36 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations -import re import inspect +import re from typing import ( + TYPE_CHECKING, Any, Dict, Generic, Iterable, + List, Literal, Optional, - TYPE_CHECKING, - List, Protocol, + Tuple, Type, TypeVar, - Tuple, Union, runtime_checkable, ) import discord + from .errors import * if TYPE_CHECKING: - from .context import Context from discord.state import Channel from discord.threads import Thread + from .parameters import Parameter from ._types import BotT, _Bot - + from .context import Context __all__ = ( 'Converter', @@ -1062,16 +1063,6 @@ def _convert_to_bool(argument: str) -> bool: raise BadBoolArgument(lowered) -def get_converter(param: inspect.Parameter) -> Any: - converter = param.annotation - if converter is param.empty: - if param.default is not param.empty: - converter = str if param.default is None else type(param.default) - else: - converter = str - return converter - - _GenericAlias = type(List[T]) @@ -1141,7 +1132,7 @@ async def _actual_conversion(ctx: Context[BotT], converter, argument: str, param raise BadArgument(f'Converting to "{name}" failed for parameter "{param.name}".') from exc -async def run_converters(ctx: Context[BotT], converter: Any, argument: str, param: inspect.Parameter) -> Any: +async def run_converters(ctx: Context[BotT], converter: Any, argument: str, param: Parameter) -> Any: """|coro| Runs converters for a given converter, argument, and parameter. @@ -1158,7 +1149,7 @@ async def run_converters(ctx: Context[BotT], converter: Any, argument: str, para The converter to run, this corresponds to the annotation in the function. argument: :class:`str` The argument to convert to. - param: :class:`inspect.Parameter` + param: :class:`Parameter` The parameter being converted. This is mainly for error reporting. Raises @@ -1183,7 +1174,7 @@ async def run_converters(ctx: Context[BotT], converter: Any, argument: str, para # with the other parameters if conv is _NoneType and param.kind != param.VAR_POSITIONAL: ctx.view.undo() - return None if param.default is param.empty else param.default + return None if param.required else await param.get_default(ctx) try: value = await run_converters(ctx, conv, argument, param) diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index c43b4320e..6ef424f56 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -23,54 +23,44 @@ DEALINGS IN THE SOFTWARE. """ from __future__ import annotations +import asyncio +import datetime +import functools +import inspect from typing import ( + TYPE_CHECKING, Any, Callable, Dict, Generator, Generic, - Literal, List, + Literal, Optional, - Union, Set, Tuple, - TypeVar, Type, - TYPE_CHECKING, + TypeVar, + Union, overload, ) -import asyncio -import functools -import inspect -import datetime import discord -from .errors import * -from .cooldowns import Cooldown, BucketType, CooldownMapping, MaxConcurrency, DynamicCooldownMapping -from .converter import run_converters, get_converter, Greedy from ._types import _BaseCommand from .cog import Cog from .context import Context - +from .converter import Greedy, run_converters +from .cooldowns import BucketType, Cooldown, CooldownMapping, DynamicCooldownMapping, MaxConcurrency +from .errors import * +from .parameters import Parameter, Signature if TYPE_CHECKING: - from typing_extensions import Concatenate, ParamSpec, TypeGuard, Self + from typing_extensions import Concatenate, ParamSpec, Self, TypeGuard from discord.message import Message - from ._types import ( - BotT, - ContextT, - Coro, - CoroFunc, - Check, - Hook, - Error, - ErrorT, - HookT, - ) + from ._types import BotT, Check, ContextT, Coro, CoroFunc, Error, ErrorT, Hook, HookT __all__ = ( @@ -131,9 +121,9 @@ def get_signature_parameters( /, *, skip_parameters: Optional[int] = None, -) -> Dict[str, inspect.Parameter]: - signature = inspect.signature(function) - params = {} +) -> Dict[str, Parameter]: + signature = Signature.from_callable(function) + params: Dict[str, Parameter] = {} cache: Dict[str, Any] = {} eval_annotation = discord.utils.evaluate_annotation required_params = discord.utils.is_inside_class(function) + 1 if skip_parameters is None else skip_parameters @@ -145,10 +135,14 @@ def get_signature_parameters( next(iterator) for name, parameter in iterator: + default = parameter.default + if isinstance(default, Parameter): # update from the default + parameter._annotation = default.annotation + parameter._default = default.default + parameter._displayed_default = default._displayed_default + annotation = parameter.annotation - if annotation is parameter.empty: - params[name] = parameter - continue + if annotation is None: params[name] = parameter.replace(annotation=type(None)) continue @@ -435,7 +429,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): except AttributeError: globalns = {} - self.params: Dict[str, inspect.Parameter] = get_signature_parameters(function, globalns) + self.params: Dict[str, Parameter] = get_signature_parameters(function, globalns) def add_check(self, func: Check[ContextT], /) -> None: """Adds a check to the command. @@ -571,9 +565,8 @@ class Command(_BaseCommand, Generic[CogT, P, T]): finally: ctx.bot.dispatch('command_error', ctx, error) - async def transform(self, ctx: Context[BotT], param: inspect.Parameter, /) -> Any: - required = param.default is param.empty - converter = get_converter(param) + async def transform(self, ctx: Context[BotT], param: Parameter, /) -> Any: + converter = param.converter consume_rest_is_special = param.kind == param.KEYWORD_ONLY and not self.rest_is_raw view = ctx.view view.skip_ws() @@ -582,7 +575,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): # it undos the view ready for the next parameter to use instead if isinstance(converter, Greedy): if param.kind in (param.POSITIONAL_OR_KEYWORD, param.POSITIONAL_ONLY): - return await self._transform_greedy_pos(ctx, param, required, converter.converter) + return await self._transform_greedy_pos(ctx, param, param.required, converter.converter) elif param.kind == param.VAR_POSITIONAL: return await self._transform_greedy_var_pos(ctx, param, converter.converter) else: @@ -594,13 +587,13 @@ class Command(_BaseCommand, Generic[CogT, P, T]): if view.eof: if param.kind == param.VAR_POSITIONAL: raise RuntimeError() # break the loop - if required: + if param.required: if self._is_typing_optional(param.annotation): return None if hasattr(converter, '__commands_is_flag__') and converter._can_be_constructible(): return await converter._construct_default(ctx) raise MissingRequiredArgument(param) - return param.default + return await param.get_default(ctx) previous = view.index if consume_rest_is_special: @@ -619,9 +612,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): # type-checker fails to narrow argument return await run_converters(ctx, converter, argument, param) # type: ignore - async def _transform_greedy_pos( - self, ctx: Context[BotT], param: inspect.Parameter, required: bool, converter: Any - ) -> Any: + async def _transform_greedy_pos(self, ctx: Context[BotT], param: Parameter, required: bool, converter: Any) -> Any: view = ctx.view result = [] while not view.eof: @@ -639,10 +630,10 @@ class Command(_BaseCommand, Generic[CogT, P, T]): result.append(value) if not result and not required: - return param.default + return await param.get_default(ctx) return result - async def _transform_greedy_var_pos(self, ctx: Context[BotT], param: inspect.Parameter, converter: Any) -> Any: + async def _transform_greedy_var_pos(self, ctx: Context[BotT], param: Parameter, converter: Any) -> Any: view = ctx.view previous = view.index try: @@ -655,8 +646,8 @@ class Command(_BaseCommand, Generic[CogT, P, T]): return value @property - def clean_params(self) -> Dict[str, inspect.Parameter]: - """Dict[:class:`str`, :class:`inspect.Parameter`]: + def clean_params(self) -> Dict[str, Parameter]: + """Dict[:class:`str`, :class:`Parameter`]: Retrieves the parameter dictionary without the context or self parameters. Useful for inspecting signature. @@ -753,9 +744,8 @@ class Command(_BaseCommand, Generic[CogT, P, T]): elif param.kind == param.KEYWORD_ONLY: # kwarg only param denotes "consume rest" semantics if self.rest_is_raw: - converter = get_converter(param) ctx.current_argument = argument = view.read_rest() - kwargs[name] = await run_converters(ctx, converter, argument, param) + kwargs[name] = await run_converters(ctx, param.converter, argument, param) else: kwargs[name] = await self.transform(ctx, param) break @@ -1078,29 +1068,31 @@ class Command(_BaseCommand, Generic[CogT, P, T]): result = [] for name, param in params.items(): - greedy = isinstance(param.annotation, Greedy) + greedy = isinstance(param.converter, Greedy) optional = False # postpone evaluation of if it's an optional argument - # for typing.Literal[...], typing.Optional[typing.Literal[...]], and Greedy[typing.Literal[...]], the - # parameter signature is a literal list of it's values - annotation = param.annotation.converter if greedy else param.annotation + annotation = param.converter.converter if greedy else param.converter # type: ignore # needs conditional types origin = getattr(annotation, '__origin__', None) if not greedy and origin is Union: none_cls = type(None) - union_args = annotation.__args__ + union_args = annotation.__args__ # type: ignore # this is safe optional = union_args[-1] is none_cls if len(union_args) == 2 and optional: annotation = union_args[0] origin = getattr(annotation, '__origin__', None) + # for typing.Literal[...], typing.Optional[typing.Literal[...]], and Greedy[typing.Literal[...]], the + # parameter signature is a literal list of it's values if origin is Literal: - name = '|'.join(f'"{v}"' if isinstance(v, str) else str(v) for v in annotation.__args__) - if param.default is not param.empty: + name = '|'.join(f'"{v}"' if isinstance(v, str) else str(v) for v in annotation.__args__) # type: ignore # this is safe + if not param.required: # We don't want None or '' to trigger the [name=value] case and instead it should # do [name] since [name=None] or [name=] are not exactly useful for the user. should_print = param.default if isinstance(param.default, str) else param.default is not None if should_print: - result.append(f'[{name}={param.default}]' if not greedy else f'[{name}={param.default}]...') + result.append( + f'[{name}={param.displayed_default}]' if not greedy else f'[{name}={param.displayed_default}]...' + ) continue else: result.append(f'[{name}]') diff --git a/discord/ext/commands/errors.py b/discord/ext/commands/errors.py index 670eba4e9..037183894 100644 --- a/discord/ext/commands/errors.py +++ b/discord/ext/commands/errors.py @@ -24,22 +24,21 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations -from typing import Optional, Any, TYPE_CHECKING, List, Callable, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union from discord.errors import ClientException, DiscordException if TYPE_CHECKING: - from inspect import Parameter - - from .converter import Converter - from .context import Context - from .cooldowns import Cooldown, BucketType - from .flags import Flag from discord.abc import GuildChannel from discord.threads import Thread from discord.types.snowflake import Snowflake, SnowflakeList from ._types import BotT + from .context import Context + from .converter import Converter + from .cooldowns import BucketType, Cooldown + from .flags import Flag + from .parameters import Parameter __all__ = ( @@ -173,7 +172,7 @@ class MissingRequiredArgument(UserInputError): Attributes ----------- - param: :class:`inspect.Parameter` + param: :class:`Parameter` The argument that is missing. """ @@ -687,11 +686,11 @@ class MissingAnyRole(CheckFailure): missing = [f"'{role}'" for role in missing_roles] if len(missing) > 2: - fmt = '{}, or {}'.format(", ".join(missing[:-1]), missing[-1]) + fmt = '{}, or {}'.format(', '.join(missing[:-1]), missing[-1]) else: fmt = ' or '.join(missing) - message = f"You are missing at least one of the required roles: {fmt}" + message = f'You are missing at least one of the required roles: {fmt}' super().__init__(message) @@ -717,11 +716,11 @@ class BotMissingAnyRole(CheckFailure): missing = [f"'{role}'" for role in missing_roles] if len(missing) > 2: - fmt = '{}, or {}'.format(", ".join(missing[:-1]), missing[-1]) + fmt = '{}, or {}'.format(', '.join(missing[:-1]), missing[-1]) else: fmt = ' or '.join(missing) - message = f"Bot is missing at least one of the required roles: {fmt}" + message = f'Bot is missing at least one of the required roles: {fmt}' super().__init__(message) @@ -761,7 +760,7 @@ class MissingPermissions(CheckFailure): missing = [perm.replace('_', ' ').replace('guild', 'server').title() for perm in missing_permissions] if len(missing) > 2: - fmt = '{}, and {}'.format(", ".join(missing[:-1]), missing[-1]) + fmt = '{}, and {}'.format(', '.join(missing[:-1]), missing[-1]) else: fmt = ' and '.join(missing) message = f'You are missing {fmt} permission(s) to run this command.' @@ -786,7 +785,7 @@ class BotMissingPermissions(CheckFailure): missing = [perm.replace('_', ' ').replace('guild', 'server').title() for perm in missing_permissions] if len(missing) > 2: - fmt = '{}, and {}'.format(", ".join(missing[:-1]), missing[-1]) + fmt = '{}, and {}'.format(', '.join(missing[:-1]), missing[-1]) else: fmt = ' and '.join(missing) message = f'Bot requires {fmt} permission(s) to run this command.' diff --git a/discord/ext/commands/flags.py b/discord/ext/commands/flags.py index 64d57a145..76d8ba051 100644 --- a/discord/ext/commands/flags.py +++ b/discord/ext/commands/flags.py @@ -24,37 +24,17 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations -from .errors import ( - BadFlagArgument, - CommandError, - MissingFlagArgument, - TooManyFlags, - MissingRequiredFlag, -) - -from discord.utils import resolve_annotation -from .view import StringView -from .converter import run_converters - -from discord.utils import maybe_coroutine, MISSING -from dataclasses import dataclass, field -from typing import ( - Dict, - Iterator, - Literal, - Optional, - Pattern, - Set, - TYPE_CHECKING, - Tuple, - List, - Any, - Union, -) - import inspect -import sys import re +import sys +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Literal, Optional, Pattern, Set, Tuple, Union + +from discord.utils import MISSING, maybe_coroutine, resolve_annotation + +from .converter import run_converters +from .errors import BadFlagArgument, CommandError, MissingFlagArgument, MissingRequiredFlag, TooManyFlags +from .view import StringView __all__ = ( 'Flag', @@ -66,9 +46,9 @@ __all__ = ( if TYPE_CHECKING: from typing_extensions import Self - from .context import Context - from ._types import BotT + from .context import Context + from .parameters import Parameter @dataclass @@ -351,7 +331,7 @@ class FlagsMeta(type): async def tuple_convert_all(ctx: Context[BotT], argument: str, flag: Flag, converter: Any) -> Tuple[Any, ...]: view = StringView(argument) results = [] - param: inspect.Parameter = ctx.current_parameter # type: ignore + param: Parameter = ctx.current_parameter # type: ignore while not view.eof: view.skip_ws() if view.eof: @@ -376,7 +356,7 @@ async def tuple_convert_all(ctx: Context[BotT], argument: str, flag: Flag, conve async def tuple_convert_flag(ctx: Context[BotT], argument: str, flag: Flag, converters: Any) -> Tuple[Any, ...]: view = StringView(argument) results = [] - param: inspect.Parameter = ctx.current_parameter # type: ignore + param: Parameter = ctx.current_parameter # type: ignore for converter in converters: view.skip_ws() if view.eof: @@ -402,7 +382,7 @@ async def tuple_convert_flag(ctx: Context[BotT], argument: str, flag: Flag, conv async def convert_flag(ctx: Context[BotT], argument: str, flag: Flag, annotation: Any = None) -> Any: - param: inspect.Parameter = ctx.current_parameter # type: ignore + param: Parameter = ctx.current_parameter # type: ignore annotation = annotation or flag.annotation try: origin = annotation.__origin__ diff --git a/discord/ext/commands/help.py b/discord/ext/commands/help.py index 01458ab28..ca2dab52e 100644 --- a/discord/ext/commands/help.py +++ b/discord/ext/commands/help.py @@ -51,13 +51,13 @@ from .errors import CommandError if TYPE_CHECKING: from typing_extensions import Self - import inspect import discord.abc from .bot import BotBase from .context import Context from .cog import Cog + from .parameters import Parameter from ._types import ( Check, @@ -224,9 +224,7 @@ class _HelpCommandImpl(Command): super().__init__(inject.command_callback, *args, **kwargs) self._original: HelpCommand = inject self._injected: HelpCommand = inject - self.params: Dict[str, inspect.Parameter] = get_signature_parameters( - inject.command_callback, globals(), skip_parameters=1 - ) + self.params: Dict[str, Parameter] = get_signature_parameters(inject.command_callback, globals(), skip_parameters=1) async def prepare(self, ctx: Context[Any]) -> None: self._injected = injected = self._original.copy() @@ -1021,7 +1019,7 @@ class DefaultHelpCommand(HelpCommand): self.sort_commands: bool = options.pop('sort_commands', True) self.dm_help: bool = options.pop('dm_help', False) self.dm_help_threshold: int = options.pop('dm_help_threshold', 1000) - self.commands_heading: str = options.pop('commands_heading', "Commands:") + self.commands_heading: str = options.pop('commands_heading', 'Commands:') self.no_category: str = options.pop('no_category', 'No Category') self.paginator: Paginator = options.pop('paginator', None) @@ -1045,8 +1043,8 @@ class DefaultHelpCommand(HelpCommand): """:class:`str`: Returns help command's ending note. This is mainly useful to override for i18n purposes.""" command_name = self.invoked_with return ( - f"Type {self.context.clean_prefix}{command_name} command for more info on a command.\n" - f"You can also type {self.context.clean_prefix}{command_name} category for more info on a category." + f'Type {self.context.clean_prefix}{command_name} command for more info on a command.\n' + f'You can also type {self.context.clean_prefix}{command_name} category for more info on a category.' ) def add_indented_commands( @@ -1235,10 +1233,10 @@ class MinimalHelpCommand(HelpCommand): def __init__(self, **options: Any) -> None: self.sort_commands: bool = options.pop('sort_commands', True) - self.commands_heading: str = options.pop('commands_heading', "Commands") + self.commands_heading: str = options.pop('commands_heading', 'Commands') self.dm_help: bool = options.pop('dm_help', False) self.dm_help_threshold: int = options.pop('dm_help_threshold', 1000) - self.aliases_heading: str = options.pop('aliases_heading', "Aliases:") + self.aliases_heading: str = options.pop('aliases_heading', 'Aliases:') self.no_category: str = options.pop('no_category', 'No Category') self.paginator: Paginator = options.pop('paginator', None) @@ -1268,8 +1266,8 @@ class MinimalHelpCommand(HelpCommand): """ command_name = self.invoked_with return ( - f"Use `{self.context.clean_prefix}{command_name} [command]` for more info on a command.\n" - f"You can also use `{self.context.clean_prefix}{command_name} [category]` for more info on a category." + f'Use `{self.context.clean_prefix}{command_name} [command]` for more info on a command.\n' + f'You can also use `{self.context.clean_prefix}{command_name} [category]` for more info on a category.' ) def get_command_signature(self, command: Command[Any, ..., Any], /) -> str: diff --git a/discord/ext/commands/parameters.py b/discord/ext/commands/parameters.py new file mode 100644 index 000000000..404b77ff1 --- /dev/null +++ b/discord/ext/commands/parameters.py @@ -0,0 +1,246 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-present Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +import inspect +from operator import attrgetter +from typing import TYPE_CHECKING, Any, Literal, Optional, OrderedDict, Union + +from discord.utils import MISSING, maybe_coroutine + +from . import converter +from .errors import MissingRequiredArgument + +if TYPE_CHECKING: + from typing_extensions import Self + + from discord import Guild, Member, TextChannel, User + + from .context import Context + +__all__ = ( + 'Parameter', + 'parameter', + 'param', + 'Author', + 'CurrentChannel', + 'CurrentGuild', +) + + +ParamKinds = Union[ + Literal[inspect.Parameter.POSITIONAL_ONLY], + Literal[inspect.Parameter.POSITIONAL_OR_KEYWORD], + Literal[inspect.Parameter.VAR_POSITIONAL], + Literal[inspect.Parameter.KEYWORD_ONLY], + Literal[inspect.Parameter.VAR_KEYWORD], +] + +empty: Any = inspect.Parameter.empty + + +def _gen_property(name: str) -> property: + attr = f'_{name}' + return property( + attrgetter(attr), + lambda self, value: setattr(self, attr, value), + doc="The parameter's {name}.", + ) + + +class Parameter(inspect.Parameter): + r"""A class that stores information on a :class:`Command`\'s parameter. + This is a subclass of :class:`inspect.Parameter`. + + .. versionadded:: 2.0 + """ + + __slots__ = ('_displayed_default',) + + def __init__( + self, + name: str, + kind: ParamKinds, + default: Any = empty, + annotation: Any = empty, + displayed_default: str = empty, + ) -> None: + super().__init__(name=name, kind=kind, default=default, annotation=annotation) + self._name = name + self._kind = kind + self._default = default + self._annotation = annotation + self._displayed_default = displayed_default + + def replace( + self, + *, + name: str = MISSING, # MISSING here cause empty is valid + kind: ParamKinds = MISSING, + default: Any = MISSING, + annotation: Any = MISSING, + displayed_default: Any = MISSING, + ) -> Self: + if name is MISSING: + name = self._name + if kind is MISSING: + kind = self._kind # type: ignore # this assignment is actually safe + if default is MISSING: + default = self._default + if annotation is MISSING: + annotation = self._annotation + if displayed_default is MISSING: + displayed_default = self._displayed_default + + return self.__class__( + name=name, + kind=kind, + default=default, + annotation=annotation, + displayed_default=displayed_default, + ) + + if not TYPE_CHECKING: # this is to prevent anything breaking if inspect internals change + name = _gen_property('name') + kind = _gen_property('kind') + default = _gen_property('default') + annotation = _gen_property('annotation') + + @property + def required(self) -> bool: + """:class:`bool`: Whether this parameter is required.""" + return self.default is empty + + @property + def converter(self) -> Any: + """The converter that should be used for this parameter.""" + if self.annotation is empty: + return type(self.default) if self.default not in (empty, None) else str + + return self.annotation + + @property + def displayed_default(self) -> Optional[str]: + """Optional[:class:`str`]: The displayed default in :class:`Command.signature`.""" + if self._displayed_default is not empty: + return self._displayed_default + + return None if self.required else str(self.default) + + async def get_default(self, ctx: Context) -> Any: + """|coro| + + Gets this parameter's default value. + + Parameters + ---------- + ctx: :class:`Context` + The invocation context that is used to get the default argument. + """ + # pre-condition: required is False + if callable(self.default): + return await maybe_coroutine(self.default, ctx) # type: ignore + return self.default + + +def parameter( + *, + converter: Any = empty, + default: Any = empty, + displayed_default: str = empty, +) -> Any: + r"""parameter(\*, converter=..., default=..., displayed_default=...) + + A way to assign custom metadata for a :class:`Command`\'s parameter. + + .. versionadded:: 2.0 + + Examples + -------- + A custom default can be used to have late binding behaviour. + + .. code-block:: python3 + + @bot.command() + async def wave(to: discord.User = commands.parameter(default=lambda ctx: ctx.author)): + await ctx.send(f'Hello {to.mention} :wave:') + + Parameters + ---------- + converter: Any + The converter to use for this parameter, this replaces the annotation at runtime which is transparent to type checkers. + default: Any + The default value for the parameter, if this is a :term:`callable` or a |coroutine_link|_ it is called with a + positional :class:`Context` argument. + displayed_default: :class:`str` + The displayed default in :attr:`Command.signature`. + """ + return Parameter( + name='empty', + kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, + annotation=converter, + default=default, + displayed_default=displayed_default, + ) + + +param = parameter +r"""param(\*, converter=..., default=..., displayed_default=...) + +An alias for :func:`parameter`. + +.. versionadded:: 2.0 +""" + +# some handy defaults +Author: Union[Member, User] = parameter( + default=attrgetter('author'), + displayed_default='', + converter=Union[converter.MemberConverter, converter.UserConverter], +) + +CurrentChannel: TextChannel = parameter( + default=attrgetter('channel'), + displayed_default='', + converter=converter.TextChannelConverter, +) + + +def default_guild(ctx: Context) -> Guild: + if ctx.guild is not None: + return ctx.guild + raise MissingRequiredArgument(ctx.current_parameter) # type: ignore # this is never going to be None + + +CurrentGuild: Guild = parameter( + default=default_guild, + displayed_default='', + converter=converter.GuildConverter, +) + + +class Signature(inspect.Signature): + _parameter_cls = Parameter + parameters: OrderedDict[str, Parameter] diff --git a/docs/ext/commands/api.rst b/docs/ext/commands/api.rst index ac3357817..5d46b3c26 100644 --- a/docs/ext/commands/api.rst +++ b/docs/ext/commands/api.rst @@ -429,6 +429,35 @@ Flag Converter .. autofunction:: discord.ext.commands.flag + +Defaults +-------- + +.. autoclass:: discord.ext.commands.Parameter() + :members: + +.. autofunction:: discord.ext.commands.parameter + +.. autofunction:: discord.ext.commands.param + +.. data:: discord.ext.commands.Author + + A default :class:`.Parameter` which returns the :attr:`~.Context.author` for this context. + + .. versionadded:: 2.0 + +.. data:: discord.ext.commands.CurrentChannel + + A default :class:`.Parameter` which returns the :attr:`~.Context.channel` for this context. + + .. versionadded:: 2.0 + +.. data:: discord.ext.commands.CurrentGuild + + A default :class:`.Parameter` which returns the :attr:`~.Context.guild` for this context. This will never be ``None``. + + .. versionadded:: 2.0 + .. _ext_commands_api_errors: Exceptions diff --git a/docs/ext/commands/commands.rst b/docs/ext/commands/commands.rst index 4fa97ae5c..d31809f58 100644 --- a/docs/ext/commands/commands.rst +++ b/docs/ext/commands/commands.rst @@ -768,6 +768,58 @@ A :class:`dict` annotation is functionally equivalent to ``List[Tuple[K, V]]`` e given as a :class:`dict` rather than a :class:`list`. +.. _ext_commands_parameter: + +Parameter Metadata +------------------- + +:func:`~ext.commands.parameter` assigns custom metadata to a :class:`~ext.commands.Command`'s parameter. + +This is useful for: + +- Custom converters as annotating a parameter with a custom converter works at runtime, type checkers don't like it + because they can't understand what's going on. + + .. code-block:: python3 + + class SomeType: + foo: int + + class MyVeryCoolConverter(commands.Converter[SomeType]): + ... # implementation left as an exercise for the reader + + @bot.command() + async def bar(ctx, cool_value: MyVeryCoolConverter): + cool_value.foo # type checker warns MyVeryCoolConverter has no value foo (uh-oh) + + However, fear not we can use :func:`~ext.commands.parameter` to tell type checkers what's going on. + + .. code-block:: python3 + + @bot.command() + async def bar(ctx, cool_value: SomeType = commands.parameter(converter=MyVeryCoolConverter)): + cool_value.foo # no error (hurray) + +- Late binding behaviour + + .. code-block:: python3 + + @bot.command() + async def wave(to: discord.User = commands.parameter(default=lambda ctx: ctx.author)): + await ctx.send(f'Hello {to.mention} :wave:') + + Because this is such a common use-case, the library provides :obj:`~.ext.commands.Author`, :obj:`~.ext.commands.CurrentChannel` and + :obj:`~.ext.commands.CurrentGuild`, armed with this we can simplify ``wave`` to: + + .. code-block:: python3 + + @bot.command() + async def wave(to: discord.User = commands.Author): + await ctx.send(f'Hello {to.mention} :wave:') + + :obj:`~.ext.commands.Author` and co also have other benefits like having the displayed default being filled. + + .. _ext_commands_error_handler: Error Handling diff --git a/docs/migrating.rst b/docs/migrating.rst index 31e137556..b1121d074 100644 --- a/docs/migrating.rst +++ b/docs/migrating.rst @@ -112,7 +112,7 @@ Quick example: With this change, constructor of :class:`Client` no longer accepts ``connector`` and ``loop`` parameters. -In parallel with this change, changes were made to loading and unloading of commands extension extensions and cogs, +In parallel with this change, changes were made to loading and unloading of commands extension extensions and cogs, see :ref:`migrating_2_0_commands_extension_cog_async` for more information. Abstract Base Classes Changes @@ -1240,7 +1240,7 @@ Quick example of loading an extension: async with bot: await bot.load_extension('my_extension') await bot.start(TOKEN) - + asyncio.run(main()) @@ -1422,6 +1422,7 @@ Miscellaneous Changes - ``BotMissingPermissions.missing_perms`` has been renamed to :attr:`ext.commands.BotMissingPermissions.missing_permissions`. - :meth:`ext.commands.Cog.cog_load` has been added as part of the :ref:`migrating_2_0_commands_extension_cog_async` changes. - :meth:`ext.commands.Cog.cog_unload` may now be a :term:`coroutine` due to the :ref:`migrating_2_0_commands_extension_cog_async` changes. +- :attr:`ext.commands.Command.clean_params` type now uses a custom :class:`inspect.Parameter` to handle defaults. .. _migrating_2_0_tasks: