Browse Source

Take back ownership of files from aiohttp for retrying requests.

Fix #1809
pull/2014/head
Rapptz 6 years ago
parent
commit
5e65ec978c
  1. 2
      discord/abc.py
  2. 45
      discord/file.py
  3. 16
      discord/http.py
  4. 21
      discord/webhook.py

2
discord/abc.py

@ -756,7 +756,7 @@ class Messageable(metaclass=abc.ABCMeta):
raise InvalidArgument('file parameter must be File')
try:
data = await state.http.send_files(channel.id, files=[(file.open_file(), file.filename)],
data = await state.http.send_files(channel.id, files=[file],
content=content, tts=tts, embed=embed, nonce=nonce)
finally:
file.close()

45
discord/file.py

@ -25,6 +25,7 @@ DEALINGS IN THE SOFTWARE.
"""
import os.path
import io
class File:
"""A parameter object used for :meth:`abc.Messageable.send`
@ -52,11 +53,28 @@ class File:
Whether the attachment is a spoiler.
"""
__slots__ = ('fp', 'filename', '_true_fp')
__slots__ = ('fp', 'filename', '_original_pos', '_owner', '_closer')
def __init__(self, fp, filename=None, *, spoiler=False):
self.fp = fp
self._true_fp = None
if isinstance(fp, io.IOBase):
if not (fp.seekable() and fp.readable()):
raise ValueError('File buffer {!r} must be seekable and readable'.format(fp))
self.fp = fp
self._original_pos = fp.tell()
self._owner = False
else:
self.fp = open(fp, 'rb')
self._original_pos = 0
self._owner = True
# aiohttp only uses two methods from IOBase
# read and close, since I want to control when the files
# close, I need to stub it so it doesn't close unless
# I tell it to
self._closer = self.fp.close
self.fp.close = lambda: None
if filename is None:
if isinstance(fp, str):
@ -66,15 +84,22 @@ class File:
else:
self.filename = filename
if spoiler and not self.filename.startswith('SPOILER_'):
if spoiler and self.filename is not None and not self.filename.startswith('SPOILER_'):
self.filename = 'SPOILER_' + self.filename
def open_file(self):
fp = self.fp
if isinstance(fp, str):
self._true_fp = fp = open(fp, 'rb')
return fp
def reset(self, *, seek=True):
# The `seek` parameter is needed because
# the retry-loop is iterated over multiple times
# starting from 0, as an implementation quirk
# the resetting must be done at the beginning
# before a request is done, since the first index
# is 0, and thus false, then this prevents an
# unnecessary seek since it's the first request
# done.
if seek:
self.fp.seek(self._original_pos)
def close(self):
if self._true_fp:
self._true_fp.close()
self.fp.close = self._closer
if self._owner:
self._closer()

16
discord/http.py

@ -105,7 +105,7 @@ class HTTPClient:
if self._session.closed:
self._session = aiohttp.ClientSession(connector=self.connector, loop=self.loop)
async def request(self, route, *, header_bypass_delay=None, **kwargs):
async def request(self, route, *, files=None, header_bypass_delay=None, **kwargs):
bucket = route.bucket
method = route.method
url = route.url
@ -151,6 +151,10 @@ class HTTPClient:
await lock.acquire()
with MaybeUnlock(lock) as maybe_lock:
for tries in range(5):
if files:
for f in files:
f.reset(seek=tries)
async with self._session.request(method, url, **kwargs) as r:
log.debug('%s %s with %s has returned %s', method, url, kwargs.get('data'), r.status)
@ -334,13 +338,13 @@ class HTTPClient:
form.add_field('payload_json', utils.to_json(payload))
if len(files) == 1:
fp = files[0]
form.add_field('file', fp[0], filename=fp[1], content_type='application/octet-stream')
file = files[0]
form.add_field('file', file.fp, filename=file.filename, content_type='application/octet-stream')
else:
for index, (buffer, filename) in enumerate(files):
form.add_field('file%s' % index, buffer, filename=filename, content_type='application/octet-stream')
for index, file in enumerate(files):
form.add_field('file%s' % index, file.fp, filename=file.filename, content_type='application/octet-stream')
return self.request(r, data=form)
return self.request(r, data=form, files=files)
async def ack_message(self, channel_id, message_id):
r = Route('POST', '/channels/{channel_id}/messages/{message_id}/ack', channel_id=channel_id, message_id=message_id)

21
discord/webhook.py

@ -110,17 +110,18 @@ class WebhookAdapter:
cleanup = None
if file is not None:
multipart = {
'file': (file.filename, file.open_file(), 'application/octet-stream'),
'file': (file.filename, file.fp, 'application/octet-stream'),
'payload_json': utils.to_json(payload)
}
data = None
cleanup = file.close
files_to_pass = [file]
elif files is not None:
multipart = {
'payload_json': utils.to_json(payload)
}
for i, file in enumerate(files, start=1):
multipart['file%i' % i] = (file.filename, file.open_file(), 'application/octet-stream')
multipart['file%i' % i] = (file.filename, file.fp, 'application/octet-stream')
data = None
def _anon():
@ -128,13 +129,15 @@ class WebhookAdapter:
f.close()
cleanup = _anon
files_to_pass = files
else:
data = payload
multipart = None
files_to_pass = None
url = '%s?wait=%d' % (self._request_url, wait)
try:
maybe_coro = self.request('POST', url, multipart=multipart, payload=data)
maybe_coro = self.request('POST', url, multipart=multipart, payload=data, files=files_to_pass)
finally:
if cleanup is not None:
if not asyncio.iscoroutine(maybe_coro):
@ -160,9 +163,10 @@ class AsyncWebhookAdapter(WebhookAdapter):
self.session = session
self.loop = asyncio.get_event_loop()
async def request(self, verb, url, payload=None, multipart=None):
async def request(self, verb, url, payload=None, multipart=None, *, files=None):
headers = {}
data = None
files = files or []
if payload:
headers['Content-Type'] = 'application/json'
data = utils.to_json(payload)
@ -176,6 +180,9 @@ class AsyncWebhookAdapter(WebhookAdapter):
data.add_field(key, value)
for tries in range(5):
for file in files:
file.reset(seek=tries)
async with self.session.request(verb, url, headers=headers, data=data) as r:
data = await r.text(encoding='utf-8')
if r.headers['Content-Type'] == 'application/json':
@ -239,9 +246,10 @@ class RequestsWebhookAdapter(WebhookAdapter):
self.session = session or requests
self.sleep = sleep
def request(self, verb, url, payload=None, multipart=None):
def request(self, verb, url, payload=None, multipart=None, *, files=None):
headers = {}
data = None
files = files or []
if payload:
headers['Content-Type'] = 'application/json'
data = utils.to_json(payload)
@ -250,6 +258,9 @@ class RequestsWebhookAdapter(WebhookAdapter):
data = {'payload_json': multipart.pop('payload_json')}
for tries in range(5):
for file in files:
file.reset(seek=tries)
r = self.session.request(verb, url, headers=headers, data=data, files=multipart)
r.encoding = 'utf-8'
data = r.text

Loading…
Cancel
Save