diff --git a/discord/http.py b/discord/http.py index 48ee88b93..0ba8c2ce1 100644 --- a/discord/http.py +++ b/discord/http.py @@ -42,6 +42,7 @@ from typing import ( Tuple, Type, TypeVar, + Union, ) from urllib.parse import quote as _uriquote import weakref @@ -51,6 +52,7 @@ import aiohttp from .errors import HTTPException, Forbidden, NotFound, LoginFailure, DiscordServerError, GatewayNotFound from .gateway import DiscordClientWebSocketResponse from . import __version__, utils +from .utils import MISSING log = logging.getLogger(__name__) @@ -93,7 +95,7 @@ if TYPE_CHECKING: Response = Coroutine[Any, Any, T] -async def json_or_text(response: aiohttp.ClientResponse) -> Any: +async def json_or_text(response: aiohttp.ClientResponse) -> Union[Dict[str, Any], str]: text = await response.text(encoding='utf-8') try: if response.headers['content-type'] == 'application/json': @@ -170,7 +172,7 @@ class HTTPClient: ) -> None: self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if loop is None else loop self.connector = connector - self.__session: Optional[aiohttp.ClientSession] = None # filled in static_login + self.__session: aiohttp.ClientSession = MISSING # filled in static_login self._locks: weakref.WeakValueDictionary = weakref.WeakValueDictionary() self._global_over: asyncio.Event = asyncio.Event() self._global_over.set() @@ -223,7 +225,7 @@ class HTTPClient: self._locks[bucket] = lock # header creation - headers = { + headers: Dict[str, str] = { 'User-Agent': self.user_agent, } @@ -254,6 +256,8 @@ class HTTPClient: # wait until the global lock is complete await self._global_over.wait() + response: Optional[aiohttp.ClientResponse] = None + data: Optional[Union[Dict[str, Any], str]] = None await lock.acquire() with MaybeUnlock(lock) as maybe_lock: for tries in range(5): @@ -268,36 +272,36 @@ class HTTPClient: kwargs['data'] = form_data try: - async with self.__session.request(method, url, **kwargs) as r: - log.debug('%s %s with %s has returned %s', method, url, kwargs.get('data'), r.status) + async with self.__session.request(method, url, **kwargs) as response: + log.debug('%s %s with %s has returned %s', method, url, kwargs.get('data'), response.status) # even errors have text involved in them so this is safe to call - data = await json_or_text(r) + data = await json_or_text(response) # check if we have rate limit header information - remaining = r.headers.get('X-Ratelimit-Remaining') - if remaining == '0' and r.status != 429: + remaining = response.headers.get('X-Ratelimit-Remaining') + if remaining == '0' and response.status != 429: # we've depleted our current bucket - delta = utils._parse_ratelimit_header(r, use_clock=self.use_clock) + delta = utils._parse_ratelimit_header(response, use_clock=self.use_clock) log.debug('A rate limit bucket has been exhausted (bucket: %s, retry: %s).', bucket, delta) maybe_lock.defer() self.loop.call_later(delta, lock.release) # the request was successful so just return the text/json - if 300 > r.status >= 200: + if 300 > response.status >= 200: log.debug('%s %s has received %s', method, url, data) return data # we are being rate limited - if r.status == 429: - if not r.headers.get('Via'): + if response.status == 429: + if not response.headers.get('Via') or isinstance(data, str): # Banned by Cloudflare more than likely. - raise HTTPException(r, data) + raise HTTPException(response, data) fmt = 'We are being rate limited. Retrying in %.2f seconds. Handled under the bucket "%s"' # sleep a bit - retry_after: float = data['retry_after'] # type: ignore + retry_after: float = data['retry_after'] log.warning(fmt, retry_after, bucket) # check if it's a global rate limit @@ -318,19 +322,19 @@ class HTTPClient: continue # we've received a 500 or 502, unconditional retry - if r.status in {500, 502}: + if response.status in {500, 502}: await asyncio.sleep(1 + tries * 2) continue # the usual error cases - if r.status == 403: - raise Forbidden(r, data) - elif r.status == 404: - raise NotFound(r, data) - elif r.status == 503: - raise DiscordServerError(r, data) + if response.status == 403: + raise Forbidden(response, data) + elif response.status == 404: + raise NotFound(response, data) + elif response.status == 503: + raise DiscordServerError(response, data) else: - raise HTTPException(r, data) + raise HTTPException(response, data) # This is handling exceptions from the request except OSError as e: @@ -340,11 +344,14 @@ class HTTPClient: continue raise - # We've run out of retries, raise. - if r.status >= 500: - raise DiscordServerError(r, data) + if response is not None: + # We've run out of retries, raise. + if response.status >= 500: + raise DiscordServerError(response, data) - raise HTTPException(r, data) + raise HTTPException(response, data) + + raise RuntimeError('Unreachable code in HTTP handling') async def get_from_cdn(self, url: str) -> bytes: async with self.__session.get(url) as resp: @@ -375,7 +382,7 @@ class HTTPClient: data = await self.request(Route('GET', '/users/@me')) except HTTPException as exc: self.token = old_token - if exc.response.status == 401: + if exc.status == 401: raise LoginFailure('Improper token has been passed.') from exc raise @@ -597,7 +604,7 @@ class HTTPClient: emoji=emoji, ) - params = { + params: Dict[str, Any] = { 'limit': limit, } if after: