Browse Source

Improve typing across the board, remove old browser references

pull/10109/head
dolfies 3 years ago
parent
commit
3395417844
  1. 2
      discord/appinfo.py
  2. 10
      discord/enums.py
  3. 14
      discord/guild.py
  4. 2
      discord/http.py
  5. 4
      discord/invite.py
  6. 6
      discord/member.py
  7. 2
      discord/settings.py
  8. 18
      discord/state.py
  9. 2
      discord/types/user.py
  10. 2
      discord/user.py
  11. 191
      discord/utils.py
  12. 7
      discord/welcome_screen.py

2
discord/appinfo.py

@ -188,6 +188,8 @@ class PartialApplication(Hashable):
'hook',
'premium_tier_level',
'tags',
'max_participants',
'install_url',
)
def __init__(self, *, state: ConnectionState, data: PartialAppInfoPayload):

10
discord/enums.py

@ -72,7 +72,6 @@ __all__ = (
'UnavailableGuildType',
'RequiredActionType',
'ReportType',
'BrowserEnum',
'ApplicationVerificationState',
'StoreApplicationState',
'RPCApplicationState',
@ -655,15 +654,6 @@ class RequiredActionType(Enum):
accept_terms = 'AGREEMENTS'
class BrowserEnum(Enum):
google_chrome = 'chrome'
chrome = 'chrome'
chromium = 'chromium'
microsoft_edge = 'microsoft-edge'
edge = 'microsoft-edge'
opera = 'opera'
class InviteTarget(Enum):
unknown = 0
stream = 1

14
discord/guild.py

@ -129,7 +129,7 @@ if TYPE_CHECKING:
StageChannel as StageChannelPayload,
)
from .types.integration import IntegrationType
from .types.snowflake import SnowflakeList
from .types.snowflake import SnowflakeList, Snowflake as _Snowflake
from .types.widget import EditWidgetSettings
VocalGuildChannel = Union[VoiceChannel, StageChannel]
@ -405,7 +405,9 @@ class Guild(Hashable):
inner = ' '.join('%s=%r' % t for t in attrs)
return f'<Guild {inner}>'
def _update_voice_state(self, data: GuildVoiceState, channel_id: int) -> Tuple[Optional[Member], VoiceState, VoiceState]:
def _update_voice_state(
self, data: GuildVoiceState, channel_id: Optional[int]
) -> Tuple[Optional[Member], VoiceState, VoiceState]:
cache_flags = self._state.member_cache_flags
user_id = int(data['user_id'])
channel: Optional[VocalGuildChannel] = self.get_channel(channel_id) # type: ignore - this will always be a voice channel
@ -3270,9 +3272,9 @@ class Guild(Hashable):
if action:
action = action.value
if isinstance(before, datetime.datetime):
if isinstance(before, datetime):
before = Object(id=utils.time_snowflake(before, high=False))
if isinstance(after, datetime.datetime):
if isinstance(after, datetime):
after = Object(id=utils.time_snowflake(after, high=True))
if oldest_first is MISSING:
@ -3636,10 +3638,10 @@ class Guild(Hashable):
limit = min(100, limit or 5)
members = await self._state.query_members(
self, query=query, limit=limit, user_ids=user_ids, presences=presences, cache=cache
self, query=query, limit=limit, user_ids=user_ids, presences=presences, cache=cache # type: ignore - The two types are compatible
)
if subscribe:
ids = [str(m.id) for m in members]
ids: List[_Snowflake] = [str(m.id) for m in members]
await self._state.ws.request_lazy_guild(self.id, members=ids)
return members

2
discord/http.py

@ -634,7 +634,7 @@ class HTTPClient:
# PM functionality
def start_group(self, recipients: List[Snowflake]) -> Response[channel.GroupDMChannel]:
def start_group(self, recipients: SnowflakeList) -> Response[channel.GroupDMChannel]:
payload = {
'recipients': recipients,
}

4
discord/invite.py

@ -477,9 +477,9 @@ class Invite(Hashable):
channel = state.get_channel(getattr(channel, 'id', None)) or channel
if message is not None:
data['message'] = message
data['message'] = message # type: ignore - Not a real field
return cls(state=state, data=data, guild=guild, channel=channel, welcome_screen=welcome_screen)
return cls(state=state, data=data, guild=guild, channel=channel, welcome_screen=welcome_screen) # type: ignore
@classmethod
def from_gateway(cls, *, state: ConnectionState, data: GatewayInvitePayload) -> Self:

6
discord/member.py

@ -142,12 +142,12 @@ class VoiceState:
)
def __init__(
self, *, data: Union[VoiceStatePayload, GuildVoiceStatePayload], channel: Optional[VocalGuildChannel] = None
self, *, data: Union[VoiceStatePayload, GuildVoiceStatePayload], channel: Optional[ConnectableChannel] = None
):
self.session_id: Optional[str] = data.get('session_id')
self._update(data, channel)
def _update(self, data: Union[VoiceStatePayload, GuildVoiceStatePayload], channel: Optional[VocalGuildChannel]):
def _update(self, data: Union[VoiceStatePayload, GuildVoiceStatePayload], channel: Optional[ConnectableChannel]):
self.self_mute: bool = data.get('self_mute', False)
self.self_deaf: bool = data.get('self_deaf', False)
self.self_stream: bool = data.get('self_stream', False)
@ -874,7 +874,7 @@ class Member(discord.abc.Messageable, discord.abc.Connectable, _UserTag):
data = await http.edit_member(guild_id, self.id, reason=reason, **payload)
if data:
return Member(data=data, guild=self.guild, state=self._state) # type: ignore
return Member(data=data, guild=self.guild, state=self._state)
async def request_to_speak(self) -> None:
"""|coro|

2
discord/settings.py

@ -385,7 +385,7 @@ class ChannelSettings:
self._channel_id = int(data['channel_id'])
self.collapsed = data.get('collapsed', False)
self.level = try_enum(NotificationLevel, data.get('message_notifications', 3)) # type: ignore
self.level = try_enum(NotificationLevel, data.get('message_notifications', 3))
self.muted = MuteConfig(data.get('muted', False), data.get('mute_config') or {})
@property

18
discord/state.py

@ -527,10 +527,12 @@ class ConnectionState:
def voice_clients(self) -> List[VoiceProtocol]:
return list(self._voice_clients.values())
def _update_voice_state(self, data: GuildVoiceState, channel_id: Optional[int]) -> Tuple[User, VoiceState, VoiceState]:
def _update_voice_state(
self, data: GuildVoiceState, channel_id: Optional[int]
) -> Tuple[Optional[User], VoiceState, VoiceState]:
user_id = int(data['user_id'])
user = self.get_user(user_id)
channel = self._get_private_channel(channel_id)
channel: Optional[Union[DMChannel, GroupChannel]] = self._get_private_channel(channel_id) # type: ignore
try:
# Check if we should remove the voice state from cache
@ -756,7 +758,13 @@ class ConnectionState:
return self.ws.request_chunks([guild_id], query=query, limit=limit, presences=presences, nonce=nonce)
async def query_members(
self, guild: Guild, query: Optional[str], limit: int, user_ids: Optional[List[int]], cache: bool, presences: bool
self,
guild: Guild,
query: Optional[str],
limit: int,
user_ids: Optional[List[Snowflake]],
cache: bool,
presences: bool,
) -> List[Member]:
guild_id = guild.id
request = ChunkRequest(guild.id, self.loop, self._get_guild, cache=cache)
@ -830,7 +838,7 @@ class ConnectionState:
) or {'guild_id': guild_data['id']}
for presence in merged_presences:
presence['user'] = {'id': presence['user_id']}
presence['user'] = {'id': presence['user_id']} # type: ignore - :(
voice_states = guild_data.setdefault('voice_states', [])
voice_states.extend(guild_extra.get('voice_states', []))
@ -1660,7 +1668,7 @@ class ConnectionState:
else:
raise RuntimeError('No channels viewable')
requests = {str(channel.id): [[0, 99]]}
requests: Dict[Snowflake, List[List[int]]] = {str(channel.id): [[0, 99]]}
def predicate(data):
return int(data['guild_id']) == guild.id

2
discord/types/user.py

@ -51,4 +51,4 @@ class User(PartialUser, total=False):
bio: str
analytics_token: str
phone: Optional[str]
token: str
token: str

2
discord/user.py

@ -93,7 +93,7 @@ class Note:
def note(self) -> Optional[str]:
"""Returns the note.
There is an alias for this named :attr:`value`.
There is an alias for this called :attr:`value`.
Raises
-------

191
discord/utils.py

@ -62,20 +62,15 @@ from inspect import isawaitable as _isawaitable, signature as _signature
from operator import attrgetter
import json
import logging
import os
import platform
import random
import re
import string
import subprocess
import sys
import tempfile
from threading import Timer
import types
import warnings
import yarl
from .enums import BrowserEnum
try:
import orjson # type: ignore
@ -1168,7 +1163,11 @@ def format_dt(dt: datetime.datetime, /, style: Optional[TimestampStyle] = None)
def set_target(
items: Iterable[ApplicationCommand], *, channel: Messageable = None, message: Message = None, user: Snowflake = None
items: Iterable[ApplicationCommand],
*,
channel: Optional[Messageable] = MISSING,
message: Optional[Message] = MISSING,
user: Optional[Snowflake] = MISSING,
) -> None:
"""A helper function to set the target for a list of items.
@ -1181,24 +1180,26 @@ def set_target(
-----------
items: Iterable[:class:`ApplicationCommand`]
A list of items to set the target for.
channel: :class:`Messageable`
channel: Optional[:class:`Messageable`]
The channel to target.
message: :class:`Message`
message: Optional[:class:`Message`]
The message to target.
user: :class:`Snowflake`
user: Optional[:class:`~abc.Snowflake`]
The user to target.
"""
attrs = {
'target_channel': channel,
'target_message': message,
'target_user': user,
}
attrs = {}
if channel is not MISSING:
attrs['target_channel'] = channel
if message is not MISSING:
attrs['target_message'] = message
if user is not MISSING:
attrs['target_user'] = user
for item in items:
for k, v in attrs.items():
if v is not None:
try:
setattr(item, k, v) # type: ignore
setattr(item, k, v)
except AttributeError:
pass
@ -1207,39 +1208,6 @@ def _generate_session_id() -> str:
return ''.join(random.choices(string.ascii_letters + string.digits, k=16))
class ExpiringQueue(asyncio.Queue): # Inspired from https://github.com/NoahCardoza/CaptchaHarvester
def __init__(self, timeout: int, maxsize: int = 0) -> None:
super().__init__(maxsize)
self.timeout = timeout
self.timers: asyncio.Queue = asyncio.Queue()
async def put(self, item: str) -> None:
thread: Timer = Timer(self.timeout, self.expire)
thread.start()
await self.timers.put(thread)
await super().put(item)
async def get(self, block: bool = True) -> str:
if block:
thread = await self.timers.get()
else:
thread = self.timers.get_nowait()
thread.cancel()
if block:
return await super().get()
else:
return self.get_nowait()
def expire(self) -> None:
try:
self._queue.popleft()
except:
pass
def to_list(self) -> List[str]:
return list(self._queue)
class ExpiringString(collections.UserString):
def __init__(self, data: str, timeout: int) -> None:
super().__init__(data)
@ -1263,133 +1231,6 @@ class ExpiringString(collections.UserString):
self._timer.cancel()
class Browser: # Inspired from https://github.com/NoahCardoza/CaptchaHarvester
def __init__(self, browser: Union[BrowserEnum, str] = None) -> None:
if isinstance(browser, (BrowserEnum, type(None))):
try:
browser = self.get_browser(browser)
except Exception:
raise RuntimeError('Could not find browser. Please pass browser path manually.')
if browser is None:
raise RuntimeError('Could not find browser. Please pass browser path manually.')
self.browser: str = browser
self.proc: subprocess.Popen = MISSING
def get_mac_browser(pkg: str, binary: str) -> Optional[os.PathLike]:
import plistlib as plist
pfile: str = f'{os.environ["HOME"]}/Library/Preferences/{pkg}.plist'
if os.path.exists(pfile):
with open(pfile, 'rb') as f:
binary_path: Optional[str] = plist.load(f).get('LastRunAppBundlePath')
if binary_path is not None:
return os.path.join(binary_path, 'Contents', 'MacOS', binary)
def get_windows_browser(browser: str) -> Optional[str]:
import winreg as reg
reg_path: str = f'SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\App Paths\\{browser}.exe'
exe_path: Optional[str] = None
for install_type in reg.HKEY_CURRENT_USER, reg.HKEY_LOCAL_MACHINE:
try:
reg_key: str = reg.OpenKey(install_type, reg_path, 0, reg.KEY_READ)
exe_path: Optional[str] = reg.QueryValue(reg_key, None)
reg_key.Close()
if not os.path.isfile(exe_path):
continue
except reg.WindowsError:
pass
else:
break
return exe_path
def get_linux_browser(browser: str) -> Optional[str]:
from shutil import which as exists
possibilities: List[str] = [browser + channel for channel in ('', '-beta', '-dev', '-developer', '-canary')]
for browser in possibilities:
if exists(browser):
return browser
registry: Dict[str, Dict[str, functools.partial]] = {
'Windows': {
'chrome': functools.partial(get_windows_browser, 'chrome'),
'chromium': functools.partial(get_windows_browser, 'chromium'),
'microsoft-edge': functools.partial(get_windows_browser, 'msedge'),
'opera': functools.partial(get_windows_browser, 'opera'),
},
'Darwin': {
'chrome': functools.partial(get_mac_browser, 'com.google.Chrome', 'Google Chrome'),
'chromium': functools.partial(get_mac_browser, 'org.chromium.Chromium', 'Chromium'),
'microsoft-edge': functools.partial(get_mac_browser, 'com.microsoft.Edge', 'Microsoft Edge'),
'opera': functools.partial(get_mac_browser, 'com.operasoftware.Opera', 'Opera'),
},
'Linux': {
'chrome': functools.partial(get_linux_browser, 'chrome'),
'chromium': functools.partial(get_linux_browser, 'chromium'),
'microsoft-edge': functools.partial(get_linux_browser, 'microsoft-edge'),
'opera': functools.partial(get_linux_browser, 'opera'),
},
}
def get_browser(self, browser: Optional[BrowserEnum] = None) -> Optional[str]:
if browser is not None:
return self.registry.get(platform.system(), {})[browser.value]()
for browser in self.registry.get(platform.system(), {}).values():
browser = browser()
if browser is not None:
return browser
@property
def running(self) -> bool:
try:
return self.proc.poll() is None
except:
return False
def launch(
self,
domain: Optional[str] = None,
server: Tuple[Optional[str], Optional[int]] = (None, None),
width: int = 400,
height: int = 500,
browser_args: List[str] = [],
extensions: Optional[str] = None,
) -> None:
browser_command: List[str] = [self.browser, *browser_args]
if extensions:
browser_command.append(f'--load-extension={extensions}')
browser_command.extend(
(
'--disable-default-apps',
'--no-default-browser-check',
'--no-check-default-browser',
'--no-first-run',
'--ignore-certificate-errors',
'--disable-background-networking',
'--disable-component-update',
'--disable-domain-reliability',
f'--user-data-dir={os.path.join(tempfile.TemporaryDirectory().name, "Profiles")}',
f'--host-rules=MAP {domain} {server[0]}:{server[1]}',
f'--window-size={width},{height}',
f'--app=https://{domain}',
)
)
self.proc = subprocess.Popen(browser_command, stdout=-1, stderr=-1)
def stop(self) -> None:
try:
self.proc.terminate()
except:
pass
async def _get_info(session: ClientSession) -> Tuple[str, str, int]:
for _ in range(3):
try:

7
discord/welcome_screen.py

@ -34,6 +34,7 @@ if TYPE_CHECKING:
from .abc import Snowflake
from .emoji import Emoji
from .guild import Guild
from .invite import PartialInviteGuild
from .state import ConnectionState
from .types.welcome_screen import (
WelcomeScreen as WelcomeScreenPayload,
@ -71,7 +72,7 @@ class WelcomeChannel:
@classmethod
def _from_dict(cls, *, data: WelcomeScreenChannelPayload, state: ConnectionState) -> WelcomeChannel:
channel_id = _get_as_snowflake(data, 'channel_id')
channel_id = int(data['channel_id'])
channel = state.get_channel(channel_id) or Object(id=channel_id)
emoji = None
@ -83,7 +84,7 @@ class WelcomeChannel:
return cls(channel=channel, description=data.get('description', ''), emoji=emoji)
def _to_dict(self) -> WelcomeScreenChannelPayload:
data = {
data: WelcomeScreenChannelPayload = {
'channel_id': self.channel.id,
'description': self.description,
'emoji_id': None,
@ -117,7 +118,7 @@ class WelcomeScreen:
The channels shown on the welcome screen.
"""
def __init__(self, *, data: WelcomeScreenPayload, guild: Guild) -> None:
def __init__(self, *, data: WelcomeScreenPayload, guild: Union[Guild, PartialInviteGuild]) -> None:
self.guild = guild
self._update(data)

Loading…
Cancel
Save