Browse Source

Improve TranslationContext type narrowing using a tagged union

pull/8342/head
Bryan Forbes 3 years ago
committed by GitHub
parent
commit
63b32994f4
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 2
      discord/app_commands/commands.py
  2. 6
      discord/app_commands/errors.py
  3. 2
      discord/app_commands/models.py
  4. 2
      discord/app_commands/transformers.py
  5. 148
      discord/app_commands/translator.py

2
discord/app_commands/commands.py

@ -52,7 +52,7 @@ from ..enums import AppCommandOptionType, AppCommandType, ChannelType, Locale
from .models import Choice
from .transformers import annotation_to_parameter, CommandParameter, NoneType
from .errors import AppCommandError, CheckFailure, CommandInvokeError, CommandSignatureMismatch, CommandAlreadyRegistered
from .translator import TranslationContext, TranslationContextLocation, Translator, locale_str
from .translator import TranslationContextLocation, TranslationContext, Translator, locale_str
from ..message import Message
from ..user import User
from ..member import Member

6
discord/app_commands/errors.py

@ -52,7 +52,7 @@ __all__ = (
if TYPE_CHECKING:
from .commands import Command, Group, ContextMenu
from .transformers import Transformer
from .translator import TranslationContext, locale_str
from .translator import TranslationContextTypes, locale_str
from ..types.snowflake import Snowflake, SnowflakeList
from .checks import Cooldown
@ -164,11 +164,11 @@ class TranslationError(AppCommandError):
*msg: str,
string: Optional[Union[str, locale_str]] = None,
locale: Optional[Locale] = None,
context: TranslationContext,
context: TranslationContextTypes,
) -> None:
self.string: Optional[Union[str, locale_str]] = string
self.locale: Optional[Locale] = locale
self.context: TranslationContext = context
self.context: TranslationContextTypes = context
if msg:
super().__init__(*msg)

2
discord/app_commands/models.py

@ -26,7 +26,7 @@ from __future__ import annotations
from datetime import datetime
from .errors import MissingApplicationID
from .translator import TranslationContextLocation, Translator, TranslationContext, locale_str
from .translator import TranslationContextLocation, TranslationContext, locale_str, Translator
from ..permissions import Permissions
from ..enums import AppCommandOptionType, AppCommandType, AppCommandPermissionType, ChannelType, Locale, try_enum
from ..mixins import Hashable

2
discord/app_commands/transformers.py

@ -46,7 +46,7 @@ from typing import (
from .errors import AppCommandError, TransformerError
from .models import AppCommandChannel, AppCommandThread, Choice
from .translator import TranslationContextLocation, locale_str, Translator, TranslationContext
from .translator import TranslationContextLocation, TranslationContext, Translator, locale_str
from ..channel import StageChannel, VoiceChannel, TextChannel, CategoryChannel
from ..abc import GuildChannel
from ..threads import Thread

148
discord/app_commands/translator.py

@ -23,13 +23,19 @@ DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, TypeVar, Union, overload
from .errors import TranslationError
from ..enums import Enum, Locale
if TYPE_CHECKING:
from .commands import Command, ContextMenu, Group, Parameter
from .models import Choice
__all__ = (
'TranslationContextLocation',
'TranslationContextTypes',
'TranslationContext',
'Translator',
'locale_str',
@ -47,7 +53,11 @@ class TranslationContextLocation(Enum):
other = 7
class TranslationContext: # type: ignore # See below
_L = TypeVar('_L', bound=TranslationContextLocation)
_D = TypeVar('_D')
class TranslationContext(Generic[_L, _D]):
"""A class that provides context for the :class:`locale_str` being translated.
This is useful to determine where exactly the string is located and aid in looking
@ -63,60 +73,77 @@ class TranslationContext: # type: ignore # See below
__slots__ = ('location', 'data')
def __init__(self, location: TranslationContextLocation, data: Any) -> None:
self.location: TranslationContextLocation = location
self.data: Any = data
if TYPE_CHECKING:
# For type checking purposes, it makes sense to allow the user to leverage type narrowing
# So code like this works as expected:
# if context.type is TranslationContextLocation.command_name:
# reveal_type(context.data) # Revealed type is Command | ContextMenu
#
# Unfortunately doing a trick like this requires lying to the type checker so
# this is what the code below enables.
#
# Should this trick stop working then it might be fair to remove this code.
# It's purely here for convenience.
from .commands import Command, ContextMenu, Group, Parameter
from .models import Choice
class _CommandNameTranslationContext:
location: Literal[TranslationContextLocation.command_name]
data: Union[Command[Any, ..., Any], ContextMenu]
class _CommandDescriptionTranslationContext:
location: Literal[TranslationContextLocation.command_description]
data: Command[Any, ..., Any]
class _GroupTranslationContext:
location: Literal[TranslationContextLocation.group_name, TranslationContextLocation.group_description]
data: Group
class _ParameterTranslationContext:
location: Literal[TranslationContextLocation.parameter_description, TranslationContextLocation.parameter_name]
data: Parameter
class _ChoiceTranslationContext:
location: Literal[TranslationContextLocation.choice_name]
data: Choice[Union[int, str, float]]
class _OtherTranslationContext:
location: Literal[TranslationContextLocation.other]
data: Any
class TranslationContext(
_CommandNameTranslationContext,
_CommandDescriptionTranslationContext,
_GroupTranslationContext,
_ParameterTranslationContext,
_ChoiceTranslationContext,
_OtherTranslationContext,
):
def __init__(self, location: TranslationContextLocation, data: Any) -> None:
...
@overload
def __init__(
self, location: Literal[TranslationContextLocation.command_name], data: Union[Command[Any, ..., Any], ContextMenu]
) -> None:
...
@overload
def __init__(
self, location: Literal[TranslationContextLocation.command_description], data: Command[Any, ..., Any]
) -> None:
...
@overload
def __init__(
self,
location: Literal[TranslationContextLocation.group_name, TranslationContextLocation.group_description],
data: Group,
) -> None:
...
@overload
def __init__(
self,
location: Literal[TranslationContextLocation.parameter_name, TranslationContextLocation.parameter_description],
data: Parameter,
) -> None:
...
@overload
def __init__(self, location: Literal[TranslationContextLocation.choice_name], data: Choice[Any]) -> None:
...
@overload
def __init__(self, location: Literal[TranslationContextLocation.other], data: Any) -> None:
...
def __init__(self, location: _L, data: _D) -> None:
self.location: _L = location
self.data: _D = data
# For type checking purposes, it makes sense to allow the user to leverage type narrowing
# So code like this works as expected:
#
# if context.type == TranslationContextLocation.command_name:
# reveal_type(context.data) # Revealed type is Command | ContextMenu
#
# This requires a union of types
CommandNameTranslationContext = TranslationContext[
Literal[TranslationContextLocation.command_name], Union['Command[Any, ..., Any]', 'ContextMenu']
]
CommandDescriptionTranslationContext = TranslationContext[
Literal[TranslationContextLocation.command_description], 'Command[Any, ..., Any]'
]
GroupTranslationContext = TranslationContext[
Literal[TranslationContextLocation.group_name, TranslationContextLocation.group_description], 'Group'
]
ParameterTranslationContext = TranslationContext[
Literal[TranslationContextLocation.parameter_name, TranslationContextLocation.parameter_description], 'Parameter'
]
ChoiceTranslationContext = TranslationContext[Literal[TranslationContextLocation.choice_name], 'Choice[Any]']
OtherTranslationContext = TranslationContext[Literal[TranslationContextLocation.other], Any]
TranslationContextTypes = Union[
CommandNameTranslationContext,
CommandDescriptionTranslationContext,
GroupTranslationContext,
ParameterTranslationContext,
ChoiceTranslationContext,
OtherTranslationContext,
]
class Translator:
@ -162,7 +189,9 @@ class Translator:
"""
pass
async def _checked_translate(self, string: locale_str, locale: Locale, context: TranslationContext) -> Optional[str]:
async def _checked_translate(
self, string: locale_str, locale: Locale, context: TranslationContextTypes
) -> Optional[str]:
try:
return await self.translate(string, locale, context)
except TranslationError:
@ -170,7 +199,7 @@ class Translator:
except Exception as e:
raise TranslationError(string=string, locale=locale, context=context) from e
async def translate(self, string: locale_str, locale: Locale, context: TranslationContext) -> Optional[str]:
async def translate(self, string: locale_str, locale: Locale, context: TranslationContextTypes) -> Optional[str]:
"""|coro|
Translates the given string to the specified locale.
@ -190,6 +219,9 @@ class Translator:
The locale being requested for translation.
context: :class:`TranslationContext`
The translation context where the string originated from.
For better type checking ergonomics, the ``TranslationContextTypes``
type can be used instead to aid with type narrowing. It is functionally
equivalent to :class:`TranslationContext`.
"""
return None

Loading…
Cancel
Save