Browse Source

Implement new read state capabilities

pull/10109/head
dolfies 2 years ago
parent
commit
4039345c15
  1. 4
      discord/abc.py
  2. 40
      discord/channel.py
  3. 73
      discord/flags.py
  4. 13
      discord/http.py
  5. 4
      discord/message.py
  6. 105
      discord/read_state.py
  7. 6
      discord/state.py
  8. 11
      discord/threads.py
  9. 2
      discord/types/gateway.py
  10. 2
      discord/types/read_state.py
  11. 5
      docs/api.rst

4
discord/abc.py

@ -1986,7 +1986,7 @@ class Messageable:
Acking the channel failed.
"""
channel = await self._get_channel()
await self._state.http.ack_message(channel.id, channel.last_message_id or utils.time_snowflake(utils.utcnow()))
await channel.read_state.ack(channel.last_message_id or utils.time_snowflake(utils.utcnow()))
async def unack(self, *, mention_count: Optional[int] = None) -> None:
"""|coro|
@ -2007,7 +2007,7 @@ class Messageable:
Unacking the channel failed.
"""
channel = await self._get_channel()
await self._state.http.ack_message(channel.id, 0, manual=True, mention_count=mention_count)
await channel.read_state.ack(0, manual=True, mention_count=mention_count)
async def ack_pins(self) -> None:
"""|coro|

40
discord/channel.py

@ -358,6 +358,14 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
"""
return self.read_state.badge_count
@property
def last_viewed_timestamp(self) -> datetime.date:
""":class:`datetime.date`: When the channel was last viewed.
.. versionadded:: 2.1
"""
return self.read_state.last_viewed # type: ignore
@overload
async def edit(self) -> Optional[TextChannel]:
...
@ -1138,6 +1146,14 @@ class VocalGuildChannel(discord.abc.Messageable, discord.abc.Connectable, discor
"""
return self.read_state.badge_count
@property
def last_viewed_timestamp(self) -> datetime.date:
""":class:`datetime.date`: When the channel was last viewed.
.. versionadded:: 2.1
"""
return self.read_state.last_viewed # type: ignore
def get_partial_message(self, message_id: int, /) -> PartialMessage:
"""Creates a :class:`PartialMessage` from the message ID.
@ -3200,6 +3216,14 @@ class DMChannel(discord.abc.Messageable, discord.abc.Connectable, discord.abc.Pr
"""
return self.read_state.badge_count
@property
def last_viewed_timestamp(self) -> datetime.date:
""":class:`datetime.date`: When the channel was last viewed.
.. versionadded:: 2.1
"""
return self.read_state.last_viewed # type: ignore
@property
def requested_at(self) -> Optional[datetime.datetime]:
"""Optional[:class:`datetime.datetime`]: Returns the message request's creation time in UTC, if applicable.
@ -3662,6 +3686,14 @@ class GroupChannel(discord.abc.Messageable, discord.abc.Connectable, discord.abc
"""
return self.read_state.badge_count
@property
def last_viewed_timestamp(self) -> datetime.date:
""":class:`datetime.date`: When the channel was last viewed.
.. versionadded:: 2.1
"""
return self.read_state.last_viewed # type: ignore
def permissions_for(self, obj: Snowflake, /) -> Permissions:
"""Handles permission resolution for a :class:`User`.
@ -4010,6 +4042,14 @@ class PartialMessageable(discord.abc.Messageable, Hashable):
""":class:`datetime.datetime`: Returns the channel's creation time in UTC."""
return utils.snowflake_time(self.id)
@property
def read_state(self) -> ReadState:
""":class:`ReadState`: Returns the read state for this channel.
.. versionadded:: 2.1
"""
return self._state.get_read_state(self.id)
def permissions_for(self, obj: Any = None, /) -> Permissions:
"""Handles permission resolution for a :class:`User`.

73
discord/flags.py

@ -57,6 +57,7 @@ __all__ = (
'OnboardingProgressFlags',
'AutoModPresets',
'MemberFlags',
'ReadStateFlags',
)
BF = TypeVar('BF', bound='BaseFlags')
@ -2289,6 +2290,8 @@ class AutoModPresets(ArrayFlags):
rather than using this raw value.
"""
__slots__ = ()
@classmethod
def all(cls: Type[Self]) -> Self:
"""A factory method that creates a :class:`AutoModPresets` with everything enabled."""
@ -2325,8 +2328,6 @@ class AutoModPresets(ArrayFlags):
class MemberFlags(BaseFlags):
r"""Wraps up the Discord Guild Member flags
.. versionadded:: 2.0
.. container:: operations
.. describe:: x == y
@ -2366,6 +2367,7 @@ class MemberFlags(BaseFlags):
to be, for example, constructed as a dict or a list of pairs.
Note that aliases are not shown.
.. versionadded:: 2.0
Attributes
-----------
@ -2374,6 +2376,8 @@ class MemberFlags(BaseFlags):
rather than using this raw value.
"""
__slots__ = ()
@flag_value
def did_rejoin(self):
""":class:`bool`: Returns ``True`` if the member left and rejoined the :attr:`~discord.Member.guild`."""
@ -2393,3 +2397,68 @@ class MemberFlags(BaseFlags):
def started_onboarding(self):
""":class:`bool`: Returns ``True`` if the member has started onboarding."""
return 1 << 3
@fill_with_flags()
class ReadStateFlags(BaseFlags):
r"""Wraps up the Discord read state flags.
.. container:: operations
.. describe:: x == y
Checks if two ReadStateFlags are equal.
.. describe:: x != y
Checks if two ReadStateFlags are not equal.
.. describe:: x | y, x |= y
Returns a ReadStateFlags instance with all enabled flags from
both x and y.
.. describe:: x & y, x &= y
Returns a ReadStateFlags instance with only flags enabled on
both x and y.
.. describe:: x ^ y, x ^= y
Returns a ReadStateFlags instance with only flags enabled on
only one of x or y, not on both.
.. describe:: ~x
Returns a ReadStateFlags instance with all flags inverted from x.
.. describe:: hash(x)
Return the flag's hash.
.. describe:: iter(x)
Returns an iterator of ``(name, value)`` pairs. This allows it
to be, for example, constructed as a dict or a list of pairs.
Note that aliases are not shown.
.. versionadded:: 2.1
Attributes
-----------
value: :class:`int`
The raw value. You should query flags via the properties
rather than using this raw value.
"""
__slots__ = ()
@flag_value
def guild_channel(self):
""":class:`bool`: Returns ``True`` if the read state is for a guild channel."""
return 1 << 0
@flag_value
def thread(self):
""":class:`bool`: Returns ``True`` if the read state is for a thread."""
return 1 << 1

13
discord/http.py

@ -1163,7 +1163,14 @@ class HTTPClient:
return self.request(Route('POST', '/channels/{channel_id}/typing', channel_id=channel_id))
async def ack_message(
self, channel_id: Snowflake, message_id: Snowflake, *, manual: bool = False, mention_count: Optional[int] = None
self,
channel_id: Snowflake,
message_id: Snowflake,
*,
manual: bool = False,
mention_count: Optional[int] = None,
flags: Optional[int] = None,
last_viewed: Optional[int] = None,
) -> None:
r = Route('POST', '/channels/{channel_id}/messages/{message_id}/ack', channel_id=channel_id, message_id=message_id)
payload = {}
@ -1173,6 +1180,10 @@ class HTTPClient:
payload['token'] = self.ack_token
if mention_count is not None:
payload['mention_count'] = mention_count
if flags is not None:
payload['flags'] = flags
if last_viewed is not None:
payload['last_viewed'] = last_viewed
data: read_state.AcknowledgementToken = await self.request(r, json=payload)
self.ack_token = data.get('token') if data else None

4
discord/message.py

@ -1144,7 +1144,7 @@ class PartialMessage(Hashable):
HTTPException
Acking failed.
"""
await self._state.http.ack_message(self.channel.id, self.id, manual=manual, mention_count=mention_count)
await self.channel.read_state.ack(self.id, manual=manual, mention_count=mention_count)
async def unack(self, *, mention_count: Optional[int] = None) -> None:
"""|coro|
@ -1164,7 +1164,7 @@ class PartialMessage(Hashable):
HTTPException
Unacking failed.
"""
await self._state.http.ack_message(self.channel.id, self.id - 1, manual=True, mention_count=mention_count)
await self.channel.read_state.ack(self.id - 1, manual=True, mention_count=mention_count)
@overload
async def reply(

105
discord/read_state.py

@ -24,10 +24,14 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
from datetime import date
from typing import TYPE_CHECKING, Optional, Union
from .channel import PartialMessageable
from .enums import ReadStateType, try_enum
from .utils import parse_time
from .flags import ReadStateFlags
from .threads import Thread
from .utils import DISCORD_EPOCH, MISSING, parse_time
if TYPE_CHECKING:
from datetime import datetime
@ -37,8 +41,8 @@ if TYPE_CHECKING:
from .abc import MessageableChannel
from .guild import Guild
from .state import ConnectionState
from .user import ClientUser
from .types.read_state import ReadState as ReadStatePayload
from .user import ClientUser
# fmt: off
__all__ = (
@ -81,6 +85,8 @@ class ReadState:
When the channel's pins were last acknowledged.
badge_count: :class:`int`
The number of badges in the read state (e.g. mentions).
last_viewed: Optional[:class:`datetime.date`]
When the resource was last viewed. Only tracked for read states of type :attr:`ReadStateType.channel`.
"""
__slots__ = (
@ -101,14 +107,18 @@ class ReadState:
self.id: int = int(data['id'])
self.type: ReadStateType = try_enum(ReadStateType, data.get('read_state_type', 0))
self._last_entity_id: Optional[int] = None
self._flags: int = 0
self.last_viewed: Optional[date] = self.unpack_last_viewed(0) if self.type == ReadStateType.channel else None
self._update(data)
def _update(self, data: ReadStatePayload):
self.last_acked_id: int = int(data.get('last_acked_id', data.get('last_message_id', 0)))
self.acked_pin_timestamp: Optional[datetime] = parse_time(data.get('last_pin_timestamp'))
self.badge_count: int = int(data.get('badge_count', data.get('mention_count', 0)))
self.last_viewed: Optional[datetime] = parse_time(data.get('last_viewed'))
self._flags: int = data.get('flags') or 0
if 'flags' in data and data['flags'] is not None:
self._flags = data['flags']
if 'last_viewed' in data and data['last_viewed']:
self.last_viewed = self.unpack_last_viewed(data['last_viewed'])
def __eq__(self, other: object) -> bool:
if isinstance(other, ReadState):
@ -130,11 +140,28 @@ class ReadState:
self.id = id
self.type = type
self._last_entity_id = None
self._flags = 0
self.last_viewed = cls.unpack_last_viewed(0) if type == ReadStateType.channel else None
self.last_acked_id = 0
self.acked_pin_timestamp = None
self.badge_count = 0
return self
@staticmethod
def unpack_last_viewed(last_viewed: int) -> date:
# last_viewed is days since the Discord epoch
return date.fromtimestamp(DISCORD_EPOCH / 1000 + last_viewed * 86400)
@staticmethod
def pack_last_viewed(last_viewed: date) -> int:
# We always round up
return int((last_viewed - date.fromtimestamp(DISCORD_EPOCH / 1000)).total_seconds() / 86400 + 0.5)
@property
def flags(self) -> ReadStateFlags:
""":class:`ReadStateFlags`: The read state's flags."""
return ReadStateFlags._from_value(self._flags)
@property
def resource(self) -> Optional[Union[ClientUser, Guild, MessageableChannel]]:
"""Optional[Union[:class:`ClientUser`, :class:`Guild`, :class:`TextChannel`, :class:`StageChannel`, :class:`VoiceChannel`, :class:`Thread`, :class:`DMChannel`, :class:`GroupChannel`, :class:`PartialMessageable`]]: The entity associated with the read state."""
@ -170,6 +197,76 @@ class ReadState:
if self.resource and hasattr(self.resource, 'last_pin_timestamp'):
return self.resource.last_pin_timestamp # type: ignore
async def ack(
self,
entity_id: int,
*,
manual: bool = False,
mention_count: Optional[int] = None,
last_viewed: Optional[date] = MISSING,
) -> None:
"""|coro|
Updates the read state. This is a purposefully low-level function.
Parameters
-----------
entity_id: :class:`int`
The ID of the entity to set the read state to.
manual: :class:`bool`
Whether the read state was manually set by the user.
Only for read states of type :attr:`ReadStateType.channel`.
mention_count: Optional[:class:`int`]
The number of mentions to set the read state to. Only applicable for
manual acknowledgements. Only for read states of type :attr:`ReadStateType.channel`.
last_viewed: Optional[:class:`datetime.date`]
The last day the user viewed the channel. Defaults to today for non-manual acknowledgements.
Only for read states of type :attr:`ReadStateType.channel`.
Raises
-------
ValueError
Invalid parameters were passed.
HTTPException
Updating the read state failed.
"""
state = self._state
if self.type == ReadStateType.channel:
flags = None
channel: MessageableChannel = self.resource # type: ignore
if not isinstance(channel, PartialMessageable):
# Read state flags are kept accurate by the client 😭
flags = ReadStateFlags()
if isinstance(channel, Thread):
flags.thread = True
elif channel.guild:
flags.guild_channel = True
if flags == self.flags:
flags = None
if not manual and last_viewed is MISSING:
last_viewed = date.today()
await state.http.ack_message(
self.id,
entity_id,
manual=manual,
mention_count=mention_count,
flags=flags.value if flags else None,
last_viewed=self.pack_last_viewed(last_viewed) if last_viewed else None,
)
return
if manual or mention_count is not None or last_viewed:
raise ValueError('Extended read state parameters are only valid for channel read states')
if self.type in (ReadStateType.scheduled_events, ReadStateType.guild_home, ReadStateType.onboarding):
await state.http.ack_guild_feature(self.id, self.type.value, entity_id)
elif self.type == ReadStateType.notification_center:
await state.http.ack_user_feature(self.type.value, entity_id)
async def delete(self):
"""|coro|

6
discord/state.py

@ -865,7 +865,7 @@ class ConnectionState:
else utils.find(lambda m: m.id == msg_id, reversed(self._call_message_cache.values()))
)
def _add_guild_from_data(self, data: GuildPayload) -> Optional[Guild]:
def _add_guild_from_data(self, data: GuildPayload) -> Guild:
guild = Guild(data=data, state=self)
self._add_guild(guild)
return guild
@ -1186,6 +1186,10 @@ class ConnectionState:
read_state.last_acked_id = message_id
if 'mention_count' in data:
read_state.badge_count = data['mention_count']
if 'flags' in data and data['flags'] is not None:
read_state._flags = data['flags']
if 'last_viewed' in data and data['last_viewed']:
read_state.last_viewed = read_state.unpack_last_viewed(data['last_viewed'])
self.dispatch('raw_message_ack', raw)
if message is not None:

11
discord/threads.py

@ -25,7 +25,6 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
from typing import Callable, Dict, Iterable, List, Literal, Optional, Sequence, Union, TYPE_CHECKING
from datetime import datetime
import asyncio
import array
import copy
@ -44,7 +43,7 @@ __all__ = (
)
if TYPE_CHECKING:
from datetime import datetime
from datetime import date, datetime
from typing_extensions import Self
from .types.threads import (
@ -391,6 +390,14 @@ class Thread(Messageable, Hashable):
"""
return self.read_state.badge_count
@property
def last_viewed_timestamp(self) -> date:
""":class:`datetime.date`: When the channel was last viewed.
.. versionadded:: 2.1
"""
return self.read_state.last_viewed # type: ignore
@property
def category(self) -> Optional[CategoryChannel]:
"""The category channel the parent channel belongs to, if applicable.

2
discord/types/gateway.py

@ -250,6 +250,8 @@ class ChannelPinsAckEvent(TypedDict):
class MessageAckEvent(TypedDict):
channel_id: Snowflake
message_id: Snowflake
flags: Optional[int]
last_viewed: Optional[int]
manual: NotRequired[bool]
mention_count: NotRequired[int]
ack_type: NotRequired[ReadStateType]

2
discord/types/read_state.py

@ -41,7 +41,7 @@ class ReadState(TypedDict):
mention_count: NotRequired[int]
badge_count: NotRequired[int]
flags: NotRequired[int]
# last_viewed: NotRequired[Optional[str]]
last_viewed: NotRequired[Optional[int]]
class BulkReadState(TypedDict):

5
docs/api.rst

@ -7830,6 +7830,11 @@ Flags
.. autoclass:: PromotionFlags()
:members:
.. attributetable:: ReadStateFlags
.. autoclass:: ReadStateFlags()
:members:
.. attributetable:: SKUFlags
.. autoclass:: SKUFlags()

Loading…
Cancel
Save