Browse Source

[commands] Add fallback behaviour to the default parameter instances

This allows users to explicitly override the default annotation for
CurrentAuthor and CurrentChannel since they're wider than what most
users would expect
pull/7845/head
Rapptz 3 years ago
parent
commit
629f36e7d7
  1. 11
      discord/ext/commands/core.py
  2. 18
      discord/ext/commands/parameters.py

11
discord/ext/commands/core.py

@ -138,7 +138,16 @@ def get_signature_parameters(
default = parameter.default default = parameter.default
if isinstance(default, Parameter): # update from the default if isinstance(default, Parameter): # update from the default
if default.annotation is not Parameter.empty: if default.annotation is not Parameter.empty:
parameter._annotation = default.annotation # There are a few cases to care about here.
# x: TextChannel = commands.CurrentChannel
# x = commands.CurrentChannel
# In both of these cases, the default parameter has an explicit annotation
# but in the second case it's only used as the fallback.
if default._fallback:
if parameter.annotation is Parameter.empty:
parameter._annotation = default.annotation
else:
parameter._annotation = default.annotation
parameter._default = default.default parameter._default = default.default
parameter._displayed_default = default._displayed_default parameter._displayed_default = default._displayed_default

18
discord/ext/commands/parameters.py

@ -31,6 +31,16 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, OrderedDict, Union
from discord.utils import MISSING, maybe_coroutine from discord.utils import MISSING, maybe_coroutine
from .errors import NoPrivateMessage from .errors import NoPrivateMessage
from .converter import GuildConverter
from discord import (
Member,
User,
TextChannel,
VoiceChannel,
DMChannel,
Thread,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from typing_extensions import Self from typing_extensions import Self
@ -77,7 +87,7 @@ class Parameter(inspect.Parameter):
.. versionadded:: 2.0 .. versionadded:: 2.0
""" """
__slots__ = ('_displayed_default',) __slots__ = ('_displayed_default', '_fallback')
def __init__( def __init__(
self, self,
@ -93,6 +103,7 @@ class Parameter(inspect.Parameter):
self._default = default self._default = default
self._annotation = annotation self._annotation = annotation
self._displayed_default = displayed_default self._displayed_default = displayed_default
self._fallback = False
def replace( def replace(
self, self,
@ -218,12 +229,16 @@ An alias for :func:`parameter`.
Author = parameter( Author = parameter(
default=attrgetter('author'), default=attrgetter('author'),
displayed_default='<you>', displayed_default='<you>',
converter=Union[Member, User],
) )
Author._fallback = True
CurrentChannel = parameter( CurrentChannel = parameter(
default=attrgetter('channel'), default=attrgetter('channel'),
displayed_default='<this channel>', displayed_default='<this channel>',
converter=Union[TextChannel, DMChannel, Thread, VoiceChannel],
) )
CurrentChannel._fallback = True
def default_guild(ctx: Context) -> Guild: def default_guild(ctx: Context) -> Guild:
@ -235,6 +250,7 @@ def default_guild(ctx: Context) -> Guild:
CurrentGuild = parameter( CurrentGuild = parameter(
default=default_guild, default=default_guild,
displayed_default='<this server>', displayed_default='<this server>',
converter=GuildConverter,
) )

Loading…
Cancel
Save