Browse Source

Revert "Refactor Client.run to use asyncio.run"

This reverts commit 6e6c8a7b28.
pull/7378/head
Rapptz 4 years ago
parent
commit
08a4db3961
  1. 200
      discord/client.py
  2. 3
      discord/http.py

200
discord/client.py

@ -26,24 +26,10 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
import signal
import sys import sys
import traceback import traceback
from typing import ( from typing import Any, Callable, Coroutine, Dict, Generator, Iterable, List, Optional, Sequence, TYPE_CHECKING, Tuple, TypeVar, Union
Any,
Callable,
Coroutine,
Dict,
Generator,
Iterable,
List,
Optional,
Sequence,
TYPE_CHECKING,
Tuple,
TypeVar,
Type,
Union,
)
import aiohttp import aiohttp
@ -82,7 +68,6 @@ if TYPE_CHECKING:
from .message import Message from .message import Message
from .member import Member from .member import Member
from .voice_client import VoiceProtocol from .voice_client import VoiceProtocol
from types import TracebackType
__all__ = ( __all__ = (
'Client', 'Client',
@ -93,8 +78,36 @@ Coro = TypeVar('Coro', bound=Callable[..., Coroutine[Any, Any, Any]])
log: logging.Logger = logging.getLogger(__name__) log: logging.Logger = logging.getLogger(__name__)
C = TypeVar('C', bound='Client') def _cancel_tasks(loop: asyncio.AbstractEventLoop) -> None:
tasks = {t for t in asyncio.all_tasks(loop=loop) if not t.done()}
if not tasks:
return
log.info('Cleaning up after %d tasks.', len(tasks))
for task in tasks:
task.cancel()
loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True))
log.info('All tasks finished cancelling.')
for task in tasks:
if task.cancelled():
continue
if task.exception() is not None:
loop.call_exception_handler({
'message': 'Unhandled exception during Client.run shutdown.',
'exception': task.exception(),
'task': task
})
def _cleanup_loop(loop: asyncio.AbstractEventLoop) -> None:
try:
_cancel_tasks(loop)
loop.run_until_complete(loop.shutdown_asyncgens())
finally:
log.info('Closing the event loop.')
loop.close()
class Client: class Client:
r"""Represents a client connection that connects to Discord. r"""Represents a client connection that connects to Discord.
@ -187,7 +200,6 @@ class Client:
loop: :class:`asyncio.AbstractEventLoop` loop: :class:`asyncio.AbstractEventLoop`
The event loop that the client uses for asynchronous operations. The event loop that the client uses for asynchronous operations.
""" """
def __init__( def __init__(
self, self,
*, *,
@ -195,8 +207,7 @@ class Client:
**options: Any, **options: Any,
): ):
self.ws: DiscordWebSocket = None # type: ignore self.ws: DiscordWebSocket = None # type: ignore
# this is filled in later self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if loop is None else loop
self.loop: asyncio.AbstractEventLoop = MISSING if loop is None else loop
self._listeners: Dict[str, List[Tuple[asyncio.Future, Callable[..., bool]]]] = {} self._listeners: Dict[str, List[Tuple[asyncio.Future, Callable[..., bool]]]] = {}
self.shard_id: Optional[int] = options.get('shard_id') self.shard_id: Optional[int] = options.get('shard_id')
self.shard_count: Optional[int] = options.get('shard_count') self.shard_count: Optional[int] = options.get('shard_count')
@ -205,16 +216,14 @@ class Client:
proxy: Optional[str] = options.pop('proxy', None) proxy: Optional[str] = options.pop('proxy', None)
proxy_auth: Optional[aiohttp.BasicAuth] = options.pop('proxy_auth', None) proxy_auth: Optional[aiohttp.BasicAuth] = options.pop('proxy_auth', None)
unsync_clock: bool = options.pop('assume_unsync_clock', True) unsync_clock: bool = options.pop('assume_unsync_clock', True)
self.http: HTTPClient = HTTPClient( self.http: HTTPClient = HTTPClient(connector, proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock, loop=self.loop)
connector, proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock, loop=loop
)
self._handlers: Dict[str, Callable] = { self._handlers: Dict[str, Callable] = {
'ready': self._handle_ready, 'ready': self._handle_ready
} }
self._hooks: Dict[str, Callable] = { self._hooks: Dict[str, Callable] = {
'before_identify': self._call_before_identify_hook, 'before_identify': self._call_before_identify_hook
} }
self._enable_debug_events: bool = options.pop('enable_debug_events', False) self._enable_debug_events: bool = options.pop('enable_debug_events', False)
@ -235,9 +244,8 @@ class Client:
return self.ws return self.ws
def _get_state(self, **options: Any) -> ConnectionState: def _get_state(self, **options: Any) -> ConnectionState:
return ConnectionState( return ConnectionState(dispatch=self.dispatch, handlers=self._handlers,
dispatch=self.dispatch, handlers=self._handlers, hooks=self._hooks, http=self.http, loop=self.loop, **options hooks=self._hooks, http=self.http, loop=self.loop, **options)
)
def _handle_ready(self) -> None: def _handle_ready(self) -> None:
self._ready.set() self._ready.set()
@ -335,9 +343,7 @@ class Client:
""":class:`bool`: Specifies if the client's internal cache is ready for use.""" """:class:`bool`: Specifies if the client's internal cache is ready for use."""
return self._ready.is_set() return self._ready.is_set()
async def _run_event( async def _run_event(self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any) -> None:
self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any
) -> None:
try: try:
await coro(*args, **kwargs) await coro(*args, **kwargs)
except asyncio.CancelledError: except asyncio.CancelledError:
@ -348,9 +354,7 @@ class Client:
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
def _schedule_event( def _schedule_event(self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any) -> asyncio.Task:
self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any
) -> asyncio.Task:
wrapped = self._run_event(coro, event_name, *args, **kwargs) wrapped = self._run_event(coro, event_name, *args, **kwargs)
# Schedules the task # Schedules the task
return asyncio.create_task(wrapped, name=f'discord.py: {event_name}') return asyncio.create_task(wrapped, name=f'discord.py: {event_name}')
@ -462,8 +466,7 @@ class Client:
""" """
log.info('logging in using static token') log.info('logging in using static token')
self.loop = loop = asyncio.get_running_loop()
self._connection.loop = loop
data = await self.http.static_login(token.strip()) data = await self.http.static_login(token.strip())
self._connection.user = ClientUser(state=self._connection, data=data) self._connection.user = ClientUser(state=self._connection, data=data)
@ -509,14 +512,12 @@ class Client:
self.dispatch('disconnect') self.dispatch('disconnect')
ws_params.update(sequence=self.ws.sequence, resume=e.resume, session=self.ws.session_id) ws_params.update(sequence=self.ws.sequence, resume=e.resume, session=self.ws.session_id)
continue continue
except ( except (OSError,
OSError, HTTPException,
HTTPException, GatewayNotFound,
GatewayNotFound, ConnectionClosed,
ConnectionClosed, aiohttp.ClientError,
aiohttp.ClientError, asyncio.TimeoutError) as exc:
asyncio.TimeoutError,
) as exc:
self.dispatch('disconnect') self.dispatch('disconnect')
if not reconnect: if not reconnect:
@ -557,22 +558,6 @@ class Client:
"""|coro| """|coro|
Closes the connection to Discord. Closes the connection to Discord.
Instead of calling this directly, it is recommended to use the asynchronous context
manager to allow resources to be cleaned up automatically:
.. code-block:: python3
async def main():
async with Client() as client:
await client.login(token)
await client.connect()
asyncio.run(main())
.. versionchanged:: 2.0
The client can now be closed with an asynchronous context manager
""" """
if self._closed: if self._closed:
return return
@ -604,47 +589,36 @@ class Client:
self._connection.clear() self._connection.clear()
self.http.recreate() self.http.recreate()
async def __aenter__(self: C) -> C:
return self
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
await self.close()
async def start(self, token: str, *, reconnect: bool = True) -> None: async def start(self, token: str, *, reconnect: bool = True) -> None:
"""|coro| """|coro|
A shorthand function equivalent to the following: A shorthand coroutine for :meth:`login` + :meth:`connect`.
.. code-block:: python3 Raises
-------
async with client: TypeError
await client.login(token) An unexpected keyword argument was received.
await client.connect()
This closes the client when it returns.
""" """
try: await self.login(token)
await self.login(token) await self.connect(reconnect=reconnect)
await self.connect(reconnect=reconnect)
finally:
await self.close()
def run(self, *args: Any, **kwargs: Any) -> None: def run(self, *args: Any, **kwargs: Any) -> None:
"""A convenience blocking call that abstracts away the event loop """A blocking call that abstracts away the event loop
initialisation from you. initialisation from you.
If you want more control over the event loop then this If you want more control over the event loop then this
function should not be used. Use :meth:`start` coroutine function should not be used. Use :meth:`start` coroutine
or :meth:`connect` + :meth:`login`. or :meth:`connect` + :meth:`login`.
Equivalent to: :: Roughly Equivalent to: ::
asyncio.run(bot.start(*args, **kwargs)) try:
loop.run_until_complete(start(*args, **kwargs))
except KeyboardInterrupt:
loop.run_until_complete(close())
# cancel all tasks lingering
finally:
loop.close()
.. warning:: .. warning::
@ -652,7 +626,41 @@ class Client:
is blocking. That means that registration of events or anything being is blocking. That means that registration of events or anything being
called after this function call will not execute until it returns. called after this function call will not execute until it returns.
""" """
asyncio.run(self.start(*args, **kwargs)) loop = self.loop
try:
loop.add_signal_handler(signal.SIGINT, lambda: loop.stop())
loop.add_signal_handler(signal.SIGTERM, lambda: loop.stop())
except NotImplementedError:
pass
async def runner():
try:
await self.start(*args, **kwargs)
finally:
if not self.is_closed():
await self.close()
def stop_loop_on_completion(f):
loop.stop()
future = asyncio.ensure_future(runner(), loop=loop)
future.add_done_callback(stop_loop_on_completion)
try:
loop.run_forever()
except KeyboardInterrupt:
log.info('Received signal to terminate bot and event loop.')
finally:
future.remove_done_callback(stop_loop_on_completion)
log.info('Cleaning up tasks.')
_cleanup_loop(loop)
if not future.cancelled():
try:
return future.result()
except KeyboardInterrupt:
# I am unsure why this gets raised here but suppress it anyway
return None
# properties # properties
@ -965,10 +973,8 @@ class Client:
future = self.loop.create_future() future = self.loop.create_future()
if check is None: if check is None:
def _check(*args): def _check(*args):
return True return True
check = _check check = _check
ev = event.lower() ev = event.lower()
@ -1077,7 +1083,7 @@ class Client:
*, *,
limit: Optional[int] = 100, limit: Optional[int] = 100,
before: SnowflakeTime = None, before: SnowflakeTime = None,
after: SnowflakeTime = None, after: SnowflakeTime = None
) -> GuildIterator: ) -> GuildIterator:
"""Retrieves an :class:`.AsyncIterator` that enables receiving your guilds. """Retrieves an :class:`.AsyncIterator` that enables receiving your guilds.
@ -1157,7 +1163,7 @@ class Client:
""" """
code = utils.resolve_template(code) code = utils.resolve_template(code)
data = await self.http.get_template(code) data = await self.http.get_template(code)
return Template(data=data, state=self._connection) # type: ignore return Template(data=data, state=self._connection) # type: ignore
async def fetch_guild(self, guild_id: int) -> Guild: async def fetch_guild(self, guild_id: int) -> Guild:
"""|coro| """|coro|
@ -1278,9 +1284,7 @@ class Client:
# Invite management # Invite management
async def fetch_invite( async def fetch_invite(self, url: Union[Invite, str], *, with_counts: bool = True, with_expiration: bool = True) -> Invite:
self, url: Union[Invite, str], *, with_counts: bool = True, with_expiration: bool = True
) -> Invite:
"""|coro| """|coro|
Gets an :class:`.Invite` from a discord.gg URL or ID. Gets an :class:`.Invite` from a discord.gg URL or ID.
@ -1516,7 +1520,7 @@ class Client:
""" """
data = await self.http.get_sticker(sticker_id) data = await self.http.get_sticker(sticker_id)
cls, _ = _sticker_factory(data['type']) # type: ignore cls, _ = _sticker_factory(data['type']) # type: ignore
return cls(state=self._connection, data=data) # type: ignore return cls(state=self._connection, data=data) # type: ignore
async def fetch_premium_sticker_packs(self) -> List[StickerPack]: async def fetch_premium_sticker_packs(self) -> List[StickerPack]:
"""|coro| """|coro|

3
discord/http.py

@ -167,7 +167,7 @@ class HTTPClient:
loop: Optional[asyncio.AbstractEventLoop] = None, loop: Optional[asyncio.AbstractEventLoop] = None,
unsync_clock: bool = True unsync_clock: bool = True
) -> None: ) -> None:
self.loop: asyncio.AbstractEventLoop = MISSING if loop is None else loop # filled in static_login self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if loop is None else loop
self.connector = connector self.connector = connector
self.__session: aiohttp.ClientSession = MISSING # filled in static_login self.__session: aiohttp.ClientSession = MISSING # filled in static_login
self._locks: weakref.WeakValueDictionary = weakref.WeakValueDictionary() self._locks: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
@ -371,7 +371,6 @@ class HTTPClient:
async def static_login(self, token: str) -> user.User: async def static_login(self, token: str) -> user.User:
# Necessary to get aiohttp to stop complaining about session creation # Necessary to get aiohttp to stop complaining about session creation
self.loop = asyncio.get_running_loop()
self.__session = aiohttp.ClientSession(connector=self.connector, ws_response_class=DiscordClientWebSocketResponse) self.__session = aiohttp.ClientSession(connector=self.connector, ws_response_class=DiscordClientWebSocketResponse)
old_token = self.token old_token = self.token
self.token = token self.token = token

Loading…
Cancel
Save