From 5390caa67d388b518f1852792a190dfa1c81182c Mon Sep 17 00:00:00 2001 From: Stocker <44980366+StockerMC@users.noreply.github.com> Date: Fri, 20 Aug 2021 20:05:02 -0400 Subject: [PATCH] Typehint shard.py --- discord/shard.py | 140 ++++++++++++++++++++++++++--------------------- 1 file changed, 77 insertions(+), 63 deletions(-) diff --git a/discord/shard.py b/discord/shard.py index 06e3f2137..ef5ed119a 100644 --- a/discord/shard.py +++ b/discord/shard.py @@ -22,8 +22,9 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from __future__ import annotations + import asyncio -import itertools import logging import aiohttp @@ -34,22 +35,29 @@ from .backoff import ExponentialBackoff from .gateway import * from .errors import ( ClientException, - InvalidArgument, HTTPException, GatewayNotFound, ConnectionClosed, PrivilegedIntentsRequired, ) -from . import utils from .enums import Status +from typing import TYPE_CHECKING, Any, Callable, Tuple, Type, Optional, List, Dict, TypeVar + +if TYPE_CHECKING: + from .gateway import DiscordWebSocket + from .activity import BaseActivity + from .enums import Status + + EI = TypeVar('EI', bound='EventItem') + __all__ = ( 'AutoShardedClient', 'ShardInfo', ) -log = logging.getLogger(__name__) +log: logging.Logger = logging.getLogger(__name__) class EventType: close = 0 @@ -62,36 +70,36 @@ class EventType: class EventItem: __slots__ = ('type', 'shard', 'error') - def __init__(self, etype, shard, error): - self.type = etype - self.shard = shard - self.error = error + def __init__(self, etype: int, shard: Optional['Shard'], error: Optional[Exception]) -> None: + self.type: int = etype + self.shard: Optional['Shard'] = shard + self.error: Optional[Exception] = error - def __lt__(self, other): + def __lt__(self: EI, other: EI) -> bool: if not isinstance(other, EventItem): return NotImplemented return self.type < other.type - def __eq__(self, other): + def __eq__(self: EI, other: EI) -> bool: if not isinstance(other, EventItem): return NotImplemented return self.type == other.type - def __hash__(self): + def __hash__(self) -> int: return hash(self.type) class Shard: - def __init__(self, ws, client, queue_put): - self.ws = ws - self._client = client - self._dispatch = client.dispatch - self._queue_put = queue_put - self.loop = self._client.loop - self._disconnect = False + def __init__(self, ws: DiscordWebSocket, client: AutoShardedClient, queue_put: Callable[[EventItem], None]) -> None: + self.ws: DiscordWebSocket = ws + self._client: Client = client + self._dispatch: Callable[..., None] = client.dispatch + self._queue_put: Callable[[EventItem], None] = queue_put + self.loop: asyncio.AbstractEventLoop = self._client.loop + self._disconnect: bool = False self._reconnect = client._reconnect - self._backoff = ExponentialBackoff() - self._task = None - self._handled_exceptions = ( + self._backoff: ExponentialBackoff = ExponentialBackoff() + self._task: Optional[asyncio.Task] = None + self._handled_exceptions: Tuple[Type[Exception], ...] = ( OSError, HTTPException, GatewayNotFound, @@ -101,25 +109,26 @@ class Shard: ) @property - def id(self): - return self.ws.shard_id + def id(self) -> int: + # DiscordWebSocket.shard_id is set in the from_client classmethod + return self.ws.shard_id # type: ignore - def launch(self): + def launch(self) -> None: self._task = self.loop.create_task(self.worker()) - def _cancel_task(self): + def _cancel_task(self) -> None: if self._task is not None and not self._task.done(): self._task.cancel() - async def close(self): + async def close(self) -> None: self._cancel_task() await self.ws.close(code=1000) - async def disconnect(self): + async def disconnect(self) -> None: await self.close() self._dispatch('shard_disconnect', self.id) - async def _handle_disconnect(self, e): + async def _handle_disconnect(self, e: Exception) -> None: self._dispatch('disconnect') self._dispatch('shard_disconnect', self.id) if not self._reconnect: @@ -148,7 +157,7 @@ class Shard: await asyncio.sleep(retry) self._queue_put(EventItem(EventType.reconnect, self, e)) - async def worker(self): + async def worker(self) -> None: while not self._client.is_closed(): try: await self.ws.poll_event() @@ -165,7 +174,7 @@ class Shard: self._queue_put(EventItem(EventType.terminate, self, e)) break - async def reidentify(self, exc): + async def reidentify(self, exc: ReconnectWebSocket) -> None: self._cancel_task() self._dispatch('disconnect') self._dispatch('shard_disconnect', self.id) @@ -183,7 +192,7 @@ class Shard: else: self.launch() - async def reconnect(self): + async def reconnect(self) -> None: self._cancel_task() try: coro = DiscordWebSocket.from_client(self._client, shard_id=self.id) @@ -215,16 +224,16 @@ class ShardInfo: __slots__ = ('_parent', 'id', 'shard_count') - def __init__(self, parent, shard_count): - self._parent = parent - self.id = parent.id - self.shard_count = shard_count + def __init__(self, parent: Shard, shard_count: Optional[int]) -> None: + self._parent: Shard = parent + self.id: int = parent.id + self.shard_count: Optional[int] = shard_count - def is_closed(self): + def is_closed(self) -> bool: """:class:`bool`: Whether the shard connection is currently closed.""" return not self._parent.ws.open - async def disconnect(self): + async def disconnect(self) -> None: """|coro| Disconnects a shard. When this is called, the shard connection will no @@ -237,7 +246,7 @@ class ShardInfo: await self._parent.disconnect() - async def reconnect(self): + async def reconnect(self) -> None: """|coro| Disconnects and then connects the shard again. @@ -246,7 +255,7 @@ class ShardInfo: await self._parent.disconnect() await self._parent.reconnect() - async def connect(self): + async def connect(self) -> None: """|coro| Connects a shard. If the shard is already connected this does nothing. @@ -257,11 +266,11 @@ class ShardInfo: await self._parent.reconnect() @property - def latency(self): + def latency(self) -> float: """:class:`float`: Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds for this shard.""" return self._parent.ws.latency - def is_ws_ratelimited(self): + def is_ws_ratelimited(self) -> bool: """:class:`bool`: Whether the websocket is currently rate limited. This can be useful to know when deciding whether you should query members @@ -297,9 +306,12 @@ class AutoShardedClient(Client): shard_ids: Optional[List[:class:`int`]] An optional list of shard_ids to launch the shards with. """ - def __init__(self, *args, loop=None, **kwargs): + if TYPE_CHECKING: + _connection: AutoShardedConnectionState + + def __init__(self, *args: Any, loop: Optional[asyncio.AbstractEventLoop] = None, **kwargs: Any) -> None: kwargs.pop('shard_id', None) - self.shard_ids = kwargs.pop('shard_ids', None) + self.shard_ids: Optional[List[int]] = kwargs.pop('shard_ids', None) super().__init__(*args, loop=loop, **kwargs) if self.shard_ids is not None: @@ -315,18 +327,19 @@ class AutoShardedClient(Client): self._connection._get_client = lambda: self self.__queue = asyncio.PriorityQueue() - def _get_websocket(self, guild_id=None, *, shard_id=None): + def _get_websocket(self, guild_id: Optional[int] = None, *, shard_id: Optional[int] = None) -> DiscordWebSocket: if shard_id is None: - shard_id = (guild_id >> 22) % self.shard_count + # guild_id won't be None if shard_id is None and shard_count won't be None here + shard_id = (guild_id >> 22) % self.shard_count # type: ignore return self.__shards[shard_id].ws - def _get_state(self, **options): + def _get_state(self, **options: Any) -> AutoShardedConnectionState: return AutoShardedConnectionState(dispatch=self.dispatch, handlers=self._handlers, hooks=self._hooks, http=self.http, loop=self.loop, **options) @property - def latency(self): + def latency(self) -> float: """:class:`float`: Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds. This operates similarly to :meth:`Client.latency` except it uses the average @@ -338,14 +351,14 @@ class AutoShardedClient(Client): return sum(latency for _, latency in self.latencies) / len(self.__shards) @property - def latencies(self): + def latencies(self) -> List[Tuple[int, float]]: """List[Tuple[:class:`int`, :class:`float`]]: A list of latencies between a HEARTBEAT and a HEARTBEAT_ACK in seconds. This returns a list of tuples with elements ``(shard_id, latency)``. """ return [(shard_id, shard.ws.latency) for shard_id, shard in self.__shards.items()] - def get_shard(self, shard_id): + def get_shard(self, shard_id: int) -> Optional[ShardInfo]: """Optional[:class:`ShardInfo`]: Gets the shard information at a given shard ID or ``None`` if not found.""" try: parent = self.__shards[shard_id] @@ -355,11 +368,11 @@ class AutoShardedClient(Client): return ShardInfo(parent, self.shard_count) @property - def shards(self): + def shards(self) -> Dict[int, ShardInfo]: """Mapping[int, :class:`ShardInfo`]: Returns a mapping of shard IDs to their respective info object.""" return { shard_id: ShardInfo(parent, self.shard_count) for shard_id, parent in self.__shards.items() } - async def launch_shard(self, gateway, shard_id, *, initial=False): + async def launch_shard(self, gateway: str, shard_id: int, *, initial: bool = False) -> None: try: coro = DiscordWebSocket.from_client(self, initial=initial, gateway=gateway, shard_id=shard_id) ws = await asyncio.wait_for(coro, timeout=180.0) @@ -372,7 +385,7 @@ class AutoShardedClient(Client): self.__shards[shard_id] = ret = Shard(ws, self, self.__queue.put_nowait) ret.launch() - async def launch_shards(self): + async def launch_shards(self) -> None: if self.shard_count is None: self.shard_count, gateway = await self.http.get_bot_gateway() else: @@ -389,7 +402,7 @@ class AutoShardedClient(Client): self._connection.shards_launched.set() - async def connect(self, *, reconnect=True): + async def connect(self, *, reconnect: bool = True) -> None: self._reconnect = reconnect await self.launch_shards() @@ -413,7 +426,7 @@ class AutoShardedClient(Client): elif item.type == EventType.clean_close: return - async def close(self): + async def close(self) -> None: """|coro| Closes the connection to Discord. @@ -425,7 +438,7 @@ class AutoShardedClient(Client): for vc in self.voice_clients: try: - await vc.disconnect() + await vc.disconnect(force=True) except Exception: pass @@ -436,7 +449,7 @@ class AutoShardedClient(Client): await self.http.close() self.__queue.put_nowait(EventItem(EventType.clean_close, None, None)) - async def change_presence(self, *, activity=None, status=None, shard_id=None): + async def change_presence(self, *, activity: Optional[BaseActivity] = None, status: Optional[Status] = None, shard_id: int = None) -> None: """|coro| Changes the client's presence. @@ -468,23 +481,23 @@ class AutoShardedClient(Client): """ if status is None: - status = 'online' + status_value = 'online' status_enum = Status.online elif status is Status.offline: - status = 'invisible' + status_value = 'invisible' status_enum = Status.offline else: status_enum = status - status = str(status) + status_value = str(status) if shard_id is None: for shard in self.__shards.values(): - await shard.ws.change_presence(activity=activity, status=status) + await shard.ws.change_presence(activity=activity, status=status_value) guilds = self._connection.guilds else: shard = self.__shards[shard_id] - await shard.ws.change_presence(activity=activity, status=status) + await shard.ws.change_presence(activity=activity, status=status_value) guilds = [g for g in self._connection.guilds if g.shard_id == shard_id] activities = () if activity is None else (activity,) @@ -493,10 +506,11 @@ class AutoShardedClient(Client): if me is None: continue - me.activities = activities + # Member.activities is typehinted as Tuple[ActivityType, ...], we may be setting it as Tuple[BaseActivity, ...] + me.activities = activities # type: ignore me.status = status_enum - def is_ws_ratelimited(self): + def is_ws_ratelimited(self) -> bool: """:class:`bool`: Whether the websocket is currently rate limited. This can be useful to know when deciding whether you should query members