From 3ce00abeae76da3d71ada2597c87df518f0f4f84 Mon Sep 17 00:00:00 2001 From: Josh Date: Tue, 22 Feb 2022 16:48:41 +1000 Subject: [PATCH] Fix some type-check errors --- discord/channel.py | 4 ++-- discord/flags.py | 2 +- discord/reaction.py | 2 +- discord/sticker.py | 2 +- discord/user.py | 11 +++++++---- discord/widget.py | 12 +++++++----- 6 files changed, 19 insertions(+), 14 deletions(-) diff --git a/discord/channel.py b/discord/channel.py index 728da7544..60f05dc97 100644 --- a/discord/channel.py +++ b/discord/channel.py @@ -884,8 +884,8 @@ class VocalGuildChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hasha self.video_quality_mode: VideoQualityMode = try_enum(VideoQualityMode, data.get('video_quality_mode', 1)) self.category_id: Optional[int] = utils._get_as_snowflake(data, 'parent_id') self.position: int = data['position'] - self.bitrate: int = data.get('bitrate') - self.user_limit: int = data.get('user_limit') + self.bitrate: int = data['bitrate'] + self.user_limit: int = data['user_limit'] self._fill_overwrites(data) @property diff --git a/discord/flags.py b/discord/flags.py index 31a594529..697a8986b 100644 --- a/discord/flags.py +++ b/discord/flags.py @@ -59,7 +59,7 @@ class flag_value: return self return instance._has_flag(self.flag) - def __set__(self, instance: BF, value: bool) -> None: + def __set__(self, instance: BaseFlags, value: bool) -> None: instance._set_flag(self.flag, value) def __repr__(self): diff --git a/discord/reaction.py b/discord/reaction.py index 9835acb3a..9f92c55c4 100644 --- a/discord/reaction.py +++ b/discord/reaction.py @@ -87,7 +87,7 @@ class Reaction: self.message: Message = message self.emoji: Union[PartialEmoji, Emoji, str] = emoji or message._state.get_reaction_emoji(data['emoji']) self.count: int = data.get('count', 1) - self.me: bool = data.get('me') + self.me: bool = data['me'] # TODO: typeguard def is_custom_emoji(self) -> bool: diff --git a/discord/sticker.py b/discord/sticker.py index 943e5cb05..4ac1234f1 100644 --- a/discord/sticker.py +++ b/discord/sticker.py @@ -123,7 +123,7 @@ class StickerPack(Hashable): @property def banner(self) -> Optional[Asset]: """:class:`Asset`: The banner asset of the sticker pack.""" - return self._banner and Asset._from_sticker_banner(self._state, self._banner) + return self._banner and Asset._from_sticker_banner(self._state, self._banner) # type: ignore - type-checker thinks _banner could be Literal[0] def __repr__(self) -> str: return f'' diff --git a/discord/user.py b/discord/user.py index 8f6ba1c54..6fa981ee4 100644 --- a/discord/user.py +++ b/discord/user.py @@ -24,7 +24,7 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations -from typing import Any, Dict, List, Optional, Type, TypeVar, TYPE_CHECKING +from typing import Any, Dict, List, Optional, Union, Type, TypeVar, TYPE_CHECKING import discord.abc from .asset import Asset @@ -41,7 +41,10 @@ if TYPE_CHECKING: from .message import Message from .state import ConnectionState from .types.channel import DMChannel as DMChannelPayload - from .types.user import User as UserPayload + from .types.user import ( + PartialUser as PartialUserPayload, + User as UserPayload, + ) __all__ = ( @@ -83,7 +86,7 @@ class BaseUser(_UserTag): _accent_colour: Optional[int] _public_flags: int - def __init__(self, *, state: ConnectionState, data: UserPayload) -> None: + def __init__(self, *, state: ConnectionState, data: Union[UserPayload, PartialUserPayload]) -> None: self._state = state self._update(data) @@ -105,7 +108,7 @@ class BaseUser(_UserTag): def __hash__(self) -> int: return self.id >> 22 - def _update(self, data: UserPayload) -> None: + def _update(self, data: Union[UserPayload, PartialUserPayload]) -> None: self.name = data['username'] self.id = int(data['id']) self.discriminator = data['discriminator'] diff --git a/discord/widget.py b/discord/widget.py index 10075caf6..6bc6a3e49 100644 --- a/discord/widget.py +++ b/discord/widget.py @@ -23,6 +23,7 @@ DEALINGS IN THE SOFTWARE. """ from __future__ import annotations +from sqlite3 import connect from typing import Any, List, Optional, TYPE_CHECKING, Union @@ -260,12 +261,13 @@ class Widget: channels = {channel.id: channel for channel in self.channels} for member in data.get('members', []): connected_channel = _get_as_snowflake(member, 'channel_id') - if connected_channel in channels: - connected_channel = channels[connected_channel] # type: ignore - elif connected_channel: - connected_channel = WidgetChannel(id=connected_channel, name='', position=0) + if connected_channel is not None: + if connected_channel in channels: + connected_channel = channels[connected_channel] + else: + connected_channel = WidgetChannel(id=connected_channel, name='', position=0) - self.members.append(WidgetMember(state=self._state, data=member, connected_channel=connected_channel)) # type: ignore + self.members.append(WidgetMember(state=self._state, data=member, connected_channel=connected_channel)) def __str__(self) -> str: return self.json_url