Browse Source

Refactor loop code to allow usage of asyncio.run

pull/7674/head
Han Seung Min - 한승민 3 years ago
committed by GitHub
parent
commit
93af158b0c
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 114
      discord/client.py
  2. 2
      discord/ext/commands/cooldowns.py
  3. 24
      discord/ext/tasks/__init__.py
  4. 12
      discord/http.py
  5. 8
      discord/player.py
  6. 8
      discord/shard.py
  7. 2
      discord/state.py
  8. 6
      discord/ui/view.py
  9. 4
      discord/voice_client.py

114
discord/client.py

@ -27,7 +27,6 @@ from __future__ import annotations
import asyncio import asyncio
import datetime import datetime
import logging import logging
import signal
import sys import sys
import traceback import traceback
from typing import ( from typing import (
@ -97,41 +96,6 @@ Coro = TypeVar('Coro', bound=Callable[..., Coroutine[Any, Any, Any]])
_log = logging.getLogger(__name__) _log = logging.getLogger(__name__)
def _cancel_tasks(loop: asyncio.AbstractEventLoop) -> None:
tasks = {t for t in asyncio.all_tasks(loop=loop) if not t.done()}
if not tasks:
return
_log.info('Cleaning up after %d tasks.', len(tasks))
for task in tasks:
task.cancel()
loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True))
_log.info('All tasks finished cancelling.')
for task in tasks:
if task.cancelled():
continue
if task.exception() is not None:
loop.call_exception_handler(
{
'message': 'Unhandled exception during Client.run shutdown.',
'exception': task.exception(),
'task': task,
}
)
def _cleanup_loop(loop: asyncio.AbstractEventLoop) -> None:
try:
_cancel_tasks(loop)
loop.run_until_complete(loop.shutdown_asyncgens())
finally:
_log.info('Closing the event loop.')
loop.close()
class Client: class Client:
r"""Represents a client connection that connects to Discord. r"""Represents a client connection that connects to Discord.
This class is used to interact with the Discord WebSocket and API. This class is used to interact with the Discord WebSocket and API.
@ -146,12 +110,6 @@ class Client:
.. versionchanged:: 1.3 .. versionchanged:: 1.3
Allow disabling the message cache and change the default size to ``1000``. Allow disabling the message cache and change the default size to ``1000``.
loop: Optional[:class:`asyncio.AbstractEventLoop`]
The :class:`asyncio.AbstractEventLoop` to use for asynchronous operations.
Defaults to ``None``, in which case the default event loop is used via
:func:`asyncio.get_event_loop()`.
connector: Optional[:class:`aiohttp.BaseConnector`]
The connector to use for connection pooling.
proxy: Optional[:class:`str`] proxy: Optional[:class:`str`]
Proxy URL. Proxy URL.
proxy_auth: Optional[:class:`aiohttp.BasicAuth`] proxy_auth: Optional[:class:`aiohttp.BasicAuth`]
@ -220,30 +178,23 @@ class Client:
----------- -----------
ws ws
The websocket gateway the client is currently connected to. Could be ``None``. The websocket gateway the client is currently connected to. Could be ``None``.
loop: :class:`asyncio.AbstractEventLoop`
The event loop that the client uses for asynchronous operations.
""" """
def __init__( def __init__(
self, self,
*,
loop: Optional[asyncio.AbstractEventLoop] = None,
**options: Any, **options: Any,
): ):
self.loop: asyncio.AbstractEventLoop = MISSING
# self.ws is set in the connect method # self.ws is set in the connect method
self.ws: DiscordWebSocket = None # type: ignore self.ws: DiscordWebSocket = None # type: ignore
self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if loop is None else loop
self._listeners: Dict[str, List[Tuple[asyncio.Future, Callable[..., bool]]]] = {} self._listeners: Dict[str, List[Tuple[asyncio.Future, Callable[..., bool]]]] = {}
self.shard_id: Optional[int] = options.get('shard_id') self.shard_id: Optional[int] = options.get('shard_id')
self.shard_count: Optional[int] = options.get('shard_count') self.shard_count: Optional[int] = options.get('shard_count')
connector: Optional[aiohttp.BaseConnector] = options.pop('connector', None)
proxy: Optional[str] = options.pop('proxy', None) proxy: Optional[str] = options.pop('proxy', None)
proxy_auth: Optional[aiohttp.BasicAuth] = options.pop('proxy_auth', None) proxy_auth: Optional[aiohttp.BasicAuth] = options.pop('proxy_auth', None)
unsync_clock: bool = options.pop('assume_unsync_clock', True) unsync_clock: bool = options.pop('assume_unsync_clock', True)
self.http: HTTPClient = HTTPClient( self.http: HTTPClient = HTTPClient(self.loop, proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock)
connector, proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock, loop=self.loop
)
self._handlers: Dict[str, Callable] = { self._handlers: Dict[str, Callable] = {
'ready': self._handle_ready, 'ready': self._handle_ready,
@ -399,7 +350,7 @@ class Client:
) -> asyncio.Task: ) -> asyncio.Task:
wrapped = self._run_event(coro, event_name, *args, **kwargs) wrapped = self._run_event(coro, event_name, *args, **kwargs)
# Schedules the task # Schedules the task
return asyncio.create_task(wrapped, name=f'discord.py: {event_name}') return self.loop.create_task(wrapped, name=f'discord.py: {event_name}')
def dispatch(self, event: str, *args: Any, **kwargs: Any) -> None: def dispatch(self, event: str, *args: Any, **kwargs: Any) -> None:
_log.debug('Dispatching event %s', event) _log.debug('Dispatching event %s', event)
@ -623,6 +574,7 @@ class Client:
await self.http.close() await self.http.close()
self._ready.clear() self._ready.clear()
self.loop = MISSING
def clear(self) -> None: def clear(self) -> None:
"""Clears the internal state of the bot. """Clears the internal state of the bot.
@ -646,8 +598,15 @@ class Client:
TypeError TypeError
An unexpected keyword argument was received. An unexpected keyword argument was received.
""" """
await self.login(token) self.loop = asyncio.get_running_loop()
await self.connect(reconnect=reconnect) self.http.loop = self.loop
self._connection.loop = self.loop
try:
await self.login(token)
await self.connect(reconnect=reconnect)
finally:
if not self.is_closed():
await self.close()
async def setup_hook(self) -> None: async def setup_hook(self) -> None:
"""|coro| """|coro|
@ -676,12 +635,9 @@ class Client:
Roughly Equivalent to: :: Roughly Equivalent to: ::
try: try:
loop.run_until_complete(start(*args, **kwargs)) asyncio.run(self.start(*args, **kwargs))
except KeyboardInterrupt: except KeyboardInterrupt:
loop.run_until_complete(close()) return
# cancel all tasks lingering
finally:
loop.close()
.. warning:: .. warning::
@ -689,41 +645,13 @@ class Client:
is blocking. That means that registration of events or anything being is blocking. That means that registration of events or anything being
called after this function call will not execute until it returns. called after this function call will not execute until it returns.
""" """
loop = self.loop
try: try:
loop.add_signal_handler(signal.SIGINT, lambda: loop.stop()) asyncio.run(self.start(*args, **kwargs))
loop.add_signal_handler(signal.SIGTERM, lambda: loop.stop())
except NotImplementedError:
pass
async def runner():
try:
await self.start(*args, **kwargs)
finally:
if not self.is_closed():
await self.close()
def stop_loop_on_completion(f):
loop.stop()
future = asyncio.ensure_future(runner(), loop=loop)
future.add_done_callback(stop_loop_on_completion)
try:
loop.run_forever()
except KeyboardInterrupt: except KeyboardInterrupt:
_log.info('Received signal to terminate bot and event loop.') # nothing to do here
finally: # `asyncio.run` handles the loop cleanup
future.remove_done_callback(stop_loop_on_completion) # and `self.start` closes all sockets and the HTTPClient instance.
_log.info('Cleaning up tasks.') return
_cleanup_loop(loop)
if not future.cancelled():
try:
return future.result()
except KeyboardInterrupt:
# I am unsure why this gets raised here but suppress it anyway
return None
# properties # properties
@ -1324,7 +1252,7 @@ class Client:
""" """
code = utils.resolve_template(code) code = utils.resolve_template(code)
data = await self.http.get_template(code) data = await self.http.get_template(code)
return Template(data=data, state=self._connection) # type: ignore return Template(data=data, state=self._connection)
async def fetch_guild(self, guild_id: int, /, *, with_counts: bool = True) -> Guild: async def fetch_guild(self, guild_id: int, /, *, with_counts: bool = True) -> Guild:
"""|coro| """|coro|

2
discord/ext/commands/cooldowns.py

@ -297,7 +297,7 @@ class _Semaphore:
def __init__(self, number: int) -> None: def __init__(self, number: int) -> None:
self.value: int = number self.value: int = number
self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() self.loop: asyncio.AbstractEventLoop = asyncio.get_running_loop()
self._waiters: Deque[asyncio.Future] = deque() self._waiters: Deque[asyncio.Future] = deque()
def __repr__(self) -> str: def __repr__(self) -> str:

24
discord/ext/tasks/__init__.py

@ -101,11 +101,9 @@ class Loop(Generic[LF]):
time: Union[datetime.time, Sequence[datetime.time]], time: Union[datetime.time, Sequence[datetime.time]],
count: Optional[int], count: Optional[int],
reconnect: bool, reconnect: bool,
loop: asyncio.AbstractEventLoop,
) -> None: ) -> None:
self.coro: LF = coro self.coro: LF = coro
self.reconnect: bool = reconnect self.reconnect: bool = reconnect
self.loop: asyncio.AbstractEventLoop = loop
self.count: Optional[int] = count self.count: Optional[int] = count
self._current_loop = 0 self._current_loop = 0
self._handle: Optional[SleepHandle] = None self._handle: Optional[SleepHandle] = None
@ -147,7 +145,7 @@ class Loop(Generic[LF]):
await coro(*args, **kwargs) await coro(*args, **kwargs)
def _try_sleep_until(self, dt: datetime.datetime): def _try_sleep_until(self, dt: datetime.datetime):
self._handle = SleepHandle(dt=dt, loop=self.loop) self._handle = SleepHandle(dt=dt, loop=asyncio.get_running_loop())
return self._handle.wait() return self._handle.wait()
async def _loop(self, *args: Any, **kwargs: Any) -> None: async def _loop(self, *args: Any, **kwargs: Any) -> None:
@ -219,7 +217,6 @@ class Loop(Generic[LF]):
time=self._time, time=self._time,
count=self.count, count=self.count,
reconnect=self.reconnect, reconnect=self.reconnect,
loop=self.loop,
) )
copy._injected = obj copy._injected = obj
copy._before_loop = self._before_loop copy._before_loop = self._before_loop
@ -332,10 +329,7 @@ class Loop(Generic[LF]):
if self._injected is not None: if self._injected is not None:
args = (self._injected, *args) args = (self._injected, *args)
if self.loop is MISSING: self._task = asyncio.create_task(self._loop(*args, **kwargs))
self.loop = asyncio.get_event_loop()
self._task = self.loop.create_task(self._loop(*args, **kwargs))
return self._task return self._task
def stop(self) -> None: def stop(self) -> None:
@ -740,9 +734,6 @@ def loop(
Whether to handle errors and restart the task Whether to handle errors and restart the task
using an exponential back-off algorithm similar to the using an exponential back-off algorithm similar to the
one used in :meth:`discord.Client.connect`. one used in :meth:`discord.Client.connect`.
loop: :class:`asyncio.AbstractEventLoop`
The loop to use to register the task, if not given
defaults to :func:`asyncio.get_event_loop`.
Raises Raises
-------- --------
@ -754,15 +745,6 @@ def loop(
""" """
def decorator(func: LF) -> Loop[LF]: def decorator(func: LF) -> Loop[LF]:
return Loop[LF]( return Loop[LF](func, seconds=seconds, minutes=minutes, hours=hours, count=count, time=time, reconnect=reconnect)
func,
seconds=seconds,
minutes=minutes,
hours=hours,
count=count,
time=time,
reconnect=reconnect,
loop=loop,
)
return decorator return decorator

12
discord/http.py

@ -332,15 +332,15 @@ class HTTPClient:
def __init__( def __init__(
self, self,
loop: asyncio.AbstractEventLoop,
connector: Optional[aiohttp.BaseConnector] = None, connector: Optional[aiohttp.BaseConnector] = None,
*, *,
proxy: Optional[str] = None, proxy: Optional[str] = None,
proxy_auth: Optional[aiohttp.BasicAuth] = None, proxy_auth: Optional[aiohttp.BasicAuth] = None,
loop: Optional[asyncio.AbstractEventLoop] = None,
unsync_clock: bool = True, unsync_clock: bool = True,
) -> None: ) -> None:
self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if loop is None else loop self.loop: asyncio.AbstractEventLoop = loop
self.connector: aiohttp.BaseConnector = connector or aiohttp.TCPConnector(limit=0) self.connector: aiohttp.BaseConnector = connector or MISSING
self.__session: aiohttp.ClientSession = MISSING # filled in static_login self.__session: aiohttp.ClientSession = MISSING # filled in static_login
self._locks: weakref.WeakValueDictionary = weakref.WeakValueDictionary() self._locks: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
self._global_over: asyncio.Event = asyncio.Event() self._global_over: asyncio.Event = asyncio.Event()
@ -544,7 +544,11 @@ class HTTPClient:
async def static_login(self, token: str) -> user.User: async def static_login(self, token: str) -> user.User:
# Necessary to get aiohttp to stop complaining about session creation # Necessary to get aiohttp to stop complaining about session creation
self.__session = aiohttp.ClientSession(connector=self.connector, ws_response_class=DiscordClientWebSocketResponse) if self.connector is MISSING:
self.connector = aiohttp.TCPConnector(loop=self.loop, limit=0)
self.__session = aiohttp.ClientSession(
connector=self.connector, ws_response_class=DiscordClientWebSocketResponse, loop=self.loop
)
old_token = self.token old_token = self.token
self.token = token self.token = token

8
discord/player.py

@ -521,9 +521,9 @@ class FFmpegOpusAudio(FFmpegAudio):
raise TypeError(f"Expected str or callable for parameter 'probe', not '{method.__class__.__name__}'") raise TypeError(f"Expected str or callable for parameter 'probe', not '{method.__class__.__name__}'")
codec = bitrate = None codec = bitrate = None
loop = asyncio.get_event_loop() loop = asyncio.get_running_loop()
try: try:
codec, bitrate = await loop.run_in_executor(None, lambda: probefunc(source, executable)) # type: ignore codec, bitrate = await loop.run_in_executor(None, lambda: probefunc(source, executable))
except Exception: except Exception:
if not fallback: if not fallback:
_log.exception("Probe '%s' using '%s' failed", method, executable) _log.exception("Probe '%s' using '%s' failed", method, executable)
@ -531,7 +531,7 @@ class FFmpegOpusAudio(FFmpegAudio):
_log.exception("Probe '%s' using '%s' failed, trying fallback", method, executable) _log.exception("Probe '%s' using '%s' failed, trying fallback", method, executable)
try: try:
codec, bitrate = await loop.run_in_executor(None, lambda: fallback(source, executable)) # type: ignore codec, bitrate = await loop.run_in_executor(None, lambda: fallback(source, executable))
except Exception: except Exception:
_log.exception("Fallback probe using '%s' failed", executable) _log.exception("Fallback probe using '%s' failed", executable)
else: else:
@ -744,6 +744,6 @@ class AudioPlayer(threading.Thread):
def _speak(self, speaking: SpeakingState) -> None: def _speak(self, speaking: SpeakingState) -> None:
try: try:
asyncio.run_coroutine_threadsafe(self.client.ws.speak(speaking), self.client.loop) asyncio.run_coroutine_threadsafe(self.client.ws.speak(speaking), self.client.client.loop)
except Exception as e: except Exception as e:
_log.info("Speaking call in player failed: %s", e) _log.info("Speaking call in player failed: %s", e)

8
discord/shard.py

@ -95,7 +95,6 @@ class Shard:
self._client: Client = client self._client: Client = client
self._dispatch: Callable[..., None] = client.dispatch self._dispatch: Callable[..., None] = client.dispatch
self._queue_put: Callable[[EventItem], None] = queue_put self._queue_put: Callable[[EventItem], None] = queue_put
self.loop: asyncio.AbstractEventLoop = self._client.loop
self._disconnect: bool = False self._disconnect: bool = False
self._reconnect = client._reconnect self._reconnect = client._reconnect
self._backoff: ExponentialBackoff = ExponentialBackoff() self._backoff: ExponentialBackoff = ExponentialBackoff()
@ -115,7 +114,7 @@ class Shard:
return self.ws.shard_id # type: ignore return self.ws.shard_id # type: ignore
def launch(self) -> None: def launch(self) -> None:
self._task = self.loop.create_task(self.worker()) self._task = self._client.loop.create_task(self.worker())
def _cancel_task(self) -> None: def _cancel_task(self) -> None:
if self._task is not None and not self._task.done(): if self._task is not None and not self._task.done():
@ -318,10 +317,10 @@ class AutoShardedClient(Client):
if TYPE_CHECKING: if TYPE_CHECKING:
_connection: AutoShardedConnectionState _connection: AutoShardedConnectionState
def __init__(self, *args: Any, loop: Optional[asyncio.AbstractEventLoop] = None, **kwargs: Any) -> None: def __init__(self, *args: Any, **kwargs: Any) -> None:
kwargs.pop('shard_id', None) kwargs.pop('shard_id', None)
self.shard_ids: Optional[List[int]] = kwargs.pop('shard_ids', None) self.shard_ids: Optional[List[int]] = kwargs.pop('shard_ids', None)
super().__init__(*args, loop=loop, **kwargs) super().__init__(*args, **kwargs)
if self.shard_ids is not None: if self.shard_ids is not None:
if self.shard_count is None: if self.shard_count is None:
@ -348,7 +347,6 @@ class AutoShardedClient(Client):
handlers=self._handlers, handlers=self._handlers,
hooks=self._hooks, hooks=self._hooks,
http=self.http, http=self.http,
loop=self.loop,
**options, **options,
) )

2
discord/state.py

@ -569,7 +569,7 @@ class ConnectionState:
else: else:
self.application_id = utils._get_as_snowflake(application, 'id') self.application_id = utils._get_as_snowflake(application, 'id')
# flags will always be present here # flags will always be present here
self.application_flags = ApplicationFlags._from_value(application['flags']) # type: ignore self.application_flags = ApplicationFlags._from_value(application['flags'])
for guild_data in data['guilds']: for guild_data in data['guilds']:
self._add_guild_from_data(guild_data) # type: ignore self._add_guild_from_data(guild_data) # type: ignore

6
discord/ui/view.py

@ -177,12 +177,11 @@ class View:
self.timeout = timeout self.timeout = timeout
self.children: List[Item] = self._init_children() self.children: List[Item] = self._init_children()
self.__weights = _ViewWeights(self.children) self.__weights = _ViewWeights(self.children)
loop = asyncio.get_running_loop()
self.id: str = os.urandom(16).hex() self.id: str = os.urandom(16).hex()
self.__cancel_callback: Optional[Callable[[View], None]] = None self.__cancel_callback: Optional[Callable[[View], None]] = None
self.__timeout_expiry: Optional[float] = None self.__timeout_expiry: Optional[float] = None
self.__timeout_task: Optional[asyncio.Task[None]] = None self.__timeout_task: Optional[asyncio.Task[None]] = None
self.__stopped: asyncio.Future[bool] = loop.create_future() self.__stopped: asyncio.Future[bool] = asyncio.get_running_loop().create_future()
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<{self.__class__.__name__} timeout={self.timeout} children={len(self.children)}>' return f'<{self.__class__.__name__} timeout={self.timeout} children={len(self.children)}>'
@ -379,12 +378,11 @@ class View:
def _start_listening_from_store(self, store: ViewStore) -> None: def _start_listening_from_store(self, store: ViewStore) -> None:
self.__cancel_callback = partial(store.remove_view) self.__cancel_callback = partial(store.remove_view)
if self.timeout: if self.timeout:
loop = asyncio.get_running_loop()
if self.__timeout_task is not None: if self.__timeout_task is not None:
self.__timeout_task.cancel() self.__timeout_task.cancel()
self.__timeout_expiry = time.monotonic() + self.timeout self.__timeout_expiry = time.monotonic() + self.timeout
self.__timeout_task = loop.create_task(self.__timeout_task_impl()) self.__timeout_task = asyncio.create_task(self.__timeout_task_impl())
def _dispatch_timeout(self): def _dispatch_timeout(self):
if self.__stopped.done(): if self.__stopped.done():

4
discord/voice_client.py

@ -222,8 +222,6 @@ class VoiceClient(VoiceProtocol):
The endpoint we are connecting to. The endpoint we are connecting to.
channel: Union[:class:`VoiceChannel`, :class:`StageChannel`] channel: Union[:class:`VoiceChannel`, :class:`StageChannel`]
The voice channel connected to. The voice channel connected to.
loop: :class:`asyncio.AbstractEventLoop`
The event loop that the voice client is running on.
""" """
channel: VocalGuildChannel channel: VocalGuildChannel
@ -405,7 +403,7 @@ class VoiceClient(VoiceProtocol):
raise raise
if self._runner is MISSING: if self._runner is MISSING:
self._runner = self.loop.create_task(self.poll_voice_ws(reconnect)) self._runner = self.client.loop.create_task(self.poll_voice_ws(reconnect))
async def potential_reconnect(self) -> bool: async def potential_reconnect(self) -> bool:
# Attempt to stop the player thread from playing early # Attempt to stop the player thread from playing early

Loading…
Cancel
Save