Browse Source

Improve component typing

pull/8047/head
Lilly Rose Berner 3 years ago
committed by GitHub
parent
commit
7267d18d9e
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 136
      discord/components.py
  2. 19
      discord/message.py
  3. 17
      discord/partial_emoji.py
  4. 2
      discord/state.py
  5. 5
      discord/types/components.py
  6. 2
      discord/types/interactions.py
  7. 1
      discord/ui/button.py
  8. 1
      discord/ui/select.py
  9. 3
      discord/ui/text_input.py
  10. 14
      discord/ui/view.py

136
discord/components.py

@ -24,7 +24,7 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations from __future__ import annotations
from typing import Any, ClassVar, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Union from typing import ClassVar, List, Literal, Optional, TYPE_CHECKING, Tuple, Union, overload
from .enums import try_enum, ComponentType, ButtonStyle, TextStyle from .enums import try_enum, ComponentType, ButtonStyle, TextStyle
from .utils import get_slots, MISSING from .utils import get_slots, MISSING
from .partial_emoji import PartialEmoji, _EmojiTag from .partial_emoji import PartialEmoji, _EmojiTag
@ -39,9 +39,12 @@ if TYPE_CHECKING:
SelectOption as SelectOptionPayload, SelectOption as SelectOptionPayload,
ActionRow as ActionRowPayload, ActionRow as ActionRowPayload,
TextInput as TextInputPayload, TextInput as TextInputPayload,
ActionRowChildComponent as ActionRowChildComponentPayload,
) )
from .emoji import Emoji from .emoji import Emoji
ActionRowChildComponentType = Union['Button', 'SelectMenu', 'TextInput']
__all__ = ( __all__ = (
'Component', 'Component',
@ -61,26 +64,26 @@ class Component:
- :class:`ActionRow` - :class:`ActionRow`
- :class:`Button` - :class:`Button`
- :class:`SelectMenu` - :class:`SelectMenu`
- :class:`TextInput`
This class is abstract and cannot be instantiated. This class is abstract and cannot be instantiated.
.. versionadded:: 2.0 .. versionadded:: 2.0
Attributes
------------
type: :class:`ComponentType`
The type of component.
""" """
__slots__: Tuple[str, ...] = ('type',) __slots__: Tuple[str, ...] = ()
__repr_info__: ClassVar[Tuple[str, ...]] __repr_info__: ClassVar[Tuple[str, ...]]
type: ComponentType
def __repr__(self) -> str: def __repr__(self) -> str:
attrs = ' '.join(f'{key}={getattr(self, key)!r}' for key in self.__repr_info__) attrs = ' '.join(f'{key}={getattr(self, key)!r}' for key in self.__repr_info__)
return f'<{self.__class__.__name__} {attrs}>' return f'<{self.__class__.__name__} {attrs}>'
@property
def type(self) -> ComponentType:
""":class:`ComponentType`: The type of component."""
raise NotImplementedError
@classmethod @classmethod
def _raw_construct(cls, **kwargs) -> Self: def _raw_construct(cls, **kwargs) -> Self:
self = cls.__new__(cls) self = cls.__new__(cls)
@ -93,7 +96,7 @@ class Component:
setattr(self, slot, value) setattr(self, slot, value)
return self return self
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> ComponentPayload:
raise NotImplementedError raise NotImplementedError
@ -108,9 +111,7 @@ class ActionRow(Component):
Attributes Attributes
------------ ------------
type: :class:`ComponentType` children: List[Union[:class:`Button`, :class:`SelectMenu`, :class:`TextInput`]]
The type of component.
children: List[:class:`Component`]
The children components that this holds, if any. The children components that this holds, if any.
""" """
@ -118,15 +119,25 @@ class ActionRow(Component):
__repr_info__: ClassVar[Tuple[str, ...]] = __slots__ __repr_info__: ClassVar[Tuple[str, ...]] = __slots__
def __init__(self, data: ComponentPayload): def __init__(self, data: ActionRowPayload, /) -> None:
self.type: Literal[ComponentType.action_row] = ComponentType.action_row self.children: List[ActionRowChildComponentType] = []
self.children: List[Component] = [_component_factory(d) for d in data.get('components', [])]
for component_data in data.get('components', []):
component = _component_factory(component_data)
if component is not None:
self.children.append(component)
@property
def type(self) -> Literal[ComponentType.action_row]:
""":class:`ComponentType`: The type of component."""
return ComponentType.action_row
def to_dict(self) -> ActionRowPayload: def to_dict(self) -> ActionRowPayload:
return { return {
'type': int(self.type), 'type': self.type.value,
'components': [child.to_dict() for child in self.children], 'components': [child.to_dict() for child in self.children],
} # type: ignore # Type checker does not understand these are the same }
class Button(Component): class Button(Component):
@ -169,8 +180,7 @@ class Button(Component):
__repr_info__: ClassVar[Tuple[str, ...]] = __slots__ __repr_info__: ClassVar[Tuple[str, ...]] = __slots__
def __init__(self, data: ButtonComponentPayload): def __init__(self, data: ButtonComponentPayload, /) -> None:
self.type: Literal[ComponentType.button] = ComponentType.button
self.style: ButtonStyle = try_enum(ButtonStyle, data['style']) self.style: ButtonStyle = try_enum(ButtonStyle, data['style'])
self.custom_id: Optional[str] = data.get('custom_id') self.custom_id: Optional[str] = data.get('custom_id')
self.url: Optional[str] = data.get('url') self.url: Optional[str] = data.get('url')
@ -182,13 +192,21 @@ class Button(Component):
except KeyError: except KeyError:
self.emoji = None self.emoji = None
@property
def type(self) -> Literal[ComponentType.button]:
""":class:`ComponentType`: The type of component."""
return ComponentType.button
def to_dict(self) -> ButtonComponentPayload: def to_dict(self) -> ButtonComponentPayload:
payload = { payload: ButtonComponentPayload = {
'type': 2, 'type': 2,
'style': int(self.style), 'style': self.style.value,
'label': self.label,
'disabled': self.disabled, 'disabled': self.disabled,
} }
if self.label:
payload['label'] = self.label
if self.custom_id: if self.custom_id:
payload['custom_id'] = self.custom_id payload['custom_id'] = self.custom_id
@ -198,7 +216,7 @@ class Button(Component):
if self.emoji: if self.emoji:
payload['emoji'] = self.emoji.to_dict() payload['emoji'] = self.emoji.to_dict()
return payload # type: ignore # Type checker does not understand these are the same return payload
class SelectMenu(Component): class SelectMenu(Component):
@ -243,8 +261,7 @@ class SelectMenu(Component):
__repr_info__: ClassVar[Tuple[str, ...]] = __slots__ __repr_info__: ClassVar[Tuple[str, ...]] = __slots__
def __init__(self, data: SelectMenuPayload): def __init__(self, data: SelectMenuPayload, /) -> None:
self.type: Literal[ComponentType.select] = ComponentType.select
self.custom_id: str = data['custom_id'] self.custom_id: str = data['custom_id']
self.placeholder: Optional[str] = data.get('placeholder') self.placeholder: Optional[str] = data.get('placeholder')
self.min_values: int = data.get('min_values', 1) self.min_values: int = data.get('min_values', 1)
@ -252,6 +269,11 @@ class SelectMenu(Component):
self.options: List[SelectOption] = [SelectOption.from_dict(option) for option in data.get('options', [])] self.options: List[SelectOption] = [SelectOption.from_dict(option) for option in data.get('options', [])]
self.disabled: bool = data.get('disabled', False) self.disabled: bool = data.get('disabled', False)
@property
def type(self) -> Literal[ComponentType.select]:
""":class:`ComponentType`: The type of component."""
return ComponentType.select
def to_dict(self) -> SelectMenuPayload: def to_dict(self) -> SelectMenuPayload:
payload: SelectMenuPayload = { payload: SelectMenuPayload = {
'type': self.type.value, 'type': self.type.value,
@ -275,7 +297,7 @@ class SelectOption:
.. versionadded:: 2.0 .. versionadded:: 2.0
Attributes Parameters
----------- -----------
label: :class:`str` label: :class:`str`
The label of the option. This is displayed to users. The label of the option. This is displayed to users.
@ -291,6 +313,23 @@ class SelectOption:
The emoji of the option, if available. The emoji of the option, if available.
default: :class:`bool` default: :class:`bool`
Whether this option is selected by default. Whether this option is selected by default.
Attributes
-----------
label: :class:`str`
The label of the option. This is displayed to users.
Can only be up to 100 characters.
value: :class:`str`
The value of the option. This is not displayed to users.
If not provided when constructed then it defaults to the
label. Can only be up to 100 characters.
description: Optional[:class:`str`]
An additional description of the option, if any.
Can only be up to 100 characters.
emoji: Optional[:class:`PartialEmoji`]
The emoji of the option, if available.
default: :class:`bool`
Whether this option is selected by default.
""" """
__slots__: Tuple[str, ...] = ( __slots__: Tuple[str, ...] = (
@ -322,7 +361,7 @@ class SelectOption:
else: else:
raise TypeError(f'expected emoji to be str, Emoji, or PartialEmoji not {emoji.__class__}') raise TypeError(f'expected emoji to be str, Emoji, or PartialEmoji not {emoji.__class__}')
self.emoji: Optional[Union[str, Emoji, PartialEmoji]] = emoji self.emoji: Optional[PartialEmoji] = emoji
self.default: bool = default self.default: bool = default
def __repr__(self) -> str: def __repr__(self) -> str:
@ -364,7 +403,7 @@ class SelectOption:
} }
if self.emoji: if self.emoji:
payload['emoji'] = self.emoji.to_dict() # type: ignore # This Dict[str, Any] is compatible with PartialEmoji payload['emoji'] = self.emoji.to_dict()
if self.description: if self.description:
payload['description'] = self.description payload['description'] = self.description
@ -414,8 +453,7 @@ class TextInput(Component):
__repr_info__: ClassVar[Tuple[str, ...]] = __slots__ __repr_info__: ClassVar[Tuple[str, ...]] = __slots__
def __init__(self, data: TextInputPayload) -> None: def __init__(self, data: TextInputPayload, /) -> None:
self.type: Literal[ComponentType.text_input] = ComponentType.text_input
self.style: TextStyle = try_enum(TextStyle, data['style']) self.style: TextStyle = try_enum(TextStyle, data['style'])
self.label: str = data['label'] self.label: str = data['label']
self.custom_id: str = data['custom_id'] self.custom_id: str = data['custom_id']
@ -425,6 +463,11 @@ class TextInput(Component):
self.min_length: Optional[int] = data.get('min_length') self.min_length: Optional[int] = data.get('min_length')
self.max_length: Optional[int] = data.get('max_length') self.max_length: Optional[int] = data.get('max_length')
@property
def type(self) -> Literal[ComponentType.text_input]:
""":class:`ComponentType`: The type of component."""
return ComponentType.text_input
def to_dict(self) -> TextInputPayload: def to_dict(self) -> TextInputPayload:
payload: TextInputPayload = { payload: TextInputPayload = {
'type': self.type.value, 'type': self.type.value,
@ -457,19 +500,22 @@ class TextInput(Component):
return self.value return self.value
def _component_factory(data: ComponentPayload) -> Component: @overload
component_type = data['type'] def _component_factory(data: ActionRowChildComponentPayload) -> Optional[ActionRowChildComponentType]:
if component_type == 1: ...
@overload
def _component_factory(data: ComponentPayload) -> Optional[Union[ActionRow, ActionRowChildComponentType]]:
...
def _component_factory(data: ComponentPayload) -> Optional[Union[ActionRow, ActionRowChildComponentType]]:
if data['type'] == 1:
return ActionRow(data) return ActionRow(data)
elif component_type == 2: elif data['type'] == 2:
# The type checker does not properly do narrowing here. return Button(data)
return Button(data) # type: ignore elif data['type'] == 3:
elif component_type == 3: return SelectMenu(data)
# The type checker does not properly do narrowing here. elif data['type'] == 4:
return SelectMenu(data) # type: ignore return TextInput(data)
elif component_type == 4:
# The type checker does not properly do narrowing here.
return TextInput(data) # type: ignore
else:
as_enum = try_enum(ComponentType, component_type)
return Component._raw_construct(type=as_enum)

19
discord/message.py

@ -87,7 +87,7 @@ if TYPE_CHECKING:
from .types.gateway import MessageReactionRemoveEvent, MessageUpdateEvent from .types.gateway import MessageReactionRemoveEvent, MessageUpdateEvent
from .abc import Snowflake from .abc import Snowflake
from .abc import GuildChannel, MessageableChannel from .abc import GuildChannel, MessageableChannel
from .components import Component from .components import ActionRow, ActionRowChildComponentType
from .state import ConnectionState from .state import ConnectionState
from .channel import TextChannel from .channel import TextChannel
from .mentions import AllowedMentions from .mentions import AllowedMentions
@ -96,6 +96,7 @@ if TYPE_CHECKING:
from .ui.view import View from .ui.view import View
EmojiInputType = Union[Emoji, PartialEmoji, str] EmojiInputType = Union[Emoji, PartialEmoji, str]
MessageComponentType = Union[ActionRow, ActionRowChildComponentType]
__all__ = ( __all__ = (
@ -1340,7 +1341,7 @@ class Message(PartialMessage, Hashable):
A list of sticker items given to the message. A list of sticker items given to the message.
.. versionadded:: 1.6 .. versionadded:: 1.6
components: List[:class:`Component`] components: List[Union[:class:`ActionRow`, :class:`Button`, :class:`SelectMenu`]]
A list of components in the message. A list of components in the message.
.. versionadded:: 2.0 .. versionadded:: 2.0
@ -1392,6 +1393,7 @@ class Message(PartialMessage, Hashable):
mentions: List[Union[User, Member]] mentions: List[Union[User, Member]]
author: Union[User, Member] author: Union[User, Member]
role_mentions: List[Role] role_mentions: List[Role]
components: List[MessageComponentType]
def __init__( def __init__(
self, self,
@ -1418,7 +1420,6 @@ class Message(PartialMessage, Hashable):
self.content: str = data['content'] self.content: str = data['content']
self.nonce: Optional[Union[int, str]] = data.get('nonce') self.nonce: Optional[Union[int, str]] = data.get('nonce')
self.stickers: List[StickerItem] = [StickerItem(data=d, state=state) for d in data.get('sticker_items', [])] self.stickers: List[StickerItem] = [StickerItem(data=d, state=state) for d in data.get('sticker_items', [])]
self.components: List[Component] = [_component_factory(d) for d in data.get('components', [])]
try: try:
# if the channel doesn't have a guild attribute, we handle that # if the channel doesn't have a guild attribute, we handle that
@ -1460,7 +1461,7 @@ class Message(PartialMessage, Hashable):
# the channel will be the correct type here # the channel will be the correct type here
ref.resolved = self.__class__(channel=chan, data=resolved, state=state) # type: ignore ref.resolved = self.__class__(channel=chan, data=resolved, state=state) # type: ignore
for handler in ('author', 'member', 'mentions', 'mention_roles'): for handler in ('author', 'member', 'mentions', 'mention_roles', 'components'):
try: try:
getattr(self, f'_handle_{handler}')(data[handler]) getattr(self, f'_handle_{handler}')(data[handler])
except KeyError: except KeyError:
@ -1631,8 +1632,14 @@ class Message(PartialMessage, Hashable):
if role is not None: if role is not None:
self.role_mentions.append(role) self.role_mentions.append(role)
def _handle_components(self, components: List[ComponentPayload]): def _handle_components(self, data: List[ComponentPayload]) -> None:
self.components = [_component_factory(d) for d in components] self.components = []
for component_data in data:
component = _component_factory(component_data)
if component is not None:
self.components.append(component)
def _handle_interaction(self, data: MessageInteractionPayload): def _handle_interaction(self, data: MessageInteractionPayload):
self.interaction = MessageInteraction(state=self._state, guild=self.guild, data=data) self.interaction = MessageInteraction(state=self._state, guild=self.guild, data=data)

17
discord/partial_emoji.py

@ -41,7 +41,7 @@ if TYPE_CHECKING:
from .state import ConnectionState from .state import ConnectionState
from datetime import datetime from datetime import datetime
from .types.message import PartialEmoji as PartialEmojiPayload from .types.emoji import Emoji as EmojiPayload, PartialEmoji as PartialEmojiPayload
from .types.activity import ActivityEmoji from .types.activity import ActivityEmoji
@ -148,13 +148,16 @@ class PartialEmoji(_EmojiTag, AssetMixin):
return cls(name=value, id=None, animated=False) return cls(name=value, id=None, animated=False)
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> EmojiPayload:
o: Dict[str, Any] = {'name': self.name} payload: EmojiPayload = {
if self.id: 'id': self.id,
o['id'] = self.id 'name': self.name,
}
if self.animated: if self.animated:
o['animated'] = self.animated payload['animated'] = self.animated
return o
return payload
def _to_partial(self) -> PartialEmoji: def _to_partial(self) -> PartialEmoji:
return self return self

2
discord/state.py

@ -730,7 +730,7 @@ class ConnectionState:
inner_data = data['data'] inner_data = data['data']
custom_id = inner_data['custom_id'] custom_id = inner_data['custom_id']
components = inner_data['components'] components = inner_data['components']
self._view_store.dispatch_modal(custom_id, interaction, components) # type: ignore self._view_store.dispatch_modal(custom_id, interaction, components)
self.dispatch('interaction', interaction) self.dispatch('interaction', interaction)
def parse_presence_update(self, data: gw.PresenceUpdateEvent) -> None: def parse_presence_update(self, data: gw.PresenceUpdateEvent) -> None:

5
discord/types/components.py

@ -36,7 +36,7 @@ TextStyle = Literal[1, 2]
class ActionRow(TypedDict): class ActionRow(TypedDict):
type: Literal[1] type: Literal[1]
components: List[Component] components: List[ActionRowChildComponent]
class ButtonComponent(TypedDict): class ButtonComponent(TypedDict):
@ -79,4 +79,5 @@ class TextInput(TypedDict):
max_length: NotRequired[int] max_length: NotRequired[int]
Component = Union[ActionRow, ButtonComponent, SelectMenu, TextInput] ActionRowChildComponent = Union[ButtonComponent, SelectMenu, TextInput]
Component = Union[ActionRow, ActionRowChildComponent]

2
discord/types/interactions.py

@ -186,7 +186,7 @@ ModalSubmitComponentInteractionData = Union[ModalSubmitActionRowInteractionData,
class ModalSubmitInteractionData(TypedDict): class ModalSubmitInteractionData(TypedDict):
custom_id: str custom_id: str
components: List[ModalSubmitActionRowInteractionData] components: List[ModalSubmitComponentInteractionData]
InteractionData = Union[ InteractionData = Union[

1
discord/ui/button.py

@ -120,7 +120,6 @@ class Button(Item[V]):
raise TypeError(f'expected emoji to be str, Emoji, or PartialEmoji not {emoji.__class__}') raise TypeError(f'expected emoji to be str, Emoji, or PartialEmoji not {emoji.__class__}')
self._underlying = ButtonComponent._raw_construct( self._underlying = ButtonComponent._raw_construct(
type=ComponentType.button,
custom_id=custom_id, custom_id=custom_id,
url=url, url=url,
disabled=disabled, disabled=disabled,

1
discord/ui/select.py

@ -117,7 +117,6 @@ class Select(Item[V]):
options = [] if options is MISSING else options options = [] if options is MISSING else options
self._underlying = SelectMenu._raw_construct( self._underlying = SelectMenu._raw_construct(
custom_id=custom_id, custom_id=custom_id,
type=ComponentType.select,
placeholder=placeholder, placeholder=placeholder,
min_values=min_values, min_values=min_values,
max_values=max_values, max_values=max_values,

3
discord/ui/text_input.py

@ -114,7 +114,6 @@ class TextInput(Item[V]):
raise TypeError(f'expected custom_id to be str not {custom_id.__class__!r}') raise TypeError(f'expected custom_id to be str not {custom_id.__class__!r}')
self._underlying = TextInputComponent._raw_construct( self._underlying = TextInputComponent._raw_construct(
type=ComponentType.text_input,
label=label, label=label,
style=style, style=style,
custom_id=custom_id, custom_id=custom_id,
@ -238,7 +237,7 @@ class TextInput(Item[V]):
@property @property
def type(self) -> Literal[ComponentType.text_input]: def type(self) -> Literal[ComponentType.text_input]:
return ComponentType.text_input return self._underlying.type
def is_dispatchable(self) -> bool: def is_dispatchable(self) -> bool:
return False return False

14
discord/ui/view.py

@ -281,7 +281,7 @@ class View:
one of its subclasses. one of its subclasses.
""" """
view = View(timeout=timeout) view = View(timeout=timeout)
for component in _walk_all_components(message.components): for component in _walk_all_components(message.components): # type: ignore
view.add_item(_component_to_item(component)) view.add_item(_component_to_item(component))
return view return view
@ -634,7 +634,15 @@ class ViewStore:
def remove_message_tracking(self, message_id: int) -> Optional[View]: def remove_message_tracking(self, message_id: int) -> Optional[View]:
return self._synced_message_views.pop(message_id, None) return self._synced_message_views.pop(message_id, None)
def update_from_message(self, message_id: int, components: List[ComponentPayload]) -> None: def update_from_message(self, message_id: int, data: List[ComponentPayload]) -> None:
components: List[Component] = []
for component_data in data:
component = _component_factory(component_data)
if component is not None:
components.append(component)
# pre-req: is_message_tracked == true # pre-req: is_message_tracked == true
view = self._synced_message_views[message_id] view = self._synced_message_views[message_id]
view._refresh([_component_factory(d) for d in components]) view._refresh(components)

Loading…
Cancel
Save