Browse Source

[typing] Type-hint client.py

pull/7135/head
Josh 4 years ago
committed by GitHub
parent
commit
7601d6cec3
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 205
      discord/client.py

205
discord/client.py

@ -29,7 +29,7 @@ import logging
import signal
import sys
import traceback
from typing import Any, Generator, List, Optional, Sequence, TYPE_CHECKING, TypeVar, Union
from typing import Any, Callable, Coroutine, Dict, Generator, Iterable, List, Optional, Sequence, TYPE_CHECKING, Tuple, TypeVar, Union
import aiohttp
@ -38,12 +38,13 @@ from .invite import Invite
from .template import Template
from .widget import Widget
from .guild import Guild
from .emoji import Emoji
from .channel import _channel_factory
from .enums import ChannelType
from .mentions import AllowedMentions
from .errors import *
from .enums import Status, VoiceRegion
from .flags import ApplicationFlags
from .flags import ApplicationFlags, Intents
from .gateway import *
from .activity import BaseActivity, create_activity
from .voice_client import VoiceClient
@ -58,16 +59,24 @@ from .appinfo import AppInfo
from .ui.view import View
from .stage_instance import StageInstance
if TYPE_CHECKING:
from .abc import SnowflakeTime, PrivateChannel, GuildChannel, Snowflake
from .channel import DMChannel
from .user import ClientUser
from .message import Message
from .member import Member
from .voice_client import VoiceProtocol
__all__ = (
'Client',
)
if TYPE_CHECKING:
from .abc import SnowflakeTime
Coro = TypeVar('Coro', bound=Callable[..., Coroutine[Any, Any, Any]])
log = logging.getLogger(__name__)
def _cancel_tasks(loop):
log: logging.Logger = logging.getLogger(__name__)
def _cancel_tasks(loop: asyncio.AbstractEventLoop) -> None:
tasks = {t for t in asyncio.all_tasks(loop=loop) if not t.done()}
if not tasks:
@ -90,7 +99,7 @@ def _cancel_tasks(loop):
'task': task
})
def _cleanup_loop(loop):
def _cleanup_loop(loop: asyncio.AbstractEventLoop) -> None:
try:
_cancel_tasks(loop)
loop.run_until_complete(loop.shutdown_asyncgens())
@ -116,7 +125,7 @@ class Client:
The :class:`asyncio.AbstractEventLoop` to use for asynchronous operations.
Defaults to ``None``, in which case the default event loop is used via
:func:`asyncio.get_event_loop()`.
connector: :class:`aiohttp.BaseConnector`
connector: Optional[:class:`aiohttp.BaseConnector`]
The connector to use for connection pooling.
proxy: Optional[:class:`str`]
Proxy URL.
@ -181,31 +190,36 @@ class Client:
loop: :class:`asyncio.AbstractEventLoop`
The event loop that the client uses for asynchronous operations.
"""
def __init__(self, *, loop=None, **options):
self.ws = None
self.loop = asyncio.get_event_loop() if loop is None else loop
self._listeners = {}
self.shard_id = options.get('shard_id')
self.shard_count = options.get('shard_count')
connector = options.pop('connector', None)
proxy = options.pop('proxy', None)
proxy_auth = options.pop('proxy_auth', None)
unsync_clock = options.pop('assume_unsync_clock', True)
self.http = HTTPClient(connector, proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock, loop=self.loop)
self._handlers = {
def __init__(
self,
*,
loop: Optional[asyncio.AbstractEventLoop] = None,
**options: Any,
):
self.ws: DiscordWebSocket = None # type: ignore
self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if loop is None else loop
self._listeners: Dict[str, List[Tuple[asyncio.Future, Callable[..., bool]]]] = {}
self.shard_id: Optional[int] = options.get('shard_id')
self.shard_count: Optional[int] = options.get('shard_count')
connector: Optional[aiohttp.BaseConnector] = options.pop('connector', None)
proxy: Optional[str] = options.pop('proxy', None)
proxy_auth: Optional[aiohttp.BasicAuth] = options.pop('proxy_auth', None)
unsync_clock: bool = options.pop('assume_unsync_clock', True)
self.http: HTTPClient = HTTPClient(connector, proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock, loop=self.loop)
self._handlers: Dict[str, Callable] = {
'ready': self._handle_ready
}
self._hooks = {
self._hooks: Dict[str, Callable] = {
'before_identify': self._call_before_identify_hook
}
self._connection = self._get_state(**options)
self._connection: ConnectionState = self._get_state(**options)
self._connection.shard_count = self.shard_count
self._closed = False
self._ready = asyncio.Event()
self._closed: bool = False
self._ready: asyncio.Event = asyncio.Event()
self._connection._get_websocket = self._get_websocket
self._connection._get_client = lambda: self
@ -215,18 +229,18 @@ class Client:
# internals
def _get_websocket(self, guild_id=None, *, shard_id=None):
def _get_websocket(self, guild_id: Optional[int] = None, *, shard_id: Optional[int] = None) -> DiscordWebSocket:
return self.ws
def _get_state(self, **options):
def _get_state(self, **options: Any) -> ConnectionState:
return ConnectionState(dispatch=self.dispatch, handlers=self._handlers,
hooks=self._hooks, http=self.http, loop=self.loop, **options)
def _handle_ready(self):
def _handle_ready(self) -> None:
self._ready.set()
@property
def latency(self):
def latency(self) -> float:
""":class:`float`: Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds.
This could be referred to as the Discord WebSocket protocol latency.
@ -234,7 +248,7 @@ class Client:
ws = self.ws
return float('nan') if not ws else ws.latency
def is_ws_ratelimited(self):
def is_ws_ratelimited(self) -> bool:
""":class:`bool`: Whether the websocket is currently rate limited.
This can be useful to know when deciding whether you should query members
@ -247,22 +261,22 @@ class Client:
return False
@property
def user(self):
def user(self) -> Optional[ClientUser]:
"""Optional[:class:`.ClientUser`]: Represents the connected client. ``None`` if not logged in."""
return self._connection.user
@property
def guilds(self):
def guilds(self) -> List[Guild]:
"""List[:class:`.Guild`]: The guilds that the connected client is a member of."""
return self._connection.guilds
@property
def emojis(self):
def emojis(self) -> List[Emoji]:
"""List[:class:`.Emoji`]: The emojis that the connected client has."""
return self._connection.emojis
@property
def cached_messages(self):
def cached_messages(self) -> Sequence[Message]:
"""Sequence[:class:`.Message`]: Read-only list of messages the connected client has cached.
.. versionadded:: 1.1
@ -270,7 +284,7 @@ class Client:
return utils.SequenceProxy(self._connection._messages or [])
@property
def private_channels(self):
def private_channels(self) -> List[PrivateChannel]:
"""List[:class:`.abc.PrivateChannel`]: The private channels that the connected client is participating on.
.. note::
@ -281,7 +295,7 @@ class Client:
return self._connection.private_channels
@property
def voice_clients(self):
def voice_clients(self) -> List[VoiceProtocol]:
"""List[:class:`.VoiceProtocol`]: Represents a list of voice connections.
These are usually :class:`.VoiceClient` instances.
@ -289,7 +303,7 @@ class Client:
return self._connection.voice_clients
@property
def application_id(self):
def application_id(self) -> Optional[int]:
"""Optional[:class:`int`]: The client's application ID.
If this is not passed via ``__init__`` then this is retrieved
@ -306,11 +320,11 @@ class Client:
"""
return self._connection.application_flags # type: ignore
def is_ready(self):
def is_ready(self) -> bool:
""":class:`bool`: Specifies if the client's internal cache is ready for use."""
return self._ready.is_set()
async def _run_event(self, coro, event_name, *args, **kwargs):
async def _run_event(self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any) -> None:
try:
await coro(*args, **kwargs)
except asyncio.CancelledError:
@ -321,12 +335,12 @@ class Client:
except asyncio.CancelledError:
pass
def _schedule_event(self, coro, event_name, *args, **kwargs):
def _schedule_event(self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any) -> asyncio.Task:
wrapped = self._run_event(coro, event_name, *args, **kwargs)
# Schedules the task
return asyncio.create_task(wrapped, name=f'discord.py: {event_name}')
def dispatch(self, event, *args, **kwargs):
def dispatch(self, event: str, *args: Any, **kwargs: Any) -> None:
log.debug('Dispatching event %s', event)
method = 'on_' + event
@ -366,7 +380,7 @@ class Client:
else:
self._schedule_event(coro, method, *args, **kwargs)
async def on_error(self, event_method, *args, **kwargs):
async def on_error(self, event_method: str, *args: Any, **kwargs: Any) -> None:
"""|coro|
The default error handler provided by the client.
@ -380,13 +394,13 @@ class Client:
# hooks
async def _call_before_identify_hook(self, shard_id, *, initial=False):
async def _call_before_identify_hook(self, shard_id: Optional[int], *, initial: bool = False) -> None:
# This hook is an internal hook that actually calls the public one.
# It allows the library to have its own hook without stepping on the
# toes of those who need to override their own hook.
await self.before_identify_hook(shard_id, initial=initial)
async def before_identify_hook(self, shard_id, *, initial=False):
async def before_identify_hook(self, shard_id: Optional[int], *, initial: bool = False) -> None:
"""|coro|
A hook that is called before IDENTIFYing a session. This is useful
@ -410,7 +424,7 @@ class Client:
# login state management
async def login(self, token):
async def login(self, token: str) -> None:
"""|coro|
Logs in the client with the specified credentials.
@ -435,7 +449,7 @@ class Client:
log.info('logging in using static token')
await self.http.static_login(token.strip())
async def connect(self, *, reconnect=True):
async def connect(self, *, reconnect: bool = True) -> None:
"""|coro|
Creates a websocket connection and lets the websocket listen
@ -519,7 +533,7 @@ class Client:
# This is apparently what the official Discord client does.
ws_params.update(sequence=self.ws.sequence, resume=True, session=self.ws.session_id)
async def close(self):
async def close(self) -> None:
"""|coro|
Closes the connection to Discord.
@ -531,7 +545,7 @@ class Client:
for voice in self.voice_clients:
try:
await voice.disconnect()
await voice.disconnect(force=True)
except Exception:
# if an error happens during disconnects, disregard it.
pass
@ -542,7 +556,7 @@ class Client:
await self.http.close()
self._ready.clear()
def clear(self):
def clear(self) -> None:
"""Clears the internal state of the bot.
After this, the bot can be considered "re-opened", i.e. :meth:`is_closed`
@ -554,7 +568,7 @@ class Client:
self._connection.clear()
self.http.recreate()
async def start(self, token, *, reconnect=True):
async def start(self, token: str, *, reconnect: bool = True) -> None:
"""|coro|
A shorthand coroutine for :meth:`login` + :meth:`connect`.
@ -567,7 +581,7 @@ class Client:
await self.login(token)
await self.connect(reconnect=reconnect)
def run(self, *args, **kwargs):
def run(self, *args: Any, **kwargs: Any) -> None:
"""A blocking call that abstracts away the event loop
initialisation from you.
@ -629,19 +643,19 @@ class Client:
# properties
def is_closed(self):
def is_closed(self) -> bool:
""":class:`bool`: Indicates if the websocket connection is closed."""
return self._closed
@property
def activity(self):
def activity(self) -> Optional[BaseActivity]:
"""Optional[:class:`.BaseActivity`]: The activity being used upon
logging in.
"""
return create_activity(self._connection._activity)
@activity.setter
def activity(self, value):
def activity(self, value: Optional[BaseActivity]) -> None:
if value is None:
self._connection._activity = None
elif isinstance(value, BaseActivity):
@ -650,7 +664,7 @@ class Client:
raise TypeError('activity must derive from BaseActivity.')
@property
def allowed_mentions(self):
def allowed_mentions(self) -> Optional[AllowedMentions]:
"""Optional[:class:`~discord.AllowedMentions`]: The allowed mention configuration.
.. versionadded:: 1.4
@ -658,14 +672,14 @@ class Client:
return self._connection.allowed_mentions
@allowed_mentions.setter
def allowed_mentions(self, value):
def allowed_mentions(self, value: Optional[AllowedMentions]) -> None:
if value is None or isinstance(value, AllowedMentions):
self._connection.allowed_mentions = value
else:
raise TypeError(f'allowed_mentions must be AllowedMentions not {value.__class__!r}')
@property
def intents(self):
def intents(self) -> Intents:
""":class:`~discord.Intents`: The intents configured for this connection.
.. versionadded:: 1.5
@ -675,11 +689,11 @@ class Client:
# helpers/getters
@property
def users(self):
def users(self) -> List[User]:
"""List[:class:`~discord.User`]: Returns a list of all the users the bot can see."""
return list(self._connection._users.values())
def get_channel(self, id):
def get_channel(self, id: int) -> Optional[Union[GuildChannel, PrivateChannel]]:
"""Returns a channel with the given ID.
Parameters
@ -716,7 +730,7 @@ class Client:
if isinstance(channel, StageChannel):
return channel.instance
def get_guild(self, id):
def get_guild(self, id) -> Optional[Guild]:
"""Returns a guild with the given ID.
Parameters
@ -731,7 +745,7 @@ class Client:
"""
return self._connection._get_guild(id)
def get_user(self, id):
def get_user(self, id) -> Optional[User]:
"""Returns a user with the given ID.
Parameters
@ -746,7 +760,7 @@ class Client:
"""
return self._connection.get_user(id)
def get_emoji(self, id):
def get_emoji(self, id) -> Optional[Emoji]:
"""Returns an emoji with the given ID.
Parameters
@ -761,7 +775,7 @@ class Client:
"""
return self._connection.get_emoji(id)
def get_all_channels(self):
def get_all_channels(self) -> Generator[GuildChannel, None, None]:
"""A generator that retrieves every :class:`.abc.GuildChannel` the client can 'access'.
This is equivalent to: ::
@ -785,7 +799,7 @@ class Client:
for guild in self.guilds:
yield from guild.channels
def get_all_members(self):
def get_all_members(self) -> Generator[Member, None, None]:
"""Returns a generator with every :class:`.Member` the client can see.
This is equivalent to: ::
@ -804,14 +818,20 @@ class Client:
# listeners/waiters
async def wait_until_ready(self):
async def wait_until_ready(self) -> None:
"""|coro|
Waits until the client's internal cache is all ready.
"""
await self._ready.wait()
def wait_for(self, event, *, check=None, timeout=None):
def wait_for(
self,
event: str,
*,
check: Optional[Callable[..., bool]] = None,
timeout: Optional[float] = None,
) -> Any:
"""|coro|
Waits for a WebSocket event to be dispatched.
@ -911,7 +931,7 @@ class Client:
# event registration
def event(self, coro):
def event(self, coro: Coro) -> Coro:
"""A decorator that registers an event to listen to.
You can find more info about the events on the :ref:`documentation below <discord-api-events>`.
@ -940,7 +960,13 @@ class Client:
log.debug('%s has successfully been registered as an event', coro.__name__)
return coro
async def change_presence(self, *, activity=None, status=None, afk=False):
async def change_presence(
self,
*,
activity: Optional[BaseActivity] = None,
status: Optional[Status] = None,
afk: bool = False,
):
"""|coro|
Changes the client's presence.
@ -972,16 +998,15 @@ class Client:
"""
if status is None:
status = 'online'
status_enum = Status.online
status_str = 'online'
status = Status.online
elif status is Status.offline:
status = 'invisible'
status_enum = Status.offline
status_str = 'invisible'
status = Status.offline
else:
status_enum = status
status = str(status)
status_str = str(status)
await self.ws.change_presence(activity=activity, status=status, afk=afk)
await self.ws.change_presence(activity=activity, status=status_str, afk=afk)
for guild in self._connection.guilds:
me = guild.me
@ -993,11 +1018,17 @@ class Client:
else:
me.activities = ()
me.status = status_enum
me.status = status
# Guild stuff
def fetch_guilds(self, *, limit: int = 100, before: SnowflakeTime = None, after: SnowflakeTime = None) -> List[Guild]:
def fetch_guilds(
self,
*,
limit: Optional[int] = 100,
before: SnowflakeTime = None,
after: SnowflakeTime = None
) -> List[Guild]:
"""Retrieves an :class:`.AsyncIterator` that enables receiving your guilds.
.. note::
@ -1052,7 +1083,7 @@ class Client:
"""
return GuildIterator(self, limit=limit, before=before, after=after)
async def fetch_template(self, code):
async def fetch_template(self, code: Union[Template, str]) -> Template:
"""|coro|
Gets a :class:`.Template` from a discord.new URL or code.
@ -1078,7 +1109,7 @@ class Client:
data = await self.http.get_template(code)
return Template(data=data, state=self._connection) # type: ignore
async def fetch_guild(self, guild_id):
async def fetch_guild(self, guild_id: int) -> Guild:
"""|coro|
Retrieves a :class:`.Guild` from an ID.
@ -1112,7 +1143,7 @@ class Client:
data = await self.http.get_guild(guild_id)
return Guild(data=data, state=self._connection)
async def create_guild(self, name: str, region: Optional[VoiceRegion] = None, icon: Any = None, *, code: str = None):
async def create_guild(self, name: str, region: Optional[VoiceRegion] = None, icon: Any = None, *, code: str = None) -> Guild:
"""|coro|
Creates a :class:`.Guild`.
@ -1259,7 +1290,7 @@ class Client:
# Miscellaneous stuff
async def fetch_widget(self, guild_id):
async def fetch_widget(self, guild_id: int) -> Widget:
"""|coro|
Gets a :class:`.Widget` from a guild ID.
@ -1289,7 +1320,7 @@ class Client:
return Widget(state=self._connection, data=data)
async def application_info(self):
async def application_info(self) -> AppInfo:
"""|coro|
Retrieves the bot's application information.
@ -1309,7 +1340,7 @@ class Client:
data['rpc_origins'] = None
return AppInfo(self._connection, data)
async def fetch_user(self, user_id):
async def fetch_user(self, user_id: int) -> User:
"""|coro|
Retrieves a :class:`~discord.User` based on their ID.
@ -1340,7 +1371,7 @@ class Client:
data = await self.http.get_user(user_id)
return User(state=self._connection, data=data)
async def fetch_channel(self, channel_id):
async def fetch_channel(self, channel_id: int) -> Union[GuildChannel, PrivateChannel]:
"""|coro|
Retrieves a :class:`.abc.GuildChannel` or :class:`.abc.PrivateChannel` with the specified ID.
@ -1382,7 +1413,7 @@ class Client:
return channel
async def fetch_webhook(self, webhook_id):
async def fetch_webhook(self, webhook_id: int) -> Webhook:
"""|coro|
Retrieves a :class:`.Webhook` with the specified ID.
@ -1404,7 +1435,7 @@ class Client:
data = await self.http.get_webhook(webhook_id)
return Webhook.from_state(data, state=self._connection)
async def create_dm(self, user):
async def create_dm(self, user: Snowflake) -> DMChannel:
"""|coro|
Creates a :class:`.DMChannel` with this user.

Loading…
Cancel
Save