Browse Source

Fix create_scheduled_event param handling

pull/9296/head
Puncher 2 years ago
committed by GitHub
parent
commit
60094b17a9
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 114
      discord/guild.py
  2. 11
      discord/scheduled_event.py

114
discord/guild.py

@ -2807,6 +2807,68 @@ class Guild(Hashable):
data = await self._state.http.get_scheduled_event(self.id, scheduled_event_id, with_counts) data = await self._state.http.get_scheduled_event(self.id, scheduled_event_id, with_counts)
return ScheduledEvent(state=self._state, data=data) return ScheduledEvent(state=self._state, data=data)
@overload
async def create_scheduled_event(
self,
*,
name: str,
start_time: datetime.datetime,
entity_type: Literal[EntityType.external] = ...,
privacy_level: PrivacyLevel = ...,
location: str = ...,
end_time: datetime.datetime = ...,
description: str = ...,
image: bytes = ...,
reason: Optional[str] = ...,
) -> ScheduledEvent:
...
@overload
async def create_scheduled_event(
self,
*,
name: str,
start_time: datetime.datetime,
entity_type: Literal[EntityType.stage_instance, EntityType.voice] = ...,
privacy_level: PrivacyLevel = ...,
channel: Snowflake = ...,
end_time: datetime.datetime = ...,
description: str = ...,
image: bytes = ...,
reason: Optional[str] = ...,
) -> ScheduledEvent:
...
@overload
async def create_scheduled_event(
self,
*,
name: str,
start_time: datetime.datetime,
privacy_level: PrivacyLevel = ...,
location: str = ...,
end_time: datetime.datetime = ...,
description: str = ...,
image: bytes = ...,
reason: Optional[str] = ...,
) -> ScheduledEvent:
...
@overload
async def create_scheduled_event(
self,
*,
name: str,
start_time: datetime.datetime,
privacy_level: PrivacyLevel = ...,
channel: Union[VoiceChannel, StageChannel] = ...,
end_time: datetime.datetime = ...,
description: str = ...,
image: bytes = ...,
reason: Optional[str] = ...,
) -> ScheduledEvent:
...
async def create_scheduled_event( async def create_scheduled_event(
self, self,
*, *,
@ -2899,27 +2961,32 @@ class Guild(Hashable):
) )
payload['scheduled_start_time'] = start_time.isoformat() payload['scheduled_start_time'] = start_time.isoformat()
entity_type = entity_type or getattr(channel, '_scheduled_event_entity_type', MISSING)
if entity_type is MISSING: if entity_type is MISSING:
if channel is MISSING: if channel and isinstance(channel, Object):
if channel.type is VoiceChannel:
entity_type = EntityType.voice
elif channel.type is StageChannel:
entity_type = EntityType.stage_instance
elif location not in (MISSING, None):
entity_type = EntityType.external entity_type = EntityType.external
else: else:
_entity_type = getattr(channel, '_scheduled_event_entity_type', MISSING) if not isinstance(entity_type, EntityType):
if _entity_type is None: raise TypeError('entity_type must be of type EntityType')
raise TypeError(
'invalid GuildChannel type passed, must be VoiceChannel or StageChannel '
f'not {channel.__class__.__name__}'
)
if _entity_type is MISSING:
raise TypeError('entity_type must be passed in when passing an ambiguous channel type')
entity_type = _entity_type payload['entity_type'] = entity_type.value
if not isinstance(entity_type, EntityType): if entity_type is None:
raise TypeError('entity_type must be of type EntityType') raise TypeError(
'invalid GuildChannel type passed, must be VoiceChannel or StageChannel ' f'not {channel.__class__.__name__}'
)
payload['entity_type'] = entity_type.value if privacy_level is not MISSING:
if not isinstance(privacy_level, PrivacyLevel):
raise TypeError('privacy_level must be of type PrivacyLevel.')
payload['privacy_level'] = PrivacyLevel.guild_only.value payload['privacy_level'] = privacy_level.value
if description is not MISSING: if description is not MISSING:
payload['description'] = description payload['description'] = description
@ -2929,7 +2996,7 @@ class Guild(Hashable):
payload['image'] = image_as_str payload['image'] = image_as_str
if entity_type in (EntityType.stage_instance, EntityType.voice): if entity_type in (EntityType.stage_instance, EntityType.voice):
if channel is MISSING or channel is None: if channel in (MISSING, None):
raise TypeError('channel must be set when entity_type is voice or stage_instance') raise TypeError('channel must be set when entity_type is voice or stage_instance')
payload['channel_id'] = channel.id payload['channel_id'] = channel.id
@ -2945,12 +3012,15 @@ class Guild(Hashable):
metadata['location'] = location metadata['location'] = location
if end_time is not MISSING: if end_time in (MISSING, None):
if end_time.tzinfo is None: raise TypeError('end_time must be set when entity_type is external')
raise ValueError(
'end_time must be an aware datetime. Consider using discord.utils.utcnow() or datetime.datetime.now().astimezone() for local time.' if end_time not in (MISSING, None):
) if end_time.tzinfo is None:
payload['scheduled_end_time'] = end_time.isoformat() raise ValueError(
'end_time must be an aware datetime. Consider using discord.utils.utcnow() or datetime.datetime.now().astimezone() for local time.'
)
payload['scheduled_end_time'] = end_time.isoformat()
if metadata: if metadata:
payload['entity_metadata'] = metadata payload['entity_metadata'] = metadata

11
discord/scheduled_event.py

@ -503,18 +503,17 @@ class ScheduledEvent(Hashable):
entity_type = EntityType.stage_instance entity_type = EntityType.stage_instance
elif location not in (MISSING, None): elif location not in (MISSING, None):
entity_type = EntityType.external entity_type = EntityType.external
else:
if not isinstance(entity_type, EntityType):
raise TypeError('entity_type must be of type EntityType')
payload['entity_type'] = entity_type.value
if entity_type is None: if entity_type is None:
raise TypeError( raise TypeError(
f'invalid GuildChannel type passed, must be VoiceChannel or StageChannel not {channel.__class__.__name__}' f'invalid GuildChannel type passed, must be VoiceChannel or StageChannel not {channel.__class__.__name__}'
) )
if entity_type is not MISSING:
if not isinstance(entity_type, EntityType):
raise TypeError('entity_type must be of type EntityType')
payload['entity_type'] = entity_type.value
_entity_type = entity_type or self.entity_type _entity_type = entity_type or self.entity_type
_entity_type_changed = _entity_type is not self.entity_type _entity_type_changed = _entity_type is not self.entity_type

Loading…
Cancel
Save