|
|
@ -37,7 +37,7 @@ from .view import StringView |
|
|
|
from .converter import run_converters |
|
|
|
|
|
|
|
from discord.utils import maybe_coroutine |
|
|
|
from dataclasses import dataclass |
|
|
|
from dataclasses import dataclass, field |
|
|
|
from typing import ( |
|
|
|
Dict, |
|
|
|
Optional, |
|
|
@ -87,6 +87,8 @@ class Flag: |
|
|
|
------------ |
|
|
|
name: :class:`str` |
|
|
|
The name of the flag. |
|
|
|
aliases: List[:class:`str`] |
|
|
|
The aliases of the flag name. |
|
|
|
attribute: :class:`str` |
|
|
|
The attribute in the class that corresponds to this flag. |
|
|
|
default: Any |
|
|
@ -101,6 +103,7 @@ class Flag: |
|
|
|
""" |
|
|
|
|
|
|
|
name: str = MISSING |
|
|
|
aliases: List[str] = field(default_factory=list) |
|
|
|
attribute: str = MISSING |
|
|
|
annotation: Any = MISSING |
|
|
|
default: Any = MISSING |
|
|
@ -120,6 +123,7 @@ class Flag: |
|
|
|
def flag( |
|
|
|
*, |
|
|
|
name: str = MISSING, |
|
|
|
aliases: List[str] = MISSING, |
|
|
|
default: Any = MISSING, |
|
|
|
max_args: int = MISSING, |
|
|
|
override: bool = MISSING, |
|
|
@ -131,6 +135,8 @@ def flag( |
|
|
|
------------ |
|
|
|
name: :class:`str` |
|
|
|
The flag name. If not given, defaults to the attribute name. |
|
|
|
aliases: List[:class:`str`] |
|
|
|
Aliases to the flag name. If not given no aliases are set. |
|
|
|
default: Any |
|
|
|
The default parameter. This could be either a value or a callable that takes |
|
|
|
:class:`Context` as its sole parameter. If not given then it defaults to |
|
|
@ -143,7 +149,7 @@ def flag( |
|
|
|
Whether multiple given values overrides the previous value. The default |
|
|
|
value depends on the annotation given. |
|
|
|
""" |
|
|
|
return Flag(name=name, default=default, max_args=max_args, override=override) |
|
|
|
return Flag(name=name, aliases=aliases, default=default, max_args=max_args, override=override) |
|
|
|
|
|
|
|
|
|
|
|
def validate_flag_name(name: str, forbidden: Set[str]): |
|
|
@ -161,8 +167,10 @@ def validate_flag_name(name: str, forbidden: Set[str]): |
|
|
|
|
|
|
|
def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[str, Any]) -> Dict[str, Flag]: |
|
|
|
annotations = namespace.get('__annotations__', {}) |
|
|
|
case_insensitive = namespace['__commands_flag_case_insensitive__'] |
|
|
|
flags: Dict[str, Flag] = {} |
|
|
|
cache: Dict[str, Any] = {} |
|
|
|
names: Set[str] = set() |
|
|
|
for name, annotation in annotations.items(): |
|
|
|
flag = namespace.pop(name, MISSING) |
|
|
|
if isinstance(flag, Flag): |
|
|
@ -176,6 +184,9 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s |
|
|
|
|
|
|
|
annotation = flag.annotation = resolve_annotation(flag.annotation, globals, locals, cache) |
|
|
|
|
|
|
|
if flag.aliases is MISSING: |
|
|
|
flag.aliases = [] |
|
|
|
|
|
|
|
# Add sensible defaults based off of the type annotation |
|
|
|
# <type> -> (max_args=1) |
|
|
|
# List[str] -> (max_args=-1) |
|
|
@ -221,6 +232,21 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s |
|
|
|
if flag.override is MISSING: |
|
|
|
flag.override = False |
|
|
|
|
|
|
|
# Validate flag names are unique |
|
|
|
name = flag.name.casefold() if case_insensitive else flag.name |
|
|
|
if name in names: |
|
|
|
raise TypeError(f'{flag.name!r} flag conflicts with previous flag or alias.') |
|
|
|
else: |
|
|
|
names.add(name) |
|
|
|
|
|
|
|
for alias in flag.aliases: |
|
|
|
# Validate alias is unique |
|
|
|
alias = alias.casefold() if case_insensitive else alias |
|
|
|
if alias in names: |
|
|
|
raise TypeError(f'{flag.name!r} flag alias {alias!r} conflicts with previous flag or alias.') |
|
|
|
else: |
|
|
|
names.add(alias) |
|
|
|
|
|
|
|
flags[flag.name] = flag |
|
|
|
|
|
|
|
return flags |
|
|
@ -230,6 +256,7 @@ class FlagsMeta(type): |
|
|
|
if TYPE_CHECKING: |
|
|
|
__commands_is_flag__: bool |
|
|
|
__commands_flags__: Dict[str, Flag] |
|
|
|
__commands_flag_aliases__: Dict[str, str] |
|
|
|
__commands_flag_regex__: Pattern[str] |
|
|
|
__commands_flag_case_insensitive__: bool |
|
|
|
__commands_flag_delimiter__: str |
|
|
@ -271,25 +298,37 @@ class FlagsMeta(type): |
|
|
|
del frame |
|
|
|
|
|
|
|
flags: Dict[str, Flag] = {} |
|
|
|
aliases: Dict[str, str] = {} |
|
|
|
for base in reversed(bases): |
|
|
|
if base.__dict__.get('__commands_is_flag__', False): |
|
|
|
flags.update(base.__dict__['__commands_flags__']) |
|
|
|
aliases.update(base.__dict__['__commands_flag_aliases__']) |
|
|
|
|
|
|
|
for flag_name, flag in get_flags(attrs, global_ns, local_ns).items(): |
|
|
|
flags[flag_name] = flag |
|
|
|
aliases.update({alias_name: flag_name for alias_name in flag.aliases}) |
|
|
|
|
|
|
|
flags.update(get_flags(attrs, global_ns, local_ns)) |
|
|
|
forbidden = set(delimiter).union(prefix) |
|
|
|
for flag_name in flags: |
|
|
|
validate_flag_name(flag_name, forbidden) |
|
|
|
for alias_name in aliases: |
|
|
|
validate_flag_name(alias_name, forbidden) |
|
|
|
|
|
|
|
regex_flags = 0 |
|
|
|
if case_insensitive: |
|
|
|
flags = {key.casefold(): value for key, value in flags.items()} |
|
|
|
aliases = {key.casefold(): value.casefold() for key, value in aliases.items()} |
|
|
|
regex_flags = re.IGNORECASE |
|
|
|
|
|
|
|
keys = sorted((re.escape(k) for k in flags), key=lambda t: len(t), reverse=True) |
|
|
|
keys = list(re.escape(k) for k in flags) |
|
|
|
keys.extend(re.escape(a) for a in aliases) |
|
|
|
keys = sorted(keys, key=lambda t: len(t), reverse=True) |
|
|
|
|
|
|
|
joined = '|'.join(keys) |
|
|
|
pattern = re.compile(f'(({re.escape(prefix)})(?P<flag>{joined}){re.escape(delimiter)})', regex_flags) |
|
|
|
attrs['__commands_flag_regex__'] = pattern |
|
|
|
attrs['__commands_flags__'] = flags |
|
|
|
attrs['__commands_flag_aliases__'] = aliases |
|
|
|
|
|
|
|
return type.__new__(cls, name, bases, attrs) |
|
|
|
|
|
|
@ -432,6 +471,7 @@ class FlagConverter(metaclass=FlagsMeta): |
|
|
|
def parse_flags(cls, argument: str) -> Dict[str, List[str]]: |
|
|
|
result: Dict[str, List[str]] = {} |
|
|
|
flags = cls.__commands_flags__ |
|
|
|
aliases = cls.__commands_flag_aliases__ |
|
|
|
last_position = 0 |
|
|
|
last_flag: Optional[Flag] = None |
|
|
|
|
|
|
@ -442,6 +482,9 @@ class FlagConverter(metaclass=FlagsMeta): |
|
|
|
if case_insensitive: |
|
|
|
key = key.casefold() |
|
|
|
|
|
|
|
if key in aliases: |
|
|
|
key = aliases[key] |
|
|
|
|
|
|
|
flag = flags.get(key) |
|
|
|
if last_position and last_flag is not None: |
|
|
|
value = argument[last_position : begin - 1].lstrip() |
|
|
|