From 5e65ec978cf278fefcc5586e2df732bc7a8bed4e Mon Sep 17 00:00:00 2001 From: Rapptz Date: Mon, 18 Mar 2019 07:54:36 -0400 Subject: [PATCH] Take back ownership of files from aiohttp for retrying requests. Fix #1809 --- discord/abc.py | 2 +- discord/file.py | 45 +++++++++++++++++++++++++++++++++++---------- discord/http.py | 16 ++++++++++------ discord/webhook.py | 21 ++++++++++++++++----- 4 files changed, 62 insertions(+), 22 deletions(-) diff --git a/discord/abc.py b/discord/abc.py index 8345786c2..2e3e1419c 100644 --- a/discord/abc.py +++ b/discord/abc.py @@ -756,7 +756,7 @@ class Messageable(metaclass=abc.ABCMeta): raise InvalidArgument('file parameter must be File') try: - data = await state.http.send_files(channel.id, files=[(file.open_file(), file.filename)], + data = await state.http.send_files(channel.id, files=[file], content=content, tts=tts, embed=embed, nonce=nonce) finally: file.close() diff --git a/discord/file.py b/discord/file.py index d3a775fd2..7240b8842 100644 --- a/discord/file.py +++ b/discord/file.py @@ -25,6 +25,7 @@ DEALINGS IN THE SOFTWARE. """ import os.path +import io class File: """A parameter object used for :meth:`abc.Messageable.send` @@ -52,11 +53,28 @@ class File: Whether the attachment is a spoiler. """ - __slots__ = ('fp', 'filename', '_true_fp') + __slots__ = ('fp', 'filename', '_original_pos', '_owner', '_closer') def __init__(self, fp, filename=None, *, spoiler=False): self.fp = fp - self._true_fp = None + + if isinstance(fp, io.IOBase): + if not (fp.seekable() and fp.readable()): + raise ValueError('File buffer {!r} must be seekable and readable'.format(fp)) + self.fp = fp + self._original_pos = fp.tell() + self._owner = False + else: + self.fp = open(fp, 'rb') + self._original_pos = 0 + self._owner = True + + # aiohttp only uses two methods from IOBase + # read and close, since I want to control when the files + # close, I need to stub it so it doesn't close unless + # I tell it to + self._closer = self.fp.close + self.fp.close = lambda: None if filename is None: if isinstance(fp, str): @@ -66,15 +84,22 @@ class File: else: self.filename = filename - if spoiler and not self.filename.startswith('SPOILER_'): + if spoiler and self.filename is not None and not self.filename.startswith('SPOILER_'): self.filename = 'SPOILER_' + self.filename - def open_file(self): - fp = self.fp - if isinstance(fp, str): - self._true_fp = fp = open(fp, 'rb') - return fp + def reset(self, *, seek=True): + # The `seek` parameter is needed because + # the retry-loop is iterated over multiple times + # starting from 0, as an implementation quirk + # the resetting must be done at the beginning + # before a request is done, since the first index + # is 0, and thus false, then this prevents an + # unnecessary seek since it's the first request + # done. + if seek: + self.fp.seek(self._original_pos) def close(self): - if self._true_fp: - self._true_fp.close() + self.fp.close = self._closer + if self._owner: + self._closer() diff --git a/discord/http.py b/discord/http.py index 6032d1044..7e34951f6 100644 --- a/discord/http.py +++ b/discord/http.py @@ -105,7 +105,7 @@ class HTTPClient: if self._session.closed: self._session = aiohttp.ClientSession(connector=self.connector, loop=self.loop) - async def request(self, route, *, header_bypass_delay=None, **kwargs): + async def request(self, route, *, files=None, header_bypass_delay=None, **kwargs): bucket = route.bucket method = route.method url = route.url @@ -151,6 +151,10 @@ class HTTPClient: await lock.acquire() with MaybeUnlock(lock) as maybe_lock: for tries in range(5): + if files: + for f in files: + f.reset(seek=tries) + async with self._session.request(method, url, **kwargs) as r: log.debug('%s %s with %s has returned %s', method, url, kwargs.get('data'), r.status) @@ -334,13 +338,13 @@ class HTTPClient: form.add_field('payload_json', utils.to_json(payload)) if len(files) == 1: - fp = files[0] - form.add_field('file', fp[0], filename=fp[1], content_type='application/octet-stream') + file = files[0] + form.add_field('file', file.fp, filename=file.filename, content_type='application/octet-stream') else: - for index, (buffer, filename) in enumerate(files): - form.add_field('file%s' % index, buffer, filename=filename, content_type='application/octet-stream') + for index, file in enumerate(files): + form.add_field('file%s' % index, file.fp, filename=file.filename, content_type='application/octet-stream') - return self.request(r, data=form) + return self.request(r, data=form, files=files) async def ack_message(self, channel_id, message_id): r = Route('POST', '/channels/{channel_id}/messages/{message_id}/ack', channel_id=channel_id, message_id=message_id) diff --git a/discord/webhook.py b/discord/webhook.py index 7e130f199..35f553ffa 100644 --- a/discord/webhook.py +++ b/discord/webhook.py @@ -110,17 +110,18 @@ class WebhookAdapter: cleanup = None if file is not None: multipart = { - 'file': (file.filename, file.open_file(), 'application/octet-stream'), + 'file': (file.filename, file.fp, 'application/octet-stream'), 'payload_json': utils.to_json(payload) } data = None cleanup = file.close + files_to_pass = [file] elif files is not None: multipart = { 'payload_json': utils.to_json(payload) } for i, file in enumerate(files, start=1): - multipart['file%i' % i] = (file.filename, file.open_file(), 'application/octet-stream') + multipart['file%i' % i] = (file.filename, file.fp, 'application/octet-stream') data = None def _anon(): @@ -128,13 +129,15 @@ class WebhookAdapter: f.close() cleanup = _anon + files_to_pass = files else: data = payload multipart = None + files_to_pass = None url = '%s?wait=%d' % (self._request_url, wait) try: - maybe_coro = self.request('POST', url, multipart=multipart, payload=data) + maybe_coro = self.request('POST', url, multipart=multipart, payload=data, files=files_to_pass) finally: if cleanup is not None: if not asyncio.iscoroutine(maybe_coro): @@ -160,9 +163,10 @@ class AsyncWebhookAdapter(WebhookAdapter): self.session = session self.loop = asyncio.get_event_loop() - async def request(self, verb, url, payload=None, multipart=None): + async def request(self, verb, url, payload=None, multipart=None, *, files=None): headers = {} data = None + files = files or [] if payload: headers['Content-Type'] = 'application/json' data = utils.to_json(payload) @@ -176,6 +180,9 @@ class AsyncWebhookAdapter(WebhookAdapter): data.add_field(key, value) for tries in range(5): + for file in files: + file.reset(seek=tries) + async with self.session.request(verb, url, headers=headers, data=data) as r: data = await r.text(encoding='utf-8') if r.headers['Content-Type'] == 'application/json': @@ -239,9 +246,10 @@ class RequestsWebhookAdapter(WebhookAdapter): self.session = session or requests self.sleep = sleep - def request(self, verb, url, payload=None, multipart=None): + def request(self, verb, url, payload=None, multipart=None, *, files=None): headers = {} data = None + files = files or [] if payload: headers['Content-Type'] = 'application/json' data = utils.to_json(payload) @@ -250,6 +258,9 @@ class RequestsWebhookAdapter(WebhookAdapter): data = {'payload_json': multipart.pop('payload_json')} for tries in range(5): + for file in files: + file.reset(seek=tries) + r = self.session.request(verb, url, headers=headers, data=data, files=multipart) r.encoding = 'utf-8' data = r.text