diff --git a/discord/scheduled_event.py b/discord/scheduled_event.py index 9f8bd9920..810ea2b2a 100644 --- a/discord/scheduled_event.py +++ b/discord/scheduled_event.py @@ -25,7 +25,7 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations from datetime import datetime -from typing import TYPE_CHECKING, AsyncIterator, Dict, Optional, Union +from typing import TYPE_CHECKING, AsyncIterator, Dict, Optional, Union, overload, Literal from .asset import Asset from .enums import EventStatus, EntityType, PrivacyLevel, try_enum @@ -298,6 +298,87 @@ class ScheduledEvent(Hashable): return await self.__modify_status(EventStatus.cancelled, reason) + @overload + async def edit( + self, + *, + name: str = ..., + description: str = ..., + start_time: datetime = ..., + end_time: Optional[datetime] = ..., + privacy_level: PrivacyLevel = ..., + status: EventStatus = ..., + image: bytes = ..., + reason: Optional[str] = ..., + ) -> ScheduledEvent: + ... + + @overload + async def edit( + self, + *, + name: str = ..., + description: str = ..., + channel: Snowflake, + start_time: datetime = ..., + end_time: Optional[datetime] = ..., + privacy_level: PrivacyLevel = ..., + entity_type: Literal[EntityType.voice, EntityType.stage_instance], + status: EventStatus = ..., + image: bytes = ..., + reason: Optional[str] = ..., + ) -> ScheduledEvent: + ... + + @overload + async def edit( + self, + *, + name: str = ..., + description: str = ..., + start_time: datetime = ..., + end_time: datetime = ..., + privacy_level: PrivacyLevel = ..., + entity_type: Literal[EntityType.external], + status: EventStatus = ..., + image: bytes = ..., + location: str, + reason: Optional[str] = ..., + ) -> ScheduledEvent: + ... + + @overload + async def edit( + self, + *, + name: str = ..., + description: str = ..., + channel: Union[VoiceChannel, StageChannel], + start_time: datetime = ..., + end_time: Optional[datetime] = ..., + privacy_level: PrivacyLevel = ..., + status: EventStatus = ..., + image: bytes = ..., + reason: Optional[str] = ..., + ) -> ScheduledEvent: + ... + + @overload + async def edit( + self, + *, + name: str = ..., + description: str = ..., + start_time: datetime = ..., + end_time: datetime = ..., + privacy_level: PrivacyLevel = ..., + status: EventStatus = ..., + image: bytes = ..., + location: str, + reason: Optional[str] = ..., + ) -> ScheduledEvent: + ... + async def edit( self, *, @@ -414,6 +495,15 @@ class ScheduledEvent(Hashable): payload['image'] = image_as_str entity_type = entity_type or getattr(channel, '_scheduled_event_entity_type', MISSING) + if entity_type is MISSING: + if channel and isinstance(channel, Object): + if channel.type is VoiceChannel: + entity_type = EntityType.voice + elif channel.type is StageChannel: + entity_type = EntityType.stage_instance + elif location not in (MISSING, None): + entity_type = EntityType.external + if entity_type is None: raise TypeError( f'invalid GuildChannel type passed, must be VoiceChannel or StageChannel not {channel.__class__.__name__}' @@ -426,12 +516,14 @@ class ScheduledEvent(Hashable): payload['entity_type'] = entity_type.value _entity_type = entity_type or self.entity_type + _entity_type_changed = _entity_type is not self.entity_type if _entity_type in (EntityType.stage_instance, EntityType.voice): if channel is MISSING or channel is None: - raise TypeError('channel must be set when entity_type is voice or stage_instance') - - payload['channel_id'] = channel.id + if _entity_type_changed: + raise TypeError('channel must be set when entity_type is voice or stage_instance') + else: + payload['channel_id'] = channel.id if location not in (MISSING, None): raise TypeError('location cannot be set when entity_type is voice or stage_instance') @@ -442,11 +534,12 @@ class ScheduledEvent(Hashable): payload['channel_id'] = None if location is MISSING or location is None: - raise TypeError('location must be set when entity_type is external') - - metadata['location'] = location + if _entity_type_changed: + raise TypeError('location must be set when entity_type is external') + else: + metadata['location'] = location - if end_time is MISSING or end_time is None: + if not self.end_time and (end_time is MISSING or end_time is None): raise TypeError('end_time must be set when entity_type is external') if end_time is not MISSING: