Browse Source

Improve component typing

pull/10109/head
Lilly Rose Berner 3 years ago
committed by dolfies
parent
commit
e9e2d8cb1c
  1. 96
      discord/components.py
  2. 19
      discord/message.py
  3. 17
      discord/partial_emoji.py
  4. 5
      discord/types/components.py
  5. 2
      discord/types/interactions.py

96
discord/components.py

@ -24,7 +24,7 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Tuple, Union
from typing import ClassVar, List, Literal, Optional, TYPE_CHECKING, Tuple, Union, overload
from .enums import try_enum, ComponentType, ButtonStyle, TextStyle, InteractionType
from .interactions import _wrapped_interaction
@ -39,13 +39,15 @@ if TYPE_CHECKING:
ButtonComponent as ButtonComponentPayload,
SelectMenu as SelectMenuPayload,
SelectOption as SelectOptionPayload,
ActionRow as ActionRowPayload,
TextInput as TextInputPayload,
ActionRowChildComponent as ActionRowChildComponentPayload,
)
from .emoji import Emoji
from .interactions import Interaction
from .message import Message
ActionRowChildComponentType = Union['Button', 'SelectMenu', 'TextInput']
__all__ = (
'Component',
@ -68,23 +70,22 @@ class Component:
- :class:`TextInput`
.. versionadded:: 2.0
Attributes
------------
type: :class:`ComponentType`
The type of component.
"""
__slots__: Tuple[str, ...] = ('type', 'message')
__repr_info__: ClassVar[Tuple[str, ...]]
type: ComponentType
message: Message
def __repr__(self) -> str:
attrs = ' '.join(f'{key}={getattr(self, key)!r}' for key in self.__repr_info__)
return f'<{self.__class__.__name__} {attrs}>'
@property
def type(self) -> ComponentType:
""":class:`ComponentType`: The type of component."""
raise NotImplementedError
@classmethod
def _raw_construct(cls, **kwargs) -> Self:
self = cls.__new__(cls)
@ -97,7 +98,7 @@ class Component:
setattr(self, slot, value)
return self
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> ComponentPayload:
raise NotImplementedError
@ -112,9 +113,7 @@ class ActionRow(Component):
Attributes
------------
type: :class:`ComponentType`
The type of component.
children: List[:class:`Component`]
children: List[Union[:class:`Button`, :class:`SelectMenu`, :class:`TextInput`]]
The children components that this holds, if any.
message: :class:`Message`
The originating message.
@ -126,14 +125,18 @@ class ActionRow(Component):
def __init__(self, data: ComponentPayload, message: Message):
self.message = message
self.type: ComponentType = try_enum(ComponentType, data['type'])
self.children: List[Component] = [_component_factory(d, message) for d in data.get('components', [])]
self.children: List[ActionRowChildComponentType] = []
def to_dict(self) -> ActionRowPayload:
return {
'type': int(self.type),
'components': [child.to_dict() for child in self.children],
} # type: ignore # Type checker does not understand these are the same
for component_data in data.get('components', []):
component = _component_factory(component_data)
if component is not None:
self.children.append(component)
@property
def type(self) -> Literal[ComponentType.action_row]:
""":class:`ComponentType`: The type of component."""
return ComponentType.action_row
class Button(Component):
@ -175,7 +178,6 @@ class Button(Component):
def __init__(self, data: ButtonComponentPayload, message: Message):
self.message = message
self.type: ComponentType = try_enum(ComponentType, data['type'])
self.style: ButtonStyle = try_enum(ButtonStyle, data['style'])
self.custom_id: Optional[str] = data.get('custom_id')
self.url: Optional[str] = data.get('url')
@ -193,6 +195,11 @@ class Button(Component):
'custom_id': self.custom_id,
}
@property
def type(self) -> Literal[ComponentType.button]:
""":class:`ComponentType`: The type of component."""
return ComponentType.button
async def click(self) -> Union[str, Interaction]:
"""|coro|
@ -268,7 +275,6 @@ class SelectMenu(Component):
def __init__(self, data: SelectMenuPayload, message: Message):
self.message = message
self.type = ComponentType.select
self.custom_id: str = data['custom_id']
self.placeholder: Optional[str] = data.get('placeholder')
self.min_values: int = data.get('min_values', 1)
@ -277,6 +283,11 @@ class SelectMenu(Component):
self.disabled: bool = data.get('disabled', False)
self.hash: str = data.get('hash', '')
@property
def type(self) -> Literal[ComponentType.select]:
""":class:`ComponentType`: The type of component."""
return ComponentType.select
def to_dict(self, options: Tuple[SelectOption]) -> dict:
return {
'component_type': self.type.value,
@ -333,7 +344,7 @@ class SelectOption:
description: Optional[:class:`str`]
An additional description of the option, if any.
Can only be up to 100 characters.
emoji: Optional[Union[:class:`str`, :class:`Emoji`, :class:`PartialEmoji`]]
emoji: Optional[:class:`PartialEmoji`]
The emoji of the option, if available.
default: :class:`bool`
Whether this option is selected by default.
@ -368,7 +379,7 @@ class SelectOption:
else:
raise TypeError(f'expected emoji to be str, Emoji, or PartialEmoji not {emoji.__class__}')
self.emoji: Optional[Union[str, Emoji, PartialEmoji]] = emoji
self.emoji: Optional[PartialEmoji] = emoji
self.default: bool = default
def __repr__(self) -> str:
@ -440,8 +451,7 @@ class TextInput(Component):
__repr_info__: ClassVar[Tuple[str, ...]] = __slots__
def __init__(self, data: TextInputPayload, _=MISSING) -> None:
self.type: ComponentType = ComponentType.text_input
def __init__(self, data: TextInputPayload, *args) -> None:
self.style: TextStyle = try_enum(TextStyle, data['style'])
self.label: str = data['label']
self.custom_id: str = data['custom_id']
@ -451,6 +461,11 @@ class TextInput(Component):
self.min_length: Optional[int] = data.get('min_length')
self.max_length: Optional[int] = data.get('max_length')
@property
def type(self) -> Literal[ComponentType.text_input]:
""":class:`ComponentType`: The type of component."""
return ComponentType.text_input
def to_dict(self) -> dict:
return {
'type': self.type.value,
@ -500,17 +515,22 @@ class TextInput(Component):
self.value = value
def _component_factory(data: ComponentPayload, message: Message = MISSING) -> Component:
# The type checker does not properly do narrowing here
component_type = data['type']
if component_type == 1:
@overload
def _component_factory(data: ActionRowChildComponentPayload, message: Message = ...) -> Optional[ActionRowChildComponentType]:
...
@overload
def _component_factory(data: ComponentPayload, message: Message = ...) -> Optional[Union[ActionRow, ActionRowChildComponentType]]:
...
def _component_factory(data: ComponentPayload, message: Message = MISSING) -> Optional[Union[ActionRow, ActionRowChildComponentType]]:
if data['type'] == 1:
return ActionRow(data, message)
elif component_type == 2:
return Button(data, message) # type: ignore
elif component_type == 3:
return SelectMenu(data, message) # type: ignore
elif component_type == 4:
return TextInput(data, message) # type: ignore
else:
as_enum = try_enum(ComponentType, component_type)
return Component._raw_construct(type=as_enum)
elif data['type'] == 2:
return Button(data, message)
elif data['type'] == 3:
return SelectMenu(data, message)
elif data['type'] == 4:
return TextInput(data, message)

19
discord/message.py

@ -94,7 +94,7 @@ if TYPE_CHECKING:
from .types.gateway import MessageReactionRemoveEvent, MessageUpdateEvent
from .abc import Snowflake
from .abc import GuildChannel, MessageableChannel
from .components import ActionRow
from .components import ActionRow, ActionRowChildComponentType
from .state import ConnectionState
from .channel import TextChannel
from .mentions import AllowedMentions
@ -102,6 +102,7 @@ if TYPE_CHECKING:
from .role import Role
EmojiInputType = Union[Emoji, PartialEmoji, str]
MessageComponentType = Union[ActionRow, ActionRowChildComponentType]
__all__ = (
@ -1254,7 +1255,7 @@ class Message(PartialMessage, Hashable):
A list of sticker items given to the message.
.. versionadded:: 1.6
components: List[:class:`Component`]
components: List[Union[:class:`ActionRow`, :class:`Button`, :class:`SelectMenu`]]
A list of components in the message.
.. versionadded:: 2.0
@ -1311,6 +1312,7 @@ class Message(PartialMessage, Hashable):
mentions: List[Union[User, Member]]
author: Union[User, Member]
role_mentions: List[Role]
components: List[MessageComponentType]
def __init__(
self,
@ -1337,7 +1339,6 @@ class Message(PartialMessage, Hashable):
self.content: str = data['content']
self.nonce: Optional[Union[int, str]] = data.get('nonce')
self.stickers: List[StickerItem] = [StickerItem(data=d, state=state) for d in data.get('sticker_items', [])]
self.components: List[ActionRow] = [_component_factory(d, self) for d in data.get('components', [])] # type: ignore # Will always be rows here
self.call: Optional[CallMessage] = None
try:
@ -1387,7 +1388,7 @@ class Message(PartialMessage, Hashable):
# The channel will be the correct type here
ref.resolved = self.__class__(channel=chan, data=resolved, state=state) # type: ignore
for handler in ('author', 'member', 'mentions', 'mention_roles', 'call', 'interaction'):
for handler in ('author', 'member', 'mentions', 'mention_roles', 'call', 'interaction', 'components'):
try:
getattr(self, f'_handle_{handler}')(data[handler])
except KeyError:
@ -1579,8 +1580,14 @@ class Message(PartialMessage, Hashable):
call['participants'] = participants
self.call = CallMessage(message=self, **call)
def _handle_components(self, components: List[ComponentPayload]):
self.components: List[ActionRow] = [_component_factory(d, self) for d in components] # type: ignore # Will always be rows here
def _handle_components(self, data: List[ComponentPayload]) -> None:
self.components = []
for component_data in data:
component = _component_factory(component_data, self)
if component is not None:
self.components.append(component)
def _handle_interaction(self, data: MessageInteractionPayload):
self.interaction = Interaction._from_message(self, **data)

17
discord/partial_emoji.py

@ -41,7 +41,7 @@ if TYPE_CHECKING:
from .state import ConnectionState
from datetime import datetime
from .types.message import PartialEmoji as PartialEmojiPayload
from .types.emoji import Emoji as EmojiPayload, PartialEmoji as PartialEmojiPayload
from .types.activity import ActivityEmoji
@ -148,13 +148,16 @@ class PartialEmoji(_EmojiTag, AssetMixin):
return cls(name=value, id=None, animated=False)
def to_dict(self) -> Dict[str, Any]:
o: Dict[str, Any] = {'name': self.name}
if self.id:
o['id'] = self.id
def to_dict(self) -> EmojiPayload:
payload: EmojiPayload = {
'id': self.id,
'name': self.name,
}
if self.animated:
o['animated'] = self.animated
return o
payload['animated'] = self.animated
return payload
def _to_partial(self) -> PartialEmoji:
return self

5
discord/types/components.py

@ -36,7 +36,7 @@ TextStyle = Literal[1, 2]
class ActionRow(TypedDict):
type: Literal[1]
components: List[Component]
components: List[ActionRowChildComponent]
class ButtonComponent(TypedDict):
@ -79,4 +79,5 @@ class TextInput(TypedDict):
max_length: NotRequired[int]
Component = Union[ActionRow, ButtonComponent, SelectMenu, TextInput]
ActionRowChildComponent = Union[ButtonComponent, SelectMenu, TextInput]
Component = Union[ActionRow, ActionRowChildComponent]

2
discord/types/interactions.py

@ -186,7 +186,7 @@ ModalSubmitComponentInteractionData = Union[ModalSubmitActionRowInteractionData,
class ModalSubmitInteractionData(TypedDict):
custom_id: str
components: List[ModalSubmitActionRowInteractionData]
components: List[ModalSubmitComponentInteractionData]
InteractionData = Union[

Loading…
Cancel
Save