Browse Source

Fix Client.fetch_channel not returning Thread

pull/7167/head
Alex Nørgaard 4 years ago
committed by GitHub
parent
commit
d1dc41ec2f
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 10
      discord/channel.py
  2. 11
      discord/client.py
  3. 4
      discord/guild.py
  4. 2
      discord/iterators.py
  5. 2
      discord/message.py
  6. 4
      discord/state.py
  7. 4
      discord/threads.py

10
discord/channel.py

@ -691,7 +691,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
type=ChannelType.public_thread.value,
)
return Thread(guild=self.guild, data=data)
return Thread(guild=self.guild, state=self._state, data=data)
def archived_threads(
self,
@ -753,7 +753,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
"""
data = await self._state.http.get_active_threads(self.id)
# TODO: thread members?
return [Thread(guild=self.guild, data=d) for d in data.get('threads', [])]
return [Thread(guild=self.guild, state=self._state, data=d) for d in data.get('threads', [])]
class VocalGuildChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hashable):
@ -1924,3 +1924,9 @@ def _channel_factory(channel_type: Union[ChannelType, int]):
return GroupChannel, value
else:
return cls, value
def _threaded_channel_factory(channel_type: Union[ChannelType, int]):
cls, value = _channel_factory(channel_type)
if value in (ChannelType.private_thread, ChannelType.public_thread, ChannelType.news_thread):
return Thread, value
return cls, value

11
discord/client.py

@ -39,7 +39,7 @@ from .template import Template
from .widget import Widget
from .guild import Guild
from .emoji import Emoji
from .channel import _channel_factory
from .channel import _threaded_channel_factory
from .enums import ChannelType
from .mentions import AllowedMentions
from .errors import *
@ -58,6 +58,7 @@ from .iterators import GuildIterator
from .appinfo import AppInfo
from .ui.view import View
from .stage_instance import StageInstance
from .threads import Thread
if TYPE_CHECKING:
from .abc import SnowflakeTime, PrivateChannel, GuildChannel, Snowflake
@ -1371,10 +1372,10 @@ class Client:
data = await self.http.get_user(user_id)
return User(state=self._connection, data=data)
async def fetch_channel(self, channel_id: int) -> Union[GuildChannel, PrivateChannel]:
async def fetch_channel(self, channel_id: int) -> Union[GuildChannel, PrivateChannel, Thread]:
"""|coro|
Retrieves a :class:`.abc.GuildChannel` or :class:`.abc.PrivateChannel` with the specified ID.
Retrieves a :class:`.abc.GuildChannel`, :class:`.abc.PrivateChannel`, or :class:`.Thread` with the specified ID.
.. note::
@ -1395,12 +1396,12 @@ class Client:
Returns
--------
Union[:class:`.abc.GuildChannel`, :class:`.abc.PrivateChannel`]
Union[:class:`.abc.GuildChannel`, :class:`.abc.PrivateChannel`, :class:`.Thread`]
The channel from the ID.
"""
data = await self.http.get_channel(channel_id)
factory, ch_type = _channel_factory(data['type'])
factory, ch_type = _threaded_channel_factory(data['type'])
if factory is None:
raise InvalidData('Unknown channel type {type} for channel ID {id}.'.format_map(data))

4
discord/guild.py

@ -287,7 +287,7 @@ class Guild(Hashable):
self._members[member.id] = member
def _store_thread(self, payload: ThreadPayload, /) -> Thread:
thread = Thread(guild=self, data=payload)
thread = Thread(guild=self, state=self._state, data=payload)
self._threads[thread.id] = thread
return thread
@ -466,7 +466,7 @@ class Guild(Hashable):
if 'threads' in data:
threads = data['threads']
for thread in threads:
self._add_thread(Thread(guild=self, data=thread))
self._add_thread(Thread(guild=self, state=self._state, data=thread))
@property
def channels(self) -> List[GuildChannel]:

2
discord/iterators.py

@ -750,4 +750,4 @@ class ArchivedThreadIterator(_AsyncIterator['Thread']):
def create_thread(self, data: ThreadPayload) -> Thread:
from .threads import Thread
return Thread(guild=self.guild, data=data)
return Thread(guild=self.guild, state=self.guild._state, data=data)

2
discord/message.py

@ -1491,7 +1491,7 @@ class Message(Hashable):
auto_archive_duration=auto_archive_duration,
type=ChannelType.public_thread.value,
)
return Thread(guild=self.guild, data=data) # type: ignore
return Thread(guild=self.guild, state=self._state, data=data) # type: ignore
async def reply(self, content: Optional[str] = None, **kwargs) -> Message:
"""|coro|

4
discord/state.py

@ -715,7 +715,7 @@ class ConnectionState:
log.debug('THREAD_CREATE referencing an unknown guild ID: %s. Discarding', guild_id)
return
thread = Thread(guild=guild, data=data)
thread = Thread(guild=guild, state=guild._state, data=data)
has_thread = guild.get_thread(thread.id)
guild._add_thread(thread)
if not has_thread:
@ -735,7 +735,7 @@ class ConnectionState:
thread._update(data)
self.dispatch('thread_update', old, thread)
else:
thread = Thread(guild=guild, data=data)
thread = Thread(guild=guild, state=guild._state, data=data)
guild._add_thread(thread)
self.dispatch('thread_join', thread)

4
discord/threads.py

@ -139,8 +139,8 @@ class Thread(Messageable, Hashable):
'archive_timestamp',
)
def __init__(self, *, guild: Guild, data: ThreadPayload):
self._state: ConnectionState = guild._state
def __init__(self, *, guild: Guild, state: ConnectionState, data: ThreadPayload):
self._state: ConnectionState = state
self.guild = guild
self._members: Dict[int, ThreadMember] = {}
self._from_data(data)

Loading…
Cancel
Save