diff --git a/discord/webhook/async_.py b/discord/webhook/async_.py index c11090b5a..d1063c4fe 100644 --- a/discord/webhook/async_.py +++ b/discord/webhook/async_.py @@ -43,7 +43,7 @@ from ..user import BaseUser, User from ..asset import Asset from ..http import Route from ..mixins import Hashable -from ..object import Object +from ..channel import PartialMessageable __all__ = ( 'Webhook', @@ -58,6 +58,7 @@ if TYPE_CHECKING: from ..file import File from ..embeds import Embed from ..mentions import AllowedMentions + from ..state import ConnectionState from ..types.webhook import ( Webhook as WebhookPayload, ) @@ -579,10 +580,11 @@ class _FriendlyHttpAttributeErrorHelper: class _WebhookState: __slots__ = ('_parent', '_webhook') - def __init__(self, webhook, parent): - self._webhook = webhook + def __init__(self, webhook: Any, parent: Optional[Union[ConnectionState, _WebhookState]]): + self._webhook: Any = webhook - if isinstance(parent, self.__class__): + self._parent: Optional[ConnectionState] + if isinstance(parent, _WebhookState): self._parent = None else: self._parent = parent @@ -595,10 +597,12 @@ class _WebhookState: def store_user(self, data): if self._parent is not None: return self._parent.store_user(data) - return BaseUser(state=self, data=data) + # state parameter is artificial + return BaseUser(state=self, data=data) # type: ignore def create_user(self, data): - return BaseUser(state=self, data=data) + # state parameter is artificial + return BaseUser(state=self, data=data) # type: ignore @property def http(self): @@ -748,9 +752,9 @@ class BaseWebhook(Hashable): '_state', ) - def __init__(self, data: WebhookPayload, token: Optional[str] = None, state=None): + def __init__(self, data: WebhookPayload, token: Optional[str] = None, state: Optional[ConnectionState] = None): self.auth_token: Optional[str] = token - self._state = state or _WebhookState(self, parent=state) + self._state: Union[ConnectionState, _WebhookState] = state or _WebhookState(self, parent=state) self._update(data) def _update(self, data: WebhookPayload): @@ -765,10 +769,8 @@ class BaseWebhook(Hashable): user = data.get('user') self.user: Optional[Union[BaseUser, User]] = None if user is not None: - if self._state is None: - self.user = BaseUser(state=None, data=user) - else: - self.user = User(state=self._state, data=user) + # state parameter may be _WebhookState + self.user = User(state=self._state, data=user) # type: ignore source_channel = data.get('source_channel') if source_channel: @@ -1179,7 +1181,9 @@ class Webhook(BaseWebhook): def _create_message(self, data): state = _WebhookState(self, parent=self._state) - channel = self.channel or Object(id=int(data['channel_id'])) + # state may be artificial (unlikely at this point...) + channel = self.channel or PartialMessageable(state=self._state, id=int(data['channel_id'])) # type: ignore + # state is artificial return WebhookMessage(data=data, state=state, channel=channel) # type: ignore @overload diff --git a/discord/webhook/sync.py b/discord/webhook/sync.py index a3b2f156f..170b387d3 100644 --- a/discord/webhook/sync.py +++ b/discord/webhook/sync.py @@ -43,7 +43,7 @@ from .. import utils from ..errors import InvalidArgument, HTTPException, Forbidden, NotFound, DiscordServerError from ..message import Message from ..http import Route -from ..object import Object +from ..channel import PartialMessageable from .async_ import BaseWebhook, handle_message_parameters, _WebhookState @@ -373,6 +373,8 @@ class SyncWebhookMessage(Message): .. versionadded:: 2.0 """ + _state: _WebhookState + def edit( self, content: Optional[str] = MISSING, @@ -745,8 +747,10 @@ class SyncWebhook(BaseWebhook): def _create_message(self, data): state = _WebhookState(self, parent=self._state) - channel = self.channel or Object(id=int(data['channel_id'])) - return SyncWebhookMessage(data=data, state=state, channel=channel) + # state may be artificial (unlikely at this point...) + channel = self.channel or PartialMessageable(state=self._state, id=int(data['channel_id'])) # type: ignore + # state is artificial + return SyncWebhookMessage(data=data, state=state, channel=channel) # type: ignore @overload def send(