diff --git a/socketio/packet.py b/socketio/packet.py index c64ff3c..2834aee 100644 --- a/socketio/packet.py +++ b/socketio/packet.py @@ -28,6 +28,7 @@ class Packet(object): else: raise ValueError('Packet does not support binary payload.') self.attachment_count = 0 + self.attachments = [] if encoded_packet: self.attachment_count = self.decode(encoded_packet) @@ -100,11 +101,21 @@ class Packet(object): self.data = self.json.loads(ep) return attachment_count + def add_attachment(self, attachment): + if self.attachment_count <= len(self.attachments): + raise ValueError('Unexpected binary attachment') + self.attachments.append(attachment) + if self.attachment_count == len(self.attachments): + self.reconstruct_binary(self.attachments) + return True + return False + def reconstruct_binary(self, attachments): """Reconstruct a decoded packet using the given list of binary attachments. """ - self.data = self._reconstruct_binary_internal(self.data, attachments) + self.data = self._reconstruct_binary_internal(self.data, + self.attachments) def _reconstruct_binary_internal(self, data, attachments): if isinstance(data, list): diff --git a/socketio/server.py b/socketio/server.py index 613bd4e..b1e4154 100644 --- a/socketio/server.py +++ b/socketio/server.py @@ -78,9 +78,7 @@ class Server(object): self.environ = {} self.handlers = {} - self._binary_packet = None - self._attachment_count = 0 - self._attachments = [] + self._binary_packet = [] if not isinstance(logger, bool): self.logger = logger @@ -434,22 +432,14 @@ class Server(object): def _handle_eio_message(self, sid, data): """Dispatch Engine.IO messages.""" - if self._attachment_count > 0: - self._attachments.append(data) - self._attachment_count -= 1 - - if self._attachment_count == 0: - self._binary_packet.reconstruct_binary(self._attachments) - if self._binary_packet.packet_type == packet.BINARY_EVENT: - self._handle_event(sid, self._binary_packet.namespace, - self._binary_packet.id, - self._binary_packet.data) + if len(self._binary_packet): + pkt = self._binary_packet[0] + if pkt.add_attachment(data): + self._binary_packet.pop(0) + if pkt.packet_type == packet.BINARY_EVENT: + self._handle_event(sid, pkt.namespace, pkt.id, pkt.data) else: - self._handle_ack(sid, self._binary_packet.namespace, - self._binary_packet.id, - self._binary_packet.data) - self._binary_packet = None - self._attachments = [] + self._handle_ack(sid, pkt.namespace, pkt.id, pkt.data) else: pkt = packet.Packet(encoded_packet=data) if pkt.packet_type == packet.CONNECT: @@ -462,9 +452,7 @@ class Server(object): self._handle_ack(sid, pkt.namespace, pkt.id, pkt.data) elif pkt.packet_type == packet.BINARY_EVENT or \ pkt.packet_type == packet.BINARY_ACK: - self._binary_packet = pkt - self._attachments = [] - self._attachment_count = pkt.attachment_count + self._binary_packet.append(pkt) elif pkt.packet_type == packet.ERROR: raise ValueError('Unexpected ERROR packet.') else: diff --git a/tests/test_packet.py b/tests/test_packet.py index 7bb56ba..039c2a5 100644 --- a/tests/test_packet.py +++ b/tests/test_packet.py @@ -50,7 +50,7 @@ class TestPacket(unittest.TestCase): def test_decode_binary_event_packet(self): pkt = packet.Packet(encoded_packet='51-{"_placeholder":true,"num":0}') - pkt.reconstruct_binary([b'1234']) + self.assertTrue(pkt.add_attachment(b'1234')) self.assertEqual(pkt.packet_type, packet.BINARY_EVENT) self.assertEqual(pkt.data, b'1234') @@ -78,7 +78,7 @@ class TestPacket(unittest.TestCase): def test_decode_binary_ack_packet(self): pkt = packet.Packet(encoded_packet='61-{"_placeholder":true,"num":0}') - pkt.reconstruct_binary([b'1234']) + self.assertTrue(pkt.add_attachment(b'1234')) self.assertEqual(pkt.packet_type, packet.BINARY_ACK) self.assertEqual(pkt.data, b'1234') @@ -157,7 +157,8 @@ class TestPacket(unittest.TestCase): pkt = packet.Packet(encoded_packet=( '52-{"a":"123","b":{"_placeholder":true,"num":0},' '"c":[{"_placeholder":true,"num":1},123]}')) - pkt.reconstruct_binary([b'456', b'789']) + self.assertFalse(pkt.add_attachment(b'456')) + self.assertTrue(pkt.add_attachment(b'789')) self.assertEqual(pkt.packet_type, packet.BINARY_EVENT) self.assertEqual(pkt.data['a'], '123') self.assertEqual(pkt.data['b'], b'456') @@ -167,12 +168,21 @@ class TestPacket(unittest.TestCase): pkt = packet.Packet(encoded_packet=( '62-{"a":"123","b":{"_placeholder":true,"num":0},' '"c":[{"_placeholder":true,"num":1},123]}')) - pkt.reconstruct_binary([b'456', b'789']) + self.assertFalse(pkt.add_attachment(b'456')) + self.assertTrue(pkt.add_attachment(b'789')) self.assertEqual(pkt.packet_type, packet.BINARY_ACK) self.assertEqual(pkt.data['a'], '123') self.assertEqual(pkt.data['b'], b'456') self.assertEqual(pkt.data['c'], [b'789', 123]) + def test_decode_too_many_binary_packets(self): + pkt = packet.Packet(encoded_packet=( + '62-{"a":"123","b":{"_placeholder":true,"num":0},' + '"c":[{"_placeholder":true,"num":1},123]}')) + self.assertFalse(pkt.add_attachment(b'456')) + self.assertTrue(pkt.add_attachment(b'789')) + self.assertRaises(ValueError, pkt.add_attachment, b'123') + def test_data_is_binary_list(self): pkt = packet.Packet() self.assertFalse(pkt._data_is_binary([six.text_type('foo')])) diff --git a/tests/test_server.py b/tests/test_server.py index 665e3a3..6ade1d4 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -283,22 +283,19 @@ class TestServer(unittest.TestCase): s._handle_eio_message('123', '52-["my message","a",' '{"_placeholder":true,"num":1},' '{"_placeholder":true,"num":0}]') - self.assertEqual(s._attachment_count, 2) s._handle_eio_message('123', b'foo') - self.assertEqual(s._attachment_count, 1) s._handle_eio_message('123', b'bar') - self.assertEqual(s._attachment_count, 0) handler.assert_called_once_with('123', 'a', b'bar', b'foo') def test_handle_event_binary_ack(self, eio): - s = server.Server() + mgr = mock.MagicMock() + s = server.Server(client_manager=mgr) s.manager.initialize(s) - s._handle_eio_message('123', '61-1["my message","a",' + s._handle_eio_message('123', '61-321["my message","a",' '{"_placeholder":true,"num":0}]') - self.assertEqual(s._attachment_count, 1) - # the following call should not raise an exception in spite of the - # callback id being invalid s._handle_eio_message('123', b'foo') + mgr.trigger_callback.assert_called_once_with( + '123', '/', 321, ['my message', 'a', b'foo']) def test_handle_event_with_ack(self, eio): s = server.Server()