Browse Source

[commands] Add support for aliasing to FlagConverter

pull/6765/head
Josh 4 years ago
committed by GitHub
parent
commit
42463bae67
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 51
      discord/ext/commands/flags.py

51
discord/ext/commands/flags.py

@ -37,7 +37,7 @@ from .view import StringView
from .converter import run_converters
from discord.utils import maybe_coroutine
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import (
Dict,
Optional,
@ -87,6 +87,8 @@ class Flag:
------------
name: :class:`str`
The name of the flag.
aliases: List[:class:`str`]
The aliases of the flag name.
attribute: :class:`str`
The attribute in the class that corresponds to this flag.
default: Any
@ -101,6 +103,7 @@ class Flag:
"""
name: str = MISSING
aliases: List[str] = field(default_factory=list)
attribute: str = MISSING
annotation: Any = MISSING
default: Any = MISSING
@ -120,6 +123,7 @@ class Flag:
def flag(
*,
name: str = MISSING,
aliases: List[str] = MISSING,
default: Any = MISSING,
max_args: int = MISSING,
override: bool = MISSING,
@ -131,6 +135,8 @@ def flag(
------------
name: :class:`str`
The flag name. If not given, defaults to the attribute name.
aliases: List[:class:`str`]
Aliases to the flag name. If not given no aliases are set.
default: Any
The default parameter. This could be either a value or a callable that takes
:class:`Context` as its sole parameter. If not given then it defaults to
@ -143,7 +149,7 @@ def flag(
Whether multiple given values overrides the previous value. The default
value depends on the annotation given.
"""
return Flag(name=name, default=default, max_args=max_args, override=override)
return Flag(name=name, aliases=aliases, default=default, max_args=max_args, override=override)
def validate_flag_name(name: str, forbidden: Set[str]):
@ -161,8 +167,10 @@ def validate_flag_name(name: str, forbidden: Set[str]):
def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[str, Any]) -> Dict[str, Flag]:
annotations = namespace.get('__annotations__', {})
case_insensitive = namespace['__commands_flag_case_insensitive__']
flags: Dict[str, Flag] = {}
cache: Dict[str, Any] = {}
names: Set[str] = set()
for name, annotation in annotations.items():
flag = namespace.pop(name, MISSING)
if isinstance(flag, Flag):
@ -176,6 +184,9 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s
annotation = flag.annotation = resolve_annotation(flag.annotation, globals, locals, cache)
if flag.aliases is MISSING:
flag.aliases = []
# Add sensible defaults based off of the type annotation
# <type> -> (max_args=1)
# List[str] -> (max_args=-1)
@ -221,6 +232,21 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s
if flag.override is MISSING:
flag.override = False
# Validate flag names are unique
name = flag.name.casefold() if case_insensitive else flag.name
if name in names:
raise TypeError(f'{flag.name!r} flag conflicts with previous flag or alias.')
else:
names.add(name)
for alias in flag.aliases:
# Validate alias is unique
alias = alias.casefold() if case_insensitive else alias
if alias in names:
raise TypeError(f'{flag.name!r} flag alias {alias!r} conflicts with previous flag or alias.')
else:
names.add(alias)
flags[flag.name] = flag
return flags
@ -230,6 +256,7 @@ class FlagsMeta(type):
if TYPE_CHECKING:
__commands_is_flag__: bool
__commands_flags__: Dict[str, Flag]
__commands_flag_aliases__: Dict[str, str]
__commands_flag_regex__: Pattern[str]
__commands_flag_case_insensitive__: bool
__commands_flag_delimiter__: str
@ -271,25 +298,37 @@ class FlagsMeta(type):
del frame
flags: Dict[str, Flag] = {}
aliases: Dict[str, str] = {}
for base in reversed(bases):
if base.__dict__.get('__commands_is_flag__', False):
flags.update(base.__dict__['__commands_flags__'])
aliases.update(base.__dict__['__commands_flag_aliases__'])
for flag_name, flag in get_flags(attrs, global_ns, local_ns).items():
flags[flag_name] = flag
aliases.update({alias_name: flag_name for alias_name in flag.aliases})
flags.update(get_flags(attrs, global_ns, local_ns))
forbidden = set(delimiter).union(prefix)
for flag_name in flags:
validate_flag_name(flag_name, forbidden)
for alias_name in aliases:
validate_flag_name(alias_name, forbidden)
regex_flags = 0
if case_insensitive:
flags = {key.casefold(): value for key, value in flags.items()}
aliases = {key.casefold(): value.casefold() for key, value in aliases.items()}
regex_flags = re.IGNORECASE
keys = sorted((re.escape(k) for k in flags), key=lambda t: len(t), reverse=True)
keys = list(re.escape(k) for k in flags)
keys.extend(re.escape(a) for a in aliases)
keys = sorted(keys, key=lambda t: len(t), reverse=True)
joined = '|'.join(keys)
pattern = re.compile(f'(({re.escape(prefix)})(?P<flag>{joined}){re.escape(delimiter)})', regex_flags)
attrs['__commands_flag_regex__'] = pattern
attrs['__commands_flags__'] = flags
attrs['__commands_flag_aliases__'] = aliases
return type.__new__(cls, name, bases, attrs)
@ -432,6 +471,7 @@ class FlagConverter(metaclass=FlagsMeta):
def parse_flags(cls, argument: str) -> Dict[str, List[str]]:
result: Dict[str, List[str]] = {}
flags = cls.__commands_flags__
aliases = cls.__commands_flag_aliases__
last_position = 0
last_flag: Optional[Flag] = None
@ -442,6 +482,9 @@ class FlagConverter(metaclass=FlagsMeta):
if case_insensitive:
key = key.casefold()
if key in aliases:
key = aliases[key]
flag = flags.get(key)
if last_position and last_flag is not None:
value = argument[last_position : begin - 1].lstrip()

Loading…
Cancel
Save