From 3e92196a2bdcf57dfcbd8e277737345af1d35330 Mon Sep 17 00:00:00 2001 From: Nadir Chowdhury Date: Sat, 10 Apr 2021 07:53:24 +0100 Subject: [PATCH] Add typings for audit logs, integrations, and webhooks --- discord/asset.py | 5 ++ discord/audit_logs.py | 22 ++++++-- discord/integrations.py | 74 ++++++++++++++++++------ discord/iterators.py | 85 ++++++++++++++++++---------- discord/types/audit_log.py | 106 +++++++++++++++++++++++++++++++++++ discord/types/integration.py | 76 +++++++++++++++++++++++++ discord/types/webhook.py | 70 +++++++++++++++++++++++ discord/utils.py | 4 +- 8 files changed, 388 insertions(+), 54 deletions(-) create mode 100644 discord/types/audit_log.py create mode 100644 discord/types/integration.py create mode 100644 discord/types/webhook.py diff --git a/discord/asset.py b/discord/asset.py index da1788d1a..13f9336f1 100644 --- a/discord/asset.py +++ b/discord/asset.py @@ -23,6 +23,7 @@ DEALINGS IN THE SOFTWARE. """ import io +from typing import Literal, TYPE_CHECKING from .errors import DiscordException from .errors import InvalidArgument from . import utils @@ -31,6 +32,10 @@ __all__ = ( 'Asset', ) +if TYPE_CHECKING: + ValidStaticFormatTypes = Literal['webp', 'jpeg', 'jpg', 'png'] + ValidAvatarFormatTypes = Literal['webp', 'jpeg', 'jpg', 'png', 'gif'] + VALID_STATIC_FORMATS = frozenset({"jpeg", "jpg", "webp", "png"}) VALID_AVATAR_FORMATS = VALID_STATIC_FORMATS | {"gif"} diff --git a/discord/audit_logs.py b/discord/audit_logs.py index ea97eb745..fb94d9fac 100644 --- a/discord/audit_logs.py +++ b/discord/audit_logs.py @@ -22,6 +22,9 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from __future__ import annotations + +from typing import Dict, List, TYPE_CHECKING from . import utils, enums from .object import Object from .permissions import PermissionOverwrite, Permissions @@ -35,6 +38,15 @@ __all__ = ( 'AuditLogEntry', ) +if TYPE_CHECKING: + from .types.audit_log import ( + AuditLogChange as AuditLogChangePayload, + AuditLogEntry as AuditLogEntryPayload, + ) + from .guild import Guild + from .user import User + + def _transform_verification_level(entry, data): return enums.try_enum(enums.VerificationLevel, data) @@ -123,7 +135,7 @@ class AuditLogChanges: 'default_message_notifications': ('default_notifications', _transform_default_notifications), } - def __init__(self, entry, data): + def __init__(self, entry, data: List[AuditLogChangePayload]): self.before = AuditLogDiff() self.after = AuditLogDiff() @@ -177,7 +189,7 @@ class AuditLogChanges: setattr(first, 'roles', []) data = [] - g = entry.guild + g: Guild = entry.guild for e in elem: role_id = int(e['id']) @@ -185,7 +197,7 @@ class AuditLogChanges: if role is None: role = Object(id=role_id) - role.name = e['name'] + role.name = e['name'] # type: ignore data.append(role) @@ -234,7 +246,7 @@ class AuditLogEntry(Hashable): which actions have this field filled out. """ - def __init__(self, *, users, data, guild): + def __init__(self, *, users: Dict[str, User], data: AuditLogEntryPayload, guild: Guild): self._state = guild._state self.guild = guild self._users = users @@ -284,7 +296,7 @@ class AuditLogEntry(Hashable): role = self.guild.get_role(instance_id) if role is None: role = Object(id=instance_id) - role.name = self.extra.get('role_name') + role.name = self.extra.get('role_name') # type: ignore self.extra = role # this key is not present when the above is present, typically. diff --git a/discord/integrations.py b/discord/integrations.py index 34458adcf..8a804bc3a 100644 --- a/discord/integrations.py +++ b/discord/integrations.py @@ -22,7 +22,10 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from __future__ import annotations + import datetime +from typing import Optional, TYPE_CHECKING, overload from .utils import _get_as_snowflake, get, parse_time from .user import User from .errors import InvalidArgument @@ -33,6 +36,14 @@ __all__ = ( 'Integration', ) +if TYPE_CHECKING: + from .types.integration import ( + IntegrationAccount as IntegrationAccountPayload, + Integration as IntegrationPayload, + ) + from .guild import Guild + + class IntegrationAccount: """Represents an integration account. @@ -48,13 +59,14 @@ class IntegrationAccount: __slots__ = ('id', 'name') - def __init__(self, **kwargs): - self.id = kwargs.pop('id') - self.name = kwargs.pop('name') + def __init__(self, data: IntegrationAccountPayload) -> None: + self.id: Optional[int] = _get_as_snowflake(data, 'id') + self.name: str = data.pop('name') - def __repr__(self): + def __repr__(self) -> str: return f'' + class Integration: """Represents a guild integration. @@ -90,20 +102,34 @@ class Integration: An aware UTC datetime representing when the integration was last synced. """ - __slots__ = ('id', '_state', 'guild', 'name', 'enabled', 'type', - 'syncing', 'role', 'expire_behaviour', 'expire_behavior', - 'expire_grace_period', 'synced_at', 'user', 'account', - 'enable_emoticons', '_role_id') - - def __init__(self, *, data, guild): + __slots__ = ( + 'id', + '_state', + 'guild', + 'name', + 'enabled', + 'type', + 'syncing', + 'role', + 'expire_behaviour', + 'expire_behavior', + 'expire_grace_period', + 'synced_at', + 'user', + 'account', + 'enable_emoticons', + '_role_id', + ) + + def __init__(self, *, data: IntegrationPayload, guild: Guild) -> None: self.guild = guild self._state = guild._state self._from_data(data) - def __repr__(self): + def __repr__(self) -> str: return f'' - def _from_data(self, integ): + def _from_data(self, integ: IntegrationPayload): self.id = _get_as_snowflake(integ, 'id') self.name = integ['name'] self.type = integ['type'] @@ -118,9 +144,23 @@ class Integration: self.synced_at = parse_time(integ['synced_at']) self.user = User(state=self._state, data=integ['user']) - self.account = IntegrationAccount(**integ['account']) - - async def edit(self, **fields): + self.account = IntegrationAccount(integ['account']) + + @overload + async def edit( + self, + *, + expire_behaviour: Optional[ExpireBehaviour] = ..., + expire_grace_period: Optional[int] = ..., + enable_emoticons: Optional[bool] = ..., + ) -> None: + ... + + @overload + async def edit(self, **fields) -> None: + ... + + async def edit(self, **fields) -> None: """|coro| Edits the integration. @@ -173,7 +213,7 @@ class Integration: self.expire_grace_period = expire_grace_period self.enable_emoticons = enable_emoticons - async def sync(self): + async def sync(self) -> None: """|coro| Syncs the integration. @@ -191,7 +231,7 @@ class Integration: await self._state.http.sync_integration(self.guild.id, self.id) self.synced_at = datetime.datetime.now(datetime.timezone.utc) - async def delete(self): + async def delete(self) -> None: """|coro| Deletes the integration. diff --git a/discord/iterators.py b/discord/iterators.py index d717d83f5..cc082f55c 100644 --- a/discord/iterators.py +++ b/discord/iterators.py @@ -42,6 +42,19 @@ __all__ = ( ) if TYPE_CHECKING: + from .types.audit_log import ( + AuditLog as AuditLogPayload, + ) + from .types.guild import ( + Guild as GuildPayload, + ) + from .types.message import ( + Message as MessagePayload, + ) + from .types.user import ( + PartialUser as PartialUserPayload, + ) + from .member import Member from .user import User from .message import Message @@ -54,6 +67,7 @@ _Func = Callable[[T], Union[OT, Awaitable[OT]]] OLDEST_OBJECT = Object(id=0) + class _AsyncIterator(AsyncIterator[T]): __slots__ = () @@ -105,9 +119,11 @@ class _AsyncIterator(AsyncIterator[T]): except NoMoreItems: raise StopAsyncIteration() + def _identity(x): return x + class _ChunkedAsyncIterator(_AsyncIterator[List[T]]): def __init__(self, iterator, max_size): self.iterator = iterator @@ -128,6 +144,7 @@ class _ChunkedAsyncIterator(_AsyncIterator[List[T]]): n += 1 return ret + class _MappedAsyncIterator(_AsyncIterator[T]): def __init__(self, iterator, func): self.iterator = iterator @@ -138,6 +155,7 @@ class _MappedAsyncIterator(_AsyncIterator[T]): item = await self.iterator.next() return await maybe_coroutine(self.func, item) + class _FilteredAsyncIterator(_AsyncIterator[T]): def __init__(self, iterator, predicate): self.iterator = iterator @@ -157,6 +175,7 @@ class _FilteredAsyncIterator(_AsyncIterator[T]): if ret: return item + class ReactionIterator(_AsyncIterator[Union['User', 'Member']]): def __init__(self, message, emoji, limit=100, after=None): self.message = message @@ -187,7 +206,9 @@ class ReactionIterator(_AsyncIterator[Union['User', 'Member']]): retrieve = self.limit if self.limit <= 100 else 100 after = self.after.id if self.after else None - data = await self.getter(self.channel_id, self.message.id, self.emoji, retrieve, after=after) + data: List[PartialUserPayload] = await self.getter( + self.channel_id, self.message.id, self.emoji, retrieve, after=after + ) if data: self.limit -= retrieve @@ -205,6 +226,7 @@ class ReactionIterator(_AsyncIterator[Union['User', 'Member']]): else: await self.users.put(User(state=self.state, data=element)) + class HistoryIterator(_AsyncIterator['Message']): """Iterator for receiving a channel's message history. @@ -239,8 +261,7 @@ class HistoryIterator(_AsyncIterator['Message']): ``True`` if `after` is specified, otherwise ``False``. """ - def __init__(self, messageable, limit, - before=None, after=None, around=None, oldest_first=None): + def __init__(self, messageable, limit, before=None, after=None, around=None, oldest_first=None): if isinstance(before, datetime.datetime): before = Object(id=time_snowflake(before, high=False)) @@ -274,7 +295,7 @@ class HistoryIterator(_AsyncIterator['Message']): elif self.limit == 101: self.limit = 100 # Thanks discord - self._retrieve_messages = self._retrieve_messages_around_strategy + self._retrieve_messages = self._retrieve_messages_around_strategy # type: ignore if self.before and self.after: self._filter = lambda m: self.after.id < int(m['id']) < self.before.id elif self.before: @@ -283,12 +304,12 @@ class HistoryIterator(_AsyncIterator['Message']): self._filter = lambda m: self.after.id < int(m['id']) else: if self.reverse: - self._retrieve_messages = self._retrieve_messages_after_strategy - if (self.before): + self._retrieve_messages = self._retrieve_messages_after_strategy # type: ignore + if self.before: self._filter = lambda m: int(m['id']) < self.before.id else: - self._retrieve_messages = self._retrieve_messages_before_strategy - if (self.after and self.after != OLDEST_OBJECT): + self._retrieve_messages = self._retrieve_messages_before_strategy # type: ignore + if self.after and self.after != OLDEST_OBJECT: self._filter = lambda m: int(m['id']) > self.after.id async def next(self) -> Message: @@ -318,7 +339,7 @@ class HistoryIterator(_AsyncIterator['Message']): if self._get_retrieve(): data = await self._retrieve_messages(self.retrieve) if len(data) < 100: - self.limit = 0 # terminate the infinite loop + self.limit = 0 # terminate the infinite loop if self.reverse: data = reversed(data) @@ -329,14 +350,14 @@ class HistoryIterator(_AsyncIterator['Message']): for element in data: await self.messages.put(self.state.create_message(channel=channel, data=element)) - async def _retrieve_messages(self, retrieve): + async def _retrieve_messages(self, retrieve) -> List[Message]: """Retrieve messages and update next parameters.""" - pass + raise NotImplementedError async def _retrieve_messages_before_strategy(self, retrieve): """Retrieve messages using before parameter.""" before = self.before.id if self.before else None - data = await self.logs_from(self.channel.id, retrieve, before=before) + data: List[MessagePayload] = await self.logs_from(self.channel.id, retrieve, before=before) if len(data): if self.limit is not None: self.limit -= retrieve @@ -346,7 +367,7 @@ class HistoryIterator(_AsyncIterator['Message']): async def _retrieve_messages_after_strategy(self, retrieve): """Retrieve messages using after parameter.""" after = self.after.id if self.after else None - data = await self.logs_from(self.channel.id, retrieve, after=after) + data: List[MessagePayload] = await self.logs_from(self.channel.id, retrieve, after=after) if len(data): if self.limit is not None: self.limit -= retrieve @@ -357,11 +378,12 @@ class HistoryIterator(_AsyncIterator['Message']): """Retrieve messages using around parameter.""" if self.around: around = self.around.id if self.around else None - data = await self.logs_from(self.channel.id, retrieve, around=around) + data: List[MessagePayload] = await self.logs_from(self.channel.id, retrieve, around=around) self.around = None return data return [] + class AuditLogIterator(_AsyncIterator['AuditLogEntry']): def __init__(self, guild, limit=None, before=None, after=None, oldest_first=None, user_id=None, action_type=None): if isinstance(before, datetime.datetime): @@ -369,7 +391,6 @@ class AuditLogIterator(_AsyncIterator['AuditLogEntry']): if isinstance(after, datetime.datetime): after = Object(id=time_snowflake(after, high=True)) - if oldest_first is None: self.reverse = after is not None else: @@ -386,12 +407,10 @@ class AuditLogIterator(_AsyncIterator['AuditLogEntry']): self._users = {} self._state = guild._state - self._filter = None # entry dict -> bool self.entries = asyncio.Queue() - if self.reverse: self._strategy = self._after_strategy if self.before: @@ -403,8 +422,9 @@ class AuditLogIterator(_AsyncIterator['AuditLogEntry']): async def _before_strategy(self, retrieve): before = self.before.id if self.before else None - data = await self.request(self.guild.id, limit=retrieve, user_id=self.user_id, - action_type=self.action_type, before=before) + data: AuditLogPayload = await self.request( + self.guild.id, limit=retrieve, user_id=self.user_id, action_type=self.action_type, before=before + ) entries = data.get('audit_log_entries', []) if len(data) and entries: @@ -415,8 +435,9 @@ class AuditLogIterator(_AsyncIterator['AuditLogEntry']): async def _after_strategy(self, retrieve): after = self.after.id if self.after else None - data = await self.request(self.guild.id, limit=retrieve, user_id=self.user_id, - action_type=self.action_type, after=after) + data: AuditLogPayload = await self.request( + self.guild.id, limit=retrieve, user_id=self.user_id, action_type=self.action_type, after=after + ) entries = data.get('audit_log_entries', []) if len(data) and entries: if self.limit is not None: @@ -448,7 +469,7 @@ class AuditLogIterator(_AsyncIterator['AuditLogEntry']): if self._get_retrieve(): users, data = await self._strategy(self.retrieve) if len(data) < 100: - self.limit = 0 # terminate the infinite loop + self.limit = 0 # terminate the infinite loop if self.reverse: data = reversed(data) @@ -495,6 +516,7 @@ class GuildIterator(_AsyncIterator['Guild']): after: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]] Object after which all guilds must be. """ + def __init__(self, bot, limit, before=None, after=None): if isinstance(before, datetime.datetime): @@ -514,12 +536,12 @@ class GuildIterator(_AsyncIterator['Guild']): self.guilds = asyncio.Queue() if self.before and self.after: - self._retrieve_guilds = self._retrieve_guilds_before_strategy + self._retrieve_guilds = self._retrieve_guilds_before_strategy # type: ignore self._filter = lambda m: int(m['id']) > self.after.id elif self.after: - self._retrieve_guilds = self._retrieve_guilds_after_strategy + self._retrieve_guilds = self._retrieve_guilds_after_strategy # type: ignore else: - self._retrieve_guilds = self._retrieve_guilds_before_strategy + self._retrieve_guilds = self._retrieve_guilds_before_strategy # type: ignore async def next(self) -> Guild: if self.guilds.empty(): @@ -541,6 +563,7 @@ class GuildIterator(_AsyncIterator['Guild']): def create_guild(self, data): from .guild import Guild + return Guild(state=self.state, data=data) async def fill_guilds(self): @@ -555,14 +578,14 @@ class GuildIterator(_AsyncIterator['Guild']): for element in data: await self.guilds.put(self.create_guild(element)) - async def _retrieve_guilds(self, retrieve): + async def _retrieve_guilds(self, retrieve) -> List[Guild]: """Retrieve guilds and update next parameters.""" - pass + raise NotImplementedError async def _retrieve_guilds_before_strategy(self, retrieve): """Retrieve guilds using before parameter.""" before = self.before.id if self.before else None - data = await self.get_guilds(retrieve, before=before) + data: List[GuildPayload] = await self.get_guilds(retrieve, before=before) if len(data): if self.limit is not None: self.limit -= retrieve @@ -572,13 +595,14 @@ class GuildIterator(_AsyncIterator['Guild']): async def _retrieve_guilds_after_strategy(self, retrieve): """Retrieve guilds using after parameter.""" after = self.after.id if self.after else None - data = await self.get_guilds(retrieve, after=after) + data: List[GuildPayload] = await self.get_guilds(retrieve, after=after) if len(data): if self.limit is not None: self.limit -= retrieve self.after = Object(id=int(data[0]['id'])) return data + class MemberIterator(_AsyncIterator['Member']): def __init__(self, guild, limit=1000, after=None): @@ -620,7 +644,7 @@ class MemberIterator(_AsyncIterator['Member']): return if len(data) < 1000: - self.limit = 0 # terminate loop + self.limit = 0 # terminate loop self.after = Object(id=int(data[-1]['user']['id'])) @@ -629,4 +653,5 @@ class MemberIterator(_AsyncIterator['Member']): def create_member(self, data): from .member import Member + return Member(data=data, guild=self.guild, state=self.state) diff --git a/discord/types/audit_log.py b/discord/types/audit_log.py new file mode 100644 index 000000000..1bbd96d60 --- /dev/null +++ b/discord/types/audit_log.py @@ -0,0 +1,106 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-present Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +from typing import Any, List, Literal, Optional, TypedDict +from .webhook import Webhook +from .integration import PartialIntegration +from .user import User +from .snowflake import Snowflake + +AuditLogEvent = Literal[ + 1, + 10, + 11, + 12, + 13, + 14, + 15, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + 27, + 28, + 30, + 31, + 32, + 40, + 41, + 42, + 50, + 51, + 52, + 60, + 61, + 62, + 72, + 73, + 74, + 75, + 80, + 81, + 82, +] + + +class AuditLogChange(TypedDict): + key: str + new_value: Any + old_value: Any + + +class AuditEntryInfo(TypedDict): + delete_member_days: str + members_removed: str + channel_id: Snowflake + message_id: Snowflake + count: str + id: Snowflake + type: Literal['0', '1'] + role_name: str + + +class _AuditLogEntryOptional(TypedDict, total=False): + changes: List[AuditLogChange] + options: AuditEntryInfo + reason: str + + +class AuditLogEntry(_AuditLogEntryOptional): + target_id: Optional[str] + user_id: Snowflake + id: Snowflake + action_type: AuditLogEvent + + +class AuditLog(TypedDict): + webhooks: List[Webhook] + users: List[User] + audit_log_entries: List[AuditLogEntry] + integrations: List[PartialIntegration] diff --git a/discord/types/integration.py b/discord/types/integration.py new file mode 100644 index 000000000..2921d1ffe --- /dev/null +++ b/discord/types/integration.py @@ -0,0 +1,76 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-present Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +from typing import Literal, Optional, TypedDict +from .snowflake import Snowflake +from .user import User + + +class _IntegrationApplicationOptional(TypedDict, total=False): + bot: User + + +class IntegrationApplication(_IntegrationApplicationOptional): + id: Snowflake + name: str + icon: Optional[str] + description: str + summary: str + + +class IntegrationAccount(TypedDict): + id: str + name: str + + +IntegrationExpireBehavior = Literal[0, 1] + + +class PartialIntegration(TypedDict): + id: Snowflake + name: str + type: IntegrationType + account: IntegrationAccount + + +class _IntegrationOptional(TypedDict, total=False): + role_id: Snowflake + enable_emoticons: bool + subscriber_count: int + revoked: bool + application: IntegrationApplication + + +IntegrationType = Literal['twitch', 'youtube', 'discord'] + + +class Integration(PartialIntegration, _IntegrationOptional): + enabled: bool + syncing: bool + synced_at: str + user: User + expire_behavior: IntegrationExpireBehavior + expire_grace_period: int diff --git a/discord/types/webhook.py b/discord/types/webhook.py new file mode 100644 index 000000000..851b5ec7a --- /dev/null +++ b/discord/types/webhook.py @@ -0,0 +1,70 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-present Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations +from typing import Literal, Optional, TypedDict +from .snowflake import Snowflake +from .user import User +from .channel import PartialChannel + + +class SourceGuild(TypedDict): + id: int + name: str + icon: str + + +class _WebhookOptional(TypedDict, total=False): + guild_id: Snowflake + user: User + token: str + + +WebhookType = Literal[1, 2] + + +class _FollowerWebhookOptional(TypedDict, total=False): + source_channel: PartialChannel + source_guild: SourceGuild + + +class FollowerWebhook(_FollowerWebhookOptional): + channel_id: Snowflake + webhook_id: Snowflake + + +class PartialWebhook(_WebhookOptional): + id: Snowflake + type: WebhookType + + +class _FullWebhook(TypedDict, total=False): + name: Optional[str] + avatar: Optional[str] + channel_id: Snowflake + application_id: Optional[Snowflake] + + +class Webhook(PartialWebhook, _FullWebhook): + ... diff --git a/discord/utils.py b/discord/utils.py index 2f151be18..7fb1d7ef7 100644 --- a/discord/utils.py +++ b/discord/utils.py @@ -25,7 +25,7 @@ DEALINGS IN THE SOFTWARE. import array import asyncio import collections.abc -from typing import Optional, overload +from typing import Any, Optional, overload import unicodedata from base64 import b64encode from bisect import bisect_left @@ -325,7 +325,7 @@ def _unique(iterable): adder = seen.add return [x for x in iterable if not (x in seen or adder(x))] -def _get_as_snowflake(data, key): +def _get_as_snowflake(data: Any, key: str) -> Optional[int]: try: value = data[key] except KeyError: