From 63b32994f4cfa03e6f5dc6ed03c1704499f9b4ea Mon Sep 17 00:00:00 2001 From: Bryan Forbes Date: Mon, 15 Aug 2022 07:17:41 -0500 Subject: [PATCH] Improve TranslationContext type narrowing using a tagged union --- discord/app_commands/commands.py | 2 +- discord/app_commands/errors.py | 6 +- discord/app_commands/models.py | 2 +- discord/app_commands/transformers.py | 2 +- discord/app_commands/translator.py | 148 ++++++++++++++++----------- 5 files changed, 96 insertions(+), 64 deletions(-) diff --git a/discord/app_commands/commands.py b/discord/app_commands/commands.py index 60f4704f7..edbafdae9 100644 --- a/discord/app_commands/commands.py +++ b/discord/app_commands/commands.py @@ -52,7 +52,7 @@ from ..enums import AppCommandOptionType, AppCommandType, ChannelType, Locale from .models import Choice from .transformers import annotation_to_parameter, CommandParameter, NoneType from .errors import AppCommandError, CheckFailure, CommandInvokeError, CommandSignatureMismatch, CommandAlreadyRegistered -from .translator import TranslationContext, TranslationContextLocation, Translator, locale_str +from .translator import TranslationContextLocation, TranslationContext, Translator, locale_str from ..message import Message from ..user import User from ..member import Member diff --git a/discord/app_commands/errors.py b/discord/app_commands/errors.py index 19c9f736f..11943a8a2 100644 --- a/discord/app_commands/errors.py +++ b/discord/app_commands/errors.py @@ -52,7 +52,7 @@ __all__ = ( if TYPE_CHECKING: from .commands import Command, Group, ContextMenu from .transformers import Transformer - from .translator import TranslationContext, locale_str + from .translator import TranslationContextTypes, locale_str from ..types.snowflake import Snowflake, SnowflakeList from .checks import Cooldown @@ -164,11 +164,11 @@ class TranslationError(AppCommandError): *msg: str, string: Optional[Union[str, locale_str]] = None, locale: Optional[Locale] = None, - context: TranslationContext, + context: TranslationContextTypes, ) -> None: self.string: Optional[Union[str, locale_str]] = string self.locale: Optional[Locale] = locale - self.context: TranslationContext = context + self.context: TranslationContextTypes = context if msg: super().__init__(*msg) diff --git a/discord/app_commands/models.py b/discord/app_commands/models.py index e306e9392..9a3be4e73 100644 --- a/discord/app_commands/models.py +++ b/discord/app_commands/models.py @@ -26,7 +26,7 @@ from __future__ import annotations from datetime import datetime from .errors import MissingApplicationID -from .translator import TranslationContextLocation, Translator, TranslationContext, locale_str +from .translator import TranslationContextLocation, TranslationContext, locale_str, Translator from ..permissions import Permissions from ..enums import AppCommandOptionType, AppCommandType, AppCommandPermissionType, ChannelType, Locale, try_enum from ..mixins import Hashable diff --git a/discord/app_commands/transformers.py b/discord/app_commands/transformers.py index 21291b55c..99a45303e 100644 --- a/discord/app_commands/transformers.py +++ b/discord/app_commands/transformers.py @@ -46,7 +46,7 @@ from typing import ( from .errors import AppCommandError, TransformerError from .models import AppCommandChannel, AppCommandThread, Choice -from .translator import TranslationContextLocation, locale_str, Translator, TranslationContext +from .translator import TranslationContextLocation, TranslationContext, Translator, locale_str from ..channel import StageChannel, VoiceChannel, TextChannel, CategoryChannel from ..abc import GuildChannel from ..threads import Thread diff --git a/discord/app_commands/translator.py b/discord/app_commands/translator.py index a0d366735..1741054e3 100644 --- a/discord/app_commands/translator.py +++ b/discord/app_commands/translator.py @@ -23,13 +23,19 @@ DEALINGS IN THE SOFTWARE. """ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, TypeVar, Union, overload from .errors import TranslationError from ..enums import Enum, Locale +if TYPE_CHECKING: + from .commands import Command, ContextMenu, Group, Parameter + from .models import Choice + + __all__ = ( 'TranslationContextLocation', + 'TranslationContextTypes', 'TranslationContext', 'Translator', 'locale_str', @@ -47,7 +53,11 @@ class TranslationContextLocation(Enum): other = 7 -class TranslationContext: # type: ignore # See below +_L = TypeVar('_L', bound=TranslationContextLocation) +_D = TypeVar('_D') + + +class TranslationContext(Generic[_L, _D]): """A class that provides context for the :class:`locale_str` being translated. This is useful to determine where exactly the string is located and aid in looking @@ -63,60 +73,77 @@ class TranslationContext: # type: ignore # See below __slots__ = ('location', 'data') - def __init__(self, location: TranslationContextLocation, data: Any) -> None: - self.location: TranslationContextLocation = location - self.data: Any = data - - -if TYPE_CHECKING: - # For type checking purposes, it makes sense to allow the user to leverage type narrowing - # So code like this works as expected: - # if context.type is TranslationContextLocation.command_name: - # reveal_type(context.data) # Revealed type is Command | ContextMenu - # - # Unfortunately doing a trick like this requires lying to the type checker so - # this is what the code below enables. - # - # Should this trick stop working then it might be fair to remove this code. - # It's purely here for convenience. - - from .commands import Command, ContextMenu, Group, Parameter - from .models import Choice - - class _CommandNameTranslationContext: - location: Literal[TranslationContextLocation.command_name] - data: Union[Command[Any, ..., Any], ContextMenu] - - class _CommandDescriptionTranslationContext: - location: Literal[TranslationContextLocation.command_description] - data: Command[Any, ..., Any] - - class _GroupTranslationContext: - location: Literal[TranslationContextLocation.group_name, TranslationContextLocation.group_description] - data: Group - - class _ParameterTranslationContext: - location: Literal[TranslationContextLocation.parameter_description, TranslationContextLocation.parameter_name] - data: Parameter - - class _ChoiceTranslationContext: - location: Literal[TranslationContextLocation.choice_name] - data: Choice[Union[int, str, float]] - - class _OtherTranslationContext: - location: Literal[TranslationContextLocation.other] - data: Any - - class TranslationContext( - _CommandNameTranslationContext, - _CommandDescriptionTranslationContext, - _GroupTranslationContext, - _ParameterTranslationContext, - _ChoiceTranslationContext, - _OtherTranslationContext, - ): - def __init__(self, location: TranslationContextLocation, data: Any) -> None: - ... + @overload + def __init__( + self, location: Literal[TranslationContextLocation.command_name], data: Union[Command[Any, ..., Any], ContextMenu] + ) -> None: + ... + + @overload + def __init__( + self, location: Literal[TranslationContextLocation.command_description], data: Command[Any, ..., Any] + ) -> None: + ... + + @overload + def __init__( + self, + location: Literal[TranslationContextLocation.group_name, TranslationContextLocation.group_description], + data: Group, + ) -> None: + ... + + @overload + def __init__( + self, + location: Literal[TranslationContextLocation.parameter_name, TranslationContextLocation.parameter_description], + data: Parameter, + ) -> None: + ... + + @overload + def __init__(self, location: Literal[TranslationContextLocation.choice_name], data: Choice[Any]) -> None: + ... + + @overload + def __init__(self, location: Literal[TranslationContextLocation.other], data: Any) -> None: + ... + + def __init__(self, location: _L, data: _D) -> None: + self.location: _L = location + self.data: _D = data + + +# For type checking purposes, it makes sense to allow the user to leverage type narrowing +# So code like this works as expected: +# +# if context.type == TranslationContextLocation.command_name: +# reveal_type(context.data) # Revealed type is Command | ContextMenu +# +# This requires a union of types +CommandNameTranslationContext = TranslationContext[ + Literal[TranslationContextLocation.command_name], Union['Command[Any, ..., Any]', 'ContextMenu'] +] +CommandDescriptionTranslationContext = TranslationContext[ + Literal[TranslationContextLocation.command_description], 'Command[Any, ..., Any]' +] +GroupTranslationContext = TranslationContext[ + Literal[TranslationContextLocation.group_name, TranslationContextLocation.group_description], 'Group' +] +ParameterTranslationContext = TranslationContext[ + Literal[TranslationContextLocation.parameter_name, TranslationContextLocation.parameter_description], 'Parameter' +] +ChoiceTranslationContext = TranslationContext[Literal[TranslationContextLocation.choice_name], 'Choice[Any]'] +OtherTranslationContext = TranslationContext[Literal[TranslationContextLocation.other], Any] + +TranslationContextTypes = Union[ + CommandNameTranslationContext, + CommandDescriptionTranslationContext, + GroupTranslationContext, + ParameterTranslationContext, + ChoiceTranslationContext, + OtherTranslationContext, +] class Translator: @@ -162,7 +189,9 @@ class Translator: """ pass - async def _checked_translate(self, string: locale_str, locale: Locale, context: TranslationContext) -> Optional[str]: + async def _checked_translate( + self, string: locale_str, locale: Locale, context: TranslationContextTypes + ) -> Optional[str]: try: return await self.translate(string, locale, context) except TranslationError: @@ -170,7 +199,7 @@ class Translator: except Exception as e: raise TranslationError(string=string, locale=locale, context=context) from e - async def translate(self, string: locale_str, locale: Locale, context: TranslationContext) -> Optional[str]: + async def translate(self, string: locale_str, locale: Locale, context: TranslationContextTypes) -> Optional[str]: """|coro| Translates the given string to the specified locale. @@ -190,6 +219,9 @@ class Translator: The locale being requested for translation. context: :class:`TranslationContext` The translation context where the string originated from. + For better type checking ergonomics, the ``TranslationContextTypes`` + type can be used instead to aid with type narrowing. It is functionally + equivalent to :class:`TranslationContext`. """ return None