Browse Source

[commands] Add support for aliasing to FlagConverter

pull/6765/head
Josh 4 years ago
committed by GitHub
parent
commit
42463bae67
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 51
      discord/ext/commands/flags.py

51
discord/ext/commands/flags.py

@ -37,7 +37,7 @@ from .view import StringView
from .converter import run_converters from .converter import run_converters
from discord.utils import maybe_coroutine from discord.utils import maybe_coroutine
from dataclasses import dataclass from dataclasses import dataclass, field
from typing import ( from typing import (
Dict, Dict,
Optional, Optional,
@ -87,6 +87,8 @@ class Flag:
------------ ------------
name: :class:`str` name: :class:`str`
The name of the flag. The name of the flag.
aliases: List[:class:`str`]
The aliases of the flag name.
attribute: :class:`str` attribute: :class:`str`
The attribute in the class that corresponds to this flag. The attribute in the class that corresponds to this flag.
default: Any default: Any
@ -101,6 +103,7 @@ class Flag:
""" """
name: str = MISSING name: str = MISSING
aliases: List[str] = field(default_factory=list)
attribute: str = MISSING attribute: str = MISSING
annotation: Any = MISSING annotation: Any = MISSING
default: Any = MISSING default: Any = MISSING
@ -120,6 +123,7 @@ class Flag:
def flag( def flag(
*, *,
name: str = MISSING, name: str = MISSING,
aliases: List[str] = MISSING,
default: Any = MISSING, default: Any = MISSING,
max_args: int = MISSING, max_args: int = MISSING,
override: bool = MISSING, override: bool = MISSING,
@ -131,6 +135,8 @@ def flag(
------------ ------------
name: :class:`str` name: :class:`str`
The flag name. If not given, defaults to the attribute name. The flag name. If not given, defaults to the attribute name.
aliases: List[:class:`str`]
Aliases to the flag name. If not given no aliases are set.
default: Any default: Any
The default parameter. This could be either a value or a callable that takes The default parameter. This could be either a value or a callable that takes
:class:`Context` as its sole parameter. If not given then it defaults to :class:`Context` as its sole parameter. If not given then it defaults to
@ -143,7 +149,7 @@ def flag(
Whether multiple given values overrides the previous value. The default Whether multiple given values overrides the previous value. The default
value depends on the annotation given. value depends on the annotation given.
""" """
return Flag(name=name, default=default, max_args=max_args, override=override) return Flag(name=name, aliases=aliases, default=default, max_args=max_args, override=override)
def validate_flag_name(name: str, forbidden: Set[str]): def validate_flag_name(name: str, forbidden: Set[str]):
@ -161,8 +167,10 @@ def validate_flag_name(name: str, forbidden: Set[str]):
def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[str, Any]) -> Dict[str, Flag]: def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[str, Any]) -> Dict[str, Flag]:
annotations = namespace.get('__annotations__', {}) annotations = namespace.get('__annotations__', {})
case_insensitive = namespace['__commands_flag_case_insensitive__']
flags: Dict[str, Flag] = {} flags: Dict[str, Flag] = {}
cache: Dict[str, Any] = {} cache: Dict[str, Any] = {}
names: Set[str] = set()
for name, annotation in annotations.items(): for name, annotation in annotations.items():
flag = namespace.pop(name, MISSING) flag = namespace.pop(name, MISSING)
if isinstance(flag, Flag): if isinstance(flag, Flag):
@ -176,6 +184,9 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s
annotation = flag.annotation = resolve_annotation(flag.annotation, globals, locals, cache) annotation = flag.annotation = resolve_annotation(flag.annotation, globals, locals, cache)
if flag.aliases is MISSING:
flag.aliases = []
# Add sensible defaults based off of the type annotation # Add sensible defaults based off of the type annotation
# <type> -> (max_args=1) # <type> -> (max_args=1)
# List[str] -> (max_args=-1) # List[str] -> (max_args=-1)
@ -221,6 +232,21 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s
if flag.override is MISSING: if flag.override is MISSING:
flag.override = False flag.override = False
# Validate flag names are unique
name = flag.name.casefold() if case_insensitive else flag.name
if name in names:
raise TypeError(f'{flag.name!r} flag conflicts with previous flag or alias.')
else:
names.add(name)
for alias in flag.aliases:
# Validate alias is unique
alias = alias.casefold() if case_insensitive else alias
if alias in names:
raise TypeError(f'{flag.name!r} flag alias {alias!r} conflicts with previous flag or alias.')
else:
names.add(alias)
flags[flag.name] = flag flags[flag.name] = flag
return flags return flags
@ -230,6 +256,7 @@ class FlagsMeta(type):
if TYPE_CHECKING: if TYPE_CHECKING:
__commands_is_flag__: bool __commands_is_flag__: bool
__commands_flags__: Dict[str, Flag] __commands_flags__: Dict[str, Flag]
__commands_flag_aliases__: Dict[str, str]
__commands_flag_regex__: Pattern[str] __commands_flag_regex__: Pattern[str]
__commands_flag_case_insensitive__: bool __commands_flag_case_insensitive__: bool
__commands_flag_delimiter__: str __commands_flag_delimiter__: str
@ -271,25 +298,37 @@ class FlagsMeta(type):
del frame del frame
flags: Dict[str, Flag] = {} flags: Dict[str, Flag] = {}
aliases: Dict[str, str] = {}
for base in reversed(bases): for base in reversed(bases):
if base.__dict__.get('__commands_is_flag__', False): if base.__dict__.get('__commands_is_flag__', False):
flags.update(base.__dict__['__commands_flags__']) flags.update(base.__dict__['__commands_flags__'])
aliases.update(base.__dict__['__commands_flag_aliases__'])
for flag_name, flag in get_flags(attrs, global_ns, local_ns).items():
flags[flag_name] = flag
aliases.update({alias_name: flag_name for alias_name in flag.aliases})
flags.update(get_flags(attrs, global_ns, local_ns))
forbidden = set(delimiter).union(prefix) forbidden = set(delimiter).union(prefix)
for flag_name in flags: for flag_name in flags:
validate_flag_name(flag_name, forbidden) validate_flag_name(flag_name, forbidden)
for alias_name in aliases:
validate_flag_name(alias_name, forbidden)
regex_flags = 0 regex_flags = 0
if case_insensitive: if case_insensitive:
flags = {key.casefold(): value for key, value in flags.items()} flags = {key.casefold(): value for key, value in flags.items()}
aliases = {key.casefold(): value.casefold() for key, value in aliases.items()}
regex_flags = re.IGNORECASE regex_flags = re.IGNORECASE
keys = sorted((re.escape(k) for k in flags), key=lambda t: len(t), reverse=True) keys = list(re.escape(k) for k in flags)
keys.extend(re.escape(a) for a in aliases)
keys = sorted(keys, key=lambda t: len(t), reverse=True)
joined = '|'.join(keys) joined = '|'.join(keys)
pattern = re.compile(f'(({re.escape(prefix)})(?P<flag>{joined}){re.escape(delimiter)})', regex_flags) pattern = re.compile(f'(({re.escape(prefix)})(?P<flag>{joined}){re.escape(delimiter)})', regex_flags)
attrs['__commands_flag_regex__'] = pattern attrs['__commands_flag_regex__'] = pattern
attrs['__commands_flags__'] = flags attrs['__commands_flags__'] = flags
attrs['__commands_flag_aliases__'] = aliases
return type.__new__(cls, name, bases, attrs) return type.__new__(cls, name, bases, attrs)
@ -432,6 +471,7 @@ class FlagConverter(metaclass=FlagsMeta):
def parse_flags(cls, argument: str) -> Dict[str, List[str]]: def parse_flags(cls, argument: str) -> Dict[str, List[str]]:
result: Dict[str, List[str]] = {} result: Dict[str, List[str]] = {}
flags = cls.__commands_flags__ flags = cls.__commands_flags__
aliases = cls.__commands_flag_aliases__
last_position = 0 last_position = 0
last_flag: Optional[Flag] = None last_flag: Optional[Flag] = None
@ -442,6 +482,9 @@ class FlagConverter(metaclass=FlagsMeta):
if case_insensitive: if case_insensitive:
key = key.casefold() key = key.casefold()
if key in aliases:
key = aliases[key]
flag = flags.get(key) flag = flags.get(key)
if last_position and last_flag is not None: if last_position and last_flag is not None:
value = argument[last_position : begin - 1].lstrip() value = argument[last_position : begin - 1].lstrip()

Loading…
Cancel
Save