Browse Source

Fix Client.fetch_channel not returning Thread

pull/7167/head
Alex Nørgaard 4 years ago
committed by GitHub
parent
commit
d1dc41ec2f
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 10
      discord/channel.py
  2. 11
      discord/client.py
  3. 4
      discord/guild.py
  4. 2
      discord/iterators.py
  5. 2
      discord/message.py
  6. 4
      discord/state.py
  7. 4
      discord/threads.py

10
discord/channel.py

@ -691,7 +691,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
type=ChannelType.public_thread.value, type=ChannelType.public_thread.value,
) )
return Thread(guild=self.guild, data=data) return Thread(guild=self.guild, state=self._state, data=data)
def archived_threads( def archived_threads(
self, self,
@ -753,7 +753,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
""" """
data = await self._state.http.get_active_threads(self.id) data = await self._state.http.get_active_threads(self.id)
# TODO: thread members? # TODO: thread members?
return [Thread(guild=self.guild, data=d) for d in data.get('threads', [])] return [Thread(guild=self.guild, state=self._state, data=d) for d in data.get('threads', [])]
class VocalGuildChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hashable): class VocalGuildChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hashable):
@ -1924,3 +1924,9 @@ def _channel_factory(channel_type: Union[ChannelType, int]):
return GroupChannel, value return GroupChannel, value
else: else:
return cls, value return cls, value
def _threaded_channel_factory(channel_type: Union[ChannelType, int]):
cls, value = _channel_factory(channel_type)
if value in (ChannelType.private_thread, ChannelType.public_thread, ChannelType.news_thread):
return Thread, value
return cls, value

11
discord/client.py

@ -39,7 +39,7 @@ from .template import Template
from .widget import Widget from .widget import Widget
from .guild import Guild from .guild import Guild
from .emoji import Emoji from .emoji import Emoji
from .channel import _channel_factory from .channel import _threaded_channel_factory
from .enums import ChannelType from .enums import ChannelType
from .mentions import AllowedMentions from .mentions import AllowedMentions
from .errors import * from .errors import *
@ -58,6 +58,7 @@ from .iterators import GuildIterator
from .appinfo import AppInfo from .appinfo import AppInfo
from .ui.view import View from .ui.view import View
from .stage_instance import StageInstance from .stage_instance import StageInstance
from .threads import Thread
if TYPE_CHECKING: if TYPE_CHECKING:
from .abc import SnowflakeTime, PrivateChannel, GuildChannel, Snowflake from .abc import SnowflakeTime, PrivateChannel, GuildChannel, Snowflake
@ -1371,10 +1372,10 @@ class Client:
data = await self.http.get_user(user_id) data = await self.http.get_user(user_id)
return User(state=self._connection, data=data) return User(state=self._connection, data=data)
async def fetch_channel(self, channel_id: int) -> Union[GuildChannel, PrivateChannel]: async def fetch_channel(self, channel_id: int) -> Union[GuildChannel, PrivateChannel, Thread]:
"""|coro| """|coro|
Retrieves a :class:`.abc.GuildChannel` or :class:`.abc.PrivateChannel` with the specified ID. Retrieves a :class:`.abc.GuildChannel`, :class:`.abc.PrivateChannel`, or :class:`.Thread` with the specified ID.
.. note:: .. note::
@ -1395,12 +1396,12 @@ class Client:
Returns Returns
-------- --------
Union[:class:`.abc.GuildChannel`, :class:`.abc.PrivateChannel`] Union[:class:`.abc.GuildChannel`, :class:`.abc.PrivateChannel`, :class:`.Thread`]
The channel from the ID. The channel from the ID.
""" """
data = await self.http.get_channel(channel_id) data = await self.http.get_channel(channel_id)
factory, ch_type = _channel_factory(data['type']) factory, ch_type = _threaded_channel_factory(data['type'])
if factory is None: if factory is None:
raise InvalidData('Unknown channel type {type} for channel ID {id}.'.format_map(data)) raise InvalidData('Unknown channel type {type} for channel ID {id}.'.format_map(data))

4
discord/guild.py

@ -287,7 +287,7 @@ class Guild(Hashable):
self._members[member.id] = member self._members[member.id] = member
def _store_thread(self, payload: ThreadPayload, /) -> Thread: def _store_thread(self, payload: ThreadPayload, /) -> Thread:
thread = Thread(guild=self, data=payload) thread = Thread(guild=self, state=self._state, data=payload)
self._threads[thread.id] = thread self._threads[thread.id] = thread
return thread return thread
@ -466,7 +466,7 @@ class Guild(Hashable):
if 'threads' in data: if 'threads' in data:
threads = data['threads'] threads = data['threads']
for thread in threads: for thread in threads:
self._add_thread(Thread(guild=self, data=thread)) self._add_thread(Thread(guild=self, state=self._state, data=thread))
@property @property
def channels(self) -> List[GuildChannel]: def channels(self) -> List[GuildChannel]:

2
discord/iterators.py

@ -750,4 +750,4 @@ class ArchivedThreadIterator(_AsyncIterator['Thread']):
def create_thread(self, data: ThreadPayload) -> Thread: def create_thread(self, data: ThreadPayload) -> Thread:
from .threads import Thread from .threads import Thread
return Thread(guild=self.guild, data=data) return Thread(guild=self.guild, state=self.guild._state, data=data)

2
discord/message.py

@ -1491,7 +1491,7 @@ class Message(Hashable):
auto_archive_duration=auto_archive_duration, auto_archive_duration=auto_archive_duration,
type=ChannelType.public_thread.value, type=ChannelType.public_thread.value,
) )
return Thread(guild=self.guild, data=data) # type: ignore return Thread(guild=self.guild, state=self._state, data=data) # type: ignore
async def reply(self, content: Optional[str] = None, **kwargs) -> Message: async def reply(self, content: Optional[str] = None, **kwargs) -> Message:
"""|coro| """|coro|

4
discord/state.py

@ -715,7 +715,7 @@ class ConnectionState:
log.debug('THREAD_CREATE referencing an unknown guild ID: %s. Discarding', guild_id) log.debug('THREAD_CREATE referencing an unknown guild ID: %s. Discarding', guild_id)
return return
thread = Thread(guild=guild, data=data) thread = Thread(guild=guild, state=guild._state, data=data)
has_thread = guild.get_thread(thread.id) has_thread = guild.get_thread(thread.id)
guild._add_thread(thread) guild._add_thread(thread)
if not has_thread: if not has_thread:
@ -735,7 +735,7 @@ class ConnectionState:
thread._update(data) thread._update(data)
self.dispatch('thread_update', old, thread) self.dispatch('thread_update', old, thread)
else: else:
thread = Thread(guild=guild, data=data) thread = Thread(guild=guild, state=guild._state, data=data)
guild._add_thread(thread) guild._add_thread(thread)
self.dispatch('thread_join', thread) self.dispatch('thread_join', thread)

4
discord/threads.py

@ -139,8 +139,8 @@ class Thread(Messageable, Hashable):
'archive_timestamp', 'archive_timestamp',
) )
def __init__(self, *, guild: Guild, data: ThreadPayload): def __init__(self, *, guild: Guild, state: ConnectionState, data: ThreadPayload):
self._state: ConnectionState = guild._state self._state: ConnectionState = state
self.guild = guild self.guild = guild
self._members: Dict[int, ThreadMember] = {} self._members: Dict[int, ThreadMember] = {}
self._from_data(data) self._from_data(data)

Loading…
Cancel
Save