|
|
@ -26,24 +26,10 @@ from __future__ import annotations |
|
|
|
|
|
|
|
import asyncio |
|
|
|
import logging |
|
|
|
import signal |
|
|
|
import sys |
|
|
|
import traceback |
|
|
|
from typing import ( |
|
|
|
Any, |
|
|
|
Callable, |
|
|
|
Coroutine, |
|
|
|
Dict, |
|
|
|
Generator, |
|
|
|
Iterable, |
|
|
|
List, |
|
|
|
Optional, |
|
|
|
Sequence, |
|
|
|
TYPE_CHECKING, |
|
|
|
Tuple, |
|
|
|
TypeVar, |
|
|
|
Type, |
|
|
|
Union, |
|
|
|
) |
|
|
|
from typing import Any, Callable, Coroutine, Dict, Generator, Iterable, List, Optional, Sequence, TYPE_CHECKING, Tuple, TypeVar, Union |
|
|
|
|
|
|
|
import aiohttp |
|
|
|
|
|
|
@ -82,7 +68,6 @@ if TYPE_CHECKING: |
|
|
|
from .message import Message |
|
|
|
from .member import Member |
|
|
|
from .voice_client import VoiceProtocol |
|
|
|
from types import TracebackType |
|
|
|
|
|
|
|
__all__ = ( |
|
|
|
'Client', |
|
|
@ -93,8 +78,36 @@ Coro = TypeVar('Coro', bound=Callable[..., Coroutine[Any, Any, Any]]) |
|
|
|
|
|
|
|
log: logging.Logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
C = TypeVar('C', bound='Client') |
|
|
|
|
|
|
|
def _cancel_tasks(loop: asyncio.AbstractEventLoop) -> None: |
|
|
|
tasks = {t for t in asyncio.all_tasks(loop=loop) if not t.done()} |
|
|
|
|
|
|
|
if not tasks: |
|
|
|
return |
|
|
|
|
|
|
|
log.info('Cleaning up after %d tasks.', len(tasks)) |
|
|
|
for task in tasks: |
|
|
|
task.cancel() |
|
|
|
|
|
|
|
loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True)) |
|
|
|
log.info('All tasks finished cancelling.') |
|
|
|
|
|
|
|
for task in tasks: |
|
|
|
if task.cancelled(): |
|
|
|
continue |
|
|
|
if task.exception() is not None: |
|
|
|
loop.call_exception_handler({ |
|
|
|
'message': 'Unhandled exception during Client.run shutdown.', |
|
|
|
'exception': task.exception(), |
|
|
|
'task': task |
|
|
|
}) |
|
|
|
|
|
|
|
def _cleanup_loop(loop: asyncio.AbstractEventLoop) -> None: |
|
|
|
try: |
|
|
|
_cancel_tasks(loop) |
|
|
|
loop.run_until_complete(loop.shutdown_asyncgens()) |
|
|
|
finally: |
|
|
|
log.info('Closing the event loop.') |
|
|
|
loop.close() |
|
|
|
|
|
|
|
class Client: |
|
|
|
r"""Represents a client connection that connects to Discord. |
|
|
@ -187,7 +200,6 @@ class Client: |
|
|
|
loop: :class:`asyncio.AbstractEventLoop` |
|
|
|
The event loop that the client uses for asynchronous operations. |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
*, |
|
|
@ -195,8 +207,7 @@ class Client: |
|
|
|
**options: Any, |
|
|
|
): |
|
|
|
self.ws: DiscordWebSocket = None # type: ignore |
|
|
|
# this is filled in later |
|
|
|
self.loop: asyncio.AbstractEventLoop = MISSING if loop is None else loop |
|
|
|
self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if loop is None else loop |
|
|
|
self._listeners: Dict[str, List[Tuple[asyncio.Future, Callable[..., bool]]]] = {} |
|
|
|
self.shard_id: Optional[int] = options.get('shard_id') |
|
|
|
self.shard_count: Optional[int] = options.get('shard_count') |
|
|
@ -205,16 +216,14 @@ class Client: |
|
|
|
proxy: Optional[str] = options.pop('proxy', None) |
|
|
|
proxy_auth: Optional[aiohttp.BasicAuth] = options.pop('proxy_auth', None) |
|
|
|
unsync_clock: bool = options.pop('assume_unsync_clock', True) |
|
|
|
self.http: HTTPClient = HTTPClient( |
|
|
|
connector, proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock, loop=loop |
|
|
|
) |
|
|
|
self.http: HTTPClient = HTTPClient(connector, proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock, loop=self.loop) |
|
|
|
|
|
|
|
self._handlers: Dict[str, Callable] = { |
|
|
|
'ready': self._handle_ready, |
|
|
|
'ready': self._handle_ready |
|
|
|
} |
|
|
|
|
|
|
|
self._hooks: Dict[str, Callable] = { |
|
|
|
'before_identify': self._call_before_identify_hook, |
|
|
|
'before_identify': self._call_before_identify_hook |
|
|
|
} |
|
|
|
|
|
|
|
self._enable_debug_events: bool = options.pop('enable_debug_events', False) |
|
|
@ -235,9 +244,8 @@ class Client: |
|
|
|
return self.ws |
|
|
|
|
|
|
|
def _get_state(self, **options: Any) -> ConnectionState: |
|
|
|
return ConnectionState( |
|
|
|
dispatch=self.dispatch, handlers=self._handlers, hooks=self._hooks, http=self.http, loop=self.loop, **options |
|
|
|
) |
|
|
|
return ConnectionState(dispatch=self.dispatch, handlers=self._handlers, |
|
|
|
hooks=self._hooks, http=self.http, loop=self.loop, **options) |
|
|
|
|
|
|
|
def _handle_ready(self) -> None: |
|
|
|
self._ready.set() |
|
|
@ -335,9 +343,7 @@ class Client: |
|
|
|
""":class:`bool`: Specifies if the client's internal cache is ready for use.""" |
|
|
|
return self._ready.is_set() |
|
|
|
|
|
|
|
async def _run_event( |
|
|
|
self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any |
|
|
|
) -> None: |
|
|
|
async def _run_event(self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any) -> None: |
|
|
|
try: |
|
|
|
await coro(*args, **kwargs) |
|
|
|
except asyncio.CancelledError: |
|
|
@ -348,9 +354,7 @@ class Client: |
|
|
|
except asyncio.CancelledError: |
|
|
|
pass |
|
|
|
|
|
|
|
def _schedule_event( |
|
|
|
self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any |
|
|
|
) -> asyncio.Task: |
|
|
|
def _schedule_event(self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any) -> asyncio.Task: |
|
|
|
wrapped = self._run_event(coro, event_name, *args, **kwargs) |
|
|
|
# Schedules the task |
|
|
|
return asyncio.create_task(wrapped, name=f'discord.py: {event_name}') |
|
|
@ -462,8 +466,7 @@ class Client: |
|
|
|
""" |
|
|
|
|
|
|
|
log.info('logging in using static token') |
|
|
|
self.loop = loop = asyncio.get_running_loop() |
|
|
|
self._connection.loop = loop |
|
|
|
|
|
|
|
data = await self.http.static_login(token.strip()) |
|
|
|
self._connection.user = ClientUser(state=self._connection, data=data) |
|
|
|
|
|
|
@ -509,14 +512,12 @@ class Client: |
|
|
|
self.dispatch('disconnect') |
|
|
|
ws_params.update(sequence=self.ws.sequence, resume=e.resume, session=self.ws.session_id) |
|
|
|
continue |
|
|
|
except ( |
|
|
|
OSError, |
|
|
|
HTTPException, |
|
|
|
GatewayNotFound, |
|
|
|
ConnectionClosed, |
|
|
|
aiohttp.ClientError, |
|
|
|
asyncio.TimeoutError, |
|
|
|
) as exc: |
|
|
|
except (OSError, |
|
|
|
HTTPException, |
|
|
|
GatewayNotFound, |
|
|
|
ConnectionClosed, |
|
|
|
aiohttp.ClientError, |
|
|
|
asyncio.TimeoutError) as exc: |
|
|
|
|
|
|
|
self.dispatch('disconnect') |
|
|
|
if not reconnect: |
|
|
@ -557,22 +558,6 @@ class Client: |
|
|
|
"""|coro| |
|
|
|
|
|
|
|
Closes the connection to Discord. |
|
|
|
|
|
|
|
Instead of calling this directly, it is recommended to use the asynchronous context |
|
|
|
manager to allow resources to be cleaned up automatically: |
|
|
|
|
|
|
|
.. code-block:: python3 |
|
|
|
|
|
|
|
async def main(): |
|
|
|
async with Client() as client: |
|
|
|
await client.login(token) |
|
|
|
await client.connect() |
|
|
|
|
|
|
|
asyncio.run(main()) |
|
|
|
|
|
|
|
|
|
|
|
.. versionchanged:: 2.0 |
|
|
|
The client can now be closed with an asynchronous context manager |
|
|
|
""" |
|
|
|
if self._closed: |
|
|
|
return |
|
|
@ -604,47 +589,36 @@ class Client: |
|
|
|
self._connection.clear() |
|
|
|
self.http.recreate() |
|
|
|
|
|
|
|
async def __aenter__(self: C) -> C: |
|
|
|
return self |
|
|
|
|
|
|
|
async def __aexit__( |
|
|
|
self, |
|
|
|
exc_type: Optional[Type[BaseException]], |
|
|
|
exc_value: Optional[BaseException], |
|
|
|
traceback: Optional[TracebackType], |
|
|
|
) -> None: |
|
|
|
await self.close() |
|
|
|
|
|
|
|
async def start(self, token: str, *, reconnect: bool = True) -> None: |
|
|
|
"""|coro| |
|
|
|
|
|
|
|
A shorthand function equivalent to the following: |
|
|
|
A shorthand coroutine for :meth:`login` + :meth:`connect`. |
|
|
|
|
|
|
|
.. code-block:: python3 |
|
|
|
|
|
|
|
async with client: |
|
|
|
await client.login(token) |
|
|
|
await client.connect() |
|
|
|
|
|
|
|
This closes the client when it returns. |
|
|
|
Raises |
|
|
|
------- |
|
|
|
TypeError |
|
|
|
An unexpected keyword argument was received. |
|
|
|
""" |
|
|
|
try: |
|
|
|
await self.login(token) |
|
|
|
await self.connect(reconnect=reconnect) |
|
|
|
finally: |
|
|
|
await self.close() |
|
|
|
await self.login(token) |
|
|
|
await self.connect(reconnect=reconnect) |
|
|
|
|
|
|
|
def run(self, *args: Any, **kwargs: Any) -> None: |
|
|
|
"""A convenience blocking call that abstracts away the event loop |
|
|
|
"""A blocking call that abstracts away the event loop |
|
|
|
initialisation from you. |
|
|
|
|
|
|
|
If you want more control over the event loop then this |
|
|
|
function should not be used. Use :meth:`start` coroutine |
|
|
|
or :meth:`connect` + :meth:`login`. |
|
|
|
|
|
|
|
Equivalent to: :: |
|
|
|
Roughly Equivalent to: :: |
|
|
|
|
|
|
|
asyncio.run(bot.start(*args, **kwargs)) |
|
|
|
try: |
|
|
|
loop.run_until_complete(start(*args, **kwargs)) |
|
|
|
except KeyboardInterrupt: |
|
|
|
loop.run_until_complete(close()) |
|
|
|
# cancel all tasks lingering |
|
|
|
finally: |
|
|
|
loop.close() |
|
|
|
|
|
|
|
.. warning:: |
|
|
|
|
|
|
@ -652,7 +626,41 @@ class Client: |
|
|
|
is blocking. That means that registration of events or anything being |
|
|
|
called after this function call will not execute until it returns. |
|
|
|
""" |
|
|
|
asyncio.run(self.start(*args, **kwargs)) |
|
|
|
loop = self.loop |
|
|
|
|
|
|
|
try: |
|
|
|
loop.add_signal_handler(signal.SIGINT, lambda: loop.stop()) |
|
|
|
loop.add_signal_handler(signal.SIGTERM, lambda: loop.stop()) |
|
|
|
except NotImplementedError: |
|
|
|
pass |
|
|
|
|
|
|
|
async def runner(): |
|
|
|
try: |
|
|
|
await self.start(*args, **kwargs) |
|
|
|
finally: |
|
|
|
if not self.is_closed(): |
|
|
|
await self.close() |
|
|
|
|
|
|
|
def stop_loop_on_completion(f): |
|
|
|
loop.stop() |
|
|
|
|
|
|
|
future = asyncio.ensure_future(runner(), loop=loop) |
|
|
|
future.add_done_callback(stop_loop_on_completion) |
|
|
|
try: |
|
|
|
loop.run_forever() |
|
|
|
except KeyboardInterrupt: |
|
|
|
log.info('Received signal to terminate bot and event loop.') |
|
|
|
finally: |
|
|
|
future.remove_done_callback(stop_loop_on_completion) |
|
|
|
log.info('Cleaning up tasks.') |
|
|
|
_cleanup_loop(loop) |
|
|
|
|
|
|
|
if not future.cancelled(): |
|
|
|
try: |
|
|
|
return future.result() |
|
|
|
except KeyboardInterrupt: |
|
|
|
# I am unsure why this gets raised here but suppress it anyway |
|
|
|
return None |
|
|
|
|
|
|
|
# properties |
|
|
|
|
|
|
@ -965,10 +973,8 @@ class Client: |
|
|
|
|
|
|
|
future = self.loop.create_future() |
|
|
|
if check is None: |
|
|
|
|
|
|
|
def _check(*args): |
|
|
|
return True |
|
|
|
|
|
|
|
check = _check |
|
|
|
|
|
|
|
ev = event.lower() |
|
|
@ -1077,7 +1083,7 @@ class Client: |
|
|
|
*, |
|
|
|
limit: Optional[int] = 100, |
|
|
|
before: SnowflakeTime = None, |
|
|
|
after: SnowflakeTime = None, |
|
|
|
after: SnowflakeTime = None |
|
|
|
) -> GuildIterator: |
|
|
|
"""Retrieves an :class:`.AsyncIterator` that enables receiving your guilds. |
|
|
|
|
|
|
@ -1157,7 +1163,7 @@ class Client: |
|
|
|
""" |
|
|
|
code = utils.resolve_template(code) |
|
|
|
data = await self.http.get_template(code) |
|
|
|
return Template(data=data, state=self._connection) # type: ignore |
|
|
|
return Template(data=data, state=self._connection) # type: ignore |
|
|
|
|
|
|
|
async def fetch_guild(self, guild_id: int) -> Guild: |
|
|
|
"""|coro| |
|
|
@ -1278,9 +1284,7 @@ class Client: |
|
|
|
|
|
|
|
# Invite management |
|
|
|
|
|
|
|
async def fetch_invite( |
|
|
|
self, url: Union[Invite, str], *, with_counts: bool = True, with_expiration: bool = True |
|
|
|
) -> Invite: |
|
|
|
async def fetch_invite(self, url: Union[Invite, str], *, with_counts: bool = True, with_expiration: bool = True) -> Invite: |
|
|
|
"""|coro| |
|
|
|
|
|
|
|
Gets an :class:`.Invite` from a discord.gg URL or ID. |
|
|
@ -1516,7 +1520,7 @@ class Client: |
|
|
|
""" |
|
|
|
data = await self.http.get_sticker(sticker_id) |
|
|
|
cls, _ = _sticker_factory(data['type']) # type: ignore |
|
|
|
return cls(state=self._connection, data=data) # type: ignore |
|
|
|
return cls(state=self._connection, data=data) # type: ignore |
|
|
|
|
|
|
|
async def fetch_premium_sticker_packs(self) -> List[StickerPack]: |
|
|
|
"""|coro| |
|
|
|