From 339541784499db837c3f70ee48d0c6b95a6b6491 Mon Sep 17 00:00:00 2001 From: dolfies Date: Sat, 2 Apr 2022 15:03:35 -0400 Subject: [PATCH] Improve typing across the board, remove old browser references --- discord/appinfo.py | 2 + discord/enums.py | 10 -- discord/guild.py | 14 +-- discord/http.py | 2 +- discord/invite.py | 4 +- discord/member.py | 6 +- discord/settings.py | 2 +- discord/state.py | 18 +++- discord/types/user.py | 2 +- discord/user.py | 2 +- discord/utils.py | 191 ++++---------------------------------- discord/welcome_screen.py | 7 +- 12 files changed, 52 insertions(+), 208 deletions(-) diff --git a/discord/appinfo.py b/discord/appinfo.py index 6e5deef90..c06fd89eb 100644 --- a/discord/appinfo.py +++ b/discord/appinfo.py @@ -188,6 +188,8 @@ class PartialApplication(Hashable): 'hook', 'premium_tier_level', 'tags', + 'max_participants', + 'install_url', ) def __init__(self, *, state: ConnectionState, data: PartialAppInfoPayload): diff --git a/discord/enums.py b/discord/enums.py index ec89930ff..8ae7a3433 100644 --- a/discord/enums.py +++ b/discord/enums.py @@ -72,7 +72,6 @@ __all__ = ( 'UnavailableGuildType', 'RequiredActionType', 'ReportType', - 'BrowserEnum', 'ApplicationVerificationState', 'StoreApplicationState', 'RPCApplicationState', @@ -655,15 +654,6 @@ class RequiredActionType(Enum): accept_terms = 'AGREEMENTS' -class BrowserEnum(Enum): - google_chrome = 'chrome' - chrome = 'chrome' - chromium = 'chromium' - microsoft_edge = 'microsoft-edge' - edge = 'microsoft-edge' - opera = 'opera' - - class InviteTarget(Enum): unknown = 0 stream = 1 diff --git a/discord/guild.py b/discord/guild.py index 888c42493..dc1e3d651 100644 --- a/discord/guild.py +++ b/discord/guild.py @@ -129,7 +129,7 @@ if TYPE_CHECKING: StageChannel as StageChannelPayload, ) from .types.integration import IntegrationType - from .types.snowflake import SnowflakeList + from .types.snowflake import SnowflakeList, Snowflake as _Snowflake from .types.widget import EditWidgetSettings VocalGuildChannel = Union[VoiceChannel, StageChannel] @@ -405,7 +405,9 @@ class Guild(Hashable): inner = ' '.join('%s=%r' % t for t in attrs) return f'' - def _update_voice_state(self, data: GuildVoiceState, channel_id: int) -> Tuple[Optional[Member], VoiceState, VoiceState]: + def _update_voice_state( + self, data: GuildVoiceState, channel_id: Optional[int] + ) -> Tuple[Optional[Member], VoiceState, VoiceState]: cache_flags = self._state.member_cache_flags user_id = int(data['user_id']) channel: Optional[VocalGuildChannel] = self.get_channel(channel_id) # type: ignore - this will always be a voice channel @@ -3270,9 +3272,9 @@ class Guild(Hashable): if action: action = action.value - if isinstance(before, datetime.datetime): + if isinstance(before, datetime): before = Object(id=utils.time_snowflake(before, high=False)) - if isinstance(after, datetime.datetime): + if isinstance(after, datetime): after = Object(id=utils.time_snowflake(after, high=True)) if oldest_first is MISSING: @@ -3636,10 +3638,10 @@ class Guild(Hashable): limit = min(100, limit or 5) members = await self._state.query_members( - self, query=query, limit=limit, user_ids=user_ids, presences=presences, cache=cache + self, query=query, limit=limit, user_ids=user_ids, presences=presences, cache=cache # type: ignore - The two types are compatible ) if subscribe: - ids = [str(m.id) for m in members] + ids: List[_Snowflake] = [str(m.id) for m in members] await self._state.ws.request_lazy_guild(self.id, members=ids) return members diff --git a/discord/http.py b/discord/http.py index 107f12728..74c052547 100644 --- a/discord/http.py +++ b/discord/http.py @@ -634,7 +634,7 @@ class HTTPClient: # PM functionality - def start_group(self, recipients: List[Snowflake]) -> Response[channel.GroupDMChannel]: + def start_group(self, recipients: SnowflakeList) -> Response[channel.GroupDMChannel]: payload = { 'recipients': recipients, } diff --git a/discord/invite.py b/discord/invite.py index fc95d8b51..522709709 100644 --- a/discord/invite.py +++ b/discord/invite.py @@ -477,9 +477,9 @@ class Invite(Hashable): channel = state.get_channel(getattr(channel, 'id', None)) or channel if message is not None: - data['message'] = message + data['message'] = message # type: ignore - Not a real field - return cls(state=state, data=data, guild=guild, channel=channel, welcome_screen=welcome_screen) + return cls(state=state, data=data, guild=guild, channel=channel, welcome_screen=welcome_screen) # type: ignore @classmethod def from_gateway(cls, *, state: ConnectionState, data: GatewayInvitePayload) -> Self: diff --git a/discord/member.py b/discord/member.py index 85a34bc09..04df7c456 100644 --- a/discord/member.py +++ b/discord/member.py @@ -142,12 +142,12 @@ class VoiceState: ) def __init__( - self, *, data: Union[VoiceStatePayload, GuildVoiceStatePayload], channel: Optional[VocalGuildChannel] = None + self, *, data: Union[VoiceStatePayload, GuildVoiceStatePayload], channel: Optional[ConnectableChannel] = None ): self.session_id: Optional[str] = data.get('session_id') self._update(data, channel) - def _update(self, data: Union[VoiceStatePayload, GuildVoiceStatePayload], channel: Optional[VocalGuildChannel]): + def _update(self, data: Union[VoiceStatePayload, GuildVoiceStatePayload], channel: Optional[ConnectableChannel]): self.self_mute: bool = data.get('self_mute', False) self.self_deaf: bool = data.get('self_deaf', False) self.self_stream: bool = data.get('self_stream', False) @@ -874,7 +874,7 @@ class Member(discord.abc.Messageable, discord.abc.Connectable, _UserTag): data = await http.edit_member(guild_id, self.id, reason=reason, **payload) if data: - return Member(data=data, guild=self.guild, state=self._state) # type: ignore + return Member(data=data, guild=self.guild, state=self._state) async def request_to_speak(self) -> None: """|coro| diff --git a/discord/settings.py b/discord/settings.py index 17a361a6e..b55560e5b 100644 --- a/discord/settings.py +++ b/discord/settings.py @@ -385,7 +385,7 @@ class ChannelSettings: self._channel_id = int(data['channel_id']) self.collapsed = data.get('collapsed', False) - self.level = try_enum(NotificationLevel, data.get('message_notifications', 3)) # type: ignore + self.level = try_enum(NotificationLevel, data.get('message_notifications', 3)) self.muted = MuteConfig(data.get('muted', False), data.get('mute_config') or {}) @property diff --git a/discord/state.py b/discord/state.py index 036fa8860..fdc1ad171 100644 --- a/discord/state.py +++ b/discord/state.py @@ -527,10 +527,12 @@ class ConnectionState: def voice_clients(self) -> List[VoiceProtocol]: return list(self._voice_clients.values()) - def _update_voice_state(self, data: GuildVoiceState, channel_id: Optional[int]) -> Tuple[User, VoiceState, VoiceState]: + def _update_voice_state( + self, data: GuildVoiceState, channel_id: Optional[int] + ) -> Tuple[Optional[User], VoiceState, VoiceState]: user_id = int(data['user_id']) user = self.get_user(user_id) - channel = self._get_private_channel(channel_id) + channel: Optional[Union[DMChannel, GroupChannel]] = self._get_private_channel(channel_id) # type: ignore try: # Check if we should remove the voice state from cache @@ -756,7 +758,13 @@ class ConnectionState: return self.ws.request_chunks([guild_id], query=query, limit=limit, presences=presences, nonce=nonce) async def query_members( - self, guild: Guild, query: Optional[str], limit: int, user_ids: Optional[List[int]], cache: bool, presences: bool + self, + guild: Guild, + query: Optional[str], + limit: int, + user_ids: Optional[List[Snowflake]], + cache: bool, + presences: bool, ) -> List[Member]: guild_id = guild.id request = ChunkRequest(guild.id, self.loop, self._get_guild, cache=cache) @@ -830,7 +838,7 @@ class ConnectionState: ) or {'guild_id': guild_data['id']} for presence in merged_presences: - presence['user'] = {'id': presence['user_id']} + presence['user'] = {'id': presence['user_id']} # type: ignore - :( voice_states = guild_data.setdefault('voice_states', []) voice_states.extend(guild_extra.get('voice_states', [])) @@ -1660,7 +1668,7 @@ class ConnectionState: else: raise RuntimeError('No channels viewable') - requests = {str(channel.id): [[0, 99]]} + requests: Dict[Snowflake, List[List[int]]] = {str(channel.id): [[0, 99]]} def predicate(data): return int(data['guild_id']) == guild.id diff --git a/discord/types/user.py b/discord/types/user.py index 699322345..1d44dcfbb 100644 --- a/discord/types/user.py +++ b/discord/types/user.py @@ -51,4 +51,4 @@ class User(PartialUser, total=False): bio: str analytics_token: str phone: Optional[str] - token: str \ No newline at end of file + token: str diff --git a/discord/user.py b/discord/user.py index acb03d5e0..74f592bdc 100644 --- a/discord/user.py +++ b/discord/user.py @@ -93,7 +93,7 @@ class Note: def note(self) -> Optional[str]: """Returns the note. - There is an alias for this named :attr:`value`. + There is an alias for this called :attr:`value`. Raises ------- diff --git a/discord/utils.py b/discord/utils.py index 65e23c433..282d74a1f 100644 --- a/discord/utils.py +++ b/discord/utils.py @@ -62,20 +62,15 @@ from inspect import isawaitable as _isawaitable, signature as _signature from operator import attrgetter import json import logging -import os -import platform import random import re import string -import subprocess import sys -import tempfile from threading import Timer import types import warnings import yarl -from .enums import BrowserEnum try: import orjson # type: ignore @@ -1168,7 +1163,11 @@ def format_dt(dt: datetime.datetime, /, style: Optional[TimestampStyle] = None) def set_target( - items: Iterable[ApplicationCommand], *, channel: Messageable = None, message: Message = None, user: Snowflake = None + items: Iterable[ApplicationCommand], + *, + channel: Optional[Messageable] = MISSING, + message: Optional[Message] = MISSING, + user: Optional[Snowflake] = MISSING, ) -> None: """A helper function to set the target for a list of items. @@ -1181,24 +1180,26 @@ def set_target( ----------- items: Iterable[:class:`ApplicationCommand`] A list of items to set the target for. - channel: :class:`Messageable` + channel: Optional[:class:`Messageable`] The channel to target. - message: :class:`Message` + message: Optional[:class:`Message`] The message to target. - user: :class:`Snowflake` + user: Optional[:class:`~abc.Snowflake`] The user to target. """ - attrs = { - 'target_channel': channel, - 'target_message': message, - 'target_user': user, - } + attrs = {} + if channel is not MISSING: + attrs['target_channel'] = channel + if message is not MISSING: + attrs['target_message'] = message + if user is not MISSING: + attrs['target_user'] = user for item in items: for k, v in attrs.items(): if v is not None: try: - setattr(item, k, v) # type: ignore + setattr(item, k, v) except AttributeError: pass @@ -1207,39 +1208,6 @@ def _generate_session_id() -> str: return ''.join(random.choices(string.ascii_letters + string.digits, k=16)) -class ExpiringQueue(asyncio.Queue): # Inspired from https://github.com/NoahCardoza/CaptchaHarvester - def __init__(self, timeout: int, maxsize: int = 0) -> None: - super().__init__(maxsize) - self.timeout = timeout - self.timers: asyncio.Queue = asyncio.Queue() - - async def put(self, item: str) -> None: - thread: Timer = Timer(self.timeout, self.expire) - thread.start() - await self.timers.put(thread) - await super().put(item) - - async def get(self, block: bool = True) -> str: - if block: - thread = await self.timers.get() - else: - thread = self.timers.get_nowait() - thread.cancel() - if block: - return await super().get() - else: - return self.get_nowait() - - def expire(self) -> None: - try: - self._queue.popleft() - except: - pass - - def to_list(self) -> List[str]: - return list(self._queue) - - class ExpiringString(collections.UserString): def __init__(self, data: str, timeout: int) -> None: super().__init__(data) @@ -1263,133 +1231,6 @@ class ExpiringString(collections.UserString): self._timer.cancel() -class Browser: # Inspired from https://github.com/NoahCardoza/CaptchaHarvester - def __init__(self, browser: Union[BrowserEnum, str] = None) -> None: - if isinstance(browser, (BrowserEnum, type(None))): - try: - browser = self.get_browser(browser) - except Exception: - raise RuntimeError('Could not find browser. Please pass browser path manually.') - - if browser is None: - raise RuntimeError('Could not find browser. Please pass browser path manually.') - - self.browser: str = browser - self.proc: subprocess.Popen = MISSING - - def get_mac_browser(pkg: str, binary: str) -> Optional[os.PathLike]: - import plistlib as plist - - pfile: str = f'{os.environ["HOME"]}/Library/Preferences/{pkg}.plist' - if os.path.exists(pfile): - with open(pfile, 'rb') as f: - binary_path: Optional[str] = plist.load(f).get('LastRunAppBundlePath') - if binary_path is not None: - return os.path.join(binary_path, 'Contents', 'MacOS', binary) - - def get_windows_browser(browser: str) -> Optional[str]: - import winreg as reg - - reg_path: str = f'SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\App Paths\\{browser}.exe' - exe_path: Optional[str] = None - for install_type in reg.HKEY_CURRENT_USER, reg.HKEY_LOCAL_MACHINE: - try: - reg_key: str = reg.OpenKey(install_type, reg_path, 0, reg.KEY_READ) - exe_path: Optional[str] = reg.QueryValue(reg_key, None) - reg_key.Close() - if not os.path.isfile(exe_path): - continue - except reg.WindowsError: - pass - else: - break - return exe_path - - def get_linux_browser(browser: str) -> Optional[str]: - from shutil import which as exists - - possibilities: List[str] = [browser + channel for channel in ('', '-beta', '-dev', '-developer', '-canary')] - for browser in possibilities: - if exists(browser): - return browser - - registry: Dict[str, Dict[str, functools.partial]] = { - 'Windows': { - 'chrome': functools.partial(get_windows_browser, 'chrome'), - 'chromium': functools.partial(get_windows_browser, 'chromium'), - 'microsoft-edge': functools.partial(get_windows_browser, 'msedge'), - 'opera': functools.partial(get_windows_browser, 'opera'), - }, - 'Darwin': { - 'chrome': functools.partial(get_mac_browser, 'com.google.Chrome', 'Google Chrome'), - 'chromium': functools.partial(get_mac_browser, 'org.chromium.Chromium', 'Chromium'), - 'microsoft-edge': functools.partial(get_mac_browser, 'com.microsoft.Edge', 'Microsoft Edge'), - 'opera': functools.partial(get_mac_browser, 'com.operasoftware.Opera', 'Opera'), - }, - 'Linux': { - 'chrome': functools.partial(get_linux_browser, 'chrome'), - 'chromium': functools.partial(get_linux_browser, 'chromium'), - 'microsoft-edge': functools.partial(get_linux_browser, 'microsoft-edge'), - 'opera': functools.partial(get_linux_browser, 'opera'), - }, - } - - def get_browser(self, browser: Optional[BrowserEnum] = None) -> Optional[str]: - if browser is not None: - return self.registry.get(platform.system(), {})[browser.value]() - - for browser in self.registry.get(platform.system(), {}).values(): - browser = browser() - if browser is not None: - return browser - - @property - def running(self) -> bool: - try: - return self.proc.poll() is None - except: - return False - - def launch( - self, - domain: Optional[str] = None, - server: Tuple[Optional[str], Optional[int]] = (None, None), - width: int = 400, - height: int = 500, - browser_args: List[str] = [], - extensions: Optional[str] = None, - ) -> None: - browser_command: List[str] = [self.browser, *browser_args] - - if extensions: - browser_command.append(f'--load-extension={extensions}') - - browser_command.extend( - ( - '--disable-default-apps', - '--no-default-browser-check', - '--no-check-default-browser', - '--no-first-run', - '--ignore-certificate-errors', - '--disable-background-networking', - '--disable-component-update', - '--disable-domain-reliability', - f'--user-data-dir={os.path.join(tempfile.TemporaryDirectory().name, "Profiles")}', - f'--host-rules=MAP {domain} {server[0]}:{server[1]}', - f'--window-size={width},{height}', - f'--app=https://{domain}', - ) - ) - - self.proc = subprocess.Popen(browser_command, stdout=-1, stderr=-1) - - def stop(self) -> None: - try: - self.proc.terminate() - except: - pass - - async def _get_info(session: ClientSession) -> Tuple[str, str, int]: for _ in range(3): try: diff --git a/discord/welcome_screen.py b/discord/welcome_screen.py index d48aa8f34..7e54151ef 100644 --- a/discord/welcome_screen.py +++ b/discord/welcome_screen.py @@ -34,6 +34,7 @@ if TYPE_CHECKING: from .abc import Snowflake from .emoji import Emoji from .guild import Guild + from .invite import PartialInviteGuild from .state import ConnectionState from .types.welcome_screen import ( WelcomeScreen as WelcomeScreenPayload, @@ -71,7 +72,7 @@ class WelcomeChannel: @classmethod def _from_dict(cls, *, data: WelcomeScreenChannelPayload, state: ConnectionState) -> WelcomeChannel: - channel_id = _get_as_snowflake(data, 'channel_id') + channel_id = int(data['channel_id']) channel = state.get_channel(channel_id) or Object(id=channel_id) emoji = None @@ -83,7 +84,7 @@ class WelcomeChannel: return cls(channel=channel, description=data.get('description', ''), emoji=emoji) def _to_dict(self) -> WelcomeScreenChannelPayload: - data = { + data: WelcomeScreenChannelPayload = { 'channel_id': self.channel.id, 'description': self.description, 'emoji_id': None, @@ -117,7 +118,7 @@ class WelcomeScreen: The channels shown on the welcome screen. """ - def __init__(self, *, data: WelcomeScreenPayload, guild: Guild) -> None: + def __init__(self, *, data: WelcomeScreenPayload, guild: Union[Guild, PartialInviteGuild]) -> None: self.guild = guild self._update(data)