From 7c952de346fba6db124cb0ec3bb9853e1ec58482 Mon Sep 17 00:00:00 2001 From: Miguel Grinberg Date: Wed, 30 Jan 2019 23:51:30 +0000 Subject: [PATCH] added wait and timeout options to client's emit --- socketio/asyncio_client.py | 46 +++++++++++++++++++++++++--- socketio/client.py | 44 +++++++++++++++++++++++--- socketio/exceptions.py | 4 +++ tests/asyncio/test_asyncio_client.py | 40 ++++++++++++++++++++++-- tests/common/test_client.py | 43 ++++++++++++++++++++++++-- 5 files changed, 165 insertions(+), 12 deletions(-) diff --git a/socketio/asyncio_client.py b/socketio/asyncio_client.py index 2f78a52..bf733bd 100644 --- a/socketio/asyncio_client.py +++ b/socketio/asyncio_client.py @@ -120,7 +120,8 @@ class AsyncClient(client.Client): if self.eio.state != 'connected': break - async def emit(self, event, data=None, namespace=None, callback=None): + async def emit(self, event, data=None, namespace=None, callback=None, + wait=False, timeout=60): """Emit a custom event to one or more connected clients. :param event: The event name. It can be any string. The event names @@ -137,11 +138,30 @@ class AsyncClient(client.Client): that will be passed to the function are those provided by the client. Callback functions can only be used when addressing an individual client. + :param wait: If set to ``True``, this function will wait for the + server to handle the event and acknowledge it via its + callback function. The value(s) passed by the server to + its callback will be returned. If set to ``False``, + this function emits the event and returns immediately. + :param timeout: If ``wait`` is set to ``True``, this parameter + specifies a waiting timeout. If the timeout is reached + before the server acknowledges the event, then a + ``TimeoutError`` exception is raised. Note: this method is a coroutine. """ namespace = namespace or '/' self.logger.info('Emitting event "%s" [%s]', event, namespace) + if wait is True: + callback_event = self.eio.create_event() + callback_args = [] + + def event_callback(*args): + callback_args.append(args) + callback_event.set() + + callback = event_callback + if callback is not None: id = self._generate_ack_id(namespace, callback) else: @@ -161,8 +181,17 @@ class AsyncClient(client.Client): await self._send_packet(packet.Packet( packet.EVENT, namespace=namespace, data=[event] + data, id=id, binary=binary)) - - async def send(self, data, namespace=None, callback=None): + if wait is True: + try: + await asyncio.wait_for(callback_event.wait(), timeout) + except asyncio.TimeoutError: + six.raise_from(exceptions.TimeoutError(), None) + return callback_args[0] if len(callback_args[0]) > 1 \ + else callback_args[0][0] if len(callback_args[0]) == 1 \ + else None + + async def send(self, data, namespace=None, callback=None, wait=False, + timeout=60): """Send a message to one or more connected clients. This function emits an event with the name ``'message'``. Use @@ -179,11 +208,20 @@ class AsyncClient(client.Client): that will be passed to the function are those provided by the client. Callback functions can only be used when addressing an individual client. + :param wait: If set to ``True``, this function will wait for the + server to handle the event and acknowledge it via its + callback function. The value(s) passed by the server to + its callback will be returned. If set to ``False``, + this function emits the event and returns immediately. + :param timeout: If ``wait`` is set to ``True``, this parameter + specifies a waiting timeout. If the timeout is reached + before the server acknowledges the event, then a + ``TimeoutError`` exception is raised. Note: this method is a coroutine. """ await self.emit('message', data=data, namespace=namespace, - callback=callback) + callback=callback, wait=wait, timeout=timeout) async def disconnect(self): """Disconnect from the server. diff --git a/socketio/client.py b/socketio/client.py index f56493f..8032389 100644 --- a/socketio/client.py +++ b/socketio/client.py @@ -224,7 +224,8 @@ class Client(object): if self.eio.state != 'connected': break - def emit(self, event, data=None, namespace=None, callback=None): + def emit(self, event, data=None, namespace=None, callback=None, + wait=False, timeout=60): """Emit a custom event to one or more connected clients. :param event: The event name. It can be any string. The event names @@ -241,9 +242,28 @@ class Client(object): that will be passed to the function are those provided by the client. Callback functions can only be used when addressing an individual client. + :param wait: If set to ``True``, this function will wait for the + server to handle the event and acknowledge it via its + callback function. The value(s) passed by the server to + its callback will be returned. If set to ``False``, + this function emits the event and returns immediately. + :param timeout: If ``wait`` is set to ``True``, this parameter + specifies a waiting timeout. If the timeout is reached + before the server acknowledges the event, then a + ``TimeoutError`` exception is raised. """ namespace = namespace or '/' self.logger.info('Emitting event "%s" [%s]', event, namespace) + if wait is True: + callback_event = self.eio.create_event() + callback_args = [] + + def event_callback(*args): + callback_args.append(args) + callback_event.set() + + callback = event_callback + if callback is not None: id = self._generate_ack_id(namespace, callback) else: @@ -263,8 +283,15 @@ class Client(object): self._send_packet(packet.Packet(packet.EVENT, namespace=namespace, data=[event] + data, id=id, binary=binary)) - - def send(self, data, namespace=None, callback=None): + if wait is True: + if not callback_event.wait(timeout=timeout): + raise exceptions.TimeoutError() + return callback_args[0] if len(callback_args[0]) > 1 \ + else callback_args[0][0] if len(callback_args[0]) == 1 \ + else None + + def send(self, data, namespace=None, callback=None, wait=False, + timeout=60): """Send a message to one or more connected clients. This function emits an event with the name ``'message'``. Use @@ -281,9 +308,18 @@ class Client(object): that will be passed to the function are those provided by the client. Callback functions can only be used when addressing an individual client. + :param wait: If set to ``True``, this function will wait for the + server to handle the event and acknowledge it via its + callback function. The value(s) passed by the server to + its callback will be returned. If set to ``False``, + this function emits the event and returns immediately. + :param timeout: If ``wait`` is set to ``True``, this parameter + specifies a waiting timeout. If the timeout is reached + before the server acknowledges the event, then a + ``TimeoutError`` exception is raised. """ self.emit('message', data=data, namespace=namespace, - callback=callback) + callback=callback, wait=wait, timeout=timeout) def disconnect(self): """Disconnect from the server.""" diff --git a/socketio/exceptions.py b/socketio/exceptions.py index 5bd8697..eb54efa 100644 --- a/socketio/exceptions.py +++ b/socketio/exceptions.py @@ -4,3 +4,7 @@ class SocketIOError(Exception): class ConnectionError(SocketIOError): pass + + +class TimeoutError(SocketIOError): + pass diff --git a/tests/asyncio/test_asyncio_client.py b/tests/asyncio/test_asyncio_client.py index 78bca0f..14b5f63 100644 --- a/tests/asyncio/test_asyncio_client.py +++ b/tests/asyncio/test_asyncio_client.py @@ -200,6 +200,41 @@ class TestAsyncClient(unittest.TestCase): expected_packet.encode()) c._generate_ack_id.assert_called_once_with('/', 'cb') + def test_emit_with_wait(self): + c = asyncio_client.AsyncClient() + + async def fake_event_wait(): + c._generate_ack_id.call_args_list[0][0][1]('foo', 321) + + c._send_packet = AsyncMock() + c._generate_ack_id = mock.MagicMock(return_value=123) + c.eio = mock.MagicMock() + c.eio.create_event.return_value.wait = fake_event_wait + self.assertEqual(_run(c.emit('foo', wait=True)), ('foo', 321)) + expected_packet = packet.Packet(packet.EVENT, namespace='/', + data=['foo'], id=123, binary=False) + self.assertEqual(c._send_packet.mock.call_count, 1) + self.assertEqual(c._send_packet.mock.call_args_list[0][0][0].encode(), + expected_packet.encode()) + + def test_emit_with_wait_and_timeout(self): + c = asyncio_client.AsyncClient() + + async def fake_event_wait(): + await asyncio.sleep(1) + + c._send_packet = AsyncMock() + c._generate_ack_id = mock.MagicMock(return_value=123) + c.eio = mock.MagicMock() + c.eio.create_event.return_value.wait = fake_event_wait + self.assertRaises(exceptions.TimeoutError, _run, + c.emit('foo', wait=True, timeout=0.01)) + expected_packet = packet.Packet(packet.EVENT, namespace='/', + data=['foo'], id=123, binary=False) + self.assertEqual(c._send_packet.mock.call_count, 1) + self.assertEqual(c._send_packet.mock.call_args_list[0][0][0].encode(), + expected_packet.encode()) + def test_emit_namespace_with_callback(self): c = asyncio_client.AsyncClient() c._send_packet = AsyncMock() @@ -240,14 +275,15 @@ class TestAsyncClient(unittest.TestCase): _run(c.send('data', 'namespace', 'callback')) c.emit.mock.assert_called_once_with( 'message', data='data', namespace='namespace', - callback='callback') + callback='callback', wait=False, timeout=60) def test_send_with_defaults(self): c = asyncio_client.AsyncClient() c.emit = AsyncMock() _run(c.send('data')) c.emit.mock.assert_called_once_with( - 'message', data='data', namespace=None, callback=None) + 'message', data='data', namespace=None, callback=None, wait=False, + timeout=60) def test_disconnect(self): c = asyncio_client.AsyncClient() diff --git a/tests/common/test_client.py b/tests/common/test_client.py index 02d5ed0..2b689b8 100644 --- a/tests/common/test_client.py +++ b/tests/common/test_client.py @@ -283,6 +283,44 @@ class TestClient(unittest.TestCase): expected_packet.encode()) c._generate_ack_id.assert_called_once_with('/', 'cb') + def test_emit_with_wait(self): + c = client.Client() + + def fake_event_wait(timeout=None): + self.assertEqual(timeout, 60) + c._generate_ack_id.call_args_list[0][0][1]('foo', 321) + return True + + c._send_packet = mock.MagicMock() + c._generate_ack_id = mock.MagicMock(return_value=123) + c.eio = mock.MagicMock() + c.eio.create_event.return_value.wait = fake_event_wait + self.assertEqual(c.emit('foo', wait=True), ('foo', 321)) + expected_packet = packet.Packet(packet.EVENT, namespace='/', + data=['foo'], id=123, binary=False) + self.assertEqual(c._send_packet.call_count, 1) + self.assertEqual(c._send_packet.call_args_list[0][0][0].encode(), + expected_packet.encode()) + + def test_emit_with_wait_and_timeout(self): + c = client.Client() + + def fake_event_wait(timeout=None): + self.assertEqual(timeout, 12) + return False + + c._send_packet = mock.MagicMock() + c._generate_ack_id = mock.MagicMock(return_value=123) + c.eio = mock.MagicMock() + c.eio.create_event.return_value.wait = fake_event_wait + self.assertRaises(exceptions.TimeoutError, c.emit, 'foo', wait=True, + timeout=12) + expected_packet = packet.Packet(packet.EVENT, namespace='/', + data=['foo'], id=123, binary=False) + self.assertEqual(c._send_packet.call_count, 1) + self.assertEqual(c._send_packet.call_args_list[0][0][0].encode(), + expected_packet.encode()) + def test_emit_namespace_with_callback(self): c = client.Client() c._send_packet = mock.MagicMock() @@ -323,14 +361,15 @@ class TestClient(unittest.TestCase): c.send('data', 'namespace', 'callback') c.emit.assert_called_once_with( 'message', data='data', namespace='namespace', - callback='callback') + callback='callback', wait=False, timeout=60) def test_send_with_defaults(self): c = client.Client() c.emit = mock.MagicMock() c.send('data') c.emit.assert_called_once_with( - 'message', data='data', namespace=None, callback=None) + 'message', data='data', namespace=None, callback=None, wait=False, + timeout=60) def test_disconnect(self): c = client.Client()