Browse Source

Use typing.Protocol instead of abc.ABCMeta

pull/6628/head
James 4 years ago
committed by GitHub
parent
commit
34ab772653
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 112
      discord/abc.py
  2. 58
      discord/ext/commands/converter.py
  3. 5
      discord/ext/commands/core.py
  4. 23
      docs/api.rst

112
discord/abc.py

@ -22,10 +22,12 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
import abc from __future__ import annotations
import sys import sys
import copy import copy
import asyncio import asyncio
from typing import TYPE_CHECKING, Optional, Protocol, runtime_checkable
from .iterators import HistoryIterator from .iterators import HistoryIterator
from .context_managers import Typing from .context_managers import Typing
@ -39,13 +41,22 @@ from .file import File
from .voice_client import VoiceClient, VoiceProtocol from .voice_client import VoiceClient, VoiceProtocol
from . import utils from . import utils
if TYPE_CHECKING:
from datetime import datetime
from .user import ClientUser
class _Undefined: class _Undefined:
def __repr__(self): def __repr__(self):
return 'see-below' return 'see-below'
_undefined = _Undefined() _undefined = _Undefined()
class Snowflake(metaclass=abc.ABCMeta):
@runtime_checkable
class Snowflake(Protocol):
"""An ABC that details the common operations on a Discord model. """An ABC that details the common operations on a Discord model.
Almost all :ref:`Discord models <discord_api_models>` meet this Almost all :ref:`Discord models <discord_api_models>` meet this
@ -60,27 +71,16 @@ class Snowflake(metaclass=abc.ABCMeta):
The model's unique ID. The model's unique ID.
""" """
__slots__ = () __slots__ = ()
id: int
@property @property
@abc.abstractmethod def created_at(self) -> datetime:
def created_at(self):
""":class:`datetime.datetime`: Returns the model's creation time as a naive datetime in UTC.""" """:class:`datetime.datetime`: Returns the model's creation time as a naive datetime in UTC."""
raise NotImplementedError raise NotImplementedError
@classmethod
def __subclasshook__(cls, C):
if cls is Snowflake:
mro = C.__mro__
for attr in ('created_at', 'id'):
for base in mro:
if attr in base.__dict__:
break
else:
return NotImplemented
return True
return NotImplemented
class User(metaclass=abc.ABCMeta): @runtime_checkable
class User(Snowflake, Protocol):
"""An ABC that details the common operations on a Discord user. """An ABC that details the common operations on a Discord user.
The following implement this ABC: The following implement this ABC:
@ -104,35 +104,24 @@ class User(metaclass=abc.ABCMeta):
""" """
__slots__ = () __slots__ = ()
name: str
discriminator: str
avatar: Optional[str]
bot: bool
@property @property
@abc.abstractmethod def display_name(self) -> str:
def display_name(self):
""":class:`str`: Returns the user's display name.""" """:class:`str`: Returns the user's display name."""
raise NotImplementedError raise NotImplementedError
@property @property
@abc.abstractmethod def mention(self) -> str:
def mention(self):
""":class:`str`: Returns a string that allows you to mention the given user.""" """:class:`str`: Returns a string that allows you to mention the given user."""
raise NotImplementedError raise NotImplementedError
@classmethod
def __subclasshook__(cls, C):
if cls is User:
if Snowflake.__subclasshook__(C) is NotImplemented:
return NotImplemented
mro = C.__mro__
for attr in ('display_name', 'mention', 'name', 'avatar', 'discriminator', 'bot'):
for base in mro:
if attr in base.__dict__:
break
else:
return NotImplemented
return True
return NotImplemented
class PrivateChannel(metaclass=abc.ABCMeta): @runtime_checkable
class PrivateChannel(Snowflake, Protocol):
"""An ABC that details the common operations on a private Discord channel. """An ABC that details the common operations on a private Discord channel.
The following implement this ABC: The following implement this ABC:
@ -149,18 +138,8 @@ class PrivateChannel(metaclass=abc.ABCMeta):
""" """
__slots__ = () __slots__ = ()
@classmethod me: ClientUser
def __subclasshook__(cls, C):
if cls is PrivateChannel:
if Snowflake.__subclasshook__(C) is NotImplemented:
return NotImplemented
mro = C.__mro__
for base in mro:
if 'me' in base.__dict__:
return True
return NotImplemented
return NotImplemented
class _Overwrites: class _Overwrites:
__slots__ = ('id', 'allow', 'deny', 'type') __slots__ = ('id', 'allow', 'deny', 'type')
@ -179,7 +158,8 @@ class _Overwrites:
'type': self.type, 'type': self.type,
} }
class GuildChannel:
class GuildChannel(Protocol):
"""An ABC that details the common operations on a Discord guild channel. """An ABC that details the common operations on a Discord guild channel.
The following implement this ABC: The following implement this ABC:
@ -190,6 +170,11 @@ class GuildChannel:
This ABC must also implement :class:`~discord.abc.Snowflake`. This ABC must also implement :class:`~discord.abc.Snowflake`.
Note
----
This ABC is not decorated with :func:`typing.runtime_checkable`, so will fail :func:`isinstance`/:func:`issubclass`
checks.
Attributes Attributes
----------- -----------
name: :class:`str` name: :class:`str`
@ -826,14 +811,13 @@ class GuildChannel:
lock_permissions = kwargs.get('sync_permissions', False) lock_permissions = kwargs.get('sync_permissions', False)
reason = kwargs.get('reason') reason = kwargs.get('reason')
for index, channel in enumerate(channels): for index, channel in enumerate(channels):
d = { 'id': channel.id, 'position': index } d = {'id': channel.id, 'position': index}
if parent_id is not ... and channel.id == self.id: if parent_id is not ... and channel.id == self.id:
d.update(parent_id=parent_id, lock_permissions=lock_permissions) d.update(parent_id=parent_id, lock_permissions=lock_permissions)
payload.append(d) payload.append(d)
await self._state.http.bulk_channel_update(self.guild.id, payload, reason=reason) await self._state.http.bulk_channel_update(self.guild.id, payload, reason=reason)
async def create_invite(self, *, reason=None, **fields): async def create_invite(self, *, reason=None, **fields):
"""|coro| """|coro|
@ -908,7 +892,8 @@ class GuildChannel:
return result return result
class Messageable(metaclass=abc.ABCMeta):
class Messageable(Protocol):
"""An ABC that details the common operations on a model that can send messages. """An ABC that details the common operations on a model that can send messages.
The following implement this ABC: The following implement this ABC:
@ -919,11 +904,16 @@ class Messageable(metaclass=abc.ABCMeta):
- :class:`~discord.User` - :class:`~discord.User`
- :class:`~discord.Member` - :class:`~discord.Member`
- :class:`~discord.ext.commands.Context` - :class:`~discord.ext.commands.Context`
Note
----
This ABC is not decorated with :func:`typing.runtime_checkable`, so will fail :func:`isinstance`/:func:`issubclass`
checks.
""" """
__slots__ = () __slots__ = ()
@abc.abstractmethod
async def _get_channel(self): async def _get_channel(self):
raise NotImplementedError raise NotImplementedError
@ -1060,8 +1050,8 @@ class Messageable(metaclass=abc.ABCMeta):
f.close() f.close()
else: else:
data = await state.http.send_message(channel.id, content, tts=tts, embed=embed, data = await state.http.send_message(channel.id, content, tts=tts, embed=embed,
nonce=nonce, allowed_mentions=allowed_mentions, nonce=nonce, allowed_mentions=allowed_mentions,
message_reference=reference) message_reference=reference)
ret = state.create_message(channel=channel, data=data) ret = state.create_message(channel=channel, data=data)
if delete_after is not None: if delete_after is not None:
@ -1213,21 +1203,25 @@ class Messageable(metaclass=abc.ABCMeta):
""" """
return HistoryIterator(self, limit=limit, before=before, after=after, around=around, oldest_first=oldest_first) return HistoryIterator(self, limit=limit, before=before, after=after, around=around, oldest_first=oldest_first)
class Connectable(metaclass=abc.ABCMeta):
class Connectable(Protocol):
"""An ABC that details the common operations on a channel that can """An ABC that details the common operations on a channel that can
connect to a voice server. connect to a voice server.
The following implement this ABC: The following implement this ABC:
- :class:`~discord.VoiceChannel` - :class:`~discord.VoiceChannel`
Note
----
This ABC is not decorated with :func:`typing.runtime_checkable`, so will fail :func:`isinstance`/:func:`issubclass`
checks.
""" """
__slots__ = () __slots__ = ()
@abc.abstractmethod
def _get_voice_client_key(self): def _get_voice_client_key(self):
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod
def _get_voice_state_pair(self): def _get_voice_state_pair(self):
raise NotImplementedError raise NotImplementedError
@ -1286,6 +1280,6 @@ class Connectable(metaclass=abc.ABCMeta):
except Exception: except Exception:
# we don't care if disconnect failed because connection failed # we don't care if disconnect failed because connection failed
pass pass
raise # re-raise raise # re-raise
return voice return voice

58
discord/ext/commands/converter.py

@ -22,14 +22,19 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations
import re import re
import inspect import inspect
import typing from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, Union, runtime_checkable
import discord import discord
from .errors import * from .errors import *
if TYPE_CHECKING:
from .context import Context
__all__ = ( __all__ = (
'Converter', 'Converter',
'MemberConverter', 'MemberConverter',
@ -54,6 +59,7 @@ __all__ = (
'Greedy', 'Greedy',
) )
def _get_from_guilds(bot, getter, argument): def _get_from_guilds(bot, getter, argument):
result = None result = None
for guild in bot.guilds: for guild in bot.guilds:
@ -62,9 +68,13 @@ def _get_from_guilds(bot, getter, argument):
return result return result
return result return result
_utils_get = discord.utils.get _utils_get = discord.utils.get
T = TypeVar("T")
class Converter:
@runtime_checkable
class Converter(Protocol[T]):
"""The base class of custom converters that require the :class:`.Context` """The base class of custom converters that require the :class:`.Context`
to be passed to be useful. to be passed to be useful.
@ -75,7 +85,7 @@ class Converter:
method to do its conversion logic. This method must be a :ref:`coroutine <coroutine>`. method to do its conversion logic. This method must be a :ref:`coroutine <coroutine>`.
""" """
async def convert(self, ctx, argument): async def convert(self, ctx: Context, argument: str) -> T:
"""|coro| """|coro|
The method to override to do conversion logic. The method to override to do conversion logic.
@ -100,7 +110,7 @@ class Converter:
""" """
raise NotImplementedError('Derived classes need to implement this.') raise NotImplementedError('Derived classes need to implement this.')
class IDConverter(Converter): class IDConverter(Converter[T]):
def __init__(self): def __init__(self):
self._id_regex = re.compile(r'([0-9]{15,20})$') self._id_regex = re.compile(r'([0-9]{15,20})$')
super().__init__() super().__init__()
@ -108,7 +118,7 @@ class IDConverter(Converter):
def _get_id_match(self, argument): def _get_id_match(self, argument):
return self._id_regex.match(argument) return self._id_regex.match(argument)
class MemberConverter(IDConverter): class MemberConverter(IDConverter[discord.Member]):
"""Converts to a :class:`~discord.Member`. """Converts to a :class:`~discord.Member`.
All lookups are via the local guild. If in a DM context, then the lookup All lookups are via the local guild. If in a DM context, then the lookup
@ -194,7 +204,7 @@ class MemberConverter(IDConverter):
return result return result
class UserConverter(IDConverter): class UserConverter(IDConverter[discord.User]):
"""Converts to a :class:`~discord.User`. """Converts to a :class:`~discord.User`.
All lookups are via the global user cache. All lookups are via the global user cache.
@ -253,7 +263,7 @@ class UserConverter(IDConverter):
return result return result
class PartialMessageConverter(Converter): class PartialMessageConverter(Converter[discord.PartialMessage], Generic[T]):
"""Converts to a :class:`discord.PartialMessage`. """Converts to a :class:`discord.PartialMessage`.
.. versionadded:: 1.7 .. versionadded:: 1.7
@ -284,7 +294,7 @@ class PartialMessageConverter(Converter):
raise ChannelNotFound(channel_id) raise ChannelNotFound(channel_id)
return discord.PartialMessage(channel=channel, id=message_id) return discord.PartialMessage(channel=channel, id=message_id)
class MessageConverter(PartialMessageConverter): class MessageConverter(PartialMessageConverter[discord.Message]):
"""Converts to a :class:`discord.Message`. """Converts to a :class:`discord.Message`.
.. versionadded:: 1.1 .. versionadded:: 1.1
@ -313,7 +323,7 @@ class MessageConverter(PartialMessageConverter):
except discord.Forbidden: except discord.Forbidden:
raise ChannelNotReadable(channel) raise ChannelNotReadable(channel)
class TextChannelConverter(IDConverter): class TextChannelConverter(IDConverter[discord.TextChannel]):
"""Converts to a :class:`~discord.TextChannel`. """Converts to a :class:`~discord.TextChannel`.
All lookups are via the local guild. If in a DM context, then the lookup All lookups are via the local guild. If in a DM context, then the lookup
@ -355,7 +365,7 @@ class TextChannelConverter(IDConverter):
return result return result
class VoiceChannelConverter(IDConverter): class VoiceChannelConverter(IDConverter[discord.VoiceChannel]):
"""Converts to a :class:`~discord.VoiceChannel`. """Converts to a :class:`~discord.VoiceChannel`.
All lookups are via the local guild. If in a DM context, then the lookup All lookups are via the local guild. If in a DM context, then the lookup
@ -396,7 +406,7 @@ class VoiceChannelConverter(IDConverter):
return result return result
class StageChannelConverter(IDConverter): class StageChannelConverter(IDConverter[discord.StageChannel]):
"""Converts to a :class:`~discord.StageChannel`. """Converts to a :class:`~discord.StageChannel`.
.. versionadded:: 1.7 .. versionadded:: 1.7
@ -436,7 +446,7 @@ class StageChannelConverter(IDConverter):
return result return result
class CategoryChannelConverter(IDConverter): class CategoryChannelConverter(IDConverter[discord.CategoryChannel]):
"""Converts to a :class:`~discord.CategoryChannel`. """Converts to a :class:`~discord.CategoryChannel`.
All lookups are via the local guild. If in a DM context, then the lookup All lookups are via the local guild. If in a DM context, then the lookup
@ -478,7 +488,7 @@ class CategoryChannelConverter(IDConverter):
return result return result
class StoreChannelConverter(IDConverter): class StoreChannelConverter(IDConverter[discord.StoreChannel]):
"""Converts to a :class:`~discord.StoreChannel`. """Converts to a :class:`~discord.StoreChannel`.
All lookups are via the local guild. If in a DM context, then the lookup All lookups are via the local guild. If in a DM context, then the lookup
@ -519,7 +529,7 @@ class StoreChannelConverter(IDConverter):
return result return result
class ColourConverter(Converter): class ColourConverter(Converter[discord.Colour]):
"""Converts to a :class:`~discord.Colour`. """Converts to a :class:`~discord.Colour`.
.. versionchanged:: 1.5 .. versionchanged:: 1.5
@ -603,7 +613,7 @@ class ColourConverter(Converter):
ColorConverter = ColourConverter ColorConverter = ColourConverter
class RoleConverter(IDConverter): class RoleConverter(IDConverter[discord.Role]):
"""Converts to a :class:`~discord.Role`. """Converts to a :class:`~discord.Role`.
All lookups are via the local guild. If in a DM context, then the lookup All lookups are via the local guild. If in a DM context, then the lookup
@ -633,12 +643,12 @@ class RoleConverter(IDConverter):
raise RoleNotFound(argument) raise RoleNotFound(argument)
return result return result
class GameConverter(Converter): class GameConverter(Converter[discord.Game]):
"""Converts to :class:`~discord.Game`.""" """Converts to :class:`~discord.Game`."""
async def convert(self, ctx, argument): async def convert(self, ctx, argument):
return discord.Game(name=argument) return discord.Game(name=argument)
class InviteConverter(Converter): class InviteConverter(Converter[discord.Invite]):
"""Converts to a :class:`~discord.Invite`. """Converts to a :class:`~discord.Invite`.
This is done via an HTTP request using :meth:`.Bot.fetch_invite`. This is done via an HTTP request using :meth:`.Bot.fetch_invite`.
@ -653,7 +663,7 @@ class InviteConverter(Converter):
except Exception as exc: except Exception as exc:
raise BadInviteArgument() from exc raise BadInviteArgument() from exc
class GuildConverter(IDConverter): class GuildConverter(IDConverter[discord.Guild]):
"""Converts to a :class:`~discord.Guild`. """Converts to a :class:`~discord.Guild`.
The lookup strategy is as follows (in order): The lookup strategy is as follows (in order):
@ -679,7 +689,7 @@ class GuildConverter(IDConverter):
raise GuildNotFound(argument) raise GuildNotFound(argument)
return result return result
class EmojiConverter(IDConverter): class EmojiConverter(IDConverter[discord.Emoji]):
"""Converts to a :class:`~discord.Emoji`. """Converts to a :class:`~discord.Emoji`.
All lookups are done for the local guild first, if available. If that lookup All lookups are done for the local guild first, if available. If that lookup
@ -722,7 +732,7 @@ class EmojiConverter(IDConverter):
return result return result
class PartialEmojiConverter(Converter): class PartialEmojiConverter(Converter[discord.PartialEmoji]):
"""Converts to a :class:`~discord.PartialEmoji`. """Converts to a :class:`~discord.PartialEmoji`.
This is done by extracting the animated flag, name and ID from the emoji. This is done by extracting the animated flag, name and ID from the emoji.
@ -743,7 +753,7 @@ class PartialEmojiConverter(Converter):
raise PartialEmojiConversionFailure(argument) raise PartialEmojiConversionFailure(argument)
class clean_content(Converter): class clean_content(Converter[str]):
"""Converts the argument to mention scrubbed version of """Converts the argument to mention scrubbed version of
said content. said content.
@ -775,7 +785,7 @@ class clean_content(Converter):
if self.fix_channel_mentions and ctx.guild: if self.fix_channel_mentions and ctx.guild:
def resolve_channel(id, *, _get=ctx.guild.get_channel): def resolve_channel(id, *, _get=ctx.guild.get_channel):
ch = _get(id) ch = _get(id)
return (f'<#{id}>'), ('#' + ch.name if ch else '#deleted-channel') return f'<#{id}>', ('#' + ch.name if ch else '#deleted-channel')
transformations.update(resolve_channel(channel) for channel in message.raw_channel_mentions) transformations.update(resolve_channel(channel) for channel in message.raw_channel_mentions)
@ -842,7 +852,7 @@ class _Greedy:
if converter is str or converter is type(None) or converter is _Greedy: if converter is str or converter is type(None) or converter is _Greedy:
raise TypeError(f'Greedy[{converter.__name__}] is invalid.') raise TypeError(f'Greedy[{converter.__name__}] is invalid.')
if getattr(converter, '__origin__', None) is typing.Union and type(None) in converter.__args__: if getattr(converter, '__origin__', None) is Union and type(None) in converter.__args__:
raise TypeError(f'Greedy[{converter!r}] is invalid.') raise TypeError(f'Greedy[{converter!r}] is invalid.')
return self.__class__(converter=converter) return self.__class__(converter=converter)

5
discord/ext/commands/core.py

@ -448,11 +448,6 @@ class Command(_BaseCommand):
instance = converter() instance = converter()
ret = await instance.convert(ctx, argument) ret = await instance.convert(ctx, argument)
return ret return ret
else:
method = getattr(converter, 'convert', None)
if method is not None and inspect.ismethod(method):
ret = await method(ctx, argument)
return ret
elif isinstance(converter, converters.Converter): elif isinstance(converter, converters.Converter):
ret = await converter.convert(ctx, argument) ret = await converter.convert(ctx, argument)
return ret return ret

23
docs/api.rst

@ -2499,20 +2499,19 @@ interface, :meth:`WebhookAdapter.request`.
Abstract Base Classes Abstract Base Classes
----------------------- -----------------------
An :term:`py:abstract base class` (also known as an ``abc``) is a class that models can inherit An :term:`abstract base class` (also known as an ``abc``) is a class that models can inherit
to get their behaviour. The Python implementation of an :doc:`abc <py:library/abc>` is to get their behaviour. **Abstract base classes should not be instantiated**.
slightly different in that you can register them at run-time. **Abstract base classes cannot be instantiated**. They are mainly there for usage with :func:`isinstance` and :func:`issubclass`\.
They are mainly there for usage with :func:`py:isinstance` and :func:`py:issubclass`\.
This library has a module related to abstract base classes, some of which are actually from the :doc:`abc <py:library/abc>` standard This library has a module related to abstract base classes, in which all the ABCs are subclasses of
module, others which are not. :class:`typing.Protocol`.
Snowflake Snowflake
~~~~~~~~~~ ~~~~~~~~~~
.. attributetable:: discord.abc.Snowflake .. attributetable:: discord.abc.Snowflake
.. autoclass:: discord.abc.Snowflake .. autoclass:: discord.abc.Snowflake()
:members: :members:
User User
@ -2520,7 +2519,7 @@ User
.. attributetable:: discord.abc.User .. attributetable:: discord.abc.User
.. autoclass:: discord.abc.User .. autoclass:: discord.abc.User()
:members: :members:
PrivateChannel PrivateChannel
@ -2528,7 +2527,7 @@ PrivateChannel
.. attributetable:: discord.abc.PrivateChannel .. attributetable:: discord.abc.PrivateChannel
.. autoclass:: discord.abc.PrivateChannel .. autoclass:: discord.abc.PrivateChannel()
:members: :members:
GuildChannel GuildChannel
@ -2536,7 +2535,7 @@ GuildChannel
.. attributetable:: discord.abc.GuildChannel .. attributetable:: discord.abc.GuildChannel
.. autoclass:: discord.abc.GuildChannel .. autoclass:: discord.abc.GuildChannel()
:members: :members:
Messageable Messageable
@ -2544,7 +2543,7 @@ Messageable
.. attributetable:: discord.abc.Messageable .. attributetable:: discord.abc.Messageable
.. autoclass:: discord.abc.Messageable .. autoclass:: discord.abc.Messageable()
:members: :members:
:exclude-members: history, typing :exclude-members: history, typing
@ -2559,7 +2558,7 @@ Connectable
.. attributetable:: discord.abc.Connectable .. attributetable:: discord.abc.Connectable
.. autoclass:: discord.abc.Connectable .. autoclass:: discord.abc.Connectable()
.. _discord_api_models: .. _discord_api_models:

Loading…
Cancel
Save