Browse Source

Implement application commands

pull/10109/head
dolfies 3 years ago
parent
commit
95ed42669d
  1. 587
      discord/commands.py
  2. 8
      discord/enums.py
  3. 77
      discord/gateway.py
  4. 269
      discord/iterators.py
  5. 40
      discord/utils.py

587
discord/commands.py

@ -0,0 +1,587 @@
"""
The MIT License (MIT)
Copyright (c) 2021-present Dolfies
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from asyncio import TimeoutError
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
from .enums import CommandType, ChannelType, OptionType, try_enum
from .errors import InvalidData
from .utils import time_snowflake
if TYPE_CHECKING:
from .abc import Messageable, Snowflake
from .interactions import Interaction
from .message import Message
from .state import ConnectionState
class ApplicationCommand:
def __init__(
self, data: Dict[str, Any]
) -> None:
self.name: str = data['name']
self.description: str = data['description']
async def __call__(self, data, channel: Optional[Messageable] = None) -> Interaction:
channel = channel or self.target_channel # type: ignore
if channel is None:
raise TypeError('__call__() missing 1 required keyword-only argument: \'channel\'')
state = self._state # type: ignore
channel = await channel._get_channel()
payload = {
'application_id': str(self._application_id), # type: ignore
'channel_id': str(channel.id),
'data': data,
'nonce': str(time_snowflake(datetime.utcnow())),
'type': 2, # Should be an enum but eh
}
if getattr(channel, 'guild', None):
payload['guild_id'] = str(channel.guild.id) # type: ignore
state._interactions[payload['nonce']] = 2
await state.http.interact(payload, form_data=True)
try:
i = await state.client.wait_for(
'interaction_finish',
check=lambda d: d.nonce == payload['nonce'],
timeout=5,
)
except TimeoutError as exc:
raise InvalidData('Did not receive a response from Discord') from exc
return i
class _BaseCommand(ApplicationCommand):
def __init__(
self, *, state: ConnectionState, data: Dict[str, Any], channel: Optional[Messageable] = None
) -> None:
super().__init__(data)
self._state = state
self._channel = channel
self._application_id: int = int(data['application_id'])
self.id: int = int(data['id'])
self.version: int = int(data['version'])
self.type: CommandType = try_enum(CommandType, data['type'])
self.default_permission: bool = data['default_permission']
self._dm_permission = data['dm_permission']
self._default_member_permissions = data['default_member_permissions']
def __repr__(self) -> str:
return f'<{self.__class__.__name__} id={self.id} name={self.name}>'
def is_group(self) -> bool:
"""Query whether this command is a group.
Here for compatibility purposes.
Returns
-------
:class:`bool`
Whether this command is a group.
"""
return False
@property
def application(self):
"""The application this command belongs to."""
...
#return self._state.get_application(self._application_id)
@property
def target_channel(self) -> Optional[Messageable]:
"""Optional[:class:`Messageable`]: The channel this application command will be used on.
You can set this in order to use this command in a different channel without re-fetching it.
"""
return self._channel
@target_channel.setter
def target_channel(self, value: Optional[Messageable]) -> None:
from .abc import Messageable
if not isinstance(value, Messageable) and value is not None:
raise TypeError('channel must derive from Messageable')
self._channel = value
class _SlashMixin:
async def __call__(self, options, channel=None):
# This will always be used in a context where all these attributes are set
obj = getattr(self, '_parent', self)
data = {
'attachments': [],
'id': str(obj.id), # type: ignore
'name': obj.name, # type: ignore
'options': options,
'type': obj.type.value, # type: ignore
'version': str(obj.version), # type: ignore
}
return await super().__call__(data, channel) # type: ignore
def _parse_kwargs(self, kwargs: Dict[str, Any]) -> List[Dict[str, Any]]:
possible_options = {o.name: o for o in self.options}
kwargs = {k: v for k, v in kwargs.items() if k in possible_options}
options = []
for k, v in kwargs.items():
option = possible_options[k]
type = option.type
if type in {
OptionType.user,
OptionType.channel,
OptionType.role,
OptionType.mentionable,
}:
v = str(v.id)
elif type is OptionType.boolean:
v = bool(v)
else:
v = option._convert(v)
if type is OptionType.string:
v = str(v)
elif type is OptionType.integer:
v = int(v)
elif type is OptionType.number:
v = float(v)
options.append({'name': k, 'value': v, 'type': type.value})
return options
def _unwrap_options(self, data: List[Dict[str, Any]]) -> None:
options = []
children = []
for option in data:
type = try_enum(OptionType, option['type'])
if type in {
OptionType.sub_command,
OptionType.sub_command_group,
}:
children.append(SubCommand(parent=self, data=option))
else:
options.append(Option(option))
for child in children:
setattr(self, child.name, child)
self.options: List[Option] = options
self.children: List[SubCommand] = children
class UserCommand(_BaseCommand):
"""Represents a user command.
Attributes
----------
id: :class:`int`
The command's ID.
name: :class:`str`
The command's name.
description: :class:`str`
The command's description, if any.
type: :class:`CommandType`
The type of application command. Always :class:`CommandType.user`.
default_permission: :class:`bool`
Whether the command is enabled in guilds by default.
"""
def __init__(self, *, user: Optional[Snowflake] = None, **kwargs):
super().__init__(**kwargs)
self._user = user
async def __call__(
self, user: Optional[Snowflake] = None, *, channel: Optional[Messageable] = None
):
"""Use the user command.
Parameters
----------
user: Optional[:class:`User`]
The user to use the command on. Overrides :attr:`target_user`.
Required if :attr:`target_user` is not set.
channel: Optional[:class:`abc.Messageable`]
The channel to use the command on. Overrides :attr:`target_channel`.
Required if :attr:`target_channel` is not set.
"""
user = user or self._user
if user is None:
raise TypeError('__call__() missing 1 required positional argument: \'user\'')
data = {
'attachments': [],
'id': str(self.id),
'name': self.name,
'options': [],
'target_id': str(user.id),
'type': self.type.value,
'version': str(self.version),
}
return await super().__call__(data, channel)
@property
def target_user(self) -> Optional[Snowflake]:
"""Optional[:class:`Snowflake`]: The user this application command will be used on.
You can set this in order to use this command on a different user without re-fetching it.
"""
return self._user
@target_user.setter
def target_user(self, value: Optional[Snowflake]) -> None:
from .abc import Snowflake
if not isinstance(value, Snowflake) and value is not None:
raise TypeError('user must be Snowflake')
self._user = value
class MessageCommand(_BaseCommand):
"""Represents a message command.
Attributes
----------
id: :class:`int`
The command's ID.
name: :class:`str`
The command's name.
description: :class:`str`
The command's description, if any.
type: :class:`CommandType`
The type of application command. Always :class:`CommandType.message`.
default_permission: :class:`bool`
Whether the command is enabled in guilds by default.
"""
def __init__(self, *, message: Optional[Message] = None, **kwargs):
super().__init__(**kwargs)
self._message = message
async def __call__(
self, message: Optional[Message] = None, *, channel: Optional[Messageable] = None
):
"""Use the message command.
Parameters
----------
message: Optional[:class:`Message`]
The message to use the command on. Overrides :attr:`target_message`.
Required if :attr:`target_message` is not set.
channel: Optional[:class:`abc.Messageable`]
The channel to use the command on. Overrides :attr:`target_channel`.
Required if :attr:`target_channel` is not set.
"""
message = message or self._message
if message is None:
raise TypeError('__call__() missing 1 required positional argument: \'message\'')
data = {
'attachments': [],
'id': str(self.id),
'name': self.name,
'options': [],
'target_id': str(message.id),
'type': self.type.value,
'version': str(self.version),
}
return await super().__call__(data, channel)
@property
def target_message(self) -> Optional[Message]:
"""Optional[:class:`Message`]: The message this application command will be used on.
You can set this in order to use this command on a different message without re-fetching it.
"""
return self._message
@target_message.setter
def target_message(self, value: Optional[Message]) -> None:
from .message import Message
if not isinstance(value, Message) and value is not None:
raise TypeError('message must be Message')
self._message = value
class SlashCommand(_SlashMixin, _BaseCommand):
"""Represents a slash command.
Attributes
----------
id: :class:`int`
The command's ID.
name: :class:`str`
The command's name.
description: :class:`str`
The command's description, if any.
type: :class:`CommandType`
The type of application command. Always :class:`CommandType.chat_input`.
default_permission: :class:`bool`
Whether the command is enabled in guilds by default.
options: List[:class:`Option`]
The command's options.
children: List[:class:`SubCommand`]
The command's subcommands. If a command has subcommands, it is a group and cannot be used.
You can access (and use) subcommands directly as attributes of the class.
"""
def __init__(
self, *, data: Dict[str, Any], **kwargs
) -> None:
super().__init__(data=data, **kwargs)
self._unwrap_options(data.get('options', []))
async def __call__(self, channel, /, **kwargs):
r"""Use the slash command.
Parameters
----------
channel: Optional[:class:`abc.Messageable`]
The channel to use the command on. Overrides :attr:`target_channel`.
Required if :attr:`target_message` is not set.
\*\*kwargs: Any
The options to use. These will be casted to the correct type.
If an option has choices, they are automatically converted from name to value for you.
"""
if self.is_group():
raise TypeError('Cannot use a group')
return await super().__call__(self._parse_kwargs(kwargs), channel)
def __repr__(self) -> str:
BASE = f'<SlashCommand id={self.id} name={self.name}'
if self.options:
BASE += f' options={len(self.options)}'
if self.children:
BASE += f' children={len(self.children)}'
return BASE + '>'
def is_group(self) -> bool:
"""Query whether this command is a group.
Returns
-------
:class:`bool`
Whether this command is a group.
"""
return bool(self.children)
class SubCommand(_SlashMixin, ApplicationCommand):
"""Represents a slash command child.
This could be a subcommand, or a subgroup.
Attributes
----------
parent: :class:`SlashCommand`
The parent command.
name: :class:`str`
The command's name.
description: :class:`str`
The command's description, if any.
type: :class:`CommandType`
The type of application command. Always :class:`CommandType.chat_input`.
"""
def __init__(self, *, parent, data):
super().__init__(data)
self.parent: Union[SlashCommand, SubCommand] = parent
self._parent: SlashCommand = getattr(parent, 'parent', parent) # type: ignore
self.type = CommandType.chat_input # Avoid confusion I guess
self._type: OptionType = try_enum(OptionType, data['type'])
self._unwrap_options(data.get('options', []))
def _walk_parents(self):
parent = self.parent
while True:
if isinstance(parent, SubCommand):
parent = parent.parent
else:
break
yield parent
async def __call__(self, channel, /, **kwargs):
r"""Use the sub command.
Parameters
----------
channel: Optional[:class:`abc.Messageable`]
The channel to use the command on. Overrides :attr:`target_channel`.
Required if :attr:`target_message` is not set.
\*\*kwargs: Any
The options to use. These will be casted to the correct type.
If an option has choices, they are automatically converted from name to value for you.
"""
if self.is_group():
raise TypeError('Cannot use a group')
options = [{
'type': self._type.value,
'name': self.name,
'options': self._parse_kwargs(kwargs),
}]
for parent in self._walk_parents():
options = [{
'type': parent.type.value,
'name': parent.name,
'options': options,
}]
return await super().__call__(options, channel)
def __repr__(self) -> str:
BASE = f'<SubCommand name={self.name}'
if self.options:
BASE += f' options={len(self.options)}'
if self.children:
BASE += f' children={len(self.children)}'
return BASE + '>'
@property
def _application_id(self) -> int:
return self._parent._application_id
@property
def version(self) -> int:
""":class:`int`: The version of the command."""
return self._parent.version
@property
def default_permission(self) -> bool:
""":class:`bool`: Whether the command is enabled in guilds by default."""
return self._parent.default_permission
def is_group(self) -> bool:
"""Query whether this command is a group.
Returns
-------
:class:`bool`
Whether this command is a group.
"""
return self._type is OptionType.sub_command_group
@property
def application(self):
"""The application this command belongs to."""
return self._parent.application
@property
def target_channel(self):
"""Optional[:class:`abc.Messageable`]: The channel this command will be used on.
You can set this in order to use this command on a different channel without re-fetching it.
"""
return self._parent.target_channel
@target_channel.setter
def target_channel(self, value: Optional[Messageable]) -> None:
self._parent.target_channel = value
class Option:
"""Represents a command option.
Attributes
----------
name: :class:`str`
The option's name.
description: :class:`str`
The option's description, if any.
type: :class:`OptionType`
The type of option.
required: :class:`bool`
Whether the option is required.
min_value: Optional[Union[:class:`int`, :class:`float`]]
Minimum value of the option. Only applicable to :attr:`OptionType.integer` and :attr:`OptionType.number`.
max_value: Optional[Union[:class:`int`, :class:`float`]]
Maximum value of the option. Only applicable to :attr:`OptionType.integer` and :attr:`OptionType.number`.
choices: List[:class:`OptionChoice`]
A list of possible choices to choose from. If these are present, you must choose one from them.
Only applicable to :attr:`OptionType.string`, :attr:`OptionType.integer`, and :attr:`OptionType.number`.
channel_types: List[:class:`ChannelType`]
A list of channel types that you can choose from. If these are present, you must choose a channel that is one of these types.
Only applicable to :attr:`OptionType.channel`.
autocomplete: :class:`bool`
Whether the option autocompletes. Always ``False`` if :attr:`choices` are present.
"""
def __init__(self, data):
self.name: str = data['name']
self.description: str = data['description']
self.type: OptionType = try_enum(OptionType, data['type'])
self.required: bool = data.get('required', False)
self.min_value: Optional[Union[int, float]] = data.get('min_value')
self.max_value: Optional[int] = data.get('max_value')
self.choices = [OptionChoice(choice, self.type) for choice in data.get('choices', [])]
self.channel_types: List[ChannelType] = [try_enum(ChannelType, c) for c in data.get('channel_types', [])]
self.autocomplete: bool = data.get('autocomplete', False)
def __repr__(self) -> str:
return f'<Option name={self.name} type={self.type} required={self.required}>'
def _convert(self, value):
for choice in self.choices:
if (new_value := choice._convert(value)) != value:
return new_value
return value
class OptionChoice:
"""Represents a choice for an option.
Attributes
----------
name: :class:`str`
The choice's displayed name.
value: Any
The choice's value. The type of this depends on the option's type.
"""
def __init__(self, data: Dict[str, str], type: OptionType):
self.name: str = data['name']
if type is OptionType.string:
self.value: str = data['value'] # type: ignore
elif type is OptionType.integer:
self.value: int = int(data['value']) # type: ignore
elif type is OptionType.number:
self.value: float = float(data['value']) # type: ignore
def __repr__(self) -> str:
return f'<OptionChoice name={self.name} value={self.value}>'
def _convert(self, value):
if value == self.name:
return self.value
return value
def _command_factory(command_type: int) -> Tuple[CommandType, _BaseCommand]:
value = try_enum(CommandType, command_type)
if value is CommandType.chat_input:
return value, SlashCommand
elif value is CommandType.user:
return value, UserCommand
elif value is CommandType.message:
return value, MessageCommand
else:
return value, _BaseCommand # IDK about this

8
discord/enums.py

@ -66,8 +66,8 @@ __all__ = (
'RequiredActionType', 'RequiredActionType',
'ReportType', 'ReportType',
'BrowserEnum', 'BrowserEnum',
'ApplicationCommandType', 'CommandType',
'ApplicationCommandOptionType', 'OptionType',
) )
@ -673,7 +673,7 @@ class InteractionType(Enum, comparable=True):
component = 3 component = 3
class ApplicationCommandType(Enum, comparable=True): class CommandType(Enum, comparable=True):
chat_input = 1 chat_input = 1
chat = 1 chat = 1
slash = 1 slash = 1
@ -684,7 +684,7 @@ class ApplicationCommandType(Enum, comparable=True):
return self.value return self.value
class ApplicationCommandOptionType(Enum, comparable=True): class OptionType(Enum, comparable=True):
sub_command = 1 sub_command = 1
sub_command_group = 2 sub_command_group = 2
string = 3 string = 3

77
discord/gateway.py

@ -24,10 +24,8 @@ DEALINGS IN THE SOFTWARE.
import asyncio import asyncio
from collections import namedtuple, deque from collections import namedtuple, deque
import concurrent.futures
import logging import logging
import struct import struct
import sys
import time import time
import threading import threading
import traceback import traceback
@ -100,7 +98,7 @@ class GatewayRatelimiter:
async with self.lock: async with self.lock:
delta = self.get_delay() delta = self.get_delay()
if delta: if delta:
_log.warning('WebSocket is ratelimited, waiting %.2f seconds.', delta) _log.warning('Gateway is ratelimited, waiting %.2f seconds.', delta)
await asyncio.sleep(delta) await asyncio.sleep(delta)
@ -245,10 +243,12 @@ class DiscordWebSocket:
a connection issue. a connection issue.
GUILD_SYNC GUILD_SYNC
Send only. Requests a guild sync. This is unfortunately no longer functional. Send only. Requests a guild sync. This is unfortunately no longer functional.
ACCESS_DM CALL_CONNECT
Send only. Tracking. Send only. Maybe used for calling? Probably just tracking.
GUILD_SUBSCRIBE GUILD_SUBSCRIBE
Send only. Subscribes you to guilds/guild members. Might respond with GUILD_MEMBER_LIST_UPDATE. Send only. Subscribes you to guilds/guild members. Might respond with GUILD_MEMBER_LIST_UPDATE.
REQUEST_COMMANDS
Send only. Requests application commands from a guild. Responds with GUILD_APPLICATION_COMMANDS_UPDATE.
gateway gateway
The gateway we are currently connected to. The gateway we are currently connected to.
token token
@ -268,8 +268,9 @@ class DiscordWebSocket:
HELLO = 10 HELLO = 10
HEARTBEAT_ACK = 11 HEARTBEAT_ACK = 11
GUILD_SYNC = 12 # :( GUILD_SYNC = 12 # :(
ACCESS_DM = 13 CALL_CONNECT = 13
GUILD_SUBSCRIBE = 14 GUILD_SUBSCRIBE = 14
REQUEST_COMMANDS = 24
def __init__(self, socket, *, loop): def __init__(self, socket, *, loop):
self.socket = socket self.socket = socket
@ -354,7 +355,7 @@ class DiscordWebSocket:
Parameters Parameters
----------- -----------
event: :class:`str` event: :class:`str`
The event name in all upper case to wait for. The event to wait for.
predicate predicate
A function that takes a data parameter to check for event A function that takes a data parameter to check for event
properties. The data parameter is the 'd' key in the JSON message. properties. The data parameter is the 'd' key in the JSON message.
@ -368,6 +369,7 @@ class DiscordWebSocket:
A future to wait for. A future to wait for.
""" """
event = event.upper()
future = self.loop.create_future() future = self.loop.create_future()
entry = EventListener(event=event, predicate=predicate, result=result, future=future) entry = EventListener(event=event, predicate=predicate, result=result, future=future)
self._dispatch_listeners.append(entry) self._dispatch_listeners.append(entry)
@ -690,7 +692,7 @@ class DiscordWebSocket:
async def access_dm(self, channel_id): async def access_dm(self, channel_id):
payload = { payload = {
'op': self.ACCESS_DM, 'op': self.CALL_CONNECT,
'd': { 'd': {
'channel_id': channel_id 'channel_id': channel_id
} }
@ -699,6 +701,32 @@ class DiscordWebSocket:
_log.debug('Sending ACCESS_DM for channel %s.', channel_id) _log.debug('Sending ACCESS_DM for channel %s.', channel_id)
await self.send_as_json(payload) await self.send_as_json(payload)
async def request_commands(self, guild_id, type, *, nonce=None, limit=None, applications=None, offset=0, query=None, command_ids=None, application_id=None):
payload = {
'op': self.REQUEST_COMMANDS,
'd': {
'guild_id': guild_id,
'type': type,
}
}
if nonce is not None:
payload['d']['nonce'] = nonce
if applications is not None:
payload['d']['applications'] = applications
if limit is not None and limit != 25:
payload['d']['limit'] = limit
if offset:
payload['d']['offset'] = offset
if query is not None:
payload['d']['query'] = query
if command_ids is not None:
payload['d']['command_ids'] = command_ids
if application_id is not None:
payload['d']['application_id'] = application_id
await self.send_as_json(payload)
async def close(self, code=4000): async def close(self, code=4000):
if self._keep_alive: if self._keep_alive:
self._keep_alive.stop() self._keep_alive.stop()
@ -720,7 +748,7 @@ class DiscordVoiceWebSocket:
Receive only. Tells the websocket that the initial connection has completed. Receive only. Tells the websocket that the initial connection has completed.
HEARTBEAT HEARTBEAT
Send only. Keeps your websocket connection alive. Send only. Keeps your websocket connection alive.
SESSION_DESCRIPTION SELECT_PROTOCOL_ACK
Receive only. Gives you the secret key required for voice. Receive only. Gives you the secret key required for voice.
SPEAKING SPEAKING
Send and receive. Notifies the client if anyone begins speaking. Send and receive. Notifies the client if anyone begins speaking.
@ -732,24 +760,25 @@ class DiscordVoiceWebSocket:
Receive only. Tells you that your websocket connection was acknowledged. Receive only. Tells you that your websocket connection was acknowledged.
RESUMED RESUMED
Sent only. Tells you that your RESUME request has succeeded. Sent only. Tells you that your RESUME request has succeeded.
CLIENT_CONNECT
Indicates a user has connected to voice.
CLIENT_DISCONNECT CLIENT_DISCONNECT
Receive only. Indicates a user has disconnected from voice. Receive only. Indicates a user has disconnected from voice.
""" """
IDENTIFY = 0 IDENTIFY = 0
SELECT_PROTOCOL = 1 SELECT_PROTOCOL = 1
READY = 2 READY = 2
HEARTBEAT = 3 HEARTBEAT = 3
SESSION_DESCRIPTION = 4 SELECT_PROTOCOL_ACK = 4
SPEAKING = 5 SPEAKING = 5
HEARTBEAT_ACK = 6 HEARTBEAT_ACK = 6
RESUME = 7 RESUME = 7
HELLO = 8 HELLO = 8
RESUMED = 9 RESUMED = 9
CLIENT_CONNECT = 12 VIDEO = 12
CLIENT_DISCONNECT = 13 CLIENT_DISCONNECT = 13
SESSION_UPDATE = 14
MEDIA_SINK_WANTS = 15
VOICE_BACKEND_VERSION = 16
def __init__(self, socket, loop, *, hook=None): def __init__(self, socket, loop, *, hook=None):
self.ws = socket self.ws = socket
@ -861,7 +890,7 @@ class DiscordVoiceWebSocket:
elif op == self.RESUMED: elif op == self.RESUMED:
_log.info('Voice RESUME succeeded.') _log.info('Voice RESUME succeeded.')
self.secret_key = self._connection.secret_key self.secret_key = self._connection.secret_key
elif op == self.SESSION_DESCRIPTION: elif op == self.SELECT_PROTOCOL_ACK:
self._connection.mode = data['mode'] self._connection.mode = data['mode']
await self.load_secret_key(data) await self.load_secret_key(data)
elif op == self.HELLO: elif op == self.HELLO:

269
discord/iterators.py

@ -26,12 +26,15 @@ from __future__ import annotations
import asyncio import asyncio
import datetime import datetime
from typing import Awaitable, TYPE_CHECKING, TypeVar, Optional, Any, Callable, Union, List, AsyncIterator from typing import Awaitable, TYPE_CHECKING, TypeVar, Optional, Any, Callable, Union, List, Tuple, AsyncIterator, Dict
from .errors import NoMoreItems from .errors import InvalidData, NoMoreItems
from .utils import snowflake_time, time_snowflake, maybe_coroutine from .utils import snowflake_time, time_snowflake, maybe_coroutine, utcnow
from .object import Object from .object import Object
from .audit_logs import AuditLogEntry from .audit_logs import AuditLogEntry
from .commands import _command_factory
from .enums import CommandType
from .errors import InvalidArgument
__all__ = ( __all__ = (
'ReactionIterator', 'ReactionIterator',
@ -61,10 +64,11 @@ if TYPE_CHECKING:
from .member import Member from .member import Member
from .user import User from .user import User
from .message import Message from .message import Message
from .audit_logs import AuditLogEntry
from .guild import Guild from .guild import Guild
from .threads import Thread from .threads import Thread
from .abc import Snowflake from .abc import Snowflake, Messageable
from .commands import ApplicationCommand
from .channel import DMChannel
T = TypeVar('T') T = TypeVar('T')
OT = TypeVar('OT') OT = TypeVar('OT')
@ -106,7 +110,7 @@ class _AsyncIterator(AsyncIterator[T]):
def chunk(self, max_size: int) -> _ChunkedAsyncIterator[T]: def chunk(self, max_size: int) -> _ChunkedAsyncIterator[T]:
if max_size <= 0: if max_size <= 0:
raise ValueError('async iterator chunk sizes must be greater than 0.') raise ValueError('Chunk size must be greater than 0')
return _ChunkedAsyncIterator(self, max_size) return _ChunkedAsyncIterator(self, max_size)
def map(self, func: _Func[T, OT]) -> _MappedAsyncIterator[OT]: def map(self, func: _Func[T, OT]) -> _MappedAsyncIterator[OT]:
@ -156,7 +160,7 @@ class _MappedAsyncIterator(_AsyncIterator[T]):
self.func = func self.func = func
async def next(self) -> T: async def next(self) -> T:
# this raises NoMoreItems and will propagate appropriately # This raises NoMoreItems and will propagate appropriately
item = await self.iterator.next() item = await self.iterator.next()
return await maybe_coroutine(self.func, item) return await maybe_coroutine(self.func, item)
@ -204,7 +208,7 @@ class ReactionIterator(_AsyncIterator[Union['User', 'Member']]):
raise NoMoreItems() raise NoMoreItems()
async def fill_users(self): async def fill_users(self):
# this is a hack because >circular imports< # This is a hack because >circular imports<
from .user import User from .user import User
if self.limit > 0: if self.limit > 0:
@ -286,7 +290,7 @@ class HistoryIterator(_AsyncIterator['Message']):
self.after = after or OLDEST_OBJECT self.after = after or OLDEST_OBJECT
self.around = around self.around = around
self._filter = None # message dict -> bool self._filter = None # Message dict -> bool
self.state = self.messageable._state self.state = self.messageable._state
self.logs_from = self.state.http.logs_from self.logs_from = self.state.http.logs_from
@ -298,7 +302,7 @@ class HistoryIterator(_AsyncIterator['Message']):
if self.limit > 101: if self.limit > 101:
raise ValueError("history max limit 101 when specifying around parameter") raise ValueError("history max limit 101 when specifying around parameter")
elif self.limit == 101: elif self.limit == 101:
self.limit = 100 # Thanks discord self.limit = 100 # Thanks Discord
self._retrieve_messages = self._retrieve_messages_around_strategy # type: ignore self._retrieve_messages = self._retrieve_messages_around_strategy # type: ignore
if self.before and self.after: if self.before and self.after:
@ -336,15 +340,14 @@ class HistoryIterator(_AsyncIterator['Message']):
return r > 0 return r > 0
async def fill_messages(self): async def fill_messages(self):
if not hasattr(self, 'channel'): if not hasattr(self, 'channel'): # Do the required set up
# do the required set up
channel = await self.messageable._get_channel() channel = await self.messageable._get_channel()
self.channel = channel self.channel = channel
if self._get_retrieve(): if self._get_retrieve():
data = await self._retrieve_messages(self.retrieve) data = await self._retrieve_messages(self.retrieve)
if len(data) < 100: if len(data) < 100:
self.limit = 0 # terminate the infinite loop self.limit = 0 # Terminate the infinite loop
if self.reverse: if self.reverse:
data = reversed(data) data = reversed(data)
@ -571,8 +574,7 @@ class GuildIterator(_AsyncIterator['Guild']):
async def fill_guilds(self): async def fill_guilds(self):
if self._get_retrieve(): if self._get_retrieve():
data = await self._retrieve_guilds(self.retrieve) data = await self._retrieve_guilds(self.retrieve)
if self.limit is None or len(data) < 200: self.limit = 0 # Max amount of guilds a user can be in is 200
self.limit = 0
if self._filter: if self._filter:
data = filter(self._filter, data) data = filter(self._filter, data)
@ -693,3 +695,240 @@ class ArchivedThreadIterator(_AsyncIterator['Thread']):
def create_thread(self, data: ThreadPayload) -> Thread: def create_thread(self, data: ThreadPayload) -> Thread:
from .threads import Thread from .threads import Thread
return Thread(guild=self.guild, state=self.guild._state, data=data) return Thread(guild=self.guild, state=self.guild._state, data=data)
def _is_fake(item: Union[Messageable, Message]) -> bool: # I hate this too, but <circular imports> and performance exist
try:
item.guild # type: ignore
except AttributeError:
return True
try:
item.channel.me # type: ignore
except AttributeError:
return False
return True
class CommandIterator(_AsyncIterator['ApplicationCommand']):
def __new__(cls, *args, **kwargs) -> Union[CommandIterator, FakeCommandIterator]:
if _is_fake(args[0]):
return FakeCommandIterator(*args)
else:
return super().__new__(cls)
def __init__(
self,
item: Union[Messageable, Message],
type: CommandType,
query: Optional[str] = None,
limit: Optional[int] = None,
command_ids: Optional[List[int]] = None,
**kwargs,
) -> None:
self.item = item
self.channel = None
self.state = item._state
self._tuple = None
self.type = type
_, self.cls = _command_factory(int(type))
self.query = query
self.limit = limit
self.command_ids = command_ids
self.applications: bool = kwargs.get('applications', True)
self.application: Snowflake = kwargs.get('application', None)
self.commands = asyncio.Queue()
async def _process_args(self) -> Tuple[DMChannel, Optional[str], Optional[Union[User, Message]]]:
item = self.item
if self.type is CommandType.user:
channel = await item._get_channel() # type: ignore
if getattr(item, 'bot', None):
item = item
else:
item = None
text = 'user'
elif self.type is CommandType.message:
message = self.item
channel = message.channel # type: ignore
text = 'message'
elif self.type is CommandType.chat_input:
channel = await item._get_channel() # type: ignore
item = None
text = None
self._process_kwargs(channel) # type: ignore
return channel, text, item # type: ignore
def _process_kwargs(self, channel) -> None:
kwargs = {
'guild_id': channel.guild.id,
'type': self.type.value,
'offset': 0,
}
if self.applications:
kwargs['applications'] = True # Only sent if it's True...
if (app := self.application):
kwargs['application'] = app.id
if (query := self.query) is not None:
kwargs['query'] = query
if (cmds := self.command_ids):
kwargs['command_ids'] = cmds
self.kwargs = kwargs
async def next(self) -> ApplicationCommand:
if self.commands.empty():
await self.fill_commands()
try:
return self.commands.get_nowait()
except asyncio.QueueEmpty:
raise NoMoreItems()
def _get_retrieve(self):
l = self.limit
if l is None or l > 100:
r = 100
else:
r = l
self.retrieve = r
return r > 0
async def fill_commands(self) -> None:
if not self._tuple: # Do the required setup
self._tuple = await self._process_args()
if not self._get_retrieve():
return
state = self.state
kwargs = self.kwargs
retrieve = self.retrieve
nonce = str(time_snowflake(utcnow()))
def predicate(d):
return d.get('nonce') == nonce
data = None
for _ in range(3):
await state.ws.request_commands(**kwargs, limit=retrieve, nonce=nonce)
try:
data: Optional[Dict[str, Any]] = await asyncio.wait_for(state.ws.wait_for('guild_application_commands_update', predicate), timeout=3)
except asyncio.TimeoutError:
pass
if data is None:
raise InvalidData('Didn\'t receive a response from Discord')
cmds = data['application_commands']
if len(cmds) < retrieve:
self.limit = 0
elif self.limit is not None:
self.limit -= retrieve
kwargs['offset'] += retrieve
for cmd in cmds:
self.commands.put_nowait(self.create_command(cmd))
for app in data.get('applications', []):
...
def create_command(self, data) -> ApplicationCommand:
channel, item, value = self._tuple # type: ignore
if item is not None:
kwargs = {item: value}
else:
kwargs = {}
return self.cls(state=channel._state, data=data, channel=channel, **kwargs)
class FakeCommandIterator(_AsyncIterator['ApplicationCommand']):
def __init__(
self,
item: Union[User, Message, DMChannel],
type: CommandType,
query: Optional[str] = None,
limit: Optional[int] = None,
command_ids: Optional[List[int]] = None,
) -> None:
self.item = item
self.channel = None
self._tuple = None
self.type = type
_, self.cls = _command_factory(int(type))
self.query = query
self.limit = limit
self.command_ids = command_ids
self.has_more = False
self.commands = asyncio.Queue()
async def _process_args(self) -> Tuple[DMChannel, Optional[str], Optional[Union[User, Message]]]:
item = self.item
if self.type is CommandType.user:
channel = await item._get_channel() # type: ignore
if getattr(item, 'bot', None):
item = item
else:
item = None
text = 'user'
elif self.type is CommandType.message:
message = self.item
channel = message.channel # type: ignore
text = 'message'
elif self.type is CommandType.chat_input:
channel = await item._get_channel() # type: ignore
item = None
text = None
if not channel.recipient.bot:
raise InvalidArgument('User is not a bot')
return channel, text, item # type: ignore
async def next(self) -> ApplicationCommand:
if self.commands.empty():
await self.fill_commands()
try:
return self.commands.get_nowait()
except asyncio.QueueEmpty:
raise NoMoreItems()
async def fill_commands(self) -> None:
if self.has_more:
raise NoMoreItems()
if not (stuff := self._tuple):
self._tuple = channel, _, _ = await self._process_args()
else:
channel = stuff[0]
limit = self.limit or -1
data = await channel._state.http.get_application_commands(channel.recipient.id)
ids = self.command_ids
query = self.query and self.query.lower()
type = self.type.value
for cmd in data:
if cmd['type'] != type:
continue
if ids:
if not int(cmd['id']) in ids:
continue
if query:
if not query in cmd['name'].lower():
continue
self.commands.put_nowait(self.create_command(cmd))
limit -= 1
if limit == 0:
break
self.has_more = True
def create_command(self, data) -> ApplicationCommand:
channel, item, value = self._tuple # type: ignore
if item is not None:
kwargs = {item: value}
else:
kwargs = {}
return self.cls(state=channel._state, data=data, channel=channel, **kwargs)

40
discord/utils.py

@ -91,6 +91,7 @@ __all__ = (
'escape_mentions', 'escape_mentions',
'as_chunks', 'as_chunks',
'format_dt', 'format_dt',
'set_target',
) )
DISCORD_EPOCH = 1420070400000 DISCORD_EPOCH = 1420070400000
@ -134,9 +135,11 @@ if TYPE_CHECKING:
from typing_extensions import ParamSpec from typing_extensions import ParamSpec
from .permissions import Permissions from .permissions import Permissions
from .abc import Snowflake from .abc import Messageable, Snowflake
from .invite import Invite from .invite import Invite
from .message import Message
from .template import Template from .template import Template
from .commands import ApplicationCommand
class _RequestLike(Protocol): class _RequestLike(Protocol):
headers: Mapping[str, Any] headers: Mapping[str, Any]
@ -1029,6 +1032,41 @@ def format_dt(dt: datetime.datetime, /, style: Optional[TimestampStyle] = None)
return f'<t:{int(dt.timestamp())}:{style}>' return f'<t:{int(dt.timestamp())}:{style}>'
def set_target(
items: Iterable[ApplicationCommand], *, channel: Messageable = None, message: Message = None, user: Snowflake = None
) -> None:
"""A helper function to set the target for a list of items.
This is used to set the target for a list of application commands.
Suppresses all AttributeErrors so you can pass multiple types of commands and
not worry about which elements support which parameter.
Parameters
-----------
items: Iterable[:class:`ApplicationCommand`]
A list of items to set the target for.
channel: :class:`Messageable`
The channel to target.
message: :class:`Message`
The message to target.
user: :class:`Snowflake`
The user to target.
"""
attrs = {
'target_channel': channel,
'target_message': message,
'target_user': user,
}
for item in items:
for k, v in attrs.items():
if v is not None:
try:
setattr(item, k, v) # type: ignore
except AttributeError:
pass
class ExpiringQueue(asyncio.Queue): # Inspired from https://github.com/NoahCardoza/CaptchaHarvester class ExpiringQueue(asyncio.Queue): # Inspired from https://github.com/NoahCardoza/CaptchaHarvester
def __init__(self, timeout: int, maxsize: int = 0) -> None: def __init__(self, timeout: int, maxsize: int = 0) -> None:
super().__init__(maxsize) super().__init__(maxsize)

Loading…
Cancel
Save