Browse Source
The name is currently pending and there's no command.signature hook for it yet since this requires bikeshedding.pull/6758/head
8 changed files with 786 additions and 0 deletions
@ -0,0 +1,530 @@ |
|||
""" |
|||
The MIT License (MIT) |
|||
|
|||
Copyright (c) 2015-present Rapptz |
|||
|
|||
Permission is hereby granted, free of charge, to any person obtaining a |
|||
copy of this software and associated documentation files (the "Software"), |
|||
to deal in the Software without restriction, including without limitation |
|||
the rights to use, copy, modify, merge, publish, distribute, sublicense, |
|||
and/or sell copies of the Software, and to permit persons to whom the |
|||
Software is furnished to do so, subject to the following conditions: |
|||
|
|||
The above copyright notice and this permission notice shall be included in |
|||
all copies or substantial portions of the Software. |
|||
|
|||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS |
|||
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
|||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
|||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
|||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING |
|||
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER |
|||
DEALINGS IN THE SOFTWARE. |
|||
""" |
|||
|
|||
from __future__ import annotations |
|||
|
|||
from .errors import ( |
|||
BadFlagArgument, |
|||
CommandError, |
|||
MissingFlagArgument, |
|||
TooManyFlags, |
|||
MissingRequiredFlag, |
|||
) |
|||
|
|||
from .core import resolve_annotation |
|||
from .view import StringView |
|||
from .converter import run_converters |
|||
|
|||
from discord.utils import maybe_coroutine |
|||
from dataclasses import dataclass |
|||
from typing import ( |
|||
Dict, |
|||
Optional, |
|||
Pattern, |
|||
Set, |
|||
TYPE_CHECKING, |
|||
Tuple, |
|||
List, |
|||
Any, |
|||
Type, |
|||
TypeVar, |
|||
Union, |
|||
) |
|||
|
|||
import inspect |
|||
import sys |
|||
import re |
|||
|
|||
__all__ = ( |
|||
'Flag', |
|||
'flag', |
|||
'FlagConverter', |
|||
) |
|||
|
|||
|
|||
if TYPE_CHECKING: |
|||
from .context import Context |
|||
|
|||
|
|||
class _MissingSentinel: |
|||
def __repr__(self): |
|||
return 'MISSING' |
|||
|
|||
|
|||
MISSING: Any = _MissingSentinel() |
|||
|
|||
|
|||
@dataclass |
|||
class Flag: |
|||
"""Represents a flag parameter for :class:`FlagConverter`. |
|||
|
|||
The :func:`~discord.ext.commands.flag` function helps |
|||
create these flag objects, but it is not necessary to |
|||
do so. These cannot be constructed manually. |
|||
|
|||
Attributes |
|||
------------ |
|||
name: :class:`str` |
|||
The name of the flag. |
|||
attribute: :class:`str` |
|||
The attribute in the class that corresponds to this flag. |
|||
default: Any |
|||
The default value of the flag, if available. |
|||
annotation: Any |
|||
The underlying evaluated annotation of the flag. |
|||
max_args: :class:`int` |
|||
The maximum number of arguments the flag can accept. |
|||
A negative value indicates an unlimited amount of arguments. |
|||
override: :class:`bool` |
|||
Whether multiple given values overrides the previous value. |
|||
""" |
|||
|
|||
name: str = MISSING |
|||
attribute: str = MISSING |
|||
annotation: Any = MISSING |
|||
default: Any = MISSING |
|||
max_args: int = MISSING |
|||
override: bool = MISSING |
|||
cast_to_dict: bool = False |
|||
|
|||
@property |
|||
def required(self) -> bool: |
|||
""":class:`bool`: Whether the flag is required. |
|||
|
|||
A required flag has no default value. |
|||
""" |
|||
return self.default is MISSING |
|||
|
|||
|
|||
def flag( |
|||
*, |
|||
name: str = MISSING, |
|||
default: Any = MISSING, |
|||
max_args: int = MISSING, |
|||
override: bool = MISSING, |
|||
) -> Any: |
|||
"""Override default functionality and parameters of the underlying :class:`FlagConverter` |
|||
class attributes. |
|||
|
|||
Parameters |
|||
------------ |
|||
name: :class:`str` |
|||
The flag name. If not given, defaults to the attribute name. |
|||
default: Any |
|||
The default parameter. This could be either a value or a callable that takes |
|||
:class:`Context` as its sole parameter. If not given then it defaults to |
|||
the default value given to the attribute. |
|||
max_args: :class:`int` |
|||
The maximum number of arguments the flag can accept. |
|||
A negative value indicates an unlimited amount of arguments. |
|||
The default value depends on the annotation given. |
|||
override: :class:`bool` |
|||
Whether multiple given values overrides the previous value. The default |
|||
value depends on the annotation given. |
|||
""" |
|||
return Flag(name=name, default=default, max_args=max_args, override=override) |
|||
|
|||
|
|||
def validate_flag_name(name: str, forbidden: Set[str]): |
|||
if not name: |
|||
raise ValueError('flag names should not be empty') |
|||
|
|||
for ch in name: |
|||
if ch.isspace(): |
|||
raise ValueError(f'flag name {name!r} cannot have spaces') |
|||
if ch == '\\': |
|||
raise ValueError(f'flag name {name!r} cannot have backslashes') |
|||
if ch in forbidden: |
|||
raise ValueError(f'flag name {name!r} cannot have any of {forbidden!r} within them') |
|||
|
|||
|
|||
def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[str, Any]) -> Dict[str, Flag]: |
|||
annotations = namespace.get('__annotations__', {}) |
|||
flags: Dict[str, Flag] = {} |
|||
cache: Dict[str, Any] = {} |
|||
for name, annotation in annotations.items(): |
|||
flag = namespace.pop(name, MISSING) |
|||
if isinstance(flag, Flag): |
|||
flag.annotation = annotation |
|||
else: |
|||
flag = Flag(name=name, annotation=annotation, default=flag) |
|||
|
|||
flag.attribute = name |
|||
if flag.name is MISSING: |
|||
flag.name = name |
|||
|
|||
annotation = flag.annotation = resolve_annotation(flag.annotation, globals, locals, cache) |
|||
|
|||
# Add sensible defaults based off of the type annotation |
|||
# <type> -> (max_args=1) |
|||
# List[str] -> (max_args=-1) |
|||
# Tuple[int, ...] -> (max_args=1) |
|||
# Dict[K, V] -> (max_args=-1, override=True) |
|||
# Optional[str] -> (default=None, max_args=1) |
|||
|
|||
try: |
|||
origin = annotation.__origin__ |
|||
except AttributeError: |
|||
# A regular type hint |
|||
if flag.max_args is MISSING: |
|||
flag.max_args = 1 |
|||
else: |
|||
if origin is Union and annotation.__args__[-1] is type(None): |
|||
# typing.Optional |
|||
if flag.max_args is MISSING: |
|||
flag.max_args = 1 |
|||
if flag.default is MISSING: |
|||
flag.default = None |
|||
elif origin is tuple: |
|||
# typing.Tuple |
|||
# tuple parsing is e.g. `flag: peter 20` |
|||
# for Tuple[str, int] would give you flag: ('peter', 20) |
|||
if flag.max_args is MISSING: |
|||
flag.max_args = 1 |
|||
elif origin is list: |
|||
# typing.List |
|||
if flag.max_args is MISSING: |
|||
flag.max_args = -1 |
|||
elif origin is dict: |
|||
# typing.Dict[K, V] |
|||
# Equivalent to: |
|||
# typing.List[typing.Tuple[K, V]] |
|||
flag.cast_to_dict = True |
|||
if flag.max_args is MISSING: |
|||
flag.max_args = -1 |
|||
if flag.override is MISSING: |
|||
flag.override = True |
|||
else: |
|||
raise TypeError(f'Unsupported typing annotation {annotation!r} for {flag.name!r} flag') |
|||
|
|||
if flag.override is MISSING: |
|||
flag.override = False |
|||
|
|||
flags[flag.name] = flag |
|||
|
|||
return flags |
|||
|
|||
|
|||
class FlagsMeta(type): |
|||
if TYPE_CHECKING: |
|||
__commands_is_flag__: bool |
|||
__commands_flags__: Dict[str, Flag] |
|||
__commands_flag_regex__: Pattern[str] |
|||
__commands_flag_case_insensitive__: bool |
|||
__commands_flag_delimiter__: str |
|||
__commands_flag_prefix__: str |
|||
|
|||
def __new__( |
|||
cls: Type[type], |
|||
name: str, |
|||
bases: Tuple[type, ...], |
|||
attrs: Dict[str, Any], |
|||
*, |
|||
case_insensitive: bool = False, |
|||
delimiter: str = ':', |
|||
prefix: str = '', |
|||
): |
|||
attrs['__commands_is_flag__'] = True |
|||
attrs['__commands_flag_case_insensitive__'] = case_insensitive |
|||
attrs['__commands_flag_delimiter__'] = delimiter |
|||
attrs['__commands_flag_prefix__'] = prefix |
|||
|
|||
if not prefix and not delimiter: |
|||
raise TypeError('Must have either a delimiter or a prefix set') |
|||
|
|||
try: |
|||
global_ns = sys.modules[attrs['__module__']].__dict__ |
|||
except KeyError: |
|||
global_ns = {} |
|||
|
|||
frame = inspect.currentframe() |
|||
try: |
|||
if frame is None: |
|||
local_ns = {} |
|||
else: |
|||
if frame.f_back is None: |
|||
local_ns = frame.f_locals |
|||
else: |
|||
local_ns = frame.f_back.f_locals |
|||
finally: |
|||
del frame |
|||
|
|||
flags: Dict[str, Flag] = {} |
|||
for base in reversed(bases): |
|||
if base.__dict__.get('__commands_is_flag__', False): |
|||
flags.update(base.__dict__['__commands_flags__']) |
|||
|
|||
flags.update(get_flags(attrs, global_ns, local_ns)) |
|||
forbidden = set(delimiter).union(prefix) |
|||
for flag_name in flags: |
|||
validate_flag_name(flag_name, forbidden) |
|||
|
|||
regex_flags = 0 |
|||
if case_insensitive: |
|||
flags = {key.casefold(): value for key, value in flags.items()} |
|||
regex_flags = re.IGNORECASE |
|||
|
|||
keys = sorted((re.escape(k) for k in flags), key=lambda t: len(t), reverse=True) |
|||
joined = '|'.join(keys) |
|||
pattern = re.compile(f'(({re.escape(prefix)})(?P<flag>{joined}){re.escape(delimiter)})', regex_flags) |
|||
attrs['__commands_flag_regex__'] = pattern |
|||
attrs['__commands_flags__'] = flags |
|||
|
|||
return type.__new__(cls, name, bases, attrs) |
|||
|
|||
|
|||
async def tuple_convert_all(ctx: Context, argument: str, flag: Flag, converter: Any) -> Tuple[Any, ...]: |
|||
view = StringView(argument) |
|||
results = [] |
|||
param: inspect.Parameter = ctx.current_parameter # type: ignore |
|||
while not view.eof: |
|||
view.skip_ws() |
|||
if view.eof: |
|||
break |
|||
|
|||
word = view.get_quoted_word() |
|||
if word is None: |
|||
break |
|||
|
|||
try: |
|||
converted = await run_converters(ctx, converter, word, param) |
|||
except CommandError: |
|||
raise |
|||
except Exception as e: |
|||
raise BadFlagArgument(flag) from e |
|||
else: |
|||
results.append(converted) |
|||
|
|||
return tuple(results) |
|||
|
|||
|
|||
async def tuple_convert_flag(ctx: Context, argument: str, flag: Flag, converters: Any) -> Tuple[Any, ...]: |
|||
view = StringView(argument) |
|||
results = [] |
|||
param: inspect.Parameter = ctx.current_parameter # type: ignore |
|||
for converter in converters: |
|||
view.skip_ws() |
|||
if view.eof: |
|||
break |
|||
|
|||
word = view.get_quoted_word() |
|||
if word is None: |
|||
break |
|||
|
|||
try: |
|||
converted = await run_converters(ctx, converter, word, param) |
|||
except CommandError: |
|||
raise |
|||
except Exception as e: |
|||
raise BadFlagArgument(flag) from e |
|||
else: |
|||
results.append(converted) |
|||
|
|||
if len(results) != len(converters): |
|||
raise BadFlagArgument(flag) |
|||
|
|||
return tuple(results) |
|||
|
|||
|
|||
async def convert_flag(ctx, argument: str, flag: Flag, annotation: Any = None) -> Any: |
|||
param: inspect.Parameter = ctx.current_parameter # type: ignore |
|||
annotation = annotation or flag.annotation |
|||
try: |
|||
origin = annotation.__origin__ |
|||
except AttributeError: |
|||
pass |
|||
else: |
|||
if origin is tuple: |
|||
if annotation.__args__[-1] is Ellipsis: |
|||
return await tuple_convert_all(ctx, argument, flag, annotation.__args__[0]) |
|||
else: |
|||
return await tuple_convert_flag(ctx, argument, flag, annotation.__args__) |
|||
elif origin is list or origin is Union and annotation.__args__[-1] is type(None): |
|||
# typing.List[x] or typing.Optional[x] |
|||
annotation = annotation.__args__[0] |
|||
return await convert_flag(ctx, argument, flag, annotation) |
|||
elif origin is dict: |
|||
# typing.Dict[K, V] -> typing.Tuple[K, V] |
|||
return await tuple_convert_flag(ctx, argument, flag, annotation.__args__) |
|||
|
|||
try: |
|||
return await run_converters(ctx, annotation, argument, param) |
|||
except CommandError: |
|||
raise |
|||
except Exception as e: |
|||
raise BadFlagArgument(flag) from e |
|||
|
|||
|
|||
F = TypeVar('F', bound='FlagConverter') |
|||
|
|||
|
|||
class FlagConverter(metaclass=FlagsMeta): |
|||
"""A converter that allows for a user-friendly flag syntax. |
|||
|
|||
The flags are defined using :pep:`526` type annotations similar |
|||
to the :mod:`dataclasses` Python module. For more information on |
|||
how this converter works, check the appropriate |
|||
:ref:`documentation <ext_commands_flag_converter>`. |
|||
|
|||
.. versionadded:: 2.0 |
|||
|
|||
Parameters |
|||
----------- |
|||
case_insensitive: :class:`bool` |
|||
A class parameter to toggle case insensitivity of the flag parsing. |
|||
If ``True`` then flags are parsed in a case insensitive manner. |
|||
Defaults to ``False``. |
|||
prefix: :class:`str` |
|||
The prefix that all flags must be prefixed with. By default |
|||
there is no prefix. |
|||
delimiter: :class:`str` |
|||
The delimiter that separates a flag's argument from the flag's name. |
|||
By default this is ``:``. |
|||
""" |
|||
|
|||
@classmethod |
|||
def get_flags(cls) -> Dict[str, Flag]: |
|||
"""Dict[:class:`str`, :class:`Flag`]: A mapping of flag name to flag object this converter has.""" |
|||
return cls.__commands_flags__.copy() |
|||
|
|||
def __repr__(self) -> str: |
|||
pairs = ' '.join([f'{flag.attribute}={getattr(self, flag.attribute)!r}' for flag in self.get_flags().values()]) |
|||
return f'<{self.__class__.__name__} {pairs}>' |
|||
|
|||
@classmethod |
|||
def parse_flags(cls, argument: str) -> Dict[str, List[str]]: |
|||
result: Dict[str, List[str]] = {} |
|||
flags = cls.get_flags() |
|||
last_position = 0 |
|||
last_flag: Optional[Flag] = None |
|||
|
|||
case_insensitive = cls.__commands_flag_case_insensitive__ |
|||
for match in cls.__commands_flag_regex__.finditer(argument): |
|||
begin, end = match.span(0) |
|||
key = match.group('flag') |
|||
if case_insensitive: |
|||
key = key.casefold() |
|||
|
|||
flag = flags.get(key) |
|||
if last_position and last_flag is not None: |
|||
value = argument[last_position : begin - 1].lstrip() |
|||
if not value: |
|||
raise MissingFlagArgument(last_flag) |
|||
|
|||
try: |
|||
values = result[last_flag.name] |
|||
except KeyError: |
|||
result[last_flag.name] = [value] |
|||
else: |
|||
values.append(value) |
|||
|
|||
last_position = end |
|||
last_flag = flag |
|||
|
|||
# Add the remaining string to the last available flag |
|||
if last_position and last_flag is not None: |
|||
value = argument[last_position:].strip() |
|||
if not value: |
|||
raise MissingFlagArgument(last_flag) |
|||
|
|||
try: |
|||
values = result[last_flag.name] |
|||
except KeyError: |
|||
result[last_flag.name] = [value] |
|||
else: |
|||
values.append(value) |
|||
|
|||
# Verification of values will come at a later stage |
|||
return result |
|||
|
|||
@classmethod |
|||
async def convert(cls: Type[F], ctx: Context, argument: str) -> F: |
|||
"""|coro| |
|||
|
|||
The method that actually converters an argument to the flag mapping. |
|||
|
|||
Parameters |
|||
---------- |
|||
cls: Type[:class:`FlagConverter`] |
|||
The flag converter class. |
|||
ctx: :class:`Context` |
|||
The invocation context. |
|||
argument: :class:`str` |
|||
The argument to convert from. |
|||
|
|||
Raises |
|||
-------- |
|||
FlagError |
|||
A flag related parsing error. |
|||
CommandError |
|||
A command related error. |
|||
|
|||
Returns |
|||
-------- |
|||
:class:`FlagConverter` |
|||
The flag converter instance with all flags parsed. |
|||
""" |
|||
arguments = cls.parse_flags(argument) |
|||
flags = cls.get_flags() |
|||
|
|||
self: F = cls.__new__(cls) |
|||
for name, flag in flags.items(): |
|||
try: |
|||
values = arguments[name] |
|||
except KeyError: |
|||
if flag.required: |
|||
raise MissingRequiredFlag(flag) |
|||
else: |
|||
if callable(flag.default): |
|||
default = await maybe_coroutine(flag.default, ctx) |
|||
setattr(self, flag.attribute, default) |
|||
else: |
|||
setattr(self, flag.attribute, flag.default) |
|||
continue |
|||
|
|||
if flag.max_args > 0 and len(values) > flag.max_args: |
|||
if flag.override: |
|||
values = values[-flag.max_args :] |
|||
else: |
|||
raise TooManyFlags(flag, values) |
|||
|
|||
# Special case: |
|||
if flag.max_args == 1: |
|||
value = await convert_flag(ctx, values[0], flag) |
|||
setattr(self, flag.attribute, value) |
|||
continue |
|||
|
|||
# Another special case, tuple parsing. |
|||
# Tuple parsing is basically converting arguments within the flag |
|||
# So, given flag: hello 20 as the input and Tuple[str, int] as the type hint |
|||
# We would receive ('hello', 20) as the resulting value |
|||
# This uses the same whitespace and quoting rules as regular parameters. |
|||
values = [await convert_flag(ctx, value, flag) for value in values] |
|||
|
|||
if flag.cast_to_dict: |
|||
values = dict(values) # type: ignore |
|||
|
|||
setattr(self, flag.attribute, values) |
|||
|
|||
return self |
After Width: | Height: | Size: 25 KiB |
After Width: | Height: | Size: 28 KiB |
After Width: | Height: | Size: 27 KiB |
Loading…
Reference in new issue