|
|
@ -28,7 +28,7 @@ import inspect |
|
|
|
import re |
|
|
|
import sys |
|
|
|
from dataclasses import dataclass, field |
|
|
|
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Literal, Optional, Pattern, Set, Tuple, Union |
|
|
|
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Literal, Optional, Pattern, Set, Tuple, Type, Union |
|
|
|
|
|
|
|
from discord.utils import MISSING, maybe_coroutine, resolve_annotation |
|
|
|
|
|
|
@ -44,7 +44,7 @@ __all__ = ( |
|
|
|
|
|
|
|
|
|
|
|
if TYPE_CHECKING: |
|
|
|
from typing_extensions import Self |
|
|
|
from typing_extensions import Self, TypeGuard |
|
|
|
|
|
|
|
from ._types import BotT |
|
|
|
from .context import Context |
|
|
@ -76,6 +76,12 @@ class Flag: |
|
|
|
A negative value indicates an unlimited amount of arguments. |
|
|
|
override: :class:`bool` |
|
|
|
Whether multiple given values overrides the previous value. |
|
|
|
description: :class:`str` |
|
|
|
The description of the flag. |
|
|
|
positional: :class:`bool` |
|
|
|
Whether the flag is positional or not. There can only be one positional flag. |
|
|
|
|
|
|
|
.. versionadded:: 2.4 |
|
|
|
""" |
|
|
|
|
|
|
|
name: str = MISSING |
|
|
@ -85,6 +91,8 @@ class Flag: |
|
|
|
default: Any = MISSING |
|
|
|
max_args: int = MISSING |
|
|
|
override: bool = MISSING |
|
|
|
description: str = MISSING |
|
|
|
positional: bool = MISSING |
|
|
|
cast_to_dict: bool = False |
|
|
|
|
|
|
|
@property |
|
|
@ -104,6 +112,8 @@ def flag( |
|
|
|
max_args: int = MISSING, |
|
|
|
override: bool = MISSING, |
|
|
|
converter: Any = MISSING, |
|
|
|
description: str = MISSING, |
|
|
|
positional: bool = MISSING, |
|
|
|
) -> Any: |
|
|
|
"""Override default functionality and parameters of the underlying :class:`FlagConverter` |
|
|
|
class attributes. |
|
|
@ -128,8 +138,27 @@ def flag( |
|
|
|
converter: Any |
|
|
|
The converter to use for this flag. This replaces the annotation at |
|
|
|
runtime which is transparent to type checkers. |
|
|
|
description: :class:`str` |
|
|
|
The description of the flag. |
|
|
|
positional: :class:`bool` |
|
|
|
Whether the flag is positional or not. There can only be one positional flag. |
|
|
|
|
|
|
|
.. versionadded:: 2.4 |
|
|
|
""" |
|
|
|
return Flag(name=name, aliases=aliases, default=default, max_args=max_args, override=override, annotation=converter) |
|
|
|
return Flag( |
|
|
|
name=name, |
|
|
|
aliases=aliases, |
|
|
|
default=default, |
|
|
|
max_args=max_args, |
|
|
|
override=override, |
|
|
|
annotation=converter, |
|
|
|
description=description, |
|
|
|
positional=positional, |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
def is_flag(obj: Any) -> TypeGuard[Type[FlagConverter]]: |
|
|
|
return hasattr(obj, '__commands_is_flag__') |
|
|
|
|
|
|
|
|
|
|
|
def validate_flag_name(name: str, forbidden: Set[str]) -> None: |
|
|
@ -151,6 +180,7 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s |
|
|
|
flags: Dict[str, Flag] = {} |
|
|
|
cache: Dict[str, Any] = {} |
|
|
|
names: Set[str] = set() |
|
|
|
positional: Optional[Flag] = None |
|
|
|
for name, annotation in annotations.items(): |
|
|
|
flag = namespace.pop(name, MISSING) |
|
|
|
if isinstance(flag, Flag): |
|
|
@ -163,6 +193,11 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s |
|
|
|
if flag.name is MISSING: |
|
|
|
flag.name = name |
|
|
|
|
|
|
|
if flag.positional: |
|
|
|
if positional is not None: |
|
|
|
raise TypeError(f"{flag.name!r} positional flag conflicts with {positional.name!r} flag.") |
|
|
|
positional = flag |
|
|
|
|
|
|
|
annotation = flag.annotation = resolve_annotation(flag.annotation, globals, locals, cache) |
|
|
|
|
|
|
|
if flag.default is MISSING and hasattr(annotation, '__commands_is_flag__') and annotation._can_be_constructible(): |
|
|
@ -250,6 +285,7 @@ class FlagsMeta(type): |
|
|
|
__commands_flag_case_insensitive__: bool |
|
|
|
__commands_flag_delimiter__: str |
|
|
|
__commands_flag_prefix__: str |
|
|
|
__commands_flag_positional__: Optional[Flag] |
|
|
|
|
|
|
|
def __new__( |
|
|
|
cls, |
|
|
@ -304,9 +340,13 @@ class FlagsMeta(type): |
|
|
|
delimiter = attrs.setdefault('__commands_flag_delimiter__', ':') |
|
|
|
prefix = attrs.setdefault('__commands_flag_prefix__', '') |
|
|
|
|
|
|
|
positional: Optional[Flag] = None |
|
|
|
for flag_name, flag in get_flags(attrs, global_ns, local_ns).items(): |
|
|
|
flags[flag_name] = flag |
|
|
|
aliases.update({alias_name: flag_name for alias_name in flag.aliases}) |
|
|
|
if flag.positional: |
|
|
|
positional = flag |
|
|
|
attrs['__commands_flag_positional__'] = positional |
|
|
|
|
|
|
|
forbidden = set(delimiter).union(prefix) |
|
|
|
for flag_name in flags: |
|
|
@ -480,10 +520,25 @@ class FlagConverter(metaclass=FlagsMeta): |
|
|
|
result: Dict[str, List[str]] = {} |
|
|
|
flags = cls.__commands_flags__ |
|
|
|
aliases = cls.__commands_flag_aliases__ |
|
|
|
positional_flag = cls.__commands_flag_positional__ |
|
|
|
last_position = 0 |
|
|
|
last_flag: Optional[Flag] = None |
|
|
|
|
|
|
|
case_insensitive = cls.__commands_flag_case_insensitive__ |
|
|
|
|
|
|
|
if positional_flag is not None: |
|
|
|
match = cls.__commands_flag_regex__.search(argument) |
|
|
|
if match is not None: |
|
|
|
begin, end = match.span(0) |
|
|
|
value = argument[:begin].strip() |
|
|
|
else: |
|
|
|
value = argument.strip() |
|
|
|
last_position = len(argument) |
|
|
|
|
|
|
|
if value: |
|
|
|
name = positional_flag.name.casefold() if case_insensitive else positional_flag.name |
|
|
|
result[name] = [value] |
|
|
|
|
|
|
|
for match in cls.__commands_flag_regex__.finditer(argument): |
|
|
|
begin, end = match.span(0) |
|
|
|
key = match.group('flag') |
|
|
|