Browse Source

Fix race condition in handling of binary attachments

Fixes #37
pull/41/head
Miguel Grinberg 9 years ago
parent
commit
024609e10e
  1. 13
      socketio/packet.py
  2. 30
      socketio/server.py
  3. 18
      tests/test_packet.py
  4. 13
      tests/test_server.py

13
socketio/packet.py

@ -28,6 +28,7 @@ class Packet(object):
else: else:
raise ValueError('Packet does not support binary payload.') raise ValueError('Packet does not support binary payload.')
self.attachment_count = 0 self.attachment_count = 0
self.attachments = []
if encoded_packet: if encoded_packet:
self.attachment_count = self.decode(encoded_packet) self.attachment_count = self.decode(encoded_packet)
@ -100,11 +101,21 @@ class Packet(object):
self.data = self.json.loads(ep) self.data = self.json.loads(ep)
return attachment_count return attachment_count
def add_attachment(self, attachment):
if self.attachment_count <= len(self.attachments):
raise ValueError('Unexpected binary attachment')
self.attachments.append(attachment)
if self.attachment_count == len(self.attachments):
self.reconstruct_binary(self.attachments)
return True
return False
def reconstruct_binary(self, attachments): def reconstruct_binary(self, attachments):
"""Reconstruct a decoded packet using the given list of binary """Reconstruct a decoded packet using the given list of binary
attachments. attachments.
""" """
self.data = self._reconstruct_binary_internal(self.data, attachments) self.data = self._reconstruct_binary_internal(self.data,
self.attachments)
def _reconstruct_binary_internal(self, data, attachments): def _reconstruct_binary_internal(self, data, attachments):
if isinstance(data, list): if isinstance(data, list):

30
socketio/server.py

@ -78,9 +78,7 @@ class Server(object):
self.environ = {} self.environ = {}
self.handlers = {} self.handlers = {}
self._binary_packet = None self._binary_packet = []
self._attachment_count = 0
self._attachments = []
if not isinstance(logger, bool): if not isinstance(logger, bool):
self.logger = logger self.logger = logger
@ -434,22 +432,14 @@ class Server(object):
def _handle_eio_message(self, sid, data): def _handle_eio_message(self, sid, data):
"""Dispatch Engine.IO messages.""" """Dispatch Engine.IO messages."""
if self._attachment_count > 0: if len(self._binary_packet):
self._attachments.append(data) pkt = self._binary_packet[0]
self._attachment_count -= 1 if pkt.add_attachment(data):
self._binary_packet.pop(0)
if self._attachment_count == 0: if pkt.packet_type == packet.BINARY_EVENT:
self._binary_packet.reconstruct_binary(self._attachments) self._handle_event(sid, pkt.namespace, pkt.id, pkt.data)
if self._binary_packet.packet_type == packet.BINARY_EVENT:
self._handle_event(sid, self._binary_packet.namespace,
self._binary_packet.id,
self._binary_packet.data)
else: else:
self._handle_ack(sid, self._binary_packet.namespace, self._handle_ack(sid, pkt.namespace, pkt.id, pkt.data)
self._binary_packet.id,
self._binary_packet.data)
self._binary_packet = None
self._attachments = []
else: else:
pkt = packet.Packet(encoded_packet=data) pkt = packet.Packet(encoded_packet=data)
if pkt.packet_type == packet.CONNECT: if pkt.packet_type == packet.CONNECT:
@ -462,9 +452,7 @@ class Server(object):
self._handle_ack(sid, pkt.namespace, pkt.id, pkt.data) self._handle_ack(sid, pkt.namespace, pkt.id, pkt.data)
elif pkt.packet_type == packet.BINARY_EVENT or \ elif pkt.packet_type == packet.BINARY_EVENT or \
pkt.packet_type == packet.BINARY_ACK: pkt.packet_type == packet.BINARY_ACK:
self._binary_packet = pkt self._binary_packet.append(pkt)
self._attachments = []
self._attachment_count = pkt.attachment_count
elif pkt.packet_type == packet.ERROR: elif pkt.packet_type == packet.ERROR:
raise ValueError('Unexpected ERROR packet.') raise ValueError('Unexpected ERROR packet.')
else: else:

18
tests/test_packet.py

@ -50,7 +50,7 @@ class TestPacket(unittest.TestCase):
def test_decode_binary_event_packet(self): def test_decode_binary_event_packet(self):
pkt = packet.Packet(encoded_packet='51-{"_placeholder":true,"num":0}') pkt = packet.Packet(encoded_packet='51-{"_placeholder":true,"num":0}')
pkt.reconstruct_binary([b'1234']) self.assertTrue(pkt.add_attachment(b'1234'))
self.assertEqual(pkt.packet_type, packet.BINARY_EVENT) self.assertEqual(pkt.packet_type, packet.BINARY_EVENT)
self.assertEqual(pkt.data, b'1234') self.assertEqual(pkt.data, b'1234')
@ -78,7 +78,7 @@ class TestPacket(unittest.TestCase):
def test_decode_binary_ack_packet(self): def test_decode_binary_ack_packet(self):
pkt = packet.Packet(encoded_packet='61-{"_placeholder":true,"num":0}') pkt = packet.Packet(encoded_packet='61-{"_placeholder":true,"num":0}')
pkt.reconstruct_binary([b'1234']) self.assertTrue(pkt.add_attachment(b'1234'))
self.assertEqual(pkt.packet_type, packet.BINARY_ACK) self.assertEqual(pkt.packet_type, packet.BINARY_ACK)
self.assertEqual(pkt.data, b'1234') self.assertEqual(pkt.data, b'1234')
@ -157,7 +157,8 @@ class TestPacket(unittest.TestCase):
pkt = packet.Packet(encoded_packet=( pkt = packet.Packet(encoded_packet=(
'52-{"a":"123","b":{"_placeholder":true,"num":0},' '52-{"a":"123","b":{"_placeholder":true,"num":0},'
'"c":[{"_placeholder":true,"num":1},123]}')) '"c":[{"_placeholder":true,"num":1},123]}'))
pkt.reconstruct_binary([b'456', b'789']) self.assertFalse(pkt.add_attachment(b'456'))
self.assertTrue(pkt.add_attachment(b'789'))
self.assertEqual(pkt.packet_type, packet.BINARY_EVENT) self.assertEqual(pkt.packet_type, packet.BINARY_EVENT)
self.assertEqual(pkt.data['a'], '123') self.assertEqual(pkt.data['a'], '123')
self.assertEqual(pkt.data['b'], b'456') self.assertEqual(pkt.data['b'], b'456')
@ -167,12 +168,21 @@ class TestPacket(unittest.TestCase):
pkt = packet.Packet(encoded_packet=( pkt = packet.Packet(encoded_packet=(
'62-{"a":"123","b":{"_placeholder":true,"num":0},' '62-{"a":"123","b":{"_placeholder":true,"num":0},'
'"c":[{"_placeholder":true,"num":1},123]}')) '"c":[{"_placeholder":true,"num":1},123]}'))
pkt.reconstruct_binary([b'456', b'789']) self.assertFalse(pkt.add_attachment(b'456'))
self.assertTrue(pkt.add_attachment(b'789'))
self.assertEqual(pkt.packet_type, packet.BINARY_ACK) self.assertEqual(pkt.packet_type, packet.BINARY_ACK)
self.assertEqual(pkt.data['a'], '123') self.assertEqual(pkt.data['a'], '123')
self.assertEqual(pkt.data['b'], b'456') self.assertEqual(pkt.data['b'], b'456')
self.assertEqual(pkt.data['c'], [b'789', 123]) self.assertEqual(pkt.data['c'], [b'789', 123])
def test_decode_too_many_binary_packets(self):
pkt = packet.Packet(encoded_packet=(
'62-{"a":"123","b":{"_placeholder":true,"num":0},'
'"c":[{"_placeholder":true,"num":1},123]}'))
self.assertFalse(pkt.add_attachment(b'456'))
self.assertTrue(pkt.add_attachment(b'789'))
self.assertRaises(ValueError, pkt.add_attachment, b'123')
def test_data_is_binary_list(self): def test_data_is_binary_list(self):
pkt = packet.Packet() pkt = packet.Packet()
self.assertFalse(pkt._data_is_binary([six.text_type('foo')])) self.assertFalse(pkt._data_is_binary([six.text_type('foo')]))

13
tests/test_server.py

@ -283,22 +283,19 @@ class TestServer(unittest.TestCase):
s._handle_eio_message('123', '52-["my message","a",' s._handle_eio_message('123', '52-["my message","a",'
'{"_placeholder":true,"num":1},' '{"_placeholder":true,"num":1},'
'{"_placeholder":true,"num":0}]') '{"_placeholder":true,"num":0}]')
self.assertEqual(s._attachment_count, 2)
s._handle_eio_message('123', b'foo') s._handle_eio_message('123', b'foo')
self.assertEqual(s._attachment_count, 1)
s._handle_eio_message('123', b'bar') s._handle_eio_message('123', b'bar')
self.assertEqual(s._attachment_count, 0)
handler.assert_called_once_with('123', 'a', b'bar', b'foo') handler.assert_called_once_with('123', 'a', b'bar', b'foo')
def test_handle_event_binary_ack(self, eio): def test_handle_event_binary_ack(self, eio):
s = server.Server() mgr = mock.MagicMock()
s = server.Server(client_manager=mgr)
s.manager.initialize(s) s.manager.initialize(s)
s._handle_eio_message('123', '61-1["my message","a",' s._handle_eio_message('123', '61-321["my message","a",'
'{"_placeholder":true,"num":0}]') '{"_placeholder":true,"num":0}]')
self.assertEqual(s._attachment_count, 1)
# the following call should not raise an exception in spite of the
# callback id being invalid
s._handle_eio_message('123', b'foo') s._handle_eio_message('123', b'foo')
mgr.trigger_callback.assert_called_once_with(
'123', '/', 321, ['my message', 'a', b'foo'])
def test_handle_event_with_ack(self, eio): def test_handle_event_with_ack(self, eio):
s = server.Server() s = server.Server()

Loading…
Cancel
Save