From 40176bb71b903461499c74e6f55247a3d71ac0cb Mon Sep 17 00:00:00 2001 From: Rapptz Date: Mon, 9 May 2022 17:24:48 -0400 Subject: [PATCH] Fix bound method autocomplete callbacks not working This also adds some regression tests --- discord/app_commands/commands.py | 5 +- tests/test_app_commands_autocomplete.py | 130 ++++++++++++++++++++++++ tests/test_app_commands_invoke.py | 4 +- 3 files changed, 134 insertions(+), 5 deletions(-) create mode 100644 tests/test_app_commands_autocomplete.py diff --git a/discord/app_commands/commands.py b/discord/app_commands/commands.py index dbe715d59..b528430a4 100644 --- a/discord/app_commands/commands.py +++ b/discord/app_commands/commands.py @@ -233,8 +233,9 @@ def _validate_auto_complete_callback( binding = getattr(callback, '__self__', None) if binding is not None: callback = callback.__func__ - - requires_binding = (binding is None and is_inside_class(callback)) or skip_binding + requires_binding = True + else: + requires_binding = is_inside_class(callback) or skip_binding callback.requires_binding = requires_binding callback.binding = binding diff --git a/tests/test_app_commands_autocomplete.py b/tests/test_app_commands_autocomplete.py new file mode 100644 index 000000000..ee3e54459 --- /dev/null +++ b/tests/test_app_commands_autocomplete.py @@ -0,0 +1,130 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-present Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations +from typing import List + +import discord +import pytest +from discord import app_commands +from discord.utils import MISSING + + +async def free_function_autocomplete(interaction: discord.Interaction, current: str) -> List[app_commands.Choice[str]]: + return [] + + +async def invalid_free_function(self, interaction: discord.Interaction, current: str) -> List[app_commands.Choice[str]]: + return [] + + +class X(app_commands.Transformer): + @classmethod + async def autocomplete(cls, interaction: discord.Interaction, current: str) -> List[app_commands.Choice[str]]: + return [] + + +class ClassBased: + async def autocomplete(self, interaction: discord.Interaction, current: str) -> List[app_commands.Choice[str]]: + return [] + + async def invalid(self, interaction: discord.Interaction, current: str, bad: int) -> List[app_commands.Choice[str]]: + return [] + + +lookup = ClassBased() +bound_autocomplete = lookup.autocomplete +invalid_bound_autocomplete = lookup.invalid + + +def test_free_function_autocomplete(): + @app_commands.command() + @app_commands.autocomplete(name=free_function_autocomplete) + async def cmd(interaction: discord.Interaction, name: str): + ... + + param = cmd._params['name'] + assert param.autocomplete is not MISSING + assert param.autocomplete.binding is None # type: ignore + assert not param.autocomplete.requires_binding # type: ignore + + +def test_invalid_free_function_autocomplete(): + with pytest.raises(TypeError): + + @app_commands.command() + @app_commands.autocomplete(name=invalid_free_function) + async def cmd(interaction: discord.Interaction, name: str): + ... + + +def test_transformer_autocomplete(): + @app_commands.command() + async def cmd(interaction: discord.Interaction, param: app_commands.Transform[str, X]): + ... + + param = cmd._params['param'] + assert param.autocomplete is not MISSING + assert param.autocomplete.binding is X # type: ignore + assert param.autocomplete.requires_binding # type: ignore + + +def test_bound_function_autocomplete(): + @app_commands.command() + @app_commands.autocomplete(name=bound_autocomplete) + async def cmd(interaction: discord.Interaction, name: str): + ... + + param = cmd._params['name'] + assert param.autocomplete is not MISSING + assert param.autocomplete.binding is lookup # type: ignore + assert param.autocomplete.requires_binding # type: ignore + + +def test_invalid_bound_function_autocomplete(): + with pytest.raises(TypeError): + + @app_commands.command() + @app_commands.autocomplete(name=invalid_bound_autocomplete) # type: ignore + async def cmd(interaction: discord.Interaction, name: str): + ... + + +def test_group_function_autocomplete(): + class MyGroup(app_commands.Group): + @app_commands.command() + async def foo(self, interaction: discord.Interaction, name: str): + ... + + @foo.autocomplete('name') + async def autocomplete(self, interaction: discord.Interaction, current: str) -> List[app_commands.Choice[str]]: + return [] + + g = MyGroup() + param = g.foo._params['name'] + + assert param.autocomplete is not MISSING + # Note: The binding is filled later when actually invoked + assert param.autocomplete.binding is None # type: ignore + assert param.autocomplete.requires_binding # type: ignore diff --git a/tests/test_app_commands_invoke.py b/tests/test_app_commands_invoke.py index 51fc88e9a..38ead3e56 100644 --- a/tests/test_app_commands_invoke.py +++ b/tests/test_app_commands_invoke.py @@ -104,9 +104,7 @@ class MockTree(discord.app_commands.CommandTree): self.last_exception = None return await super().call(interaction) - async def on_error( - self, interaction: discord.Interaction, error: discord.app_commands.AppCommandError - ) -> None: + async def on_error(self, interaction: discord.Interaction, error: discord.app_commands.AppCommandError) -> None: self.last_exception = error