Browse Source

Improve typing across the board, remove old browser references

pull/10109/head
dolfies 3 years ago
parent
commit
3395417844
  1. 2
      discord/appinfo.py
  2. 10
      discord/enums.py
  3. 14
      discord/guild.py
  4. 2
      discord/http.py
  5. 4
      discord/invite.py
  6. 6
      discord/member.py
  7. 2
      discord/settings.py
  8. 18
      discord/state.py
  9. 2
      discord/types/user.py
  10. 2
      discord/user.py
  11. 191
      discord/utils.py
  12. 7
      discord/welcome_screen.py

2
discord/appinfo.py

@ -188,6 +188,8 @@ class PartialApplication(Hashable):
'hook', 'hook',
'premium_tier_level', 'premium_tier_level',
'tags', 'tags',
'max_participants',
'install_url',
) )
def __init__(self, *, state: ConnectionState, data: PartialAppInfoPayload): def __init__(self, *, state: ConnectionState, data: PartialAppInfoPayload):

10
discord/enums.py

@ -72,7 +72,6 @@ __all__ = (
'UnavailableGuildType', 'UnavailableGuildType',
'RequiredActionType', 'RequiredActionType',
'ReportType', 'ReportType',
'BrowserEnum',
'ApplicationVerificationState', 'ApplicationVerificationState',
'StoreApplicationState', 'StoreApplicationState',
'RPCApplicationState', 'RPCApplicationState',
@ -655,15 +654,6 @@ class RequiredActionType(Enum):
accept_terms = 'AGREEMENTS' accept_terms = 'AGREEMENTS'
class BrowserEnum(Enum):
google_chrome = 'chrome'
chrome = 'chrome'
chromium = 'chromium'
microsoft_edge = 'microsoft-edge'
edge = 'microsoft-edge'
opera = 'opera'
class InviteTarget(Enum): class InviteTarget(Enum):
unknown = 0 unknown = 0
stream = 1 stream = 1

14
discord/guild.py

@ -129,7 +129,7 @@ if TYPE_CHECKING:
StageChannel as StageChannelPayload, StageChannel as StageChannelPayload,
) )
from .types.integration import IntegrationType from .types.integration import IntegrationType
from .types.snowflake import SnowflakeList from .types.snowflake import SnowflakeList, Snowflake as _Snowflake
from .types.widget import EditWidgetSettings from .types.widget import EditWidgetSettings
VocalGuildChannel = Union[VoiceChannel, StageChannel] VocalGuildChannel = Union[VoiceChannel, StageChannel]
@ -405,7 +405,9 @@ class Guild(Hashable):
inner = ' '.join('%s=%r' % t for t in attrs) inner = ' '.join('%s=%r' % t for t in attrs)
return f'<Guild {inner}>' return f'<Guild {inner}>'
def _update_voice_state(self, data: GuildVoiceState, channel_id: int) -> Tuple[Optional[Member], VoiceState, VoiceState]: def _update_voice_state(
self, data: GuildVoiceState, channel_id: Optional[int]
) -> Tuple[Optional[Member], VoiceState, VoiceState]:
cache_flags = self._state.member_cache_flags cache_flags = self._state.member_cache_flags
user_id = int(data['user_id']) user_id = int(data['user_id'])
channel: Optional[VocalGuildChannel] = self.get_channel(channel_id) # type: ignore - this will always be a voice channel channel: Optional[VocalGuildChannel] = self.get_channel(channel_id) # type: ignore - this will always be a voice channel
@ -3270,9 +3272,9 @@ class Guild(Hashable):
if action: if action:
action = action.value action = action.value
if isinstance(before, datetime.datetime): if isinstance(before, datetime):
before = Object(id=utils.time_snowflake(before, high=False)) before = Object(id=utils.time_snowflake(before, high=False))
if isinstance(after, datetime.datetime): if isinstance(after, datetime):
after = Object(id=utils.time_snowflake(after, high=True)) after = Object(id=utils.time_snowflake(after, high=True))
if oldest_first is MISSING: if oldest_first is MISSING:
@ -3636,10 +3638,10 @@ class Guild(Hashable):
limit = min(100, limit or 5) limit = min(100, limit or 5)
members = await self._state.query_members( members = await self._state.query_members(
self, query=query, limit=limit, user_ids=user_ids, presences=presences, cache=cache self, query=query, limit=limit, user_ids=user_ids, presences=presences, cache=cache # type: ignore - The two types are compatible
) )
if subscribe: if subscribe:
ids = [str(m.id) for m in members] ids: List[_Snowflake] = [str(m.id) for m in members]
await self._state.ws.request_lazy_guild(self.id, members=ids) await self._state.ws.request_lazy_guild(self.id, members=ids)
return members return members

2
discord/http.py

@ -634,7 +634,7 @@ class HTTPClient:
# PM functionality # PM functionality
def start_group(self, recipients: List[Snowflake]) -> Response[channel.GroupDMChannel]: def start_group(self, recipients: SnowflakeList) -> Response[channel.GroupDMChannel]:
payload = { payload = {
'recipients': recipients, 'recipients': recipients,
} }

4
discord/invite.py

@ -477,9 +477,9 @@ class Invite(Hashable):
channel = state.get_channel(getattr(channel, 'id', None)) or channel channel = state.get_channel(getattr(channel, 'id', None)) or channel
if message is not None: if message is not None:
data['message'] = message data['message'] = message # type: ignore - Not a real field
return cls(state=state, data=data, guild=guild, channel=channel, welcome_screen=welcome_screen) return cls(state=state, data=data, guild=guild, channel=channel, welcome_screen=welcome_screen) # type: ignore
@classmethod @classmethod
def from_gateway(cls, *, state: ConnectionState, data: GatewayInvitePayload) -> Self: def from_gateway(cls, *, state: ConnectionState, data: GatewayInvitePayload) -> Self:

6
discord/member.py

@ -142,12 +142,12 @@ class VoiceState:
) )
def __init__( def __init__(
self, *, data: Union[VoiceStatePayload, GuildVoiceStatePayload], channel: Optional[VocalGuildChannel] = None self, *, data: Union[VoiceStatePayload, GuildVoiceStatePayload], channel: Optional[ConnectableChannel] = None
): ):
self.session_id: Optional[str] = data.get('session_id') self.session_id: Optional[str] = data.get('session_id')
self._update(data, channel) self._update(data, channel)
def _update(self, data: Union[VoiceStatePayload, GuildVoiceStatePayload], channel: Optional[VocalGuildChannel]): def _update(self, data: Union[VoiceStatePayload, GuildVoiceStatePayload], channel: Optional[ConnectableChannel]):
self.self_mute: bool = data.get('self_mute', False) self.self_mute: bool = data.get('self_mute', False)
self.self_deaf: bool = data.get('self_deaf', False) self.self_deaf: bool = data.get('self_deaf', False)
self.self_stream: bool = data.get('self_stream', False) self.self_stream: bool = data.get('self_stream', False)
@ -874,7 +874,7 @@ class Member(discord.abc.Messageable, discord.abc.Connectable, _UserTag):
data = await http.edit_member(guild_id, self.id, reason=reason, **payload) data = await http.edit_member(guild_id, self.id, reason=reason, **payload)
if data: if data:
return Member(data=data, guild=self.guild, state=self._state) # type: ignore return Member(data=data, guild=self.guild, state=self._state)
async def request_to_speak(self) -> None: async def request_to_speak(self) -> None:
"""|coro| """|coro|

2
discord/settings.py

@ -385,7 +385,7 @@ class ChannelSettings:
self._channel_id = int(data['channel_id']) self._channel_id = int(data['channel_id'])
self.collapsed = data.get('collapsed', False) self.collapsed = data.get('collapsed', False)
self.level = try_enum(NotificationLevel, data.get('message_notifications', 3)) # type: ignore self.level = try_enum(NotificationLevel, data.get('message_notifications', 3))
self.muted = MuteConfig(data.get('muted', False), data.get('mute_config') or {}) self.muted = MuteConfig(data.get('muted', False), data.get('mute_config') or {})
@property @property

18
discord/state.py

@ -527,10 +527,12 @@ class ConnectionState:
def voice_clients(self) -> List[VoiceProtocol]: def voice_clients(self) -> List[VoiceProtocol]:
return list(self._voice_clients.values()) return list(self._voice_clients.values())
def _update_voice_state(self, data: GuildVoiceState, channel_id: Optional[int]) -> Tuple[User, VoiceState, VoiceState]: def _update_voice_state(
self, data: GuildVoiceState, channel_id: Optional[int]
) -> Tuple[Optional[User], VoiceState, VoiceState]:
user_id = int(data['user_id']) user_id = int(data['user_id'])
user = self.get_user(user_id) user = self.get_user(user_id)
channel = self._get_private_channel(channel_id) channel: Optional[Union[DMChannel, GroupChannel]] = self._get_private_channel(channel_id) # type: ignore
try: try:
# Check if we should remove the voice state from cache # Check if we should remove the voice state from cache
@ -756,7 +758,13 @@ class ConnectionState:
return self.ws.request_chunks([guild_id], query=query, limit=limit, presences=presences, nonce=nonce) return self.ws.request_chunks([guild_id], query=query, limit=limit, presences=presences, nonce=nonce)
async def query_members( async def query_members(
self, guild: Guild, query: Optional[str], limit: int, user_ids: Optional[List[int]], cache: bool, presences: bool self,
guild: Guild,
query: Optional[str],
limit: int,
user_ids: Optional[List[Snowflake]],
cache: bool,
presences: bool,
) -> List[Member]: ) -> List[Member]:
guild_id = guild.id guild_id = guild.id
request = ChunkRequest(guild.id, self.loop, self._get_guild, cache=cache) request = ChunkRequest(guild.id, self.loop, self._get_guild, cache=cache)
@ -830,7 +838,7 @@ class ConnectionState:
) or {'guild_id': guild_data['id']} ) or {'guild_id': guild_data['id']}
for presence in merged_presences: for presence in merged_presences:
presence['user'] = {'id': presence['user_id']} presence['user'] = {'id': presence['user_id']} # type: ignore - :(
voice_states = guild_data.setdefault('voice_states', []) voice_states = guild_data.setdefault('voice_states', [])
voice_states.extend(guild_extra.get('voice_states', [])) voice_states.extend(guild_extra.get('voice_states', []))
@ -1660,7 +1668,7 @@ class ConnectionState:
else: else:
raise RuntimeError('No channels viewable') raise RuntimeError('No channels viewable')
requests = {str(channel.id): [[0, 99]]} requests: Dict[Snowflake, List[List[int]]] = {str(channel.id): [[0, 99]]}
def predicate(data): def predicate(data):
return int(data['guild_id']) == guild.id return int(data['guild_id']) == guild.id

2
discord/types/user.py

@ -51,4 +51,4 @@ class User(PartialUser, total=False):
bio: str bio: str
analytics_token: str analytics_token: str
phone: Optional[str] phone: Optional[str]
token: str token: str

2
discord/user.py

@ -93,7 +93,7 @@ class Note:
def note(self) -> Optional[str]: def note(self) -> Optional[str]:
"""Returns the note. """Returns the note.
There is an alias for this named :attr:`value`. There is an alias for this called :attr:`value`.
Raises Raises
------- -------

191
discord/utils.py

@ -62,20 +62,15 @@ from inspect import isawaitable as _isawaitable, signature as _signature
from operator import attrgetter from operator import attrgetter
import json import json
import logging import logging
import os
import platform
import random import random
import re import re
import string import string
import subprocess
import sys import sys
import tempfile
from threading import Timer from threading import Timer
import types import types
import warnings import warnings
import yarl import yarl
from .enums import BrowserEnum
try: try:
import orjson # type: ignore import orjson # type: ignore
@ -1168,7 +1163,11 @@ def format_dt(dt: datetime.datetime, /, style: Optional[TimestampStyle] = None)
def set_target( def set_target(
items: Iterable[ApplicationCommand], *, channel: Messageable = None, message: Message = None, user: Snowflake = None items: Iterable[ApplicationCommand],
*,
channel: Optional[Messageable] = MISSING,
message: Optional[Message] = MISSING,
user: Optional[Snowflake] = MISSING,
) -> None: ) -> None:
"""A helper function to set the target for a list of items. """A helper function to set the target for a list of items.
@ -1181,24 +1180,26 @@ def set_target(
----------- -----------
items: Iterable[:class:`ApplicationCommand`] items: Iterable[:class:`ApplicationCommand`]
A list of items to set the target for. A list of items to set the target for.
channel: :class:`Messageable` channel: Optional[:class:`Messageable`]
The channel to target. The channel to target.
message: :class:`Message` message: Optional[:class:`Message`]
The message to target. The message to target.
user: :class:`Snowflake` user: Optional[:class:`~abc.Snowflake`]
The user to target. The user to target.
""" """
attrs = { attrs = {}
'target_channel': channel, if channel is not MISSING:
'target_message': message, attrs['target_channel'] = channel
'target_user': user, if message is not MISSING:
} attrs['target_message'] = message
if user is not MISSING:
attrs['target_user'] = user
for item in items: for item in items:
for k, v in attrs.items(): for k, v in attrs.items():
if v is not None: if v is not None:
try: try:
setattr(item, k, v) # type: ignore setattr(item, k, v)
except AttributeError: except AttributeError:
pass pass
@ -1207,39 +1208,6 @@ def _generate_session_id() -> str:
return ''.join(random.choices(string.ascii_letters + string.digits, k=16)) return ''.join(random.choices(string.ascii_letters + string.digits, k=16))
class ExpiringQueue(asyncio.Queue): # Inspired from https://github.com/NoahCardoza/CaptchaHarvester
def __init__(self, timeout: int, maxsize: int = 0) -> None:
super().__init__(maxsize)
self.timeout = timeout
self.timers: asyncio.Queue = asyncio.Queue()
async def put(self, item: str) -> None:
thread: Timer = Timer(self.timeout, self.expire)
thread.start()
await self.timers.put(thread)
await super().put(item)
async def get(self, block: bool = True) -> str:
if block:
thread = await self.timers.get()
else:
thread = self.timers.get_nowait()
thread.cancel()
if block:
return await super().get()
else:
return self.get_nowait()
def expire(self) -> None:
try:
self._queue.popleft()
except:
pass
def to_list(self) -> List[str]:
return list(self._queue)
class ExpiringString(collections.UserString): class ExpiringString(collections.UserString):
def __init__(self, data: str, timeout: int) -> None: def __init__(self, data: str, timeout: int) -> None:
super().__init__(data) super().__init__(data)
@ -1263,133 +1231,6 @@ class ExpiringString(collections.UserString):
self._timer.cancel() self._timer.cancel()
class Browser: # Inspired from https://github.com/NoahCardoza/CaptchaHarvester
def __init__(self, browser: Union[BrowserEnum, str] = None) -> None:
if isinstance(browser, (BrowserEnum, type(None))):
try:
browser = self.get_browser(browser)
except Exception:
raise RuntimeError('Could not find browser. Please pass browser path manually.')
if browser is None:
raise RuntimeError('Could not find browser. Please pass browser path manually.')
self.browser: str = browser
self.proc: subprocess.Popen = MISSING
def get_mac_browser(pkg: str, binary: str) -> Optional[os.PathLike]:
import plistlib as plist
pfile: str = f'{os.environ["HOME"]}/Library/Preferences/{pkg}.plist'
if os.path.exists(pfile):
with open(pfile, 'rb') as f:
binary_path: Optional[str] = plist.load(f).get('LastRunAppBundlePath')
if binary_path is not None:
return os.path.join(binary_path, 'Contents', 'MacOS', binary)
def get_windows_browser(browser: str) -> Optional[str]:
import winreg as reg
reg_path: str = f'SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\App Paths\\{browser}.exe'
exe_path: Optional[str] = None
for install_type in reg.HKEY_CURRENT_USER, reg.HKEY_LOCAL_MACHINE:
try:
reg_key: str = reg.OpenKey(install_type, reg_path, 0, reg.KEY_READ)
exe_path: Optional[str] = reg.QueryValue(reg_key, None)
reg_key.Close()
if not os.path.isfile(exe_path):
continue
except reg.WindowsError:
pass
else:
break
return exe_path
def get_linux_browser(browser: str) -> Optional[str]:
from shutil import which as exists
possibilities: List[str] = [browser + channel for channel in ('', '-beta', '-dev', '-developer', '-canary')]
for browser in possibilities:
if exists(browser):
return browser
registry: Dict[str, Dict[str, functools.partial]] = {
'Windows': {
'chrome': functools.partial(get_windows_browser, 'chrome'),
'chromium': functools.partial(get_windows_browser, 'chromium'),
'microsoft-edge': functools.partial(get_windows_browser, 'msedge'),
'opera': functools.partial(get_windows_browser, 'opera'),
},
'Darwin': {
'chrome': functools.partial(get_mac_browser, 'com.google.Chrome', 'Google Chrome'),
'chromium': functools.partial(get_mac_browser, 'org.chromium.Chromium', 'Chromium'),
'microsoft-edge': functools.partial(get_mac_browser, 'com.microsoft.Edge', 'Microsoft Edge'),
'opera': functools.partial(get_mac_browser, 'com.operasoftware.Opera', 'Opera'),
},
'Linux': {
'chrome': functools.partial(get_linux_browser, 'chrome'),
'chromium': functools.partial(get_linux_browser, 'chromium'),
'microsoft-edge': functools.partial(get_linux_browser, 'microsoft-edge'),
'opera': functools.partial(get_linux_browser, 'opera'),
},
}
def get_browser(self, browser: Optional[BrowserEnum] = None) -> Optional[str]:
if browser is not None:
return self.registry.get(platform.system(), {})[browser.value]()
for browser in self.registry.get(platform.system(), {}).values():
browser = browser()
if browser is not None:
return browser
@property
def running(self) -> bool:
try:
return self.proc.poll() is None
except:
return False
def launch(
self,
domain: Optional[str] = None,
server: Tuple[Optional[str], Optional[int]] = (None, None),
width: int = 400,
height: int = 500,
browser_args: List[str] = [],
extensions: Optional[str] = None,
) -> None:
browser_command: List[str] = [self.browser, *browser_args]
if extensions:
browser_command.append(f'--load-extension={extensions}')
browser_command.extend(
(
'--disable-default-apps',
'--no-default-browser-check',
'--no-check-default-browser',
'--no-first-run',
'--ignore-certificate-errors',
'--disable-background-networking',
'--disable-component-update',
'--disable-domain-reliability',
f'--user-data-dir={os.path.join(tempfile.TemporaryDirectory().name, "Profiles")}',
f'--host-rules=MAP {domain} {server[0]}:{server[1]}',
f'--window-size={width},{height}',
f'--app=https://{domain}',
)
)
self.proc = subprocess.Popen(browser_command, stdout=-1, stderr=-1)
def stop(self) -> None:
try:
self.proc.terminate()
except:
pass
async def _get_info(session: ClientSession) -> Tuple[str, str, int]: async def _get_info(session: ClientSession) -> Tuple[str, str, int]:
for _ in range(3): for _ in range(3):
try: try:

7
discord/welcome_screen.py

@ -34,6 +34,7 @@ if TYPE_CHECKING:
from .abc import Snowflake from .abc import Snowflake
from .emoji import Emoji from .emoji import Emoji
from .guild import Guild from .guild import Guild
from .invite import PartialInviteGuild
from .state import ConnectionState from .state import ConnectionState
from .types.welcome_screen import ( from .types.welcome_screen import (
WelcomeScreen as WelcomeScreenPayload, WelcomeScreen as WelcomeScreenPayload,
@ -71,7 +72,7 @@ class WelcomeChannel:
@classmethod @classmethod
def _from_dict(cls, *, data: WelcomeScreenChannelPayload, state: ConnectionState) -> WelcomeChannel: def _from_dict(cls, *, data: WelcomeScreenChannelPayload, state: ConnectionState) -> WelcomeChannel:
channel_id = _get_as_snowflake(data, 'channel_id') channel_id = int(data['channel_id'])
channel = state.get_channel(channel_id) or Object(id=channel_id) channel = state.get_channel(channel_id) or Object(id=channel_id)
emoji = None emoji = None
@ -83,7 +84,7 @@ class WelcomeChannel:
return cls(channel=channel, description=data.get('description', ''), emoji=emoji) return cls(channel=channel, description=data.get('description', ''), emoji=emoji)
def _to_dict(self) -> WelcomeScreenChannelPayload: def _to_dict(self) -> WelcomeScreenChannelPayload:
data = { data: WelcomeScreenChannelPayload = {
'channel_id': self.channel.id, 'channel_id': self.channel.id,
'description': self.description, 'description': self.description,
'emoji_id': None, 'emoji_id': None,
@ -117,7 +118,7 @@ class WelcomeScreen:
The channels shown on the welcome screen. The channels shown on the welcome screen.
""" """
def __init__(self, *, data: WelcomeScreenPayload, guild: Guild) -> None: def __init__(self, *, data: WelcomeScreenPayload, guild: Union[Guild, PartialInviteGuild]) -> None:
self.guild = guild self.guild = guild
self._update(data) self._update(data)

Loading…
Cancel
Save