Browse Source

Revert "Refactor Client.run to use asyncio.run"

This reverts commit 6e6c8a7b28.
pull/7378/head
Rapptz 4 years ago
parent
commit
08a4db3961
  1. 200
      discord/client.py
  2. 3
      discord/http.py

200
discord/client.py

@ -26,24 +26,10 @@ from __future__ import annotations
import asyncio
import logging
import signal
import sys
import traceback
from typing import (
Any,
Callable,
Coroutine,
Dict,
Generator,
Iterable,
List,
Optional,
Sequence,
TYPE_CHECKING,
Tuple,
TypeVar,
Type,
Union,
)
from typing import Any, Callable, Coroutine, Dict, Generator, Iterable, List, Optional, Sequence, TYPE_CHECKING, Tuple, TypeVar, Union
import aiohttp
@ -82,7 +68,6 @@ if TYPE_CHECKING:
from .message import Message
from .member import Member
from .voice_client import VoiceProtocol
from types import TracebackType
__all__ = (
'Client',
@ -93,8 +78,36 @@ Coro = TypeVar('Coro', bound=Callable[..., Coroutine[Any, Any, Any]])
log: logging.Logger = logging.getLogger(__name__)
C = TypeVar('C', bound='Client')
def _cancel_tasks(loop: asyncio.AbstractEventLoop) -> None:
tasks = {t for t in asyncio.all_tasks(loop=loop) if not t.done()}
if not tasks:
return
log.info('Cleaning up after %d tasks.', len(tasks))
for task in tasks:
task.cancel()
loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True))
log.info('All tasks finished cancelling.')
for task in tasks:
if task.cancelled():
continue
if task.exception() is not None:
loop.call_exception_handler({
'message': 'Unhandled exception during Client.run shutdown.',
'exception': task.exception(),
'task': task
})
def _cleanup_loop(loop: asyncio.AbstractEventLoop) -> None:
try:
_cancel_tasks(loop)
loop.run_until_complete(loop.shutdown_asyncgens())
finally:
log.info('Closing the event loop.')
loop.close()
class Client:
r"""Represents a client connection that connects to Discord.
@ -187,7 +200,6 @@ class Client:
loop: :class:`asyncio.AbstractEventLoop`
The event loop that the client uses for asynchronous operations.
"""
def __init__(
self,
*,
@ -195,8 +207,7 @@ class Client:
**options: Any,
):
self.ws: DiscordWebSocket = None # type: ignore
# this is filled in later
self.loop: asyncio.AbstractEventLoop = MISSING if loop is None else loop
self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if loop is None else loop
self._listeners: Dict[str, List[Tuple[asyncio.Future, Callable[..., bool]]]] = {}
self.shard_id: Optional[int] = options.get('shard_id')
self.shard_count: Optional[int] = options.get('shard_count')
@ -205,16 +216,14 @@ class Client:
proxy: Optional[str] = options.pop('proxy', None)
proxy_auth: Optional[aiohttp.BasicAuth] = options.pop('proxy_auth', None)
unsync_clock: bool = options.pop('assume_unsync_clock', True)
self.http: HTTPClient = HTTPClient(
connector, proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock, loop=loop
)
self.http: HTTPClient = HTTPClient(connector, proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock, loop=self.loop)
self._handlers: Dict[str, Callable] = {
'ready': self._handle_ready,
'ready': self._handle_ready
}
self._hooks: Dict[str, Callable] = {
'before_identify': self._call_before_identify_hook,
'before_identify': self._call_before_identify_hook
}
self._enable_debug_events: bool = options.pop('enable_debug_events', False)
@ -235,9 +244,8 @@ class Client:
return self.ws
def _get_state(self, **options: Any) -> ConnectionState:
return ConnectionState(
dispatch=self.dispatch, handlers=self._handlers, hooks=self._hooks, http=self.http, loop=self.loop, **options
)
return ConnectionState(dispatch=self.dispatch, handlers=self._handlers,
hooks=self._hooks, http=self.http, loop=self.loop, **options)
def _handle_ready(self) -> None:
self._ready.set()
@ -335,9 +343,7 @@ class Client:
""":class:`bool`: Specifies if the client's internal cache is ready for use."""
return self._ready.is_set()
async def _run_event(
self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any
) -> None:
async def _run_event(self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any) -> None:
try:
await coro(*args, **kwargs)
except asyncio.CancelledError:
@ -348,9 +354,7 @@ class Client:
except asyncio.CancelledError:
pass
def _schedule_event(
self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any
) -> asyncio.Task:
def _schedule_event(self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any) -> asyncio.Task:
wrapped = self._run_event(coro, event_name, *args, **kwargs)
# Schedules the task
return asyncio.create_task(wrapped, name=f'discord.py: {event_name}')
@ -462,8 +466,7 @@ class Client:
"""
log.info('logging in using static token')
self.loop = loop = asyncio.get_running_loop()
self._connection.loop = loop
data = await self.http.static_login(token.strip())
self._connection.user = ClientUser(state=self._connection, data=data)
@ -509,14 +512,12 @@ class Client:
self.dispatch('disconnect')
ws_params.update(sequence=self.ws.sequence, resume=e.resume, session=self.ws.session_id)
continue
except (
OSError,
HTTPException,
GatewayNotFound,
ConnectionClosed,
aiohttp.ClientError,
asyncio.TimeoutError,
) as exc:
except (OSError,
HTTPException,
GatewayNotFound,
ConnectionClosed,
aiohttp.ClientError,
asyncio.TimeoutError) as exc:
self.dispatch('disconnect')
if not reconnect:
@ -557,22 +558,6 @@ class Client:
"""|coro|
Closes the connection to Discord.
Instead of calling this directly, it is recommended to use the asynchronous context
manager to allow resources to be cleaned up automatically:
.. code-block:: python3
async def main():
async with Client() as client:
await client.login(token)
await client.connect()
asyncio.run(main())
.. versionchanged:: 2.0
The client can now be closed with an asynchronous context manager
"""
if self._closed:
return
@ -604,47 +589,36 @@ class Client:
self._connection.clear()
self.http.recreate()
async def __aenter__(self: C) -> C:
return self
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
await self.close()
async def start(self, token: str, *, reconnect: bool = True) -> None:
"""|coro|
A shorthand function equivalent to the following:
A shorthand coroutine for :meth:`login` + :meth:`connect`.
.. code-block:: python3
async with client:
await client.login(token)
await client.connect()
This closes the client when it returns.
Raises
-------
TypeError
An unexpected keyword argument was received.
"""
try:
await self.login(token)
await self.connect(reconnect=reconnect)
finally:
await self.close()
await self.login(token)
await self.connect(reconnect=reconnect)
def run(self, *args: Any, **kwargs: Any) -> None:
"""A convenience blocking call that abstracts away the event loop
"""A blocking call that abstracts away the event loop
initialisation from you.
If you want more control over the event loop then this
function should not be used. Use :meth:`start` coroutine
or :meth:`connect` + :meth:`login`.
Equivalent to: ::
Roughly Equivalent to: ::
asyncio.run(bot.start(*args, **kwargs))
try:
loop.run_until_complete(start(*args, **kwargs))
except KeyboardInterrupt:
loop.run_until_complete(close())
# cancel all tasks lingering
finally:
loop.close()
.. warning::
@ -652,7 +626,41 @@ class Client:
is blocking. That means that registration of events or anything being
called after this function call will not execute until it returns.
"""
asyncio.run(self.start(*args, **kwargs))
loop = self.loop
try:
loop.add_signal_handler(signal.SIGINT, lambda: loop.stop())
loop.add_signal_handler(signal.SIGTERM, lambda: loop.stop())
except NotImplementedError:
pass
async def runner():
try:
await self.start(*args, **kwargs)
finally:
if not self.is_closed():
await self.close()
def stop_loop_on_completion(f):
loop.stop()
future = asyncio.ensure_future(runner(), loop=loop)
future.add_done_callback(stop_loop_on_completion)
try:
loop.run_forever()
except KeyboardInterrupt:
log.info('Received signal to terminate bot and event loop.')
finally:
future.remove_done_callback(stop_loop_on_completion)
log.info('Cleaning up tasks.')
_cleanup_loop(loop)
if not future.cancelled():
try:
return future.result()
except KeyboardInterrupt:
# I am unsure why this gets raised here but suppress it anyway
return None
# properties
@ -965,10 +973,8 @@ class Client:
future = self.loop.create_future()
if check is None:
def _check(*args):
return True
check = _check
ev = event.lower()
@ -1077,7 +1083,7 @@ class Client:
*,
limit: Optional[int] = 100,
before: SnowflakeTime = None,
after: SnowflakeTime = None,
after: SnowflakeTime = None
) -> GuildIterator:
"""Retrieves an :class:`.AsyncIterator` that enables receiving your guilds.
@ -1157,7 +1163,7 @@ class Client:
"""
code = utils.resolve_template(code)
data = await self.http.get_template(code)
return Template(data=data, state=self._connection) # type: ignore
return Template(data=data, state=self._connection) # type: ignore
async def fetch_guild(self, guild_id: int) -> Guild:
"""|coro|
@ -1278,9 +1284,7 @@ class Client:
# Invite management
async def fetch_invite(
self, url: Union[Invite, str], *, with_counts: bool = True, with_expiration: bool = True
) -> Invite:
async def fetch_invite(self, url: Union[Invite, str], *, with_counts: bool = True, with_expiration: bool = True) -> Invite:
"""|coro|
Gets an :class:`.Invite` from a discord.gg URL or ID.
@ -1516,7 +1520,7 @@ class Client:
"""
data = await self.http.get_sticker(sticker_id)
cls, _ = _sticker_factory(data['type']) # type: ignore
return cls(state=self._connection, data=data) # type: ignore
return cls(state=self._connection, data=data) # type: ignore
async def fetch_premium_sticker_packs(self) -> List[StickerPack]:
"""|coro|

3
discord/http.py

@ -167,7 +167,7 @@ class HTTPClient:
loop: Optional[asyncio.AbstractEventLoop] = None,
unsync_clock: bool = True
) -> None:
self.loop: asyncio.AbstractEventLoop = MISSING if loop is None else loop # filled in static_login
self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if loop is None else loop
self.connector = connector
self.__session: aiohttp.ClientSession = MISSING # filled in static_login
self._locks: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
@ -371,7 +371,6 @@ class HTTPClient:
async def static_login(self, token: str) -> user.User:
# Necessary to get aiohttp to stop complaining about session creation
self.loop = asyncio.get_running_loop()
self.__session = aiohttp.ClientSession(connector=self.connector, ws_response_class=DiscordClientWebSocketResponse)
old_token = self.token
self.token = token

Loading…
Cancel
Save