Browse Source

Improve TranslationContext type narrowing using a tagged union

pull/8342/head
Bryan Forbes 3 years ago
committed by GitHub
parent
commit
63b32994f4
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 2
      discord/app_commands/commands.py
  2. 6
      discord/app_commands/errors.py
  3. 2
      discord/app_commands/models.py
  4. 2
      discord/app_commands/transformers.py
  5. 148
      discord/app_commands/translator.py

2
discord/app_commands/commands.py

@ -52,7 +52,7 @@ from ..enums import AppCommandOptionType, AppCommandType, ChannelType, Locale
from .models import Choice from .models import Choice
from .transformers import annotation_to_parameter, CommandParameter, NoneType from .transformers import annotation_to_parameter, CommandParameter, NoneType
from .errors import AppCommandError, CheckFailure, CommandInvokeError, CommandSignatureMismatch, CommandAlreadyRegistered from .errors import AppCommandError, CheckFailure, CommandInvokeError, CommandSignatureMismatch, CommandAlreadyRegistered
from .translator import TranslationContext, TranslationContextLocation, Translator, locale_str from .translator import TranslationContextLocation, TranslationContext, Translator, locale_str
from ..message import Message from ..message import Message
from ..user import User from ..user import User
from ..member import Member from ..member import Member

6
discord/app_commands/errors.py

@ -52,7 +52,7 @@ __all__ = (
if TYPE_CHECKING: if TYPE_CHECKING:
from .commands import Command, Group, ContextMenu from .commands import Command, Group, ContextMenu
from .transformers import Transformer from .transformers import Transformer
from .translator import TranslationContext, locale_str from .translator import TranslationContextTypes, locale_str
from ..types.snowflake import Snowflake, SnowflakeList from ..types.snowflake import Snowflake, SnowflakeList
from .checks import Cooldown from .checks import Cooldown
@ -164,11 +164,11 @@ class TranslationError(AppCommandError):
*msg: str, *msg: str,
string: Optional[Union[str, locale_str]] = None, string: Optional[Union[str, locale_str]] = None,
locale: Optional[Locale] = None, locale: Optional[Locale] = None,
context: TranslationContext, context: TranslationContextTypes,
) -> None: ) -> None:
self.string: Optional[Union[str, locale_str]] = string self.string: Optional[Union[str, locale_str]] = string
self.locale: Optional[Locale] = locale self.locale: Optional[Locale] = locale
self.context: TranslationContext = context self.context: TranslationContextTypes = context
if msg: if msg:
super().__init__(*msg) super().__init__(*msg)

2
discord/app_commands/models.py

@ -26,7 +26,7 @@ from __future__ import annotations
from datetime import datetime from datetime import datetime
from .errors import MissingApplicationID from .errors import MissingApplicationID
from .translator import TranslationContextLocation, Translator, TranslationContext, locale_str from .translator import TranslationContextLocation, TranslationContext, locale_str, Translator
from ..permissions import Permissions from ..permissions import Permissions
from ..enums import AppCommandOptionType, AppCommandType, AppCommandPermissionType, ChannelType, Locale, try_enum from ..enums import AppCommandOptionType, AppCommandType, AppCommandPermissionType, ChannelType, Locale, try_enum
from ..mixins import Hashable from ..mixins import Hashable

2
discord/app_commands/transformers.py

@ -46,7 +46,7 @@ from typing import (
from .errors import AppCommandError, TransformerError from .errors import AppCommandError, TransformerError
from .models import AppCommandChannel, AppCommandThread, Choice from .models import AppCommandChannel, AppCommandThread, Choice
from .translator import TranslationContextLocation, locale_str, Translator, TranslationContext from .translator import TranslationContextLocation, TranslationContext, Translator, locale_str
from ..channel import StageChannel, VoiceChannel, TextChannel, CategoryChannel from ..channel import StageChannel, VoiceChannel, TextChannel, CategoryChannel
from ..abc import GuildChannel from ..abc import GuildChannel
from ..threads import Thread from ..threads import Thread

148
discord/app_commands/translator.py

@ -23,13 +23,19 @@ DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING, Any, Literal, Optional, Union from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, TypeVar, Union, overload
from .errors import TranslationError from .errors import TranslationError
from ..enums import Enum, Locale from ..enums import Enum, Locale
if TYPE_CHECKING:
from .commands import Command, ContextMenu, Group, Parameter
from .models import Choice
__all__ = ( __all__ = (
'TranslationContextLocation', 'TranslationContextLocation',
'TranslationContextTypes',
'TranslationContext', 'TranslationContext',
'Translator', 'Translator',
'locale_str', 'locale_str',
@ -47,7 +53,11 @@ class TranslationContextLocation(Enum):
other = 7 other = 7
class TranslationContext: # type: ignore # See below _L = TypeVar('_L', bound=TranslationContextLocation)
_D = TypeVar('_D')
class TranslationContext(Generic[_L, _D]):
"""A class that provides context for the :class:`locale_str` being translated. """A class that provides context for the :class:`locale_str` being translated.
This is useful to determine where exactly the string is located and aid in looking This is useful to determine where exactly the string is located and aid in looking
@ -63,60 +73,77 @@ class TranslationContext: # type: ignore # See below
__slots__ = ('location', 'data') __slots__ = ('location', 'data')
def __init__(self, location: TranslationContextLocation, data: Any) -> None: @overload
self.location: TranslationContextLocation = location def __init__(
self.data: Any = data self, location: Literal[TranslationContextLocation.command_name], data: Union[Command[Any, ..., Any], ContextMenu]
) -> None:
...
if TYPE_CHECKING:
# For type checking purposes, it makes sense to allow the user to leverage type narrowing @overload
# So code like this works as expected: def __init__(
# if context.type is TranslationContextLocation.command_name: self, location: Literal[TranslationContextLocation.command_description], data: Command[Any, ..., Any]
# reveal_type(context.data) # Revealed type is Command | ContextMenu ) -> None:
# ...
# Unfortunately doing a trick like this requires lying to the type checker so
# this is what the code below enables. @overload
# def __init__(
# Should this trick stop working then it might be fair to remove this code. self,
# It's purely here for convenience. location: Literal[TranslationContextLocation.group_name, TranslationContextLocation.group_description],
data: Group,
from .commands import Command, ContextMenu, Group, Parameter ) -> None:
from .models import Choice ...
class _CommandNameTranslationContext: @overload
location: Literal[TranslationContextLocation.command_name] def __init__(
data: Union[Command[Any, ..., Any], ContextMenu] self,
location: Literal[TranslationContextLocation.parameter_name, TranslationContextLocation.parameter_description],
class _CommandDescriptionTranslationContext: data: Parameter,
location: Literal[TranslationContextLocation.command_description] ) -> None:
data: Command[Any, ..., Any] ...
class _GroupTranslationContext: @overload
location: Literal[TranslationContextLocation.group_name, TranslationContextLocation.group_description] def __init__(self, location: Literal[TranslationContextLocation.choice_name], data: Choice[Any]) -> None:
data: Group ...
class _ParameterTranslationContext: @overload
location: Literal[TranslationContextLocation.parameter_description, TranslationContextLocation.parameter_name] def __init__(self, location: Literal[TranslationContextLocation.other], data: Any) -> None:
data: Parameter ...
class _ChoiceTranslationContext: def __init__(self, location: _L, data: _D) -> None:
location: Literal[TranslationContextLocation.choice_name] self.location: _L = location
data: Choice[Union[int, str, float]] self.data: _D = data
class _OtherTranslationContext:
location: Literal[TranslationContextLocation.other] # For type checking purposes, it makes sense to allow the user to leverage type narrowing
data: Any # So code like this works as expected:
#
class TranslationContext( # if context.type == TranslationContextLocation.command_name:
_CommandNameTranslationContext, # reveal_type(context.data) # Revealed type is Command | ContextMenu
_CommandDescriptionTranslationContext, #
_GroupTranslationContext, # This requires a union of types
_ParameterTranslationContext, CommandNameTranslationContext = TranslationContext[
_ChoiceTranslationContext, Literal[TranslationContextLocation.command_name], Union['Command[Any, ..., Any]', 'ContextMenu']
_OtherTranslationContext, ]
): CommandDescriptionTranslationContext = TranslationContext[
def __init__(self, location: TranslationContextLocation, data: Any) -> None: Literal[TranslationContextLocation.command_description], 'Command[Any, ..., Any]'
... ]
GroupTranslationContext = TranslationContext[
Literal[TranslationContextLocation.group_name, TranslationContextLocation.group_description], 'Group'
]
ParameterTranslationContext = TranslationContext[
Literal[TranslationContextLocation.parameter_name, TranslationContextLocation.parameter_description], 'Parameter'
]
ChoiceTranslationContext = TranslationContext[Literal[TranslationContextLocation.choice_name], 'Choice[Any]']
OtherTranslationContext = TranslationContext[Literal[TranslationContextLocation.other], Any]
TranslationContextTypes = Union[
CommandNameTranslationContext,
CommandDescriptionTranslationContext,
GroupTranslationContext,
ParameterTranslationContext,
ChoiceTranslationContext,
OtherTranslationContext,
]
class Translator: class Translator:
@ -162,7 +189,9 @@ class Translator:
""" """
pass pass
async def _checked_translate(self, string: locale_str, locale: Locale, context: TranslationContext) -> Optional[str]: async def _checked_translate(
self, string: locale_str, locale: Locale, context: TranslationContextTypes
) -> Optional[str]:
try: try:
return await self.translate(string, locale, context) return await self.translate(string, locale, context)
except TranslationError: except TranslationError:
@ -170,7 +199,7 @@ class Translator:
except Exception as e: except Exception as e:
raise TranslationError(string=string, locale=locale, context=context) from e raise TranslationError(string=string, locale=locale, context=context) from e
async def translate(self, string: locale_str, locale: Locale, context: TranslationContext) -> Optional[str]: async def translate(self, string: locale_str, locale: Locale, context: TranslationContextTypes) -> Optional[str]:
"""|coro| """|coro|
Translates the given string to the specified locale. Translates the given string to the specified locale.
@ -190,6 +219,9 @@ class Translator:
The locale being requested for translation. The locale being requested for translation.
context: :class:`TranslationContext` context: :class:`TranslationContext`
The translation context where the string originated from. The translation context where the string originated from.
For better type checking ergonomics, the ``TranslationContextTypes``
type can be used instead to aid with type narrowing. It is functionally
equivalent to :class:`TranslationContext`.
""" """
return None return None

Loading…
Cancel
Save