diff --git a/discord/ext/tasks/__init__.py b/discord/ext/tasks/__init__.py index dbbd2df78..12ce721c9 100644 --- a/discord/ext/tasks/__init__.py +++ b/discord/ext/tasks/__init__.py @@ -210,8 +210,8 @@ class Loop: self._task.add_done_callback(restart_when_over) self._task.cancel() - def add_exception_type(self, exc): - r"""Adds an exception type to be handled during the reconnect logic. + def add_exception_type(self, *exceptions): + r"""Adds exception types to be handled during the reconnect logic. By default the exception types handled are those handled by :meth:`discord.Client.connect`\, which includes a lot of internet disconnection @@ -222,21 +222,22 @@ class Loop: Parameters ------------ - exc: Type[:class:`BaseException`] - The exception class to handle. + \*exceptions: Type[:class:`BaseException`] + An argument list of exception classes to handle. Raises -------- TypeError - The exception passed is either not a class or not inherited from :class:`BaseException`. + An exception passed is either not a class or not inherited from :class:`BaseException`. """ - if not inspect.isclass(exc): - raise TypeError('{0!r} must be a class.'.format(exc)) - if not issubclass(exc, BaseException): - raise TypeError('{0!r} must inherit from BaseException.'.format(exc)) + for exc in exceptions: + if not inspect.isclass(exc): + raise TypeError('{0!r} must be a class.'.format(exc)) + if not issubclass(exc, BaseException): + raise TypeError('{0!r} must inherit from BaseException.'.format(exc)) - self._valid_exception = (*self._valid_exception, exc) + self._valid_exception = (*self._valid_exception, *exceptions) def clear_exception_types(self): """Removes all exception types that are handled. @@ -247,22 +248,22 @@ class Loop: """ self._valid_exception = tuple() - def remove_exception_type(self, exc): - """Removes an exception type from being handled during the reconnect logic. + def remove_exception_type(self, *exceptions): + r"""Removes exception types from being handled during the reconnect logic. Parameters ------------ - exc: Type[:class:`BaseException`] - The exception class to handle. + \*exceptions: Type[:class:`BaseException`] + An argument list of exception classes to handle. Returns --------- :class:`bool` - Whether it was successfully removed. + Whether all exceptions were successfully removed. """ old_length = len(self._valid_exception) - self._valid_exception = tuple(x for x in self._valid_exception if x is not exc) - return len(self._valid_exception) != old_length + self._valid_exception = tuple(x for x in self._valid_exception if x not in exceptions) + return len(self._valid_exception) == old_length - len(exceptions) def get_task(self): """Optional[:class:`asyncio.Task`]: Fetches the internal task or ``None`` if there isn't one running."""