Browse Source
The name is currently pending and there's no command.signature hook for it yet since this requires bikeshedding.pull/6758/head
8 changed files with 786 additions and 0 deletions
@ -0,0 +1,530 @@ |
|||||
|
""" |
||||
|
The MIT License (MIT) |
||||
|
|
||||
|
Copyright (c) 2015-present Rapptz |
||||
|
|
||||
|
Permission is hereby granted, free of charge, to any person obtaining a |
||||
|
copy of this software and associated documentation files (the "Software"), |
||||
|
to deal in the Software without restriction, including without limitation |
||||
|
the rights to use, copy, modify, merge, publish, distribute, sublicense, |
||||
|
and/or sell copies of the Software, and to permit persons to whom the |
||||
|
Software is furnished to do so, subject to the following conditions: |
||||
|
|
||||
|
The above copyright notice and this permission notice shall be included in |
||||
|
all copies or substantial portions of the Software. |
||||
|
|
||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS |
||||
|
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING |
||||
|
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER |
||||
|
DEALINGS IN THE SOFTWARE. |
||||
|
""" |
||||
|
|
||||
|
from __future__ import annotations |
||||
|
|
||||
|
from .errors import ( |
||||
|
BadFlagArgument, |
||||
|
CommandError, |
||||
|
MissingFlagArgument, |
||||
|
TooManyFlags, |
||||
|
MissingRequiredFlag, |
||||
|
) |
||||
|
|
||||
|
from .core import resolve_annotation |
||||
|
from .view import StringView |
||||
|
from .converter import run_converters |
||||
|
|
||||
|
from discord.utils import maybe_coroutine |
||||
|
from dataclasses import dataclass |
||||
|
from typing import ( |
||||
|
Dict, |
||||
|
Optional, |
||||
|
Pattern, |
||||
|
Set, |
||||
|
TYPE_CHECKING, |
||||
|
Tuple, |
||||
|
List, |
||||
|
Any, |
||||
|
Type, |
||||
|
TypeVar, |
||||
|
Union, |
||||
|
) |
||||
|
|
||||
|
import inspect |
||||
|
import sys |
||||
|
import re |
||||
|
|
||||
|
__all__ = ( |
||||
|
'Flag', |
||||
|
'flag', |
||||
|
'FlagConverter', |
||||
|
) |
||||
|
|
||||
|
|
||||
|
if TYPE_CHECKING: |
||||
|
from .context import Context |
||||
|
|
||||
|
|
||||
|
class _MissingSentinel: |
||||
|
def __repr__(self): |
||||
|
return 'MISSING' |
||||
|
|
||||
|
|
||||
|
MISSING: Any = _MissingSentinel() |
||||
|
|
||||
|
|
||||
|
@dataclass |
||||
|
class Flag: |
||||
|
"""Represents a flag parameter for :class:`FlagConverter`. |
||||
|
|
||||
|
The :func:`~discord.ext.commands.flag` function helps |
||||
|
create these flag objects, but it is not necessary to |
||||
|
do so. These cannot be constructed manually. |
||||
|
|
||||
|
Attributes |
||||
|
------------ |
||||
|
name: :class:`str` |
||||
|
The name of the flag. |
||||
|
attribute: :class:`str` |
||||
|
The attribute in the class that corresponds to this flag. |
||||
|
default: Any |
||||
|
The default value of the flag, if available. |
||||
|
annotation: Any |
||||
|
The underlying evaluated annotation of the flag. |
||||
|
max_args: :class:`int` |
||||
|
The maximum number of arguments the flag can accept. |
||||
|
A negative value indicates an unlimited amount of arguments. |
||||
|
override: :class:`bool` |
||||
|
Whether multiple given values overrides the previous value. |
||||
|
""" |
||||
|
|
||||
|
name: str = MISSING |
||||
|
attribute: str = MISSING |
||||
|
annotation: Any = MISSING |
||||
|
default: Any = MISSING |
||||
|
max_args: int = MISSING |
||||
|
override: bool = MISSING |
||||
|
cast_to_dict: bool = False |
||||
|
|
||||
|
@property |
||||
|
def required(self) -> bool: |
||||
|
""":class:`bool`: Whether the flag is required. |
||||
|
|
||||
|
A required flag has no default value. |
||||
|
""" |
||||
|
return self.default is MISSING |
||||
|
|
||||
|
|
||||
|
def flag( |
||||
|
*, |
||||
|
name: str = MISSING, |
||||
|
default: Any = MISSING, |
||||
|
max_args: int = MISSING, |
||||
|
override: bool = MISSING, |
||||
|
) -> Any: |
||||
|
"""Override default functionality and parameters of the underlying :class:`FlagConverter` |
||||
|
class attributes. |
||||
|
|
||||
|
Parameters |
||||
|
------------ |
||||
|
name: :class:`str` |
||||
|
The flag name. If not given, defaults to the attribute name. |
||||
|
default: Any |
||||
|
The default parameter. This could be either a value or a callable that takes |
||||
|
:class:`Context` as its sole parameter. If not given then it defaults to |
||||
|
the default value given to the attribute. |
||||
|
max_args: :class:`int` |
||||
|
The maximum number of arguments the flag can accept. |
||||
|
A negative value indicates an unlimited amount of arguments. |
||||
|
The default value depends on the annotation given. |
||||
|
override: :class:`bool` |
||||
|
Whether multiple given values overrides the previous value. The default |
||||
|
value depends on the annotation given. |
||||
|
""" |
||||
|
return Flag(name=name, default=default, max_args=max_args, override=override) |
||||
|
|
||||
|
|
||||
|
def validate_flag_name(name: str, forbidden: Set[str]): |
||||
|
if not name: |
||||
|
raise ValueError('flag names should not be empty') |
||||
|
|
||||
|
for ch in name: |
||||
|
if ch.isspace(): |
||||
|
raise ValueError(f'flag name {name!r} cannot have spaces') |
||||
|
if ch == '\\': |
||||
|
raise ValueError(f'flag name {name!r} cannot have backslashes') |
||||
|
if ch in forbidden: |
||||
|
raise ValueError(f'flag name {name!r} cannot have any of {forbidden!r} within them') |
||||
|
|
||||
|
|
||||
|
def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[str, Any]) -> Dict[str, Flag]: |
||||
|
annotations = namespace.get('__annotations__', {}) |
||||
|
flags: Dict[str, Flag] = {} |
||||
|
cache: Dict[str, Any] = {} |
||||
|
for name, annotation in annotations.items(): |
||||
|
flag = namespace.pop(name, MISSING) |
||||
|
if isinstance(flag, Flag): |
||||
|
flag.annotation = annotation |
||||
|
else: |
||||
|
flag = Flag(name=name, annotation=annotation, default=flag) |
||||
|
|
||||
|
flag.attribute = name |
||||
|
if flag.name is MISSING: |
||||
|
flag.name = name |
||||
|
|
||||
|
annotation = flag.annotation = resolve_annotation(flag.annotation, globals, locals, cache) |
||||
|
|
||||
|
# Add sensible defaults based off of the type annotation |
||||
|
# <type> -> (max_args=1) |
||||
|
# List[str] -> (max_args=-1) |
||||
|
# Tuple[int, ...] -> (max_args=1) |
||||
|
# Dict[K, V] -> (max_args=-1, override=True) |
||||
|
# Optional[str] -> (default=None, max_args=1) |
||||
|
|
||||
|
try: |
||||
|
origin = annotation.__origin__ |
||||
|
except AttributeError: |
||||
|
# A regular type hint |
||||
|
if flag.max_args is MISSING: |
||||
|
flag.max_args = 1 |
||||
|
else: |
||||
|
if origin is Union and annotation.__args__[-1] is type(None): |
||||
|
# typing.Optional |
||||
|
if flag.max_args is MISSING: |
||||
|
flag.max_args = 1 |
||||
|
if flag.default is MISSING: |
||||
|
flag.default = None |
||||
|
elif origin is tuple: |
||||
|
# typing.Tuple |
||||
|
# tuple parsing is e.g. `flag: peter 20` |
||||
|
# for Tuple[str, int] would give you flag: ('peter', 20) |
||||
|
if flag.max_args is MISSING: |
||||
|
flag.max_args = 1 |
||||
|
elif origin is list: |
||||
|
# typing.List |
||||
|
if flag.max_args is MISSING: |
||||
|
flag.max_args = -1 |
||||
|
elif origin is dict: |
||||
|
# typing.Dict[K, V] |
||||
|
# Equivalent to: |
||||
|
# typing.List[typing.Tuple[K, V]] |
||||
|
flag.cast_to_dict = True |
||||
|
if flag.max_args is MISSING: |
||||
|
flag.max_args = -1 |
||||
|
if flag.override is MISSING: |
||||
|
flag.override = True |
||||
|
else: |
||||
|
raise TypeError(f'Unsupported typing annotation {annotation!r} for {flag.name!r} flag') |
||||
|
|
||||
|
if flag.override is MISSING: |
||||
|
flag.override = False |
||||
|
|
||||
|
flags[flag.name] = flag |
||||
|
|
||||
|
return flags |
||||
|
|
||||
|
|
||||
|
class FlagsMeta(type): |
||||
|
if TYPE_CHECKING: |
||||
|
__commands_is_flag__: bool |
||||
|
__commands_flags__: Dict[str, Flag] |
||||
|
__commands_flag_regex__: Pattern[str] |
||||
|
__commands_flag_case_insensitive__: bool |
||||
|
__commands_flag_delimiter__: str |
||||
|
__commands_flag_prefix__: str |
||||
|
|
||||
|
def __new__( |
||||
|
cls: Type[type], |
||||
|
name: str, |
||||
|
bases: Tuple[type, ...], |
||||
|
attrs: Dict[str, Any], |
||||
|
*, |
||||
|
case_insensitive: bool = False, |
||||
|
delimiter: str = ':', |
||||
|
prefix: str = '', |
||||
|
): |
||||
|
attrs['__commands_is_flag__'] = True |
||||
|
attrs['__commands_flag_case_insensitive__'] = case_insensitive |
||||
|
attrs['__commands_flag_delimiter__'] = delimiter |
||||
|
attrs['__commands_flag_prefix__'] = prefix |
||||
|
|
||||
|
if not prefix and not delimiter: |
||||
|
raise TypeError('Must have either a delimiter or a prefix set') |
||||
|
|
||||
|
try: |
||||
|
global_ns = sys.modules[attrs['__module__']].__dict__ |
||||
|
except KeyError: |
||||
|
global_ns = {} |
||||
|
|
||||
|
frame = inspect.currentframe() |
||||
|
try: |
||||
|
if frame is None: |
||||
|
local_ns = {} |
||||
|
else: |
||||
|
if frame.f_back is None: |
||||
|
local_ns = frame.f_locals |
||||
|
else: |
||||
|
local_ns = frame.f_back.f_locals |
||||
|
finally: |
||||
|
del frame |
||||
|
|
||||
|
flags: Dict[str, Flag] = {} |
||||
|
for base in reversed(bases): |
||||
|
if base.__dict__.get('__commands_is_flag__', False): |
||||
|
flags.update(base.__dict__['__commands_flags__']) |
||||
|
|
||||
|
flags.update(get_flags(attrs, global_ns, local_ns)) |
||||
|
forbidden = set(delimiter).union(prefix) |
||||
|
for flag_name in flags: |
||||
|
validate_flag_name(flag_name, forbidden) |
||||
|
|
||||
|
regex_flags = 0 |
||||
|
if case_insensitive: |
||||
|
flags = {key.casefold(): value for key, value in flags.items()} |
||||
|
regex_flags = re.IGNORECASE |
||||
|
|
||||
|
keys = sorted((re.escape(k) for k in flags), key=lambda t: len(t), reverse=True) |
||||
|
joined = '|'.join(keys) |
||||
|
pattern = re.compile(f'(({re.escape(prefix)})(?P<flag>{joined}){re.escape(delimiter)})', regex_flags) |
||||
|
attrs['__commands_flag_regex__'] = pattern |
||||
|
attrs['__commands_flags__'] = flags |
||||
|
|
||||
|
return type.__new__(cls, name, bases, attrs) |
||||
|
|
||||
|
|
||||
|
async def tuple_convert_all(ctx: Context, argument: str, flag: Flag, converter: Any) -> Tuple[Any, ...]: |
||||
|
view = StringView(argument) |
||||
|
results = [] |
||||
|
param: inspect.Parameter = ctx.current_parameter # type: ignore |
||||
|
while not view.eof: |
||||
|
view.skip_ws() |
||||
|
if view.eof: |
||||
|
break |
||||
|
|
||||
|
word = view.get_quoted_word() |
||||
|
if word is None: |
||||
|
break |
||||
|
|
||||
|
try: |
||||
|
converted = await run_converters(ctx, converter, word, param) |
||||
|
except CommandError: |
||||
|
raise |
||||
|
except Exception as e: |
||||
|
raise BadFlagArgument(flag) from e |
||||
|
else: |
||||
|
results.append(converted) |
||||
|
|
||||
|
return tuple(results) |
||||
|
|
||||
|
|
||||
|
async def tuple_convert_flag(ctx: Context, argument: str, flag: Flag, converters: Any) -> Tuple[Any, ...]: |
||||
|
view = StringView(argument) |
||||
|
results = [] |
||||
|
param: inspect.Parameter = ctx.current_parameter # type: ignore |
||||
|
for converter in converters: |
||||
|
view.skip_ws() |
||||
|
if view.eof: |
||||
|
break |
||||
|
|
||||
|
word = view.get_quoted_word() |
||||
|
if word is None: |
||||
|
break |
||||
|
|
||||
|
try: |
||||
|
converted = await run_converters(ctx, converter, word, param) |
||||
|
except CommandError: |
||||
|
raise |
||||
|
except Exception as e: |
||||
|
raise BadFlagArgument(flag) from e |
||||
|
else: |
||||
|
results.append(converted) |
||||
|
|
||||
|
if len(results) != len(converters): |
||||
|
raise BadFlagArgument(flag) |
||||
|
|
||||
|
return tuple(results) |
||||
|
|
||||
|
|
||||
|
async def convert_flag(ctx, argument: str, flag: Flag, annotation: Any = None) -> Any: |
||||
|
param: inspect.Parameter = ctx.current_parameter # type: ignore |
||||
|
annotation = annotation or flag.annotation |
||||
|
try: |
||||
|
origin = annotation.__origin__ |
||||
|
except AttributeError: |
||||
|
pass |
||||
|
else: |
||||
|
if origin is tuple: |
||||
|
if annotation.__args__[-1] is Ellipsis: |
||||
|
return await tuple_convert_all(ctx, argument, flag, annotation.__args__[0]) |
||||
|
else: |
||||
|
return await tuple_convert_flag(ctx, argument, flag, annotation.__args__) |
||||
|
elif origin is list or origin is Union and annotation.__args__[-1] is type(None): |
||||
|
# typing.List[x] or typing.Optional[x] |
||||
|
annotation = annotation.__args__[0] |
||||
|
return await convert_flag(ctx, argument, flag, annotation) |
||||
|
elif origin is dict: |
||||
|
# typing.Dict[K, V] -> typing.Tuple[K, V] |
||||
|
return await tuple_convert_flag(ctx, argument, flag, annotation.__args__) |
||||
|
|
||||
|
try: |
||||
|
return await run_converters(ctx, annotation, argument, param) |
||||
|
except CommandError: |
||||
|
raise |
||||
|
except Exception as e: |
||||
|
raise BadFlagArgument(flag) from e |
||||
|
|
||||
|
|
||||
|
F = TypeVar('F', bound='FlagConverter') |
||||
|
|
||||
|
|
||||
|
class FlagConverter(metaclass=FlagsMeta): |
||||
|
"""A converter that allows for a user-friendly flag syntax. |
||||
|
|
||||
|
The flags are defined using :pep:`526` type annotations similar |
||||
|
to the :mod:`dataclasses` Python module. For more information on |
||||
|
how this converter works, check the appropriate |
||||
|
:ref:`documentation <ext_commands_flag_converter>`. |
||||
|
|
||||
|
.. versionadded:: 2.0 |
||||
|
|
||||
|
Parameters |
||||
|
----------- |
||||
|
case_insensitive: :class:`bool` |
||||
|
A class parameter to toggle case insensitivity of the flag parsing. |
||||
|
If ``True`` then flags are parsed in a case insensitive manner. |
||||
|
Defaults to ``False``. |
||||
|
prefix: :class:`str` |
||||
|
The prefix that all flags must be prefixed with. By default |
||||
|
there is no prefix. |
||||
|
delimiter: :class:`str` |
||||
|
The delimiter that separates a flag's argument from the flag's name. |
||||
|
By default this is ``:``. |
||||
|
""" |
||||
|
|
||||
|
@classmethod |
||||
|
def get_flags(cls) -> Dict[str, Flag]: |
||||
|
"""Dict[:class:`str`, :class:`Flag`]: A mapping of flag name to flag object this converter has.""" |
||||
|
return cls.__commands_flags__.copy() |
||||
|
|
||||
|
def __repr__(self) -> str: |
||||
|
pairs = ' '.join([f'{flag.attribute}={getattr(self, flag.attribute)!r}' for flag in self.get_flags().values()]) |
||||
|
return f'<{self.__class__.__name__} {pairs}>' |
||||
|
|
||||
|
@classmethod |
||||
|
def parse_flags(cls, argument: str) -> Dict[str, List[str]]: |
||||
|
result: Dict[str, List[str]] = {} |
||||
|
flags = cls.get_flags() |
||||
|
last_position = 0 |
||||
|
last_flag: Optional[Flag] = None |
||||
|
|
||||
|
case_insensitive = cls.__commands_flag_case_insensitive__ |
||||
|
for match in cls.__commands_flag_regex__.finditer(argument): |
||||
|
begin, end = match.span(0) |
||||
|
key = match.group('flag') |
||||
|
if case_insensitive: |
||||
|
key = key.casefold() |
||||
|
|
||||
|
flag = flags.get(key) |
||||
|
if last_position and last_flag is not None: |
||||
|
value = argument[last_position : begin - 1].lstrip() |
||||
|
if not value: |
||||
|
raise MissingFlagArgument(last_flag) |
||||
|
|
||||
|
try: |
||||
|
values = result[last_flag.name] |
||||
|
except KeyError: |
||||
|
result[last_flag.name] = [value] |
||||
|
else: |
||||
|
values.append(value) |
||||
|
|
||||
|
last_position = end |
||||
|
last_flag = flag |
||||
|
|
||||
|
# Add the remaining string to the last available flag |
||||
|
if last_position and last_flag is not None: |
||||
|
value = argument[last_position:].strip() |
||||
|
if not value: |
||||
|
raise MissingFlagArgument(last_flag) |
||||
|
|
||||
|
try: |
||||
|
values = result[last_flag.name] |
||||
|
except KeyError: |
||||
|
result[last_flag.name] = [value] |
||||
|
else: |
||||
|
values.append(value) |
||||
|
|
||||
|
# Verification of values will come at a later stage |
||||
|
return result |
||||
|
|
||||
|
@classmethod |
||||
|
async def convert(cls: Type[F], ctx: Context, argument: str) -> F: |
||||
|
"""|coro| |
||||
|
|
||||
|
The method that actually converters an argument to the flag mapping. |
||||
|
|
||||
|
Parameters |
||||
|
---------- |
||||
|
cls: Type[:class:`FlagConverter`] |
||||
|
The flag converter class. |
||||
|
ctx: :class:`Context` |
||||
|
The invocation context. |
||||
|
argument: :class:`str` |
||||
|
The argument to convert from. |
||||
|
|
||||
|
Raises |
||||
|
-------- |
||||
|
FlagError |
||||
|
A flag related parsing error. |
||||
|
CommandError |
||||
|
A command related error. |
||||
|
|
||||
|
Returns |
||||
|
-------- |
||||
|
:class:`FlagConverter` |
||||
|
The flag converter instance with all flags parsed. |
||||
|
""" |
||||
|
arguments = cls.parse_flags(argument) |
||||
|
flags = cls.get_flags() |
||||
|
|
||||
|
self: F = cls.__new__(cls) |
||||
|
for name, flag in flags.items(): |
||||
|
try: |
||||
|
values = arguments[name] |
||||
|
except KeyError: |
||||
|
if flag.required: |
||||
|
raise MissingRequiredFlag(flag) |
||||
|
else: |
||||
|
if callable(flag.default): |
||||
|
default = await maybe_coroutine(flag.default, ctx) |
||||
|
setattr(self, flag.attribute, default) |
||||
|
else: |
||||
|
setattr(self, flag.attribute, flag.default) |
||||
|
continue |
||||
|
|
||||
|
if flag.max_args > 0 and len(values) > flag.max_args: |
||||
|
if flag.override: |
||||
|
values = values[-flag.max_args :] |
||||
|
else: |
||||
|
raise TooManyFlags(flag, values) |
||||
|
|
||||
|
# Special case: |
||||
|
if flag.max_args == 1: |
||||
|
value = await convert_flag(ctx, values[0], flag) |
||||
|
setattr(self, flag.attribute, value) |
||||
|
continue |
||||
|
|
||||
|
# Another special case, tuple parsing. |
||||
|
# Tuple parsing is basically converting arguments within the flag |
||||
|
# So, given flag: hello 20 as the input and Tuple[str, int] as the type hint |
||||
|
# We would receive ('hello', 20) as the resulting value |
||||
|
# This uses the same whitespace and quoting rules as regular parameters. |
||||
|
values = [await convert_flag(ctx, value, flag) for value in values] |
||||
|
|
||||
|
if flag.cast_to_dict: |
||||
|
values = dict(values) # type: ignore |
||||
|
|
||||
|
setattr(self, flag.attribute, values) |
||||
|
|
||||
|
return self |
After Width: | Height: | Size: 25 KiB |
After Width: | Height: | Size: 28 KiB |
After Width: | Height: | Size: 27 KiB |
Loading…
Reference in new issue