committed by
GitHub
1 changed files with 199 additions and 0 deletions
@ -0,0 +1,199 @@ |
|||
""" |
|||
The MIT License (MIT) |
|||
|
|||
Copyright (c) 2015-present Rapptz |
|||
|
|||
Permission is hereby granted, free of charge, to any person obtaining a |
|||
copy of this software and associated documentation files (the "Software"), |
|||
to deal in the Software without restriction, including without limitation |
|||
the rights to use, copy, modify, merge, publish, distribute, sublicense, |
|||
and/or sell copies of the Software, and to permit persons to whom the |
|||
Software is furnished to do so, subject to the following conditions: |
|||
|
|||
The above copyright notice and this permission notice shall be included in |
|||
all copies or substantial portions of the Software. |
|||
|
|||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS |
|||
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
|||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
|||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
|||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING |
|||
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER |
|||
DEALINGS IN THE SOFTWARE. |
|||
""" |
|||
from __future__ import annotations |
|||
|
|||
|
|||
from functools import wraps |
|||
import pytest |
|||
from typing import Awaitable, TYPE_CHECKING, Callable, Coroutine, Optional, TypeVar, Any, Type, List, Union |
|||
|
|||
import discord |
|||
|
|||
if TYPE_CHECKING: |
|||
|
|||
from typing_extensions import ParamSpec |
|||
from discord.types.interactions import ( |
|||
ApplicationCommandInteraction as ApplicationCommandInteractionPayload, |
|||
ChatInputApplicationCommandInteractionData as ChatInputApplicationCommandInteractionDataPayload, |
|||
ApplicationCommandInteractionDataOption as ApplicationCommandInteractionDataOptionPayload, |
|||
) |
|||
|
|||
P = ParamSpec('P') |
|||
|
|||
|
|||
T = TypeVar('T') |
|||
|
|||
|
|||
class MockCommandInteraction(discord.Interaction): |
|||
@classmethod |
|||
def _get_command_options(cls, **options: str) -> List[ApplicationCommandInteractionDataOptionPayload]: |
|||
return [ |
|||
{ |
|||
'type': discord.AppCommandOptionType.string.value, |
|||
'name': name, |
|||
'value': value, |
|||
} |
|||
for name, value in options.items() |
|||
] |
|||
|
|||
@classmethod |
|||
def _get_command_data( |
|||
cls, |
|||
command: Union[discord.app_commands.Command[Any, ..., Any], discord.app_commands.Group], |
|||
options: List[ApplicationCommandInteractionDataOptionPayload], |
|||
) -> ChatInputApplicationCommandInteractionDataPayload: |
|||
|
|||
data: Union[ChatInputApplicationCommandInteractionDataPayload, ApplicationCommandInteractionDataOptionPayload] = { |
|||
'type': discord.AppCommandType.chat_input.value, |
|||
'name': command.name, |
|||
'options': options, |
|||
} |
|||
|
|||
if command.parent is None: |
|||
data['id'] = hash(command) # type: ignore # narrowing isn't possible |
|||
return data # type: ignore # see above |
|||
else: |
|||
return cls._get_command_data(command.parent, [data]) |
|||
|
|||
def __init__( |
|||
self, |
|||
client: discord.Client, |
|||
command: discord.app_commands.Command[Any, ..., Any], |
|||
**options: str, |
|||
) -> None: |
|||
|
|||
data: ApplicationCommandInteractionPayload = { |
|||
"id": 0, |
|||
"application_id": 0, |
|||
"token": "", |
|||
"version": 1, |
|||
"type": 2, |
|||
"data": self._get_command_data(command, self._get_command_options(**options)), |
|||
} |
|||
super().__init__(data=data, state=client._connection) |
|||
|
|||
|
|||
client = discord.Client() |
|||
|
|||
|
|||
class MockTree(discord.app_commands.CommandTree): |
|||
last_exception: Optional[discord.app_commands.AppCommandError] |
|||
|
|||
async def call(self, interaction: discord.Interaction) -> None: |
|||
self.last_exception = None |
|||
return await super().call(interaction) |
|||
|
|||
async def on_error( |
|||
self, interaction: discord.Interaction, command: Any, error: discord.app_commands.AppCommandError |
|||
) -> None: |
|||
self.last_exception = error |
|||
|
|||
|
|||
tree = MockTree(client) |
|||
|
|||
|
|||
@tree.command() |
|||
async def test_command(interaction: discord.Interaction, foo: str) -> None: |
|||
pass |
|||
|
|||
|
|||
def wrapper(func: Callable[P, Awaitable[T]]) -> Callable[P, Coroutine[Any, Any, T]]: |
|||
@wraps(func) |
|||
async def deco(*args: P.args, **kwargs: P.kwargs) -> T: |
|||
return await func(*args, **kwargs) |
|||
|
|||
return deco |
|||
|
|||
|
|||
@tree.command() |
|||
@wrapper |
|||
async def test_wrapped_command(interaction: discord.Interaction, foo: str) -> None: |
|||
pass |
|||
|
|||
|
|||
@tree.command() |
|||
async def test_command_raises(interaction: discord.Interaction, foo: str) -> None: |
|||
raise TypeError |
|||
|
|||
|
|||
@tree.command() |
|||
@wrapper |
|||
async def test_wrapped_command_raises(interaction: discord.Interaction, foo: str) -> None: |
|||
raise TypeError |
|||
|
|||
|
|||
group = discord.app_commands.Group(name='group', description='...') |
|||
test_subcommand = group.command()(test_command.callback) |
|||
test_wrapped_subcommand = group.command()(test_wrapped_command.callback) |
|||
test_subcommand_raises = group.command()(test_command_raises.callback) |
|||
test_wrapped_subcommand_raises = group.command()(test_wrapped_command_raises.callback) |
|||
tree.add_command(group) |
|||
|
|||
|
|||
@pytest.mark.parametrize( |
|||
('command', 'raises'), |
|||
[ |
|||
(test_command, None), |
|||
(test_wrapped_command, None), |
|||
(test_command_raises, TypeError), |
|||
(test_wrapped_command_raises, TypeError), |
|||
(test_subcommand, None), |
|||
(test_wrapped_subcommand, None), |
|||
(test_subcommand_raises, TypeError), |
|||
(test_wrapped_subcommand_raises, TypeError), |
|||
], |
|||
) |
|||
@pytest.mark.asyncio |
|||
async def test_valid_command_invoke( |
|||
command: discord.app_commands.Command[Any, ..., Any], raises: Optional[Type[BaseException]] |
|||
): |
|||
interaction = MockCommandInteraction(client, command, foo='foo') |
|||
await tree.call(interaction) |
|||
|
|||
if raises is None: |
|||
assert tree.last_exception is None |
|||
else: |
|||
assert isinstance(tree.last_exception, discord.app_commands.CommandInvokeError) |
|||
assert isinstance(tree.last_exception.original, raises) |
|||
|
|||
|
|||
@pytest.mark.parametrize( |
|||
('command',), |
|||
[ |
|||
(test_command,), |
|||
(test_wrapped_command,), |
|||
(test_command_raises,), |
|||
(test_wrapped_command_raises,), |
|||
(test_subcommand,), |
|||
(test_subcommand_raises,), |
|||
(test_wrapped_subcommand,), |
|||
(test_wrapped_subcommand_raises,), |
|||
], |
|||
) |
|||
@pytest.mark.asyncio |
|||
async def test_invalid_command_invoke(command: discord.app_commands.Command[Any, ..., Any]): |
|||
interaction = MockCommandInteraction(client, command, bar='bar') |
|||
await tree.call(interaction) |
|||
|
|||
assert isinstance(tree.last_exception, discord.app_commands.CommandSignatureMismatch) |
Loading…
Reference in new issue