Browse Source

Fix types in guild.py

pull/7494/head
Josh 3 years ago
committed by Rapptz
parent
commit
285069de08
  1. 148
      discord/guild.py
  2. 6
      discord/http.py
  3. 22
      discord/member.py
  4. 4
      discord/state.py
  5. 6
      discord/types/invite.py

148
discord/guild.py

@ -31,6 +31,7 @@ from typing import (
Any, Any,
AsyncIterator, AsyncIterator,
ClassVar, ClassVar,
Coroutine,
Dict, Dict,
List, List,
NamedTuple, NamedTuple,
@ -91,7 +92,13 @@ MISSING = utils.MISSING
if TYPE_CHECKING: if TYPE_CHECKING:
from .abc import Snowflake, SnowflakeTime from .abc import Snowflake, SnowflakeTime
from .types.guild import Ban as BanPayload, Guild as GuildPayload, MFALevel, GuildFeature from .types.guild import (
Ban as BanPayload,
Guild as GuildPayload,
RolePositionUpdate as RolePositionUpdatePayload,
MFALevel,
GuildFeature,
)
from .types.threads import ( from .types.threads import (
Thread as ThreadPayload, Thread as ThreadPayload,
) )
@ -102,6 +109,17 @@ if TYPE_CHECKING:
from .webhook import Webhook from .webhook import Webhook
from .state import ConnectionState from .state import ConnectionState
from .voice_client import VoiceProtocol from .voice_client import VoiceProtocol
from .types.channel import (
GuildChannel as GuildChannelPayload,
TextChannel as TextChannelPayload,
NewsChannel as NewsChannelPayload,
VoiceChannel as VoiceChannelPayload,
CategoryChannel as CategoryChannelPayload,
StoreChannel as StoreChannelPayload,
StageChannel as StageChannelPayload,
)
from .types.integration import IntegrationType
from .types.snowflake import SnowflakeList
VocalGuildChannel = Union[VoiceChannel, StageChannel] VocalGuildChannel = Union[VoiceChannel, StageChannel]
GuildChannel = Union[VocalGuildChannel, TextChannel, CategoryChannel, StoreChannel] GuildChannel = Union[VocalGuildChannel, TextChannel, CategoryChannel, StoreChannel]
@ -289,7 +307,7 @@ class Guild(Hashable):
3: _GuildLimit(emoji=250, stickers=60, bitrate=384e3, filesize=104857600), 3: _GuildLimit(emoji=250, stickers=60, bitrate=384e3, filesize=104857600),
} }
def __init__(self, *, data: GuildPayload, state: ConnectionState): def __init__(self, *, data: GuildPayload, state: ConnectionState) -> None:
self._channels: Dict[int, GuildChannel] = {} self._channels: Dict[int, GuildChannel] = {}
self._members: Dict[int, Member] = {} self._members: Dict[int, Member] = {}
self._voice_states: Dict[int, VoiceState] = {} self._voice_states: Dict[int, VoiceState] = {}
@ -353,7 +371,7 @@ class Guild(Hashable):
def _update_voice_state(self, data: GuildVoiceState, channel_id: int) -> Tuple[Optional[Member], VoiceState, VoiceState]: def _update_voice_state(self, data: GuildVoiceState, channel_id: int) -> Tuple[Optional[Member], VoiceState, VoiceState]:
user_id = int(data['user_id']) user_id = int(data['user_id'])
channel = self.get_channel(channel_id) channel: Optional[VocalGuildChannel] = self.get_channel(channel_id) # type: ignore - this will always be a voice channel
try: try:
# check if we should remove the voice state from cache # check if we should remove the voice state from cache
if channel is None: if channel is None:
@ -408,14 +426,14 @@ class Guild(Hashable):
if member_count is not None: if member_count is not None:
self._member_count: int = member_count self._member_count: int = member_count
self.name: str = guild.get('name') self.name: str = guild.get('name', '')
self.region: VoiceRegion = try_enum(VoiceRegion, guild.get('region')) self.region: VoiceRegion = try_enum(VoiceRegion, guild.get('region'))
self.verification_level: VerificationLevel = try_enum(VerificationLevel, guild.get('verification_level')) self.verification_level: VerificationLevel = try_enum(VerificationLevel, guild.get('verification_level'))
self.default_notifications: NotificationLevel = try_enum( self.default_notifications: NotificationLevel = try_enum(
NotificationLevel, guild.get('default_message_notifications') NotificationLevel, guild.get('default_message_notifications')
) )
self.explicit_content_filter: ContentFilter = try_enum(ContentFilter, guild.get('explicit_content_filter', 0)) self.explicit_content_filter: ContentFilter = try_enum(ContentFilter, guild.get('explicit_content_filter', 0))
self.afk_timeout: int = guild.get('afk_timeout') self.afk_timeout: int = guild.get('afk_timeout', 0)
self._icon: Optional[str] = guild.get('icon') self._icon: Optional[str] = guild.get('icon')
self._banner: Optional[str] = guild.get('banner') self._banner: Optional[str] = guild.get('banner')
self.unavailable: bool = guild.get('unavailable', False) self.unavailable: bool = guild.get('unavailable', False)
@ -426,7 +444,7 @@ class Guild(Hashable):
role = Role(guild=self, data=r, state=state) role = Role(guild=self, data=r, state=state)
self._roles[role.id] = role self._roles[role.id] = role
self.mfa_level: MFALevel = guild.get('mfa_level') self.mfa_level: MFALevel = guild.get('mfa_level', 0)
self.emojis: Tuple[Emoji, ...] = tuple(map(lambda d: state.store_emoji(self, d), guild.get('emojis', []))) self.emojis: Tuple[Emoji, ...] = tuple(map(lambda d: state.store_emoji(self, d), guild.get('emojis', [])))
self.stickers: Tuple[GuildSticker, ...] = tuple( self.stickers: Tuple[GuildSticker, ...] = tuple(
map(lambda d: state.store_sticker(self, d), guild.get('stickers', [])) map(lambda d: state.store_sticker(self, d), guild.get('stickers', []))
@ -455,7 +473,7 @@ class Guild(Hashable):
cache_joined = self._state.member_cache_flags.joined cache_joined = self._state.member_cache_flags.joined
self_id = self._state.self_id self_id = self._state.self_id
for mdata in guild.get('members', []): for mdata in guild.get('members', []):
member = Member(data=mdata, guild=self, state=state) member = Member(data=mdata, guild=self, state=state) # type: ignore - Members will have the 'user' key in this scenario
if cache_joined or member.id == self_id: if cache_joined or member.id == self_id:
self._add_member(member) self._add_member(member)
@ -548,7 +566,7 @@ class Guild(Hashable):
""":class:`Member`: Similar to :attr:`Client.user` except an instance of :class:`Member`. """:class:`Member`: Similar to :attr:`Client.user` except an instance of :class:`Member`.
This is essentially used to get the member version of yourself. This is essentially used to get the member version of yourself.
""" """
self_id = self._state.user.id self_id = self._state.user.id # type: ignore - state.user won't be None if we're logged in
# The self member is *always* cached # The self member is *always* cached
return self.get_member(self_id) # type: ignore return self.get_member(self_id) # type: ignore
@ -976,6 +994,94 @@ class Guild(Hashable):
return utils.find(pred, members) return utils.find(pred, members)
@overload
def _create_channel(
self,
name: str,
channel_type: Literal[ChannelType.text],
overwrites: Dict[Union[Role, Member], PermissionOverwrite] = ...,
category: Optional[Snowflake] = ...,
**options: Any,
) -> Coroutine[Any, Any, TextChannelPayload]:
...
@overload
def _create_channel(
self,
name: str,
channel_type: Literal[ChannelType.voice],
overwrites: Dict[Union[Role, Member], PermissionOverwrite] = ...,
category: Optional[Snowflake] = ...,
**options: Any,
) -> Coroutine[Any, Any, VoiceChannelPayload]:
...
@overload
def _create_channel(
self,
name: str,
channel_type: Literal[ChannelType.stage_voice],
overwrites: Dict[Union[Role, Member], PermissionOverwrite] = ...,
category: Optional[Snowflake] = ...,
**options: Any,
) -> Coroutine[Any, Any, StageChannelPayload]:
...
@overload
def _create_channel(
self,
name: str,
channel_type: Literal[ChannelType.category],
overwrites: Dict[Union[Role, Member], PermissionOverwrite] = ...,
category: Optional[Snowflake] = ...,
**options: Any,
) -> Coroutine[Any, Any, CategoryChannelPayload]:
...
@overload
def _create_channel(
self,
name: str,
channel_type: Literal[ChannelType.news],
overwrites: Dict[Union[Role, Member], PermissionOverwrite] = ...,
category: Optional[Snowflake] = ...,
**options: Any,
) -> Coroutine[Any, Any, NewsChannelPayload]:
...
@overload
def _create_channel(
self,
name: str,
channel_type: Literal[ChannelType.store],
overwrites: Dict[Union[Role, Member], PermissionOverwrite] = ...,
category: Optional[Snowflake] = ...,
**options: Any,
) -> Coroutine[Any, Any, StoreChannelPayload]:
...
@overload
def _create_channel(
self,
name: str,
channel_type: Literal[ChannelType.text],
overwrites: Dict[Union[Role, Member], PermissionOverwrite] = ...,
category: Optional[Snowflake] = ...,
**options: Any,
) -> Coroutine[Any, Any, GuildChannelPayload]:
...
@overload
def _create_channel(
self,
name: str,
channel_type: ChannelType,
overwrites: Dict[Union[Role, Member], PermissionOverwrite] = ...,
category: Optional[Snowflake] = ...,
**options: Any,
) -> Coroutine[Any, Any, GuildChannelPayload]:
...
def _create_channel( def _create_channel(
self, self,
name: str, name: str,
@ -983,7 +1089,7 @@ class Guild(Hashable):
overwrites: Dict[Union[Role, Member], PermissionOverwrite] = MISSING, overwrites: Dict[Union[Role, Member], PermissionOverwrite] = MISSING,
category: Optional[Snowflake] = None, category: Optional[Snowflake] = None,
**options: Any, **options: Any,
): ) -> Coroutine[Any, Any, GuildChannelPayload]:
if overwrites is MISSING: if overwrites is MISSING:
overwrites = {} overwrites = {}
elif not isinstance(overwrites, dict): elif not isinstance(overwrites, dict):
@ -1829,11 +1935,11 @@ class Guild(Hashable):
if ch_type in (ChannelType.group, ChannelType.private): if ch_type in (ChannelType.group, ChannelType.private):
raise InvalidData('Channel ID resolved to a private channel') raise InvalidData('Channel ID resolved to a private channel')
guild_id = int(data['guild_id']) guild_id = int(data['guild_id']) # type: ignore - channel won't be a private channel
if self.id != guild_id: if self.id != guild_id:
raise InvalidData('Guild ID resolved to a different guild') raise InvalidData('Guild ID resolved to a different guild')
channel: GuildChannel = factory(guild=self, state=self._state, data=data) # type: ignore channel: GuildChannel = factory(guild=self, state=self._state, data=data) # type: ignore - channel won't be a private channel
return channel return channel
async def bans(self) -> List[BanEntry]: async def bans(self) -> List[BanEntry]:
@ -1977,13 +2083,16 @@ class Guild(Hashable):
data = await self._state.http.guild_webhooks(self.id) data = await self._state.http.guild_webhooks(self.id)
return [Webhook.from_state(d, state=self._state) for d in data] return [Webhook.from_state(d, state=self._state) for d in data]
async def estimate_pruned_members(self, *, days: int, roles: List[Snowflake] = MISSING) -> int: async def estimate_pruned_members(self, *, days: int, roles: List[Snowflake] = MISSING) -> Optional[int]:
"""|coro| """|coro|
Similar to :meth:`prune_members` except instead of actually Similar to :meth:`prune_members` except instead of actually
pruning members, it returns how many members it would prune pruning members, it returns how many members it would prune
from the guild had it been called. from the guild had it been called.
.. versionchanged:: 2.0
The returned value can be ``None``.
Parameters Parameters
----------- -----------
days: :class:`int` days: :class:`int`
@ -2005,7 +2114,7 @@ class Guild(Hashable):
Returns Returns
--------- ---------
:class:`int` Optional[:class:`int`]
The number of members estimated to be pruned. The number of members estimated to be pruned.
""" """
@ -2077,7 +2186,7 @@ class Guild(Hashable):
return Template(state=self._state, data=data) return Template(state=self._state, data=data)
async def create_integration(self, *, type: str, id: int) -> None: async def create_integration(self, *, type: IntegrationType, id: int) -> None:
"""|coro| """|coro|
Attaches an integration to the guild. Attaches an integration to the guild.
@ -2380,7 +2489,7 @@ class Guild(Hashable):
img = utils._bytes_to_base64_data(image) img = utils._bytes_to_base64_data(image)
if roles: if roles:
role_ids = [role.id for role in roles] role_ids: SnowflakeList = [role.id for role in roles]
else: else:
role_ids = [] role_ids = []
@ -2612,10 +2721,10 @@ class Guild(Hashable):
if not isinstance(positions, dict): if not isinstance(positions, dict):
raise InvalidArgument('positions parameter expects a dict.') raise InvalidArgument('positions parameter expects a dict.')
role_positions: List[Dict[str, Any]] = [] role_positions = []
for role, position in positions.items(): for role, position in positions.items():
payload = {'id': role.id, 'position': position} payload: RolePositionUpdatePayload = {'id': role.id, 'position': position}
role_positions.append(payload) role_positions.append(payload)
@ -2754,7 +2863,7 @@ class Guild(Hashable):
payload['max_uses'] = 0 payload['max_uses'] = 0
payload['max_age'] = 0 payload['max_age'] = 0
payload['uses'] = payload.get('uses', 0) payload['uses'] = payload.get('uses', 0)
return Invite(state=self._state, data=payload, guild=self, channel=channel) return Invite(state=self._state, data=payload, guild=self, channel=channel) # type: ignore - we're faking a payload here
async def audit_logs( async def audit_logs(
self, self,
@ -2990,7 +3099,8 @@ class Guild(Hashable):
raise ClientException('Intents.members must be enabled to use this.') raise ClientException('Intents.members must be enabled to use this.')
if not self._state.is_guild_evicted(self): if not self._state.is_guild_evicted(self):
return await self._state.chunk_guild(self, cache=cache) await self._state.chunk_guild(self, cache=cache)
return
async def query_members( async def query_members(
self, self,

6
discord/http.py

@ -1123,7 +1123,7 @@ class HTTPClient:
def guild_templates(self, guild_id: Snowflake) -> Response[List[template.Template]]: def guild_templates(self, guild_id: Snowflake) -> Response[List[template.Template]]:
return self.request(Route('GET', '/guilds/{guild_id}/templates', guild_id=guild_id)) return self.request(Route('GET', '/guilds/{guild_id}/templates', guild_id=guild_id))
def create_template(self, guild_id: Snowflake, payload: template.CreateTemplate) -> Response[template.Template]: def create_template(self, guild_id: Snowflake, payload: Dict[str, Any]) -> Response[template.Template]:
return self.request(Route('POST', '/guilds/{guild_id}/templates', guild_id=guild_id), json=payload) return self.request(Route('POST', '/guilds/{guild_id}/templates', guild_id=guild_id), json=payload)
def sync_template(self, guild_id: Snowflake, code: str) -> Response[template.Template]: def sync_template(self, guild_id: Snowflake, code: str) -> Response[template.Template]:
@ -1229,7 +1229,7 @@ class HTTPClient:
) )
def create_guild_sticker( def create_guild_sticker(
self, guild_id: Snowflake, payload: sticker.CreateGuildSticker, file: File, reason: str self, guild_id: Snowflake, payload: Dict[str, Any], file: File, reason: Optional[str]
) -> Response[sticker.GuildSticker]: ) -> Response[sticker.GuildSticker]:
initial_bytes = file.fp.read(16) initial_bytes = file.fp.read(16)
@ -1293,7 +1293,7 @@ class HTTPClient:
self, self,
guild_id: Snowflake, guild_id: Snowflake,
name: str, name: str,
image: bytes, image: str,
*, *,
roles: Optional[SnowflakeList] = None, roles: Optional[SnowflakeList] = None,
reason: Optional[str] = None, reason: Optional[str] = None,

22
discord/member.py

@ -65,7 +65,10 @@ if TYPE_CHECKING:
from .state import ConnectionState from .state import ConnectionState
from .message import Message from .message import Message
from .role import Role from .role import Role
from .types.voice import VoiceState as VoiceStatePayload from .types.voice import (
GuildVoiceState as GuildVoiceStatePayload,
VoiceState as VoiceStatePayload,
)
VocalGuildChannel = Union[VoiceChannel, StageChannel] VocalGuildChannel = Union[VoiceChannel, StageChannel]
@ -127,11 +130,13 @@ class VoiceState:
'suppress', 'suppress',
) )
def __init__(self, *, data: VoiceStatePayload, channel: Optional[VocalGuildChannel] = None): def __init__(
self.session_id: str = data.get('session_id') self, *, data: Union[VoiceStatePayload, GuildVoiceStatePayload], channel: Optional[VocalGuildChannel] = None
):
self.session_id: Optional[str] = data.get('session_id')
self._update(data, channel) self._update(data, channel)
def _update(self, data: VoiceStatePayload, channel: Optional[VocalGuildChannel]): def _update(self, data: Union[VoiceStatePayload, GuildVoiceStatePayload], channel: Optional[VocalGuildChannel]):
self.self_mute: bool = data.get('self_mute', False) self.self_mute: bool = data.get('self_mute', False)
self.self_deaf: bool = data.get('self_deaf', False) self.self_deaf: bool = data.get('self_deaf', False)
self.self_stream: bool = data.get('self_stream', False) self.self_stream: bool = data.get('self_stream', False)
@ -748,11 +753,13 @@ class Member(discord.abc.Messageable, _UserTag):
payload['mute'] = mute payload['mute'] = mute
if suppress is not MISSING: if suppress is not MISSING:
voice_state_payload = { voice_state_payload: Dict[str, Any] = {
'channel_id': self.voice.channel.id,
'suppress': suppress, 'suppress': suppress,
} }
if self.voice is not None and self.voice.channel is not None:
voice_state_payload['channel_id'] = self.voice.channel.id
if suppress or self.bot: if suppress or self.bot:
voice_state_payload['request_to_speak_timestamp'] = None voice_state_payload['request_to_speak_timestamp'] = None
@ -804,6 +811,9 @@ class Member(discord.abc.Messageable, _UserTag):
HTTPException HTTPException
The operation failed. The operation failed.
""" """
if self.voice is None or self.voice.channel is None:
raise RuntimeError('Cannot request to speak while not connected to a voice channel.')
payload = { payload = {
'channel_id': self.voice.channel.id, 'channel_id': self.voice.channel.id,
'request_to_speak_timestamp': datetime.datetime.utcnow().isoformat(), 'request_to_speak_timestamp': datetime.datetime.utcnow().isoformat(),

4
discord/state.py

@ -470,7 +470,9 @@ class ConnectionState:
ws = self._get_websocket(guild_id) # This is ignored upstream ws = self._get_websocket(guild_id) # This is ignored upstream
await ws.request_chunks(guild_id, query=query, limit=limit, presences=presences, nonce=nonce) await ws.request_chunks(guild_id, query=query, limit=limit, presences=presences, nonce=nonce)
async def query_members(self, guild: Guild, query: str, limit: int, user_ids: List[int], cache: bool, presences: bool): async def query_members(
self, guild: Guild, query: Optional[str], limit: int, user_ids: Optional[List[int]], cache: bool, presences: bool
) -> List[Member]:
guild_id = guild.id guild_id = guild.id
ws = self._get_websocket(guild_id) ws = self._get_websocket(guild_id)
if ws is None: if ws is None:

6
discord/types/invite.py

@ -52,7 +52,11 @@ class _InviteMetadata(TypedDict, total=False):
expires_at: Optional[str] expires_at: Optional[str]
class VanityInvite(_InviteMetadata): class _VanityInviteOptional(_InviteMetadata, total=False):
revoked: bool
class VanityInvite(_VanityInviteOptional):
code: Optional[str] code: Optional[str]

Loading…
Cancel
Save