diff --git a/discord/gateway.py b/discord/gateway.py index 56c1be9ec..d54ab5970 100644 --- a/discord/gateway.py +++ b/discord/gateway.py @@ -164,14 +164,9 @@ class VoiceKeepAliveHandler(KeepAliveHandler): self.latency = ack_time - self._last_send self.recent_ack_latencies.append(self.latency) -# Monkey patch certain things from the aiohttp websocket code -# Check this whenever we update dependencies. -OLD_CLOSE = aiohttp.ClientWebSocketResponse.close - -async def _new_ws_close(self, *, code: int = 4000, message: bytes = b'') -> bool: - return await OLD_CLOSE(self, code=code, message=message) - -aiohttp.ClientWebSocketResponse.close = _new_ws_close +class DiscordClientWebSocketResponse(aiohttp.ClientWebSocketResponse): + async def close(self, *, code: int = 4000, message: bytes = b'') -> bool: + return await super().close(code=code, message=message) class DiscordWebSocket: """Implements a WebSocket for Discord's gateway v6. diff --git a/discord/http.py b/discord/http.py index 4651490a5..d91664c97 100644 --- a/discord/http.py +++ b/discord/http.py @@ -34,6 +34,7 @@ import weakref import aiohttp from .errors import HTTPException, Forbidden, NotFound, LoginFailure, GatewayNotFound +from .gateway import DiscordClientWebSocketResponse from . import __version__, utils log = logging.getLogger(__name__) @@ -113,7 +114,7 @@ class HTTPClient: def recreate(self): if self.__session.closed: - self.__session = aiohttp.ClientSession(connector=self.connector) + self.__session = aiohttp.ClientSession(connector=self.connector, ws_response_class=DiscordClientWebSocketResponse) async def ws_connect(self, url, *, compress=0): kwargs = { @@ -279,7 +280,7 @@ class HTTPClient: async def static_login(self, token, *, bot): # Necessary to get aiohttp to stop complaining about session creation - self.__session = aiohttp.ClientSession(connector=self.connector) + self.__session = aiohttp.ClientSession(connector=self.connector, ws_response_class=DiscordClientWebSocketResponse) old_token, old_bot = self.token, self.bot_token self._token(token, bot=bot)