diff --git a/discord/guild.py b/discord/guild.py index c457c215e..cd0cdfb6b 100644 --- a/discord/guild.py +++ b/discord/guild.py @@ -2840,6 +2840,68 @@ class Guild(Hashable): data = await self._state.http.get_scheduled_event(self.id, scheduled_event_id, with_counts) return ScheduledEvent(state=self._state, data=data) + @overload + async def create_scheduled_event( + self, + *, + name: str, + start_time: datetime.datetime, + entity_type: Literal[EntityType.external] = ..., + privacy_level: PrivacyLevel = ..., + location: str = ..., + end_time: datetime.datetime = ..., + description: str = ..., + image: bytes = ..., + reason: Optional[str] = ..., + ) -> ScheduledEvent: + ... + + @overload + async def create_scheduled_event( + self, + *, + name: str, + start_time: datetime.datetime, + entity_type: Literal[EntityType.stage_instance, EntityType.voice] = ..., + privacy_level: PrivacyLevel = ..., + channel: Snowflake = ..., + end_time: datetime.datetime = ..., + description: str = ..., + image: bytes = ..., + reason: Optional[str] = ..., + ) -> ScheduledEvent: + ... + + @overload + async def create_scheduled_event( + self, + *, + name: str, + start_time: datetime.datetime, + privacy_level: PrivacyLevel = ..., + location: str = ..., + end_time: datetime.datetime = ..., + description: str = ..., + image: bytes = ..., + reason: Optional[str] = ..., + ) -> ScheduledEvent: + ... + + @overload + async def create_scheduled_event( + self, + *, + name: str, + start_time: datetime.datetime, + privacy_level: PrivacyLevel = ..., + channel: Union[VoiceChannel, StageChannel] = ..., + end_time: datetime.datetime = ..., + description: str = ..., + image: bytes = ..., + reason: Optional[str] = ..., + ) -> ScheduledEvent: + ... + async def create_scheduled_event( self, *, @@ -2926,12 +2988,32 @@ class Guild(Hashable): ) payload['scheduled_start_time'] = start_time.isoformat() - if not isinstance(entity_type, EntityType): - raise TypeError('entity_type must be of type EntityType') + entity_type = entity_type or getattr(channel, '_scheduled_event_entity_type', MISSING) + if entity_type is MISSING: + if channel and isinstance(channel, Object): + if channel.type is VoiceChannel: + entity_type = EntityType.voice + elif channel.type is StageChannel: + entity_type = EntityType.stage_instance - payload['entity_type'] = entity_type.value + elif location not in (MISSING, None): + entity_type = EntityType.external + else: + if not isinstance(entity_type, EntityType): + raise TypeError('entity_type must be of type EntityType') + + payload['entity_type'] = entity_type.value - payload['privacy_level'] = PrivacyLevel.guild_only.value + if entity_type is None: + raise TypeError( + 'invalid GuildChannel type passed, must be VoiceChannel or StageChannel ' f'not {channel.__class__.__name__}' + ) + + if privacy_level is not MISSING: + if not isinstance(privacy_level, PrivacyLevel): + raise TypeError('privacy_level must be of type PrivacyLevel.') + + payload['privacy_level'] = privacy_level.value if description is not MISSING: payload['description'] = description @@ -2941,7 +3023,7 @@ class Guild(Hashable): payload['image'] = image_as_str if entity_type in (EntityType.stage_instance, EntityType.voice): - if channel is MISSING or channel is None: + if channel in (MISSING, None): raise TypeError('channel must be set when entity_type is voice or stage_instance') payload['channel_id'] = channel.id @@ -2957,12 +3039,15 @@ class Guild(Hashable): metadata['location'] = location - if end_time is not MISSING: - if end_time.tzinfo is None: - raise ValueError( - 'end_time must be an aware datetime. Consider using discord.utils.utcnow() or datetime.datetime.now().astimezone() for local time.' - ) - payload['scheduled_end_time'] = end_time.isoformat() + if end_time in (MISSING, None): + raise TypeError('end_time must be set when entity_type is external') + + if end_time not in (MISSING, None): + if end_time.tzinfo is None: + raise ValueError( + 'end_time must be an aware datetime. Consider using discord.utils.utcnow() or datetime.datetime.now().astimezone() for local time.' + ) + payload['scheduled_end_time'] = end_time.isoformat() if metadata: payload['entity_metadata'] = metadata diff --git a/discord/scheduled_event.py b/discord/scheduled_event.py index 2dc6d4088..89cdddaba 100644 --- a/discord/scheduled_event.py +++ b/discord/scheduled_event.py @@ -499,18 +499,17 @@ class ScheduledEvent(Hashable): entity_type = EntityType.stage_instance elif location not in (MISSING, None): entity_type = EntityType.external + else: + if not isinstance(entity_type, EntityType): + raise TypeError('entity_type must be of type EntityType') + + payload['entity_type'] = entity_type.value if entity_type is None: raise TypeError( f'invalid GuildChannel type passed, must be VoiceChannel or StageChannel not {channel.__class__.__name__}' ) - if entity_type is not MISSING: - if not isinstance(entity_type, EntityType): - raise TypeError('entity_type must be of type EntityType') - - payload['entity_type'] = entity_type.value - _entity_type = entity_type or self.entity_type _entity_type_changed = _entity_type is not self.entity_type