Browse Source

Refactor Client.run to use asyncio.run

This also adds asynchronous context manager support to allow for
idiomatic asyncio usage for the lower-level counterpart. At first
I wanted to remove Client.run but I figured that a lot of beginners
would have been confused or not enjoyed the verbosity of the newer
approach of using async-with.
pull/7378/head
Rapptz 4 years ago
parent
commit
6e6c8a7b28
  1. 200
      discord/client.py
  2. 3
      discord/http.py

200
discord/client.py

@ -26,10 +26,24 @@ from __future__ import annotations
import asyncio
import logging
import signal
import sys
import traceback
from typing import Any, Callable, Coroutine, Dict, Generator, Iterable, List, Optional, Sequence, TYPE_CHECKING, Tuple, TypeVar, Union
from typing import (
Any,
Callable,
Coroutine,
Dict,
Generator,
Iterable,
List,
Optional,
Sequence,
TYPE_CHECKING,
Tuple,
TypeVar,
Type,
Union,
)
import aiohttp
@ -68,6 +82,7 @@ if TYPE_CHECKING:
from .message import Message
from .member import Member
from .voice_client import VoiceProtocol
from types import TracebackType
__all__ = (
'Client',
@ -78,36 +93,8 @@ Coro = TypeVar('Coro', bound=Callable[..., Coroutine[Any, Any, Any]])
log: logging.Logger = logging.getLogger(__name__)
def _cancel_tasks(loop: asyncio.AbstractEventLoop) -> None:
tasks = {t for t in asyncio.all_tasks(loop=loop) if not t.done()}
if not tasks:
return
log.info('Cleaning up after %d tasks.', len(tasks))
for task in tasks:
task.cancel()
loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True))
log.info('All tasks finished cancelling.')
for task in tasks:
if task.cancelled():
continue
if task.exception() is not None:
loop.call_exception_handler({
'message': 'Unhandled exception during Client.run shutdown.',
'exception': task.exception(),
'task': task
})
def _cleanup_loop(loop: asyncio.AbstractEventLoop) -> None:
try:
_cancel_tasks(loop)
loop.run_until_complete(loop.shutdown_asyncgens())
finally:
log.info('Closing the event loop.')
loop.close()
C = TypeVar('C', bound='Client')
class Client:
r"""Represents a client connection that connects to Discord.
@ -200,6 +187,7 @@ class Client:
loop: :class:`asyncio.AbstractEventLoop`
The event loop that the client uses for asynchronous operations.
"""
def __init__(
self,
*,
@ -207,7 +195,8 @@ class Client:
**options: Any,
):
self.ws: DiscordWebSocket = None # type: ignore
self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if loop is None else loop
# this is filled in later
self.loop: asyncio.AbstractEventLoop = MISSING if loop is None else loop
self._listeners: Dict[str, List[Tuple[asyncio.Future, Callable[..., bool]]]] = {}
self.shard_id: Optional[int] = options.get('shard_id')
self.shard_count: Optional[int] = options.get('shard_count')
@ -216,14 +205,16 @@ class Client:
proxy: Optional[str] = options.pop('proxy', None)
proxy_auth: Optional[aiohttp.BasicAuth] = options.pop('proxy_auth', None)
unsync_clock: bool = options.pop('assume_unsync_clock', True)
self.http: HTTPClient = HTTPClient(connector, proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock, loop=self.loop)
self.http: HTTPClient = HTTPClient(
connector, proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock, loop=loop
)
self._handlers: Dict[str, Callable] = {
'ready': self._handle_ready
'ready': self._handle_ready,
}
self._hooks: Dict[str, Callable] = {
'before_identify': self._call_before_identify_hook
'before_identify': self._call_before_identify_hook,
}
self._enable_debug_events: bool = options.pop('enable_debug_events', False)
@ -244,8 +235,9 @@ class Client:
return self.ws
def _get_state(self, **options: Any) -> ConnectionState:
return ConnectionState(dispatch=self.dispatch, handlers=self._handlers,
hooks=self._hooks, http=self.http, loop=self.loop, **options)
return ConnectionState(
dispatch=self.dispatch, handlers=self._handlers, hooks=self._hooks, http=self.http, loop=self.loop, **options
)
def _handle_ready(self) -> None:
self._ready.set()
@ -343,7 +335,9 @@ class Client:
""":class:`bool`: Specifies if the client's internal cache is ready for use."""
return self._ready.is_set()
async def _run_event(self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any) -> None:
async def _run_event(
self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any
) -> None:
try:
await coro(*args, **kwargs)
except asyncio.CancelledError:
@ -354,7 +348,9 @@ class Client:
except asyncio.CancelledError:
pass
def _schedule_event(self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any) -> asyncio.Task:
def _schedule_event(
self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any
) -> asyncio.Task:
wrapped = self._run_event(coro, event_name, *args, **kwargs)
# Schedules the task
return asyncio.create_task(wrapped, name=f'discord.py: {event_name}')
@ -466,7 +462,8 @@ class Client:
"""
log.info('logging in using static token')
self.loop = loop = asyncio.get_running_loop()
self._connection.loop = loop
data = await self.http.static_login(token.strip())
self._connection.user = ClientUser(state=self._connection, data=data)
@ -512,12 +509,14 @@ class Client:
self.dispatch('disconnect')
ws_params.update(sequence=self.ws.sequence, resume=e.resume, session=self.ws.session_id)
continue
except (OSError,
HTTPException,
GatewayNotFound,
ConnectionClosed,
aiohttp.ClientError,
asyncio.TimeoutError) as exc:
except (
OSError,
HTTPException,
GatewayNotFound,
ConnectionClosed,
aiohttp.ClientError,
asyncio.TimeoutError,
) as exc:
self.dispatch('disconnect')
if not reconnect:
@ -558,6 +557,22 @@ class Client:
"""|coro|
Closes the connection to Discord.
Instead of calling this directly, it is recommended to use the asynchronous context
manager to allow resources to be cleaned up automatically:
.. code-block:: python3
async def main():
async with Client() as client:
await client.login(token)
await client.connect()
asyncio.run(main())
.. versionchanged:: 2.0
The client can now be closed with an asynchronous context manager
"""
if self._closed:
return
@ -589,36 +604,47 @@ class Client:
self._connection.clear()
self.http.recreate()
async def __aenter__(self: C) -> C:
return self
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
await self.close()
async def start(self, token: str, *, reconnect: bool = True) -> None:
"""|coro|
A shorthand coroutine for :meth:`login` + :meth:`connect`.
A shorthand function equivalent to the following:
Raises
-------
TypeError
An unexpected keyword argument was received.
.. code-block:: python3
async with client:
await client.login(token)
await client.connect()
This closes the client when it returns.
"""
await self.login(token)
await self.connect(reconnect=reconnect)
try:
await self.login(token)
await self.connect(reconnect=reconnect)
finally:
await self.close()
def run(self, *args: Any, **kwargs: Any) -> None:
"""A blocking call that abstracts away the event loop
"""A convenience blocking call that abstracts away the event loop
initialisation from you.
If you want more control over the event loop then this
function should not be used. Use :meth:`start` coroutine
or :meth:`connect` + :meth:`login`.
Roughly Equivalent to: ::
Equivalent to: ::
try:
loop.run_until_complete(start(*args, **kwargs))
except KeyboardInterrupt:
loop.run_until_complete(close())
# cancel all tasks lingering
finally:
loop.close()
asyncio.run(bot.start(*args, **kwargs))
.. warning::
@ -626,41 +652,7 @@ class Client:
is blocking. That means that registration of events or anything being
called after this function call will not execute until it returns.
"""
loop = self.loop
try:
loop.add_signal_handler(signal.SIGINT, lambda: loop.stop())
loop.add_signal_handler(signal.SIGTERM, lambda: loop.stop())
except NotImplementedError:
pass
async def runner():
try:
await self.start(*args, **kwargs)
finally:
if not self.is_closed():
await self.close()
def stop_loop_on_completion(f):
loop.stop()
future = asyncio.ensure_future(runner(), loop=loop)
future.add_done_callback(stop_loop_on_completion)
try:
loop.run_forever()
except KeyboardInterrupt:
log.info('Received signal to terminate bot and event loop.')
finally:
future.remove_done_callback(stop_loop_on_completion)
log.info('Cleaning up tasks.')
_cleanup_loop(loop)
if not future.cancelled():
try:
return future.result()
except KeyboardInterrupt:
# I am unsure why this gets raised here but suppress it anyway
return None
asyncio.run(self.start(*args, **kwargs))
# properties
@ -973,8 +965,10 @@ class Client:
future = self.loop.create_future()
if check is None:
def _check(*args):
return True
check = _check
ev = event.lower()
@ -1083,7 +1077,7 @@ class Client:
*,
limit: Optional[int] = 100,
before: SnowflakeTime = None,
after: SnowflakeTime = None
after: SnowflakeTime = None,
) -> GuildIterator:
"""Retrieves an :class:`.AsyncIterator` that enables receiving your guilds.
@ -1163,7 +1157,7 @@ class Client:
"""
code = utils.resolve_template(code)
data = await self.http.get_template(code)
return Template(data=data, state=self._connection) # type: ignore
return Template(data=data, state=self._connection) # type: ignore
async def fetch_guild(self, guild_id: int) -> Guild:
"""|coro|
@ -1284,7 +1278,9 @@ class Client:
# Invite management
async def fetch_invite(self, url: Union[Invite, str], *, with_counts: bool = True, with_expiration: bool = True) -> Invite:
async def fetch_invite(
self, url: Union[Invite, str], *, with_counts: bool = True, with_expiration: bool = True
) -> Invite:
"""|coro|
Gets an :class:`.Invite` from a discord.gg URL or ID.
@ -1520,7 +1516,7 @@ class Client:
"""
data = await self.http.get_sticker(sticker_id)
cls, _ = _sticker_factory(data['type']) # type: ignore
return cls(state=self._connection, data=data) # type: ignore
return cls(state=self._connection, data=data) # type: ignore
async def fetch_premium_sticker_packs(self) -> List[StickerPack]:
"""|coro|

3
discord/http.py

@ -167,7 +167,7 @@ class HTTPClient:
loop: Optional[asyncio.AbstractEventLoop] = None,
unsync_clock: bool = True
) -> None:
self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if loop is None else loop
self.loop: asyncio.AbstractEventLoop = MISSING if loop is None else loop # filled in static_login
self.connector = connector
self.__session: aiohttp.ClientSession = MISSING # filled in static_login
self._locks: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
@ -371,6 +371,7 @@ class HTTPClient:
async def static_login(self, token: str) -> user.User:
# Necessary to get aiohttp to stop complaining about session creation
self.loop = asyncio.get_running_loop()
self.__session = aiohttp.ClientSession(connector=self.connector, ws_response_class=DiscordClientWebSocketResponse)
old_token = self.token
self.token = token

Loading…
Cancel
Save