diff --git a/discord/client.py b/discord/client.py index 64c9a6789..2b3c3e17c 100644 --- a/discord/client.py +++ b/discord/client.py @@ -29,7 +29,7 @@ import logging import signal import sys import traceback -from typing import Any, Generator, List, Optional, Sequence, TYPE_CHECKING, TypeVar, Union +from typing import Any, Callable, Coroutine, Dict, Generator, Iterable, List, Optional, Sequence, TYPE_CHECKING, Tuple, TypeVar, Union import aiohttp @@ -38,12 +38,13 @@ from .invite import Invite from .template import Template from .widget import Widget from .guild import Guild +from .emoji import Emoji from .channel import _channel_factory from .enums import ChannelType from .mentions import AllowedMentions from .errors import * from .enums import Status, VoiceRegion -from .flags import ApplicationFlags +from .flags import ApplicationFlags, Intents from .gateway import * from .activity import BaseActivity, create_activity from .voice_client import VoiceClient @@ -58,16 +59,24 @@ from .appinfo import AppInfo from .ui.view import View from .stage_instance import StageInstance +if TYPE_CHECKING: + from .abc import SnowflakeTime, PrivateChannel, GuildChannel, Snowflake + from .channel import DMChannel + from .user import ClientUser + from .message import Message + from .member import Member + from .voice_client import VoiceProtocol + __all__ = ( 'Client', ) -if TYPE_CHECKING: - from .abc import SnowflakeTime +Coro = TypeVar('Coro', bound=Callable[..., Coroutine[Any, Any, Any]]) -log = logging.getLogger(__name__) -def _cancel_tasks(loop): +log: logging.Logger = logging.getLogger(__name__) + +def _cancel_tasks(loop: asyncio.AbstractEventLoop) -> None: tasks = {t for t in asyncio.all_tasks(loop=loop) if not t.done()} if not tasks: @@ -90,7 +99,7 @@ def _cancel_tasks(loop): 'task': task }) -def _cleanup_loop(loop): +def _cleanup_loop(loop: asyncio.AbstractEventLoop) -> None: try: _cancel_tasks(loop) loop.run_until_complete(loop.shutdown_asyncgens()) @@ -116,7 +125,7 @@ class Client: The :class:`asyncio.AbstractEventLoop` to use for asynchronous operations. Defaults to ``None``, in which case the default event loop is used via :func:`asyncio.get_event_loop()`. - connector: :class:`aiohttp.BaseConnector` + connector: Optional[:class:`aiohttp.BaseConnector`] The connector to use for connection pooling. proxy: Optional[:class:`str`] Proxy URL. @@ -181,31 +190,36 @@ class Client: loop: :class:`asyncio.AbstractEventLoop` The event loop that the client uses for asynchronous operations. """ - def __init__(self, *, loop=None, **options): - self.ws = None - self.loop = asyncio.get_event_loop() if loop is None else loop - self._listeners = {} - self.shard_id = options.get('shard_id') - self.shard_count = options.get('shard_count') - - connector = options.pop('connector', None) - proxy = options.pop('proxy', None) - proxy_auth = options.pop('proxy_auth', None) - unsync_clock = options.pop('assume_unsync_clock', True) - self.http = HTTPClient(connector, proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock, loop=self.loop) - - self._handlers = { + def __init__( + self, + *, + loop: Optional[asyncio.AbstractEventLoop] = None, + **options: Any, + ): + self.ws: DiscordWebSocket = None # type: ignore + self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if loop is None else loop + self._listeners: Dict[str, List[Tuple[asyncio.Future, Callable[..., bool]]]] = {} + self.shard_id: Optional[int] = options.get('shard_id') + self.shard_count: Optional[int] = options.get('shard_count') + + connector: Optional[aiohttp.BaseConnector] = options.pop('connector', None) + proxy: Optional[str] = options.pop('proxy', None) + proxy_auth: Optional[aiohttp.BasicAuth] = options.pop('proxy_auth', None) + unsync_clock: bool = options.pop('assume_unsync_clock', True) + self.http: HTTPClient = HTTPClient(connector, proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock, loop=self.loop) + + self._handlers: Dict[str, Callable] = { 'ready': self._handle_ready } - self._hooks = { + self._hooks: Dict[str, Callable] = { 'before_identify': self._call_before_identify_hook } - self._connection = self._get_state(**options) + self._connection: ConnectionState = self._get_state(**options) self._connection.shard_count = self.shard_count - self._closed = False - self._ready = asyncio.Event() + self._closed: bool = False + self._ready: asyncio.Event = asyncio.Event() self._connection._get_websocket = self._get_websocket self._connection._get_client = lambda: self @@ -215,18 +229,18 @@ class Client: # internals - def _get_websocket(self, guild_id=None, *, shard_id=None): + def _get_websocket(self, guild_id: Optional[int] = None, *, shard_id: Optional[int] = None) -> DiscordWebSocket: return self.ws - def _get_state(self, **options): + def _get_state(self, **options: Any) -> ConnectionState: return ConnectionState(dispatch=self.dispatch, handlers=self._handlers, hooks=self._hooks, http=self.http, loop=self.loop, **options) - def _handle_ready(self): + def _handle_ready(self) -> None: self._ready.set() @property - def latency(self): + def latency(self) -> float: """:class:`float`: Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds. This could be referred to as the Discord WebSocket protocol latency. @@ -234,7 +248,7 @@ class Client: ws = self.ws return float('nan') if not ws else ws.latency - def is_ws_ratelimited(self): + def is_ws_ratelimited(self) -> bool: """:class:`bool`: Whether the websocket is currently rate limited. This can be useful to know when deciding whether you should query members @@ -247,22 +261,22 @@ class Client: return False @property - def user(self): + def user(self) -> Optional[ClientUser]: """Optional[:class:`.ClientUser`]: Represents the connected client. ``None`` if not logged in.""" return self._connection.user @property - def guilds(self): + def guilds(self) -> List[Guild]: """List[:class:`.Guild`]: The guilds that the connected client is a member of.""" return self._connection.guilds @property - def emojis(self): + def emojis(self) -> List[Emoji]: """List[:class:`.Emoji`]: The emojis that the connected client has.""" return self._connection.emojis @property - def cached_messages(self): + def cached_messages(self) -> Sequence[Message]: """Sequence[:class:`.Message`]: Read-only list of messages the connected client has cached. .. versionadded:: 1.1 @@ -270,7 +284,7 @@ class Client: return utils.SequenceProxy(self._connection._messages or []) @property - def private_channels(self): + def private_channels(self) -> List[PrivateChannel]: """List[:class:`.abc.PrivateChannel`]: The private channels that the connected client is participating on. .. note:: @@ -281,7 +295,7 @@ class Client: return self._connection.private_channels @property - def voice_clients(self): + def voice_clients(self) -> List[VoiceProtocol]: """List[:class:`.VoiceProtocol`]: Represents a list of voice connections. These are usually :class:`.VoiceClient` instances. @@ -289,7 +303,7 @@ class Client: return self._connection.voice_clients @property - def application_id(self): + def application_id(self) -> Optional[int]: """Optional[:class:`int`]: The client's application ID. If this is not passed via ``__init__`` then this is retrieved @@ -306,11 +320,11 @@ class Client: """ return self._connection.application_flags # type: ignore - def is_ready(self): + def is_ready(self) -> bool: """:class:`bool`: Specifies if the client's internal cache is ready for use.""" return self._ready.is_set() - async def _run_event(self, coro, event_name, *args, **kwargs): + async def _run_event(self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any) -> None: try: await coro(*args, **kwargs) except asyncio.CancelledError: @@ -321,12 +335,12 @@ class Client: except asyncio.CancelledError: pass - def _schedule_event(self, coro, event_name, *args, **kwargs): + def _schedule_event(self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any) -> asyncio.Task: wrapped = self._run_event(coro, event_name, *args, **kwargs) # Schedules the task return asyncio.create_task(wrapped, name=f'discord.py: {event_name}') - def dispatch(self, event, *args, **kwargs): + def dispatch(self, event: str, *args: Any, **kwargs: Any) -> None: log.debug('Dispatching event %s', event) method = 'on_' + event @@ -366,7 +380,7 @@ class Client: else: self._schedule_event(coro, method, *args, **kwargs) - async def on_error(self, event_method, *args, **kwargs): + async def on_error(self, event_method: str, *args: Any, **kwargs: Any) -> None: """|coro| The default error handler provided by the client. @@ -380,13 +394,13 @@ class Client: # hooks - async def _call_before_identify_hook(self, shard_id, *, initial=False): + async def _call_before_identify_hook(self, shard_id: Optional[int], *, initial: bool = False) -> None: # This hook is an internal hook that actually calls the public one. # It allows the library to have its own hook without stepping on the # toes of those who need to override their own hook. await self.before_identify_hook(shard_id, initial=initial) - async def before_identify_hook(self, shard_id, *, initial=False): + async def before_identify_hook(self, shard_id: Optional[int], *, initial: bool = False) -> None: """|coro| A hook that is called before IDENTIFYing a session. This is useful @@ -410,7 +424,7 @@ class Client: # login state management - async def login(self, token): + async def login(self, token: str) -> None: """|coro| Logs in the client with the specified credentials. @@ -435,7 +449,7 @@ class Client: log.info('logging in using static token') await self.http.static_login(token.strip()) - async def connect(self, *, reconnect=True): + async def connect(self, *, reconnect: bool = True) -> None: """|coro| Creates a websocket connection and lets the websocket listen @@ -519,7 +533,7 @@ class Client: # This is apparently what the official Discord client does. ws_params.update(sequence=self.ws.sequence, resume=True, session=self.ws.session_id) - async def close(self): + async def close(self) -> None: """|coro| Closes the connection to Discord. @@ -531,7 +545,7 @@ class Client: for voice in self.voice_clients: try: - await voice.disconnect() + await voice.disconnect(force=True) except Exception: # if an error happens during disconnects, disregard it. pass @@ -542,7 +556,7 @@ class Client: await self.http.close() self._ready.clear() - def clear(self): + def clear(self) -> None: """Clears the internal state of the bot. After this, the bot can be considered "re-opened", i.e. :meth:`is_closed` @@ -554,7 +568,7 @@ class Client: self._connection.clear() self.http.recreate() - async def start(self, token, *, reconnect=True): + async def start(self, token: str, *, reconnect: bool = True) -> None: """|coro| A shorthand coroutine for :meth:`login` + :meth:`connect`. @@ -567,7 +581,7 @@ class Client: await self.login(token) await self.connect(reconnect=reconnect) - def run(self, *args, **kwargs): + def run(self, *args: Any, **kwargs: Any) -> None: """A blocking call that abstracts away the event loop initialisation from you. @@ -629,19 +643,19 @@ class Client: # properties - def is_closed(self): + def is_closed(self) -> bool: """:class:`bool`: Indicates if the websocket connection is closed.""" return self._closed @property - def activity(self): + def activity(self) -> Optional[BaseActivity]: """Optional[:class:`.BaseActivity`]: The activity being used upon logging in. """ return create_activity(self._connection._activity) @activity.setter - def activity(self, value): + def activity(self, value: Optional[BaseActivity]) -> None: if value is None: self._connection._activity = None elif isinstance(value, BaseActivity): @@ -650,7 +664,7 @@ class Client: raise TypeError('activity must derive from BaseActivity.') @property - def allowed_mentions(self): + def allowed_mentions(self) -> Optional[AllowedMentions]: """Optional[:class:`~discord.AllowedMentions`]: The allowed mention configuration. .. versionadded:: 1.4 @@ -658,14 +672,14 @@ class Client: return self._connection.allowed_mentions @allowed_mentions.setter - def allowed_mentions(self, value): + def allowed_mentions(self, value: Optional[AllowedMentions]) -> None: if value is None or isinstance(value, AllowedMentions): self._connection.allowed_mentions = value else: raise TypeError(f'allowed_mentions must be AllowedMentions not {value.__class__!r}') @property - def intents(self): + def intents(self) -> Intents: """:class:`~discord.Intents`: The intents configured for this connection. .. versionadded:: 1.5 @@ -675,11 +689,11 @@ class Client: # helpers/getters @property - def users(self): + def users(self) -> List[User]: """List[:class:`~discord.User`]: Returns a list of all the users the bot can see.""" return list(self._connection._users.values()) - def get_channel(self, id): + def get_channel(self, id: int) -> Optional[Union[GuildChannel, PrivateChannel]]: """Returns a channel with the given ID. Parameters @@ -716,7 +730,7 @@ class Client: if isinstance(channel, StageChannel): return channel.instance - def get_guild(self, id): + def get_guild(self, id) -> Optional[Guild]: """Returns a guild with the given ID. Parameters @@ -731,7 +745,7 @@ class Client: """ return self._connection._get_guild(id) - def get_user(self, id): + def get_user(self, id) -> Optional[User]: """Returns a user with the given ID. Parameters @@ -746,7 +760,7 @@ class Client: """ return self._connection.get_user(id) - def get_emoji(self, id): + def get_emoji(self, id) -> Optional[Emoji]: """Returns an emoji with the given ID. Parameters @@ -761,7 +775,7 @@ class Client: """ return self._connection.get_emoji(id) - def get_all_channels(self): + def get_all_channels(self) -> Generator[GuildChannel, None, None]: """A generator that retrieves every :class:`.abc.GuildChannel` the client can 'access'. This is equivalent to: :: @@ -785,7 +799,7 @@ class Client: for guild in self.guilds: yield from guild.channels - def get_all_members(self): + def get_all_members(self) -> Generator[Member, None, None]: """Returns a generator with every :class:`.Member` the client can see. This is equivalent to: :: @@ -804,14 +818,20 @@ class Client: # listeners/waiters - async def wait_until_ready(self): + async def wait_until_ready(self) -> None: """|coro| Waits until the client's internal cache is all ready. """ await self._ready.wait() - def wait_for(self, event, *, check=None, timeout=None): + def wait_for( + self, + event: str, + *, + check: Optional[Callable[..., bool]] = None, + timeout: Optional[float] = None, + ) -> Any: """|coro| Waits for a WebSocket event to be dispatched. @@ -911,7 +931,7 @@ class Client: # event registration - def event(self, coro): + def event(self, coro: Coro) -> Coro: """A decorator that registers an event to listen to. You can find more info about the events on the :ref:`documentation below `. @@ -940,7 +960,13 @@ class Client: log.debug('%s has successfully been registered as an event', coro.__name__) return coro - async def change_presence(self, *, activity=None, status=None, afk=False): + async def change_presence( + self, + *, + activity: Optional[BaseActivity] = None, + status: Optional[Status] = None, + afk: bool = False, + ): """|coro| Changes the client's presence. @@ -972,16 +998,15 @@ class Client: """ if status is None: - status = 'online' - status_enum = Status.online + status_str = 'online' + status = Status.online elif status is Status.offline: - status = 'invisible' - status_enum = Status.offline + status_str = 'invisible' + status = Status.offline else: - status_enum = status - status = str(status) + status_str = str(status) - await self.ws.change_presence(activity=activity, status=status, afk=afk) + await self.ws.change_presence(activity=activity, status=status_str, afk=afk) for guild in self._connection.guilds: me = guild.me @@ -993,11 +1018,17 @@ class Client: else: me.activities = () - me.status = status_enum + me.status = status # Guild stuff - def fetch_guilds(self, *, limit: int = 100, before: SnowflakeTime = None, after: SnowflakeTime = None) -> List[Guild]: + def fetch_guilds( + self, + *, + limit: Optional[int] = 100, + before: SnowflakeTime = None, + after: SnowflakeTime = None + ) -> List[Guild]: """Retrieves an :class:`.AsyncIterator` that enables receiving your guilds. .. note:: @@ -1052,7 +1083,7 @@ class Client: """ return GuildIterator(self, limit=limit, before=before, after=after) - async def fetch_template(self, code): + async def fetch_template(self, code: Union[Template, str]) -> Template: """|coro| Gets a :class:`.Template` from a discord.new URL or code. @@ -1078,7 +1109,7 @@ class Client: data = await self.http.get_template(code) return Template(data=data, state=self._connection) # type: ignore - async def fetch_guild(self, guild_id): + async def fetch_guild(self, guild_id: int) -> Guild: """|coro| Retrieves a :class:`.Guild` from an ID. @@ -1112,7 +1143,7 @@ class Client: data = await self.http.get_guild(guild_id) return Guild(data=data, state=self._connection) - async def create_guild(self, name: str, region: Optional[VoiceRegion] = None, icon: Any = None, *, code: str = None): + async def create_guild(self, name: str, region: Optional[VoiceRegion] = None, icon: Any = None, *, code: str = None) -> Guild: """|coro| Creates a :class:`.Guild`. @@ -1259,7 +1290,7 @@ class Client: # Miscellaneous stuff - async def fetch_widget(self, guild_id): + async def fetch_widget(self, guild_id: int) -> Widget: """|coro| Gets a :class:`.Widget` from a guild ID. @@ -1289,7 +1320,7 @@ class Client: return Widget(state=self._connection, data=data) - async def application_info(self): + async def application_info(self) -> AppInfo: """|coro| Retrieves the bot's application information. @@ -1309,7 +1340,7 @@ class Client: data['rpc_origins'] = None return AppInfo(self._connection, data) - async def fetch_user(self, user_id): + async def fetch_user(self, user_id: int) -> User: """|coro| Retrieves a :class:`~discord.User` based on their ID. @@ -1340,7 +1371,7 @@ class Client: data = await self.http.get_user(user_id) return User(state=self._connection, data=data) - async def fetch_channel(self, channel_id): + async def fetch_channel(self, channel_id: int) -> Union[GuildChannel, PrivateChannel]: """|coro| Retrieves a :class:`.abc.GuildChannel` or :class:`.abc.PrivateChannel` with the specified ID. @@ -1382,7 +1413,7 @@ class Client: return channel - async def fetch_webhook(self, webhook_id): + async def fetch_webhook(self, webhook_id: int) -> Webhook: """|coro| Retrieves a :class:`.Webhook` with the specified ID. @@ -1404,7 +1435,7 @@ class Client: data = await self.http.get_webhook(webhook_id) return Webhook.from_state(data, state=self._connection) - async def create_dm(self, user): + async def create_dm(self, user: Snowflake) -> DMChannel: """|coro| Creates a :class:`.DMChannel` with this user.