Browse Source

Improve component typing

pull/8047/head
Lilly Rose Berner 3 years ago
committed by GitHub
parent
commit
7267d18d9e
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 136
      discord/components.py
  2. 19
      discord/message.py
  3. 17
      discord/partial_emoji.py
  4. 2
      discord/state.py
  5. 5
      discord/types/components.py
  6. 2
      discord/types/interactions.py
  7. 1
      discord/ui/button.py
  8. 1
      discord/ui/select.py
  9. 3
      discord/ui/text_input.py
  10. 14
      discord/ui/view.py

136
discord/components.py

@ -24,7 +24,7 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
from typing import Any, ClassVar, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Union
from typing import ClassVar, List, Literal, Optional, TYPE_CHECKING, Tuple, Union, overload
from .enums import try_enum, ComponentType, ButtonStyle, TextStyle
from .utils import get_slots, MISSING
from .partial_emoji import PartialEmoji, _EmojiTag
@ -39,9 +39,12 @@ if TYPE_CHECKING:
SelectOption as SelectOptionPayload,
ActionRow as ActionRowPayload,
TextInput as TextInputPayload,
ActionRowChildComponent as ActionRowChildComponentPayload,
)
from .emoji import Emoji
ActionRowChildComponentType = Union['Button', 'SelectMenu', 'TextInput']
__all__ = (
'Component',
@ -61,26 +64,26 @@ class Component:
- :class:`ActionRow`
- :class:`Button`
- :class:`SelectMenu`
- :class:`TextInput`
This class is abstract and cannot be instantiated.
.. versionadded:: 2.0
Attributes
------------
type: :class:`ComponentType`
The type of component.
"""
__slots__: Tuple[str, ...] = ('type',)
__slots__: Tuple[str, ...] = ()
__repr_info__: ClassVar[Tuple[str, ...]]
type: ComponentType
def __repr__(self) -> str:
attrs = ' '.join(f'{key}={getattr(self, key)!r}' for key in self.__repr_info__)
return f'<{self.__class__.__name__} {attrs}>'
@property
def type(self) -> ComponentType:
""":class:`ComponentType`: The type of component."""
raise NotImplementedError
@classmethod
def _raw_construct(cls, **kwargs) -> Self:
self = cls.__new__(cls)
@ -93,7 +96,7 @@ class Component:
setattr(self, slot, value)
return self
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> ComponentPayload:
raise NotImplementedError
@ -108,9 +111,7 @@ class ActionRow(Component):
Attributes
------------
type: :class:`ComponentType`
The type of component.
children: List[:class:`Component`]
children: List[Union[:class:`Button`, :class:`SelectMenu`, :class:`TextInput`]]
The children components that this holds, if any.
"""
@ -118,15 +119,25 @@ class ActionRow(Component):
__repr_info__: ClassVar[Tuple[str, ...]] = __slots__
def __init__(self, data: ComponentPayload):
self.type: Literal[ComponentType.action_row] = ComponentType.action_row
self.children: List[Component] = [_component_factory(d) for d in data.get('components', [])]
def __init__(self, data: ActionRowPayload, /) -> None:
self.children: List[ActionRowChildComponentType] = []
for component_data in data.get('components', []):
component = _component_factory(component_data)
if component is not None:
self.children.append(component)
@property
def type(self) -> Literal[ComponentType.action_row]:
""":class:`ComponentType`: The type of component."""
return ComponentType.action_row
def to_dict(self) -> ActionRowPayload:
return {
'type': int(self.type),
'type': self.type.value,
'components': [child.to_dict() for child in self.children],
} # type: ignore # Type checker does not understand these are the same
}
class Button(Component):
@ -169,8 +180,7 @@ class Button(Component):
__repr_info__: ClassVar[Tuple[str, ...]] = __slots__
def __init__(self, data: ButtonComponentPayload):
self.type: Literal[ComponentType.button] = ComponentType.button
def __init__(self, data: ButtonComponentPayload, /) -> None:
self.style: ButtonStyle = try_enum(ButtonStyle, data['style'])
self.custom_id: Optional[str] = data.get('custom_id')
self.url: Optional[str] = data.get('url')
@ -182,13 +192,21 @@ class Button(Component):
except KeyError:
self.emoji = None
@property
def type(self) -> Literal[ComponentType.button]:
""":class:`ComponentType`: The type of component."""
return ComponentType.button
def to_dict(self) -> ButtonComponentPayload:
payload = {
payload: ButtonComponentPayload = {
'type': 2,
'style': int(self.style),
'label': self.label,
'style': self.style.value,
'disabled': self.disabled,
}
if self.label:
payload['label'] = self.label
if self.custom_id:
payload['custom_id'] = self.custom_id
@ -198,7 +216,7 @@ class Button(Component):
if self.emoji:
payload['emoji'] = self.emoji.to_dict()
return payload # type: ignore # Type checker does not understand these are the same
return payload
class SelectMenu(Component):
@ -243,8 +261,7 @@ class SelectMenu(Component):
__repr_info__: ClassVar[Tuple[str, ...]] = __slots__
def __init__(self, data: SelectMenuPayload):
self.type: Literal[ComponentType.select] = ComponentType.select
def __init__(self, data: SelectMenuPayload, /) -> None:
self.custom_id: str = data['custom_id']
self.placeholder: Optional[str] = data.get('placeholder')
self.min_values: int = data.get('min_values', 1)
@ -252,6 +269,11 @@ class SelectMenu(Component):
self.options: List[SelectOption] = [SelectOption.from_dict(option) for option in data.get('options', [])]
self.disabled: bool = data.get('disabled', False)
@property
def type(self) -> Literal[ComponentType.select]:
""":class:`ComponentType`: The type of component."""
return ComponentType.select
def to_dict(self) -> SelectMenuPayload:
payload: SelectMenuPayload = {
'type': self.type.value,
@ -275,7 +297,7 @@ class SelectOption:
.. versionadded:: 2.0
Attributes
Parameters
-----------
label: :class:`str`
The label of the option. This is displayed to users.
@ -291,6 +313,23 @@ class SelectOption:
The emoji of the option, if available.
default: :class:`bool`
Whether this option is selected by default.
Attributes
-----------
label: :class:`str`
The label of the option. This is displayed to users.
Can only be up to 100 characters.
value: :class:`str`
The value of the option. This is not displayed to users.
If not provided when constructed then it defaults to the
label. Can only be up to 100 characters.
description: Optional[:class:`str`]
An additional description of the option, if any.
Can only be up to 100 characters.
emoji: Optional[:class:`PartialEmoji`]
The emoji of the option, if available.
default: :class:`bool`
Whether this option is selected by default.
"""
__slots__: Tuple[str, ...] = (
@ -322,7 +361,7 @@ class SelectOption:
else:
raise TypeError(f'expected emoji to be str, Emoji, or PartialEmoji not {emoji.__class__}')
self.emoji: Optional[Union[str, Emoji, PartialEmoji]] = emoji
self.emoji: Optional[PartialEmoji] = emoji
self.default: bool = default
def __repr__(self) -> str:
@ -364,7 +403,7 @@ class SelectOption:
}
if self.emoji:
payload['emoji'] = self.emoji.to_dict() # type: ignore # This Dict[str, Any] is compatible with PartialEmoji
payload['emoji'] = self.emoji.to_dict()
if self.description:
payload['description'] = self.description
@ -414,8 +453,7 @@ class TextInput(Component):
__repr_info__: ClassVar[Tuple[str, ...]] = __slots__
def __init__(self, data: TextInputPayload) -> None:
self.type: Literal[ComponentType.text_input] = ComponentType.text_input
def __init__(self, data: TextInputPayload, /) -> None:
self.style: TextStyle = try_enum(TextStyle, data['style'])
self.label: str = data['label']
self.custom_id: str = data['custom_id']
@ -425,6 +463,11 @@ class TextInput(Component):
self.min_length: Optional[int] = data.get('min_length')
self.max_length: Optional[int] = data.get('max_length')
@property
def type(self) -> Literal[ComponentType.text_input]:
""":class:`ComponentType`: The type of component."""
return ComponentType.text_input
def to_dict(self) -> TextInputPayload:
payload: TextInputPayload = {
'type': self.type.value,
@ -457,19 +500,22 @@ class TextInput(Component):
return self.value
def _component_factory(data: ComponentPayload) -> Component:
component_type = data['type']
if component_type == 1:
@overload
def _component_factory(data: ActionRowChildComponentPayload) -> Optional[ActionRowChildComponentType]:
...
@overload
def _component_factory(data: ComponentPayload) -> Optional[Union[ActionRow, ActionRowChildComponentType]]:
...
def _component_factory(data: ComponentPayload) -> Optional[Union[ActionRow, ActionRowChildComponentType]]:
if data['type'] == 1:
return ActionRow(data)
elif component_type == 2:
# The type checker does not properly do narrowing here.
return Button(data) # type: ignore
elif component_type == 3:
# The type checker does not properly do narrowing here.
return SelectMenu(data) # type: ignore
elif component_type == 4:
# The type checker does not properly do narrowing here.
return TextInput(data) # type: ignore
else:
as_enum = try_enum(ComponentType, component_type)
return Component._raw_construct(type=as_enum)
elif data['type'] == 2:
return Button(data)
elif data['type'] == 3:
return SelectMenu(data)
elif data['type'] == 4:
return TextInput(data)

19
discord/message.py

@ -87,7 +87,7 @@ if TYPE_CHECKING:
from .types.gateway import MessageReactionRemoveEvent, MessageUpdateEvent
from .abc import Snowflake
from .abc import GuildChannel, MessageableChannel
from .components import Component
from .components import ActionRow, ActionRowChildComponentType
from .state import ConnectionState
from .channel import TextChannel
from .mentions import AllowedMentions
@ -96,6 +96,7 @@ if TYPE_CHECKING:
from .ui.view import View
EmojiInputType = Union[Emoji, PartialEmoji, str]
MessageComponentType = Union[ActionRow, ActionRowChildComponentType]
__all__ = (
@ -1340,7 +1341,7 @@ class Message(PartialMessage, Hashable):
A list of sticker items given to the message.
.. versionadded:: 1.6
components: List[:class:`Component`]
components: List[Union[:class:`ActionRow`, :class:`Button`, :class:`SelectMenu`]]
A list of components in the message.
.. versionadded:: 2.0
@ -1392,6 +1393,7 @@ class Message(PartialMessage, Hashable):
mentions: List[Union[User, Member]]
author: Union[User, Member]
role_mentions: List[Role]
components: List[MessageComponentType]
def __init__(
self,
@ -1418,7 +1420,6 @@ class Message(PartialMessage, Hashable):
self.content: str = data['content']
self.nonce: Optional[Union[int, str]] = data.get('nonce')
self.stickers: List[StickerItem] = [StickerItem(data=d, state=state) for d in data.get('sticker_items', [])]
self.components: List[Component] = [_component_factory(d) for d in data.get('components', [])]
try:
# if the channel doesn't have a guild attribute, we handle that
@ -1460,7 +1461,7 @@ class Message(PartialMessage, Hashable):
# the channel will be the correct type here
ref.resolved = self.__class__(channel=chan, data=resolved, state=state) # type: ignore
for handler in ('author', 'member', 'mentions', 'mention_roles'):
for handler in ('author', 'member', 'mentions', 'mention_roles', 'components'):
try:
getattr(self, f'_handle_{handler}')(data[handler])
except KeyError:
@ -1631,8 +1632,14 @@ class Message(PartialMessage, Hashable):
if role is not None:
self.role_mentions.append(role)
def _handle_components(self, components: List[ComponentPayload]):
self.components = [_component_factory(d) for d in components]
def _handle_components(self, data: List[ComponentPayload]) -> None:
self.components = []
for component_data in data:
component = _component_factory(component_data)
if component is not None:
self.components.append(component)
def _handle_interaction(self, data: MessageInteractionPayload):
self.interaction = MessageInteraction(state=self._state, guild=self.guild, data=data)

17
discord/partial_emoji.py

@ -41,7 +41,7 @@ if TYPE_CHECKING:
from .state import ConnectionState
from datetime import datetime
from .types.message import PartialEmoji as PartialEmojiPayload
from .types.emoji import Emoji as EmojiPayload, PartialEmoji as PartialEmojiPayload
from .types.activity import ActivityEmoji
@ -148,13 +148,16 @@ class PartialEmoji(_EmojiTag, AssetMixin):
return cls(name=value, id=None, animated=False)
def to_dict(self) -> Dict[str, Any]:
o: Dict[str, Any] = {'name': self.name}
if self.id:
o['id'] = self.id
def to_dict(self) -> EmojiPayload:
payload: EmojiPayload = {
'id': self.id,
'name': self.name,
}
if self.animated:
o['animated'] = self.animated
return o
payload['animated'] = self.animated
return payload
def _to_partial(self) -> PartialEmoji:
return self

2
discord/state.py

@ -730,7 +730,7 @@ class ConnectionState:
inner_data = data['data']
custom_id = inner_data['custom_id']
components = inner_data['components']
self._view_store.dispatch_modal(custom_id, interaction, components) # type: ignore
self._view_store.dispatch_modal(custom_id, interaction, components)
self.dispatch('interaction', interaction)
def parse_presence_update(self, data: gw.PresenceUpdateEvent) -> None:

5
discord/types/components.py

@ -36,7 +36,7 @@ TextStyle = Literal[1, 2]
class ActionRow(TypedDict):
type: Literal[1]
components: List[Component]
components: List[ActionRowChildComponent]
class ButtonComponent(TypedDict):
@ -79,4 +79,5 @@ class TextInput(TypedDict):
max_length: NotRequired[int]
Component = Union[ActionRow, ButtonComponent, SelectMenu, TextInput]
ActionRowChildComponent = Union[ButtonComponent, SelectMenu, TextInput]
Component = Union[ActionRow, ActionRowChildComponent]

2
discord/types/interactions.py

@ -186,7 +186,7 @@ ModalSubmitComponentInteractionData = Union[ModalSubmitActionRowInteractionData,
class ModalSubmitInteractionData(TypedDict):
custom_id: str
components: List[ModalSubmitActionRowInteractionData]
components: List[ModalSubmitComponentInteractionData]
InteractionData = Union[

1
discord/ui/button.py

@ -120,7 +120,6 @@ class Button(Item[V]):
raise TypeError(f'expected emoji to be str, Emoji, or PartialEmoji not {emoji.__class__}')
self._underlying = ButtonComponent._raw_construct(
type=ComponentType.button,
custom_id=custom_id,
url=url,
disabled=disabled,

1
discord/ui/select.py

@ -117,7 +117,6 @@ class Select(Item[V]):
options = [] if options is MISSING else options
self._underlying = SelectMenu._raw_construct(
custom_id=custom_id,
type=ComponentType.select,
placeholder=placeholder,
min_values=min_values,
max_values=max_values,

3
discord/ui/text_input.py

@ -114,7 +114,6 @@ class TextInput(Item[V]):
raise TypeError(f'expected custom_id to be str not {custom_id.__class__!r}')
self._underlying = TextInputComponent._raw_construct(
type=ComponentType.text_input,
label=label,
style=style,
custom_id=custom_id,
@ -238,7 +237,7 @@ class TextInput(Item[V]):
@property
def type(self) -> Literal[ComponentType.text_input]:
return ComponentType.text_input
return self._underlying.type
def is_dispatchable(self) -> bool:
return False

14
discord/ui/view.py

@ -281,7 +281,7 @@ class View:
one of its subclasses.
"""
view = View(timeout=timeout)
for component in _walk_all_components(message.components):
for component in _walk_all_components(message.components): # type: ignore
view.add_item(_component_to_item(component))
return view
@ -634,7 +634,15 @@ class ViewStore:
def remove_message_tracking(self, message_id: int) -> Optional[View]:
return self._synced_message_views.pop(message_id, None)
def update_from_message(self, message_id: int, components: List[ComponentPayload]) -> None:
def update_from_message(self, message_id: int, data: List[ComponentPayload]) -> None:
components: List[Component] = []
for component_data in data:
component = _component_factory(component_data)
if component is not None:
components.append(component)
# pre-req: is_message_tracked == true
view = self._synced_message_views[message_id]
view._refresh([_component_factory(d) for d in components])
view._refresh(components)

Loading…
Cancel
Save