From 4122bef8eeb0121c5f04fdd927e7e70b5301a2a7 Mon Sep 17 00:00:00 2001 From: Eta <24918963+Eta0@users.noreply.github.com> Date: Sun, 27 Nov 2022 00:43:24 -0600 Subject: [PATCH] Fix async iterators requesting past their bounds This affects Messageable.history, ScheduledEvent.users, Client.fetch_guilds, and Guild.audit_logs. To illustrate the problem, Messageable.history counted returned messages to tell when to stop iteration, but did so before filtering away those past the before or after boundaries. When both oldest_first=False and an after boundary were provided, this led to the history iterator continuing to retrieve messages older than the after boundary, which would then all be filtered away, continuing until the message limit or the beginning of the entire channel was reached. A similar situation would also occur with oldest_first=True and a before boundary provided. This commit changes the logic in these methods to count items after filtering, so they stop requesting more as soon as the in-bounds items are exhausted. --- discord/abc.py | 14 ++++++++------ discord/client.py | 14 ++++++++------ discord/guild.py | 18 ++++++++++-------- discord/scheduled_event.py | 12 +++++++----- 4 files changed, 33 insertions(+), 25 deletions(-) diff --git a/discord/abc.py b/discord/abc.py index e1b5f3232..19912a646 100644 --- a/discord/abc.py +++ b/discord/abc.py @@ -1774,24 +1774,26 @@ class Messageable: channel = await self._get_channel() while True: - retrieve = min(100 if limit is None else limit, 100) + retrieve = 100 if limit is None else min(limit, 100) if retrieve < 1: return data, state, limit = await strategy(retrieve, state, limit) - # Terminate loop on next iteration; there's no data left after this - if len(data) < 100: - limit = 0 - if reverse: data = reversed(data) if predicate: data = filter(predicate, data) - for raw_message in data: + count = 0 + + for count, raw_message in enumerate(data, 1): yield self._state.create_message(channel=channel, data=raw_message) + if count < 100: + # There's no data left after this + break + class Connectable(Protocol): """An ABC that details the common operations on a channel that can diff --git a/discord/client.py b/discord/client.py index 125c24350..3878ccd0d 100644 --- a/discord/client.py +++ b/discord/client.py @@ -1424,22 +1424,24 @@ class Client: predicate = lambda m: int(m['id']) > after.id while True: - retrieve = min(200 if limit is None else limit, 200) + retrieve = 200 if limit is None else min(limit, 200) if retrieve < 1: return data, state, limit = await strategy(retrieve, state, limit) - # Terminate loop on next iteration; there's no data left after this - if len(data) < 200: - limit = 0 - if predicate: data = filter(predicate, data) - for raw_guild in data: + count = 0 + + for count, raw_guild in enumerate(data, 1): yield Guild(state=self._connection, data=raw_guild) + if count < 200: + # There's no data left after this + break + async def fetch_template(self, code: Union[Template, str]) -> Template: """|coro| diff --git a/discord/guild.py b/discord/guild.py index b855ea2c9..621c09b53 100644 --- a/discord/guild.py +++ b/discord/guild.py @@ -2079,7 +2079,7 @@ class Guild(Hashable): raise ClientException('Intents.members must be enabled to use this.') while True: - retrieve = min(1000 if limit is None else limit, 1000) + retrieve = 1000 if limit is None else min(limit, 1000) if retrieve < 1: return @@ -2304,7 +2304,7 @@ class Guild(Hashable): strategy, state = _after_strategy, after while True: - retrieve = min(1000 if limit is None else limit, 1000) + retrieve = 1000 if limit is None else min(limit, 1000) if retrieve < 1: return @@ -3660,16 +3660,12 @@ class Guild(Hashable): from .app_commands import AppCommand while True: - retrieve = min(100 if limit is None else limit, 100) + retrieve = 100 if limit is None else min(limit, 100) if retrieve < 1: return data, raw_entries, state, limit = await strategy(retrieve, state, limit) - # Terminate loop on next iteration; there's no data left after this - if len(raw_entries) < 100: - limit = 0 - if reverse: raw_entries = reversed(raw_entries) if predicate: @@ -3690,7 +3686,9 @@ class Guild(Hashable): ) automod_rule_map = {rule.id: rule for rule in automod_rules} - for raw_entry in raw_entries: + count = 0 + + for count, raw_entry in enumerate(raw_entries, 1): # Weird Discord quirk if raw_entry['action_type'] is None: continue @@ -3704,6 +3702,10 @@ class Guild(Hashable): guild=self, ) + if count < 100: + # There's no data left after this + break + async def widget(self) -> Widget: """|coro| diff --git a/discord/scheduled_event.py b/discord/scheduled_event.py index 929591a49..dd5c8c37b 100644 --- a/discord/scheduled_event.py +++ b/discord/scheduled_event.py @@ -560,25 +560,27 @@ class ScheduledEvent(Hashable): predicate = lambda u: u['user']['id'] > after.id while True: - retrieve = min(100 if limit is None else limit, 100) + retrieve = 100 if limit is None else min(limit, 100) if retrieve < 1: return data, state, limit = await strategy(retrieve, state, limit) - if len(data) < 100: - limit = 0 - if reverse: data = reversed(data) if predicate: data = filter(predicate, data) users = (self._state.store_user(raw_user['user']) for raw_user in data) + count = 0 - for user in users: + for count, user in enumerate(users, 1): yield user + if count < 100: + # There's no data left after this + break + def _add_user(self, user: User) -> None: self._users[user.id] = user