@ -26,10 +26,24 @@ from __future__ import annotations
import asyncio
import logging
import signal
import sys
import traceback
from typing import Any , Callable , Coroutine , Dict , Generator , Iterable , List , Optional , Sequence , TYPE_CHECKING , Tuple , TypeVar , Union
from typing import (
Any ,
Callable ,
Coroutine ,
Dict ,
Generator ,
Iterable ,
List ,
Optional ,
Sequence ,
TYPE_CHECKING ,
Tuple ,
TypeVar ,
Type ,
Union ,
)
import aiohttp
@ -68,6 +82,7 @@ if TYPE_CHECKING:
from . message import Message
from . member import Member
from . voice_client import VoiceProtocol
from types import TracebackType
__all__ = (
' Client ' ,
@ -78,36 +93,8 @@ Coro = TypeVar('Coro', bound=Callable[..., Coroutine[Any, Any, Any]])
log : logging . Logger = logging . getLogger ( __name__ )
def _cancel_tasks ( loop : asyncio . AbstractEventLoop ) - > None :
tasks = { t for t in asyncio . all_tasks ( loop = loop ) if not t . done ( ) }
if not tasks :
return
log . info ( ' Cleaning up after %d tasks. ' , len ( tasks ) )
for task in tasks :
task . cancel ( )
loop . run_until_complete ( asyncio . gather ( * tasks , return_exceptions = True ) )
log . info ( ' All tasks finished cancelling. ' )
for task in tasks :
if task . cancelled ( ) :
continue
if task . exception ( ) is not None :
loop . call_exception_handler ( {
' message ' : ' Unhandled exception during Client.run shutdown. ' ,
' exception ' : task . exception ( ) ,
' task ' : task
} )
def _cleanup_loop ( loop : asyncio . AbstractEventLoop ) - > None :
try :
_cancel_tasks ( loop )
loop . run_until_complete ( loop . shutdown_asyncgens ( ) )
finally :
log . info ( ' Closing the event loop. ' )
loop . close ( )
C = TypeVar ( ' C ' , bound = ' Client ' )
class Client :
r """ Represents a client connection that connects to Discord.
@ -200,6 +187,7 @@ class Client:
loop : : class : ` asyncio . AbstractEventLoop `
The event loop that the client uses for asynchronous operations .
"""
def __init__ (
self ,
* ,
@ -207,7 +195,8 @@ class Client:
* * options : Any ,
) :
self . ws : DiscordWebSocket = None # type: ignore
self . loop : asyncio . AbstractEventLoop = asyncio . get_event_loop ( ) if loop is None else loop
# this is filled in later
self . loop : asyncio . AbstractEventLoop = MISSING if loop is None else loop
self . _listeners : Dict [ str , List [ Tuple [ asyncio . Future , Callable [ . . . , bool ] ] ] ] = { }
self . shard_id : Optional [ int ] = options . get ( ' shard_id ' )
self . shard_count : Optional [ int ] = options . get ( ' shard_count ' )
@ -216,14 +205,16 @@ class Client:
proxy : Optional [ str ] = options . pop ( ' proxy ' , None )
proxy_auth : Optional [ aiohttp . BasicAuth ] = options . pop ( ' proxy_auth ' , None )
unsync_clock : bool = options . pop ( ' assume_unsync_clock ' , True )
self . http : HTTPClient = HTTPClient ( connector , proxy = proxy , proxy_auth = proxy_auth , unsync_clock = unsync_clock , loop = self . loop )
self . http : HTTPClient = HTTPClient (
connector , proxy = proxy , proxy_auth = proxy_auth , unsync_clock = unsync_clock , loop = loop
)
self . _handlers : Dict [ str , Callable ] = {
' ready ' : self . _handle_ready
' ready ' : self . _handle_ready ,
}
self . _hooks : Dict [ str , Callable ] = {
' before_identify ' : self . _call_before_identify_hook
' before_identify ' : self . _call_before_identify_hook ,
}
self . _enable_debug_events : bool = options . pop ( ' enable_debug_events ' , False )
@ -244,8 +235,9 @@ class Client:
return self . ws
def _get_state ( self , * * options : Any ) - > ConnectionState :
return ConnectionState ( dispatch = self . dispatch , handlers = self . _handlers ,
hooks = self . _hooks , http = self . http , loop = self . loop , * * options )
return ConnectionState (
dispatch = self . dispatch , handlers = self . _handlers , hooks = self . _hooks , http = self . http , loop = self . loop , * * options
)
def _handle_ready ( self ) - > None :
self . _ready . set ( )
@ -343,7 +335,9 @@ class Client:
""" :class:`bool`: Specifies if the client ' s internal cache is ready for use. """
return self . _ready . is_set ( )
async def _run_event ( self , coro : Callable [ . . . , Coroutine [ Any , Any , Any ] ] , event_name : str , * args : Any , * * kwargs : Any ) - > None :
async def _run_event (
self , coro : Callable [ . . . , Coroutine [ Any , Any , Any ] ] , event_name : str , * args : Any , * * kwargs : Any
) - > None :
try :
await coro ( * args , * * kwargs )
except asyncio . CancelledError :
@ -354,7 +348,9 @@ class Client:
except asyncio . CancelledError :
pass
def _schedule_event ( self , coro : Callable [ . . . , Coroutine [ Any , Any , Any ] ] , event_name : str , * args : Any , * * kwargs : Any ) - > asyncio . Task :
def _schedule_event (
self , coro : Callable [ . . . , Coroutine [ Any , Any , Any ] ] , event_name : str , * args : Any , * * kwargs : Any
) - > asyncio . Task :
wrapped = self . _run_event ( coro , event_name , * args , * * kwargs )
# Schedules the task
return asyncio . create_task ( wrapped , name = f ' discord.py: { event_name } ' )
@ -466,7 +462,8 @@ class Client:
"""
log . info ( ' logging in using static token ' )
self . loop = loop = asyncio . get_running_loop ( )
self . _connection . loop = loop
data = await self . http . static_login ( token . strip ( ) )
self . _connection . user = ClientUser ( state = self . _connection , data = data )
@ -512,12 +509,14 @@ class Client:
self . dispatch ( ' disconnect ' )
ws_params . update ( sequence = self . ws . sequence , resume = e . resume , session = self . ws . session_id )
continue
except ( OSError ,
HTTPException ,
GatewayNotFound ,
ConnectionClosed ,
aiohttp . ClientError ,
asyncio . TimeoutError ) as exc :
except (
OSError ,
HTTPException ,
GatewayNotFound ,
ConnectionClosed ,
aiohttp . ClientError ,
asyncio . TimeoutError ,
) as exc :
self . dispatch ( ' disconnect ' )
if not reconnect :
@ -558,6 +557,22 @@ class Client:
""" |coro|
Closes the connection to Discord .
Instead of calling this directly , it is recommended to use the asynchronous context
manager to allow resources to be cleaned up automatically :
. . code - block : : python3
async def main ( ) :
async with Client ( ) as client :
await client . login ( token )
await client . connect ( )
asyncio . run ( main ( ) )
. . versionchanged : : 2.0
The client can now be closed with an asynchronous context manager
"""
if self . _closed :
return
@ -589,36 +604,47 @@ class Client:
self . _connection . clear ( )
self . http . recreate ( )
async def __aenter__ ( self : C ) - > C :
return self
async def __aexit__ (
self ,
exc_type : Optional [ Type [ BaseException ] ] ,
exc_value : Optional [ BaseException ] ,
traceback : Optional [ TracebackType ] ,
) - > None :
await self . close ( )
async def start ( self , token : str , * , reconnect : bool = True ) - > None :
""" |coro|
A shorthand coroutine for : meth : ` login ` + : meth : ` connect ` .
A shorthand function equivalent to the following :
Raises
- - - - - - -
TypeError
An unexpected keyword argument was received .
. . code - block : : python3
async with client :
await client . login ( token )
await client . connect ( )
This closes the client when it returns .
"""
await self . login ( token )
await self . connect ( reconnect = reconnect )
try :
await self . login ( token )
await self . connect ( reconnect = reconnect )
finally :
await self . close ( )
def run ( self , * args : Any , * * kwargs : Any ) - > None :
""" A blocking call that abstracts away the event loop
""" A convenience blocking call that abstracts away the event loop
initialisation from you .
If you want more control over the event loop then this
function should not be used . Use : meth : ` start ` coroutine
or : meth : ` connect ` + : meth : ` login ` .
Roughly Equivalent to : : :
Equivalent to : : :
try :
loop . run_until_complete ( start ( * args , * * kwargs ) )
except KeyboardInterrupt :
loop . run_until_complete ( close ( ) )
# cancel all tasks lingering
finally :
loop . close ( )
asyncio . run ( bot . start ( * args , * * kwargs ) )
. . warning : :
@ -626,41 +652,7 @@ class Client:
is blocking . That means that registration of events or anything being
called after this function call will not execute until it returns .
"""
loop = self . loop
try :
loop . add_signal_handler ( signal . SIGINT , lambda : loop . stop ( ) )
loop . add_signal_handler ( signal . SIGTERM , lambda : loop . stop ( ) )
except NotImplementedError :
pass
async def runner ( ) :
try :
await self . start ( * args , * * kwargs )
finally :
if not self . is_closed ( ) :
await self . close ( )
def stop_loop_on_completion ( f ) :
loop . stop ( )
future = asyncio . ensure_future ( runner ( ) , loop = loop )
future . add_done_callback ( stop_loop_on_completion )
try :
loop . run_forever ( )
except KeyboardInterrupt :
log . info ( ' Received signal to terminate bot and event loop. ' )
finally :
future . remove_done_callback ( stop_loop_on_completion )
log . info ( ' Cleaning up tasks. ' )
_cleanup_loop ( loop )
if not future . cancelled ( ) :
try :
return future . result ( )
except KeyboardInterrupt :
# I am unsure why this gets raised here but suppress it anyway
return None
asyncio . run ( self . start ( * args , * * kwargs ) )
# properties
@ -973,8 +965,10 @@ class Client:
future = self . loop . create_future ( )
if check is None :
def _check ( * args ) :
return True
check = _check
ev = event . lower ( )
@ -1083,7 +1077,7 @@ class Client:
* ,
limit : Optional [ int ] = 100 ,
before : SnowflakeTime = None ,
after : SnowflakeTime = None
after : SnowflakeTime = None ,
) - > GuildIterator :
""" Retrieves an :class:`.AsyncIterator` that enables receiving your guilds.
@ -1163,7 +1157,7 @@ class Client:
"""
code = utils . resolve_template ( code )
data = await self . http . get_template ( code )
return Template ( data = data , state = self . _connection ) # type: ignore
return Template ( data = data , state = self . _connection ) # type: ignore
async def fetch_guild ( self , guild_id : int ) - > Guild :
""" |coro|
@ -1284,7 +1278,9 @@ class Client:
# Invite management
async def fetch_invite ( self , url : Union [ Invite , str ] , * , with_counts : bool = True , with_expiration : bool = True ) - > Invite :
async def fetch_invite (
self , url : Union [ Invite , str ] , * , with_counts : bool = True , with_expiration : bool = True
) - > Invite :
""" |coro|
Gets an : class : ` . Invite ` from a discord . gg URL or ID .
@ -1520,7 +1516,7 @@ class Client:
"""
data = await self . http . get_sticker ( sticker_id )
cls , _ = _sticker_factory ( data [ ' type ' ] ) # type: ignore
return cls ( state = self . _connection , data = data ) # type: ignore
return cls ( state = self . _connection , data = data ) # type: ignore
async def fetch_premium_sticker_packs ( self ) - > List [ StickerPack ] :
""" |coro|