Browse Source

Improve component typing

pull/10109/head
Lilly Rose Berner 3 years ago
committed by dolfies
parent
commit
e9e2d8cb1c
  1. 96
      discord/components.py
  2. 19
      discord/message.py
  3. 17
      discord/partial_emoji.py
  4. 5
      discord/types/components.py
  5. 2
      discord/types/interactions.py

96
discord/components.py

@ -24,7 +24,7 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations from __future__ import annotations
from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Tuple, Union from typing import ClassVar, List, Literal, Optional, TYPE_CHECKING, Tuple, Union, overload
from .enums import try_enum, ComponentType, ButtonStyle, TextStyle, InteractionType from .enums import try_enum, ComponentType, ButtonStyle, TextStyle, InteractionType
from .interactions import _wrapped_interaction from .interactions import _wrapped_interaction
@ -39,13 +39,15 @@ if TYPE_CHECKING:
ButtonComponent as ButtonComponentPayload, ButtonComponent as ButtonComponentPayload,
SelectMenu as SelectMenuPayload, SelectMenu as SelectMenuPayload,
SelectOption as SelectOptionPayload, SelectOption as SelectOptionPayload,
ActionRow as ActionRowPayload,
TextInput as TextInputPayload, TextInput as TextInputPayload,
ActionRowChildComponent as ActionRowChildComponentPayload,
) )
from .emoji import Emoji from .emoji import Emoji
from .interactions import Interaction from .interactions import Interaction
from .message import Message from .message import Message
ActionRowChildComponentType = Union['Button', 'SelectMenu', 'TextInput']
__all__ = ( __all__ = (
'Component', 'Component',
@ -68,23 +70,22 @@ class Component:
- :class:`TextInput` - :class:`TextInput`
.. versionadded:: 2.0 .. versionadded:: 2.0
Attributes
------------
type: :class:`ComponentType`
The type of component.
""" """
__slots__: Tuple[str, ...] = ('type', 'message') __slots__: Tuple[str, ...] = ('type', 'message')
__repr_info__: ClassVar[Tuple[str, ...]] __repr_info__: ClassVar[Tuple[str, ...]]
type: ComponentType
message: Message message: Message
def __repr__(self) -> str: def __repr__(self) -> str:
attrs = ' '.join(f'{key}={getattr(self, key)!r}' for key in self.__repr_info__) attrs = ' '.join(f'{key}={getattr(self, key)!r}' for key in self.__repr_info__)
return f'<{self.__class__.__name__} {attrs}>' return f'<{self.__class__.__name__} {attrs}>'
@property
def type(self) -> ComponentType:
""":class:`ComponentType`: The type of component."""
raise NotImplementedError
@classmethod @classmethod
def _raw_construct(cls, **kwargs) -> Self: def _raw_construct(cls, **kwargs) -> Self:
self = cls.__new__(cls) self = cls.__new__(cls)
@ -97,7 +98,7 @@ class Component:
setattr(self, slot, value) setattr(self, slot, value)
return self return self
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> ComponentPayload:
raise NotImplementedError raise NotImplementedError
@ -112,9 +113,7 @@ class ActionRow(Component):
Attributes Attributes
------------ ------------
type: :class:`ComponentType` children: List[Union[:class:`Button`, :class:`SelectMenu`, :class:`TextInput`]]
The type of component.
children: List[:class:`Component`]
The children components that this holds, if any. The children components that this holds, if any.
message: :class:`Message` message: :class:`Message`
The originating message. The originating message.
@ -126,14 +125,18 @@ class ActionRow(Component):
def __init__(self, data: ComponentPayload, message: Message): def __init__(self, data: ComponentPayload, message: Message):
self.message = message self.message = message
self.type: ComponentType = try_enum(ComponentType, data['type']) self.children: List[ActionRowChildComponentType] = []
self.children: List[Component] = [_component_factory(d, message) for d in data.get('components', [])]
def to_dict(self) -> ActionRowPayload: for component_data in data.get('components', []):
return { component = _component_factory(component_data)
'type': int(self.type),
'components': [child.to_dict() for child in self.children], if component is not None:
} # type: ignore # Type checker does not understand these are the same self.children.append(component)
@property
def type(self) -> Literal[ComponentType.action_row]:
""":class:`ComponentType`: The type of component."""
return ComponentType.action_row
class Button(Component): class Button(Component):
@ -175,7 +178,6 @@ class Button(Component):
def __init__(self, data: ButtonComponentPayload, message: Message): def __init__(self, data: ButtonComponentPayload, message: Message):
self.message = message self.message = message
self.type: ComponentType = try_enum(ComponentType, data['type'])
self.style: ButtonStyle = try_enum(ButtonStyle, data['style']) self.style: ButtonStyle = try_enum(ButtonStyle, data['style'])
self.custom_id: Optional[str] = data.get('custom_id') self.custom_id: Optional[str] = data.get('custom_id')
self.url: Optional[str] = data.get('url') self.url: Optional[str] = data.get('url')
@ -193,6 +195,11 @@ class Button(Component):
'custom_id': self.custom_id, 'custom_id': self.custom_id,
} }
@property
def type(self) -> Literal[ComponentType.button]:
""":class:`ComponentType`: The type of component."""
return ComponentType.button
async def click(self) -> Union[str, Interaction]: async def click(self) -> Union[str, Interaction]:
"""|coro| """|coro|
@ -268,7 +275,6 @@ class SelectMenu(Component):
def __init__(self, data: SelectMenuPayload, message: Message): def __init__(self, data: SelectMenuPayload, message: Message):
self.message = message self.message = message
self.type = ComponentType.select
self.custom_id: str = data['custom_id'] self.custom_id: str = data['custom_id']
self.placeholder: Optional[str] = data.get('placeholder') self.placeholder: Optional[str] = data.get('placeholder')
self.min_values: int = data.get('min_values', 1) self.min_values: int = data.get('min_values', 1)
@ -277,6 +283,11 @@ class SelectMenu(Component):
self.disabled: bool = data.get('disabled', False) self.disabled: bool = data.get('disabled', False)
self.hash: str = data.get('hash', '') self.hash: str = data.get('hash', '')
@property
def type(self) -> Literal[ComponentType.select]:
""":class:`ComponentType`: The type of component."""
return ComponentType.select
def to_dict(self, options: Tuple[SelectOption]) -> dict: def to_dict(self, options: Tuple[SelectOption]) -> dict:
return { return {
'component_type': self.type.value, 'component_type': self.type.value,
@ -333,7 +344,7 @@ class SelectOption:
description: Optional[:class:`str`] description: Optional[:class:`str`]
An additional description of the option, if any. An additional description of the option, if any.
Can only be up to 100 characters. Can only be up to 100 characters.
emoji: Optional[Union[:class:`str`, :class:`Emoji`, :class:`PartialEmoji`]] emoji: Optional[:class:`PartialEmoji`]
The emoji of the option, if available. The emoji of the option, if available.
default: :class:`bool` default: :class:`bool`
Whether this option is selected by default. Whether this option is selected by default.
@ -368,7 +379,7 @@ class SelectOption:
else: else:
raise TypeError(f'expected emoji to be str, Emoji, or PartialEmoji not {emoji.__class__}') raise TypeError(f'expected emoji to be str, Emoji, or PartialEmoji not {emoji.__class__}')
self.emoji: Optional[Union[str, Emoji, PartialEmoji]] = emoji self.emoji: Optional[PartialEmoji] = emoji
self.default: bool = default self.default: bool = default
def __repr__(self) -> str: def __repr__(self) -> str:
@ -440,8 +451,7 @@ class TextInput(Component):
__repr_info__: ClassVar[Tuple[str, ...]] = __slots__ __repr_info__: ClassVar[Tuple[str, ...]] = __slots__
def __init__(self, data: TextInputPayload, _=MISSING) -> None: def __init__(self, data: TextInputPayload, *args) -> None:
self.type: ComponentType = ComponentType.text_input
self.style: TextStyle = try_enum(TextStyle, data['style']) self.style: TextStyle = try_enum(TextStyle, data['style'])
self.label: str = data['label'] self.label: str = data['label']
self.custom_id: str = data['custom_id'] self.custom_id: str = data['custom_id']
@ -451,6 +461,11 @@ class TextInput(Component):
self.min_length: Optional[int] = data.get('min_length') self.min_length: Optional[int] = data.get('min_length')
self.max_length: Optional[int] = data.get('max_length') self.max_length: Optional[int] = data.get('max_length')
@property
def type(self) -> Literal[ComponentType.text_input]:
""":class:`ComponentType`: The type of component."""
return ComponentType.text_input
def to_dict(self) -> dict: def to_dict(self) -> dict:
return { return {
'type': self.type.value, 'type': self.type.value,
@ -500,17 +515,22 @@ class TextInput(Component):
self.value = value self.value = value
def _component_factory(data: ComponentPayload, message: Message = MISSING) -> Component: @overload
# The type checker does not properly do narrowing here def _component_factory(data: ActionRowChildComponentPayload, message: Message = ...) -> Optional[ActionRowChildComponentType]:
component_type = data['type'] ...
if component_type == 1:
@overload
def _component_factory(data: ComponentPayload, message: Message = ...) -> Optional[Union[ActionRow, ActionRowChildComponentType]]:
...
def _component_factory(data: ComponentPayload, message: Message = MISSING) -> Optional[Union[ActionRow, ActionRowChildComponentType]]:
if data['type'] == 1:
return ActionRow(data, message) return ActionRow(data, message)
elif component_type == 2: elif data['type'] == 2:
return Button(data, message) # type: ignore return Button(data, message)
elif component_type == 3: elif data['type'] == 3:
return SelectMenu(data, message) # type: ignore return SelectMenu(data, message)
elif component_type == 4: elif data['type'] == 4:
return TextInput(data, message) # type: ignore return TextInput(data, message)
else:
as_enum = try_enum(ComponentType, component_type)
return Component._raw_construct(type=as_enum)

19
discord/message.py

@ -94,7 +94,7 @@ if TYPE_CHECKING:
from .types.gateway import MessageReactionRemoveEvent, MessageUpdateEvent from .types.gateway import MessageReactionRemoveEvent, MessageUpdateEvent
from .abc import Snowflake from .abc import Snowflake
from .abc import GuildChannel, MessageableChannel from .abc import GuildChannel, MessageableChannel
from .components import ActionRow from .components import ActionRow, ActionRowChildComponentType
from .state import ConnectionState from .state import ConnectionState
from .channel import TextChannel from .channel import TextChannel
from .mentions import AllowedMentions from .mentions import AllowedMentions
@ -102,6 +102,7 @@ if TYPE_CHECKING:
from .role import Role from .role import Role
EmojiInputType = Union[Emoji, PartialEmoji, str] EmojiInputType = Union[Emoji, PartialEmoji, str]
MessageComponentType = Union[ActionRow, ActionRowChildComponentType]
__all__ = ( __all__ = (
@ -1254,7 +1255,7 @@ class Message(PartialMessage, Hashable):
A list of sticker items given to the message. A list of sticker items given to the message.
.. versionadded:: 1.6 .. versionadded:: 1.6
components: List[:class:`Component`] components: List[Union[:class:`ActionRow`, :class:`Button`, :class:`SelectMenu`]]
A list of components in the message. A list of components in the message.
.. versionadded:: 2.0 .. versionadded:: 2.0
@ -1311,6 +1312,7 @@ class Message(PartialMessage, Hashable):
mentions: List[Union[User, Member]] mentions: List[Union[User, Member]]
author: Union[User, Member] author: Union[User, Member]
role_mentions: List[Role] role_mentions: List[Role]
components: List[MessageComponentType]
def __init__( def __init__(
self, self,
@ -1337,7 +1339,6 @@ class Message(PartialMessage, Hashable):
self.content: str = data['content'] self.content: str = data['content']
self.nonce: Optional[Union[int, str]] = data.get('nonce') self.nonce: Optional[Union[int, str]] = data.get('nonce')
self.stickers: List[StickerItem] = [StickerItem(data=d, state=state) for d in data.get('sticker_items', [])] self.stickers: List[StickerItem] = [StickerItem(data=d, state=state) for d in data.get('sticker_items', [])]
self.components: List[ActionRow] = [_component_factory(d, self) for d in data.get('components', [])] # type: ignore # Will always be rows here
self.call: Optional[CallMessage] = None self.call: Optional[CallMessage] = None
try: try:
@ -1387,7 +1388,7 @@ class Message(PartialMessage, Hashable):
# The channel will be the correct type here # The channel will be the correct type here
ref.resolved = self.__class__(channel=chan, data=resolved, state=state) # type: ignore ref.resolved = self.__class__(channel=chan, data=resolved, state=state) # type: ignore
for handler in ('author', 'member', 'mentions', 'mention_roles', 'call', 'interaction'): for handler in ('author', 'member', 'mentions', 'mention_roles', 'call', 'interaction', 'components'):
try: try:
getattr(self, f'_handle_{handler}')(data[handler]) getattr(self, f'_handle_{handler}')(data[handler])
except KeyError: except KeyError:
@ -1579,8 +1580,14 @@ class Message(PartialMessage, Hashable):
call['participants'] = participants call['participants'] = participants
self.call = CallMessage(message=self, **call) self.call = CallMessage(message=self, **call)
def _handle_components(self, components: List[ComponentPayload]): def _handle_components(self, data: List[ComponentPayload]) -> None:
self.components: List[ActionRow] = [_component_factory(d, self) for d in components] # type: ignore # Will always be rows here self.components = []
for component_data in data:
component = _component_factory(component_data, self)
if component is not None:
self.components.append(component)
def _handle_interaction(self, data: MessageInteractionPayload): def _handle_interaction(self, data: MessageInteractionPayload):
self.interaction = Interaction._from_message(self, **data) self.interaction = Interaction._from_message(self, **data)

17
discord/partial_emoji.py

@ -41,7 +41,7 @@ if TYPE_CHECKING:
from .state import ConnectionState from .state import ConnectionState
from datetime import datetime from datetime import datetime
from .types.message import PartialEmoji as PartialEmojiPayload from .types.emoji import Emoji as EmojiPayload, PartialEmoji as PartialEmojiPayload
from .types.activity import ActivityEmoji from .types.activity import ActivityEmoji
@ -148,13 +148,16 @@ class PartialEmoji(_EmojiTag, AssetMixin):
return cls(name=value, id=None, animated=False) return cls(name=value, id=None, animated=False)
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> EmojiPayload:
o: Dict[str, Any] = {'name': self.name} payload: EmojiPayload = {
if self.id: 'id': self.id,
o['id'] = self.id 'name': self.name,
}
if self.animated: if self.animated:
o['animated'] = self.animated payload['animated'] = self.animated
return o
return payload
def _to_partial(self) -> PartialEmoji: def _to_partial(self) -> PartialEmoji:
return self return self

5
discord/types/components.py

@ -36,7 +36,7 @@ TextStyle = Literal[1, 2]
class ActionRow(TypedDict): class ActionRow(TypedDict):
type: Literal[1] type: Literal[1]
components: List[Component] components: List[ActionRowChildComponent]
class ButtonComponent(TypedDict): class ButtonComponent(TypedDict):
@ -79,4 +79,5 @@ class TextInput(TypedDict):
max_length: NotRequired[int] max_length: NotRequired[int]
Component = Union[ActionRow, ButtonComponent, SelectMenu, TextInput] ActionRowChildComponent = Union[ButtonComponent, SelectMenu, TextInput]
Component = Union[ActionRow, ActionRowChildComponent]

2
discord/types/interactions.py

@ -186,7 +186,7 @@ ModalSubmitComponentInteractionData = Union[ModalSubmitActionRowInteractionData,
class ModalSubmitInteractionData(TypedDict): class ModalSubmitInteractionData(TypedDict):
custom_id: str custom_id: str
components: List[ModalSubmitActionRowInteractionData] components: List[ModalSubmitComponentInteractionData]
InteractionData = Union[ InteractionData = Union[

Loading…
Cancel
Save