Browse Source

Fix types in guild.py

pull/7494/head
Josh 3 years ago
committed by Rapptz
parent
commit
285069de08
  1. 148
      discord/guild.py
  2. 6
      discord/http.py
  3. 22
      discord/member.py
  4. 4
      discord/state.py
  5. 6
      discord/types/invite.py

148
discord/guild.py

@ -31,6 +31,7 @@ from typing import (
Any,
AsyncIterator,
ClassVar,
Coroutine,
Dict,
List,
NamedTuple,
@ -91,7 +92,13 @@ MISSING = utils.MISSING
if TYPE_CHECKING:
from .abc import Snowflake, SnowflakeTime
from .types.guild import Ban as BanPayload, Guild as GuildPayload, MFALevel, GuildFeature
from .types.guild import (
Ban as BanPayload,
Guild as GuildPayload,
RolePositionUpdate as RolePositionUpdatePayload,
MFALevel,
GuildFeature,
)
from .types.threads import (
Thread as ThreadPayload,
)
@ -102,6 +109,17 @@ if TYPE_CHECKING:
from .webhook import Webhook
from .state import ConnectionState
from .voice_client import VoiceProtocol
from .types.channel import (
GuildChannel as GuildChannelPayload,
TextChannel as TextChannelPayload,
NewsChannel as NewsChannelPayload,
VoiceChannel as VoiceChannelPayload,
CategoryChannel as CategoryChannelPayload,
StoreChannel as StoreChannelPayload,
StageChannel as StageChannelPayload,
)
from .types.integration import IntegrationType
from .types.snowflake import SnowflakeList
VocalGuildChannel = Union[VoiceChannel, StageChannel]
GuildChannel = Union[VocalGuildChannel, TextChannel, CategoryChannel, StoreChannel]
@ -289,7 +307,7 @@ class Guild(Hashable):
3: _GuildLimit(emoji=250, stickers=60, bitrate=384e3, filesize=104857600),
}
def __init__(self, *, data: GuildPayload, state: ConnectionState):
def __init__(self, *, data: GuildPayload, state: ConnectionState) -> None:
self._channels: Dict[int, GuildChannel] = {}
self._members: Dict[int, Member] = {}
self._voice_states: Dict[int, VoiceState] = {}
@ -353,7 +371,7 @@ class Guild(Hashable):
def _update_voice_state(self, data: GuildVoiceState, channel_id: int) -> Tuple[Optional[Member], VoiceState, VoiceState]:
user_id = int(data['user_id'])
channel = self.get_channel(channel_id)
channel: Optional[VocalGuildChannel] = self.get_channel(channel_id) # type: ignore - this will always be a voice channel
try:
# check if we should remove the voice state from cache
if channel is None:
@ -408,14 +426,14 @@ class Guild(Hashable):
if member_count is not None:
self._member_count: int = member_count
self.name: str = guild.get('name')
self.name: str = guild.get('name', '')
self.region: VoiceRegion = try_enum(VoiceRegion, guild.get('region'))
self.verification_level: VerificationLevel = try_enum(VerificationLevel, guild.get('verification_level'))
self.default_notifications: NotificationLevel = try_enum(
NotificationLevel, guild.get('default_message_notifications')
)
self.explicit_content_filter: ContentFilter = try_enum(ContentFilter, guild.get('explicit_content_filter', 0))
self.afk_timeout: int = guild.get('afk_timeout')
self.afk_timeout: int = guild.get('afk_timeout', 0)
self._icon: Optional[str] = guild.get('icon')
self._banner: Optional[str] = guild.get('banner')
self.unavailable: bool = guild.get('unavailable', False)
@ -426,7 +444,7 @@ class Guild(Hashable):
role = Role(guild=self, data=r, state=state)
self._roles[role.id] = role
self.mfa_level: MFALevel = guild.get('mfa_level')
self.mfa_level: MFALevel = guild.get('mfa_level', 0)
self.emojis: Tuple[Emoji, ...] = tuple(map(lambda d: state.store_emoji(self, d), guild.get('emojis', [])))
self.stickers: Tuple[GuildSticker, ...] = tuple(
map(lambda d: state.store_sticker(self, d), guild.get('stickers', []))
@ -455,7 +473,7 @@ class Guild(Hashable):
cache_joined = self._state.member_cache_flags.joined
self_id = self._state.self_id
for mdata in guild.get('members', []):
member = Member(data=mdata, guild=self, state=state)
member = Member(data=mdata, guild=self, state=state) # type: ignore - Members will have the 'user' key in this scenario
if cache_joined or member.id == self_id:
self._add_member(member)
@ -548,7 +566,7 @@ class Guild(Hashable):
""":class:`Member`: Similar to :attr:`Client.user` except an instance of :class:`Member`.
This is essentially used to get the member version of yourself.
"""
self_id = self._state.user.id
self_id = self._state.user.id # type: ignore - state.user won't be None if we're logged in
# The self member is *always* cached
return self.get_member(self_id) # type: ignore
@ -976,6 +994,94 @@ class Guild(Hashable):
return utils.find(pred, members)
@overload
def _create_channel(
self,
name: str,
channel_type: Literal[ChannelType.text],
overwrites: Dict[Union[Role, Member], PermissionOverwrite] = ...,
category: Optional[Snowflake] = ...,
**options: Any,
) -> Coroutine[Any, Any, TextChannelPayload]:
...
@overload
def _create_channel(
self,
name: str,
channel_type: Literal[ChannelType.voice],
overwrites: Dict[Union[Role, Member], PermissionOverwrite] = ...,
category: Optional[Snowflake] = ...,
**options: Any,
) -> Coroutine[Any, Any, VoiceChannelPayload]:
...
@overload
def _create_channel(
self,
name: str,
channel_type: Literal[ChannelType.stage_voice],
overwrites: Dict[Union[Role, Member], PermissionOverwrite] = ...,
category: Optional[Snowflake] = ...,
**options: Any,
) -> Coroutine[Any, Any, StageChannelPayload]:
...
@overload
def _create_channel(
self,
name: str,
channel_type: Literal[ChannelType.category],
overwrites: Dict[Union[Role, Member], PermissionOverwrite] = ...,
category: Optional[Snowflake] = ...,
**options: Any,
) -> Coroutine[Any, Any, CategoryChannelPayload]:
...
@overload
def _create_channel(
self,
name: str,
channel_type: Literal[ChannelType.news],
overwrites: Dict[Union[Role, Member], PermissionOverwrite] = ...,
category: Optional[Snowflake] = ...,
**options: Any,
) -> Coroutine[Any, Any, NewsChannelPayload]:
...
@overload
def _create_channel(
self,
name: str,
channel_type: Literal[ChannelType.store],
overwrites: Dict[Union[Role, Member], PermissionOverwrite] = ...,
category: Optional[Snowflake] = ...,
**options: Any,
) -> Coroutine[Any, Any, StoreChannelPayload]:
...
@overload
def _create_channel(
self,
name: str,
channel_type: Literal[ChannelType.text],
overwrites: Dict[Union[Role, Member], PermissionOverwrite] = ...,
category: Optional[Snowflake] = ...,
**options: Any,
) -> Coroutine[Any, Any, GuildChannelPayload]:
...
@overload
def _create_channel(
self,
name: str,
channel_type: ChannelType,
overwrites: Dict[Union[Role, Member], PermissionOverwrite] = ...,
category: Optional[Snowflake] = ...,
**options: Any,
) -> Coroutine[Any, Any, GuildChannelPayload]:
...
def _create_channel(
self,
name: str,
@ -983,7 +1089,7 @@ class Guild(Hashable):
overwrites: Dict[Union[Role, Member], PermissionOverwrite] = MISSING,
category: Optional[Snowflake] = None,
**options: Any,
):
) -> Coroutine[Any, Any, GuildChannelPayload]:
if overwrites is MISSING:
overwrites = {}
elif not isinstance(overwrites, dict):
@ -1829,11 +1935,11 @@ class Guild(Hashable):
if ch_type in (ChannelType.group, ChannelType.private):
raise InvalidData('Channel ID resolved to a private channel')
guild_id = int(data['guild_id'])
guild_id = int(data['guild_id']) # type: ignore - channel won't be a private channel
if self.id != guild_id:
raise InvalidData('Guild ID resolved to a different guild')
channel: GuildChannel = factory(guild=self, state=self._state, data=data) # type: ignore
channel: GuildChannel = factory(guild=self, state=self._state, data=data) # type: ignore - channel won't be a private channel
return channel
async def bans(self) -> List[BanEntry]:
@ -1977,13 +2083,16 @@ class Guild(Hashable):
data = await self._state.http.guild_webhooks(self.id)
return [Webhook.from_state(d, state=self._state) for d in data]
async def estimate_pruned_members(self, *, days: int, roles: List[Snowflake] = MISSING) -> int:
async def estimate_pruned_members(self, *, days: int, roles: List[Snowflake] = MISSING) -> Optional[int]:
"""|coro|
Similar to :meth:`prune_members` except instead of actually
pruning members, it returns how many members it would prune
from the guild had it been called.
.. versionchanged:: 2.0
The returned value can be ``None``.
Parameters
-----------
days: :class:`int`
@ -2005,7 +2114,7 @@ class Guild(Hashable):
Returns
---------
:class:`int`
Optional[:class:`int`]
The number of members estimated to be pruned.
"""
@ -2077,7 +2186,7 @@ class Guild(Hashable):
return Template(state=self._state, data=data)
async def create_integration(self, *, type: str, id: int) -> None:
async def create_integration(self, *, type: IntegrationType, id: int) -> None:
"""|coro|
Attaches an integration to the guild.
@ -2380,7 +2489,7 @@ class Guild(Hashable):
img = utils._bytes_to_base64_data(image)
if roles:
role_ids = [role.id for role in roles]
role_ids: SnowflakeList = [role.id for role in roles]
else:
role_ids = []
@ -2612,10 +2721,10 @@ class Guild(Hashable):
if not isinstance(positions, dict):
raise InvalidArgument('positions parameter expects a dict.')
role_positions: List[Dict[str, Any]] = []
role_positions = []
for role, position in positions.items():
payload = {'id': role.id, 'position': position}
payload: RolePositionUpdatePayload = {'id': role.id, 'position': position}
role_positions.append(payload)
@ -2754,7 +2863,7 @@ class Guild(Hashable):
payload['max_uses'] = 0
payload['max_age'] = 0
payload['uses'] = payload.get('uses', 0)
return Invite(state=self._state, data=payload, guild=self, channel=channel)
return Invite(state=self._state, data=payload, guild=self, channel=channel) # type: ignore - we're faking a payload here
async def audit_logs(
self,
@ -2990,7 +3099,8 @@ class Guild(Hashable):
raise ClientException('Intents.members must be enabled to use this.')
if not self._state.is_guild_evicted(self):
return await self._state.chunk_guild(self, cache=cache)
await self._state.chunk_guild(self, cache=cache)
return
async def query_members(
self,

6
discord/http.py

@ -1123,7 +1123,7 @@ class HTTPClient:
def guild_templates(self, guild_id: Snowflake) -> Response[List[template.Template]]:
return self.request(Route('GET', '/guilds/{guild_id}/templates', guild_id=guild_id))
def create_template(self, guild_id: Snowflake, payload: template.CreateTemplate) -> Response[template.Template]:
def create_template(self, guild_id: Snowflake, payload: Dict[str, Any]) -> Response[template.Template]:
return self.request(Route('POST', '/guilds/{guild_id}/templates', guild_id=guild_id), json=payload)
def sync_template(self, guild_id: Snowflake, code: str) -> Response[template.Template]:
@ -1229,7 +1229,7 @@ class HTTPClient:
)
def create_guild_sticker(
self, guild_id: Snowflake, payload: sticker.CreateGuildSticker, file: File, reason: str
self, guild_id: Snowflake, payload: Dict[str, Any], file: File, reason: Optional[str]
) -> Response[sticker.GuildSticker]:
initial_bytes = file.fp.read(16)
@ -1293,7 +1293,7 @@ class HTTPClient:
self,
guild_id: Snowflake,
name: str,
image: bytes,
image: str,
*,
roles: Optional[SnowflakeList] = None,
reason: Optional[str] = None,

22
discord/member.py

@ -65,7 +65,10 @@ if TYPE_CHECKING:
from .state import ConnectionState
from .message import Message
from .role import Role
from .types.voice import VoiceState as VoiceStatePayload
from .types.voice import (
GuildVoiceState as GuildVoiceStatePayload,
VoiceState as VoiceStatePayload,
)
VocalGuildChannel = Union[VoiceChannel, StageChannel]
@ -127,11 +130,13 @@ class VoiceState:
'suppress',
)
def __init__(self, *, data: VoiceStatePayload, channel: Optional[VocalGuildChannel] = None):
self.session_id: str = data.get('session_id')
def __init__(
self, *, data: Union[VoiceStatePayload, GuildVoiceStatePayload], channel: Optional[VocalGuildChannel] = None
):
self.session_id: Optional[str] = data.get('session_id')
self._update(data, channel)
def _update(self, data: VoiceStatePayload, channel: Optional[VocalGuildChannel]):
def _update(self, data: Union[VoiceStatePayload, GuildVoiceStatePayload], channel: Optional[VocalGuildChannel]):
self.self_mute: bool = data.get('self_mute', False)
self.self_deaf: bool = data.get('self_deaf', False)
self.self_stream: bool = data.get('self_stream', False)
@ -748,11 +753,13 @@ class Member(discord.abc.Messageable, _UserTag):
payload['mute'] = mute
if suppress is not MISSING:
voice_state_payload = {
'channel_id': self.voice.channel.id,
voice_state_payload: Dict[str, Any] = {
'suppress': suppress,
}
if self.voice is not None and self.voice.channel is not None:
voice_state_payload['channel_id'] = self.voice.channel.id
if suppress or self.bot:
voice_state_payload['request_to_speak_timestamp'] = None
@ -804,6 +811,9 @@ class Member(discord.abc.Messageable, _UserTag):
HTTPException
The operation failed.
"""
if self.voice is None or self.voice.channel is None:
raise RuntimeError('Cannot request to speak while not connected to a voice channel.')
payload = {
'channel_id': self.voice.channel.id,
'request_to_speak_timestamp': datetime.datetime.utcnow().isoformat(),

4
discord/state.py

@ -470,7 +470,9 @@ class ConnectionState:
ws = self._get_websocket(guild_id) # This is ignored upstream
await ws.request_chunks(guild_id, query=query, limit=limit, presences=presences, nonce=nonce)
async def query_members(self, guild: Guild, query: str, limit: int, user_ids: List[int], cache: bool, presences: bool):
async def query_members(
self, guild: Guild, query: Optional[str], limit: int, user_ids: Optional[List[int]], cache: bool, presences: bool
) -> List[Member]:
guild_id = guild.id
ws = self._get_websocket(guild_id)
if ws is None:

6
discord/types/invite.py

@ -52,7 +52,11 @@ class _InviteMetadata(TypedDict, total=False):
expires_at: Optional[str]
class VanityInvite(_InviteMetadata):
class _VanityInviteOptional(_InviteMetadata, total=False):
revoked: bool
class VanityInvite(_VanityInviteOptional):
code: Optional[str]

Loading…
Cancel
Save