diff --git a/docs/client.rst b/docs/client.rst index 3224646..9b2ba36 100644 --- a/docs/client.rst +++ b/docs/client.rst @@ -65,13 +65,18 @@ or can also be coroutines:: async def message(data): print('I received a message!') -The ``connect`` and ``disconnect`` events are special; they are invoked -automatically when a client connects or disconnects from the server:: +The ``connect``, ``connect_error`` and ``disconnect`` events are special; they +are invoked automatically when a client connects or disconnects from the +server:: @sio.event def connect(): print("I'm connected!") + @sio.event + def connect_error(): + print("The connection failed!") + @sio.event def disconnect(): print("I'm disconnected!") diff --git a/socketio/asyncio_client.py b/socketio/asyncio_client.py index dd8caf3..2b10434 100644 --- a/socketio/asyncio_client.py +++ b/socketio/asyncio_client.py @@ -348,10 +348,15 @@ class AsyncClient(client.Client): else: callback(*data) - def _handle_error(self, namespace): + async def _handle_error(self, namespace, data): namespace = namespace or '/' self.logger.info('Connection to namespace {} was rejected'.format( namespace)) + if data is None: + data = tuple() + elif not isinstance(data, (tuple, list)): + data = (data,) + await self._trigger_event('connect_error', namespace, *data) if namespace in self.namespaces: self.namespaces.remove(namespace) if namespace == '/': @@ -445,7 +450,7 @@ class AsyncClient(client.Client): pkt.packet_type == packet.BINARY_ACK: self._binary_packet = pkt elif pkt.packet_type == packet.ERROR: - self._handle_error(pkt.namespace) + await self._handle_error(pkt.namespace, pkt.data) else: raise ValueError('Unknown packet type.') diff --git a/socketio/asyncio_server.py b/socketio/asyncio_server.py index 27e1ad1..251d581 100644 --- a/socketio/asyncio_server.py +++ b/socketio/asyncio_server.py @@ -403,7 +403,6 @@ class AsyncServer(server.Server): packet.ERROR, data=fail_reason, namespace=namespace)) if sid in self.environ: # pragma: no cover del self.environ[sid] - return False elif not self.always_connect: await self._send_packet(sid, packet.Packet(packet.CONNECT, namespace=namespace)) diff --git a/socketio/client.py b/socketio/client.py index 0751c29..e917d63 100644 --- a/socketio/client.py +++ b/socketio/client.py @@ -504,10 +504,15 @@ class Client(object): if callback is not None: callback(*data) - def _handle_error(self, namespace): + def _handle_error(self, namespace, data): namespace = namespace or '/' self.logger.info('Connection to namespace {} was rejected'.format( namespace)) + if data is None: + data = tuple() + elif not isinstance(data, (tuple, list)): + data = (data,) + self._trigger_event('connect_error', namespace, *data) if namespace in self.namespaces: self.namespaces.remove(namespace) if namespace == '/': @@ -591,7 +596,7 @@ class Client(object): pkt.packet_type == packet.BINARY_ACK: self._binary_packet = pkt elif pkt.packet_type == packet.ERROR: - self._handle_error(pkt.namespace) + self._handle_error(pkt.namespace, pkt.data) else: raise ValueError('Unknown packet type.') diff --git a/socketio/exceptions.py b/socketio/exceptions.py index 289300c..36dddd9 100644 --- a/socketio/exceptions.py +++ b/socketio/exceptions.py @@ -16,7 +16,7 @@ class ConnectionRefusedError(ConnectionError): def __init__(self, *args): if len(args) == 0: self.error_args = None - elif len(args) == 1: + elif len(args) == 1 and not isinstance(args[0], list): self.error_args = args[0] else: self.error_args = args diff --git a/socketio/server.py b/socketio/server.py index 3cae39f..76b7d2e 100644 --- a/socketio/server.py +++ b/socketio/server.py @@ -613,7 +613,6 @@ class Server(object): packet.ERROR, data=fail_reason, namespace=namespace)) if sid in self.environ: # pragma: no cover del self.environ[sid] - return False elif not self.always_connect: self._send_packet(sid, packet.Packet(packet.CONNECT, namespace=namespace)) diff --git a/tests/asyncio/test_asyncio_client.py b/tests/asyncio/test_asyncio_client.py index 6ae071b..e5b7828 100644 --- a/tests/asyncio/test_asyncio_client.py +++ b/tests/asyncio/test_asyncio_client.py @@ -551,24 +551,50 @@ class TestAsyncClient(unittest.TestCase): def test_handle_error(self): c = asyncio_client.AsyncClient() c.connected = True + c._trigger_event = AsyncMock() + c.namespaces = ['/foo', '/bar'] + _run(c._handle_error('/', 'error')) + self.assertEqual(c.namespaces, []) + self.assertFalse(c.connected) + c._trigger_event.mock.assert_called_once_with('connect_error', '/', + 'error') + + def test_handle_error_with_no_arguments(self): + c = asyncio_client.AsyncClient() + c.connected = True + c._trigger_event = AsyncMock() c.namespaces = ['/foo', '/bar'] - c._handle_error('/') + _run(c._handle_error('/', None)) self.assertEqual(c.namespaces, []) self.assertFalse(c.connected) + c._trigger_event.mock.assert_called_once_with('connect_error', '/') def test_handle_error_namespace(self): c = asyncio_client.AsyncClient() c.connected = True c.namespaces = ['/foo', '/bar'] - c._handle_error('/bar') + c._trigger_event = AsyncMock() + _run(c._handle_error('/bar', ['error', 'message'])) + self.assertEqual(c.namespaces, ['/foo']) + self.assertTrue(c.connected) + c._trigger_event.mock.assert_called_once_with('connect_error', '/bar', + 'error', 'message') + + def test_handle_error_namespace_with_no_arguments(self): + c = asyncio_client.AsyncClient() + c.connected = True + c.namespaces = ['/foo', '/bar'] + c._trigger_event = AsyncMock() + _run(c._handle_error('/bar', None)) self.assertEqual(c.namespaces, ['/foo']) self.assertTrue(c.connected) + c._trigger_event.mock.assert_called_once_with('connect_error', '/bar') def test_handle_error_unknown_namespace(self): c = asyncio_client.AsyncClient() c.connected = True c.namespaces = ['/foo', '/bar'] - c._handle_error('/baz') + _run(c._handle_error('/baz', 'error')) self.assertEqual(c.namespaces, ['/foo', '/bar']) self.assertTrue(c.connected) @@ -685,7 +711,7 @@ class TestAsyncClient(unittest.TestCase): c._handle_disconnect = AsyncMock() c._handle_event = AsyncMock() c._handle_ack = AsyncMock() - c._handle_error = mock.MagicMock() + c._handle_error = AsyncMock() _run(c._handle_eio_message('0')) c._handle_connect.mock.assert_called_with(None) @@ -700,9 +726,15 @@ class TestAsyncClient(unittest.TestCase): _run(c._handle_eio_message('3/foo,["bar"]')) c._handle_ack.mock.assert_called_with('/foo', None, ['bar']) _run(c._handle_eio_message('4')) - c._handle_error.assert_called_with(None) + c._handle_error.mock.assert_called_with(None, None) + _run(c._handle_eio_message('4"foo"')) + c._handle_error.mock.assert_called_with(None, 'foo') + _run(c._handle_eio_message('4["foo"]')) + c._handle_error.mock.assert_called_with(None, ['foo']) _run(c._handle_eio_message('4/foo')) - c._handle_error.assert_called_with('/foo') + c._handle_error.mock.assert_called_with('/foo', None) + _run(c._handle_eio_message('4/foo,["foo","bar"]')) + c._handle_error.mock.assert_called_with('/foo', ['foo', 'bar']) _run(c._handle_eio_message('51-{"_placeholder":true,"num":0}')) self.assertEqual(c._binary_packet.packet_type, packet.BINARY_EVENT) _run(c._handle_eio_message(b'foo')) diff --git a/tests/common/test_client.py b/tests/common/test_client.py index 3974a8e..42840f7 100644 --- a/tests/common/test_client.py +++ b/tests/common/test_client.py @@ -660,23 +660,48 @@ class TestClient(unittest.TestCase): c = client.Client() c.connected = True c.namespaces = ['/foo', '/bar'] - c._handle_error('/') + c._trigger_event = mock.MagicMock() + c._handle_error('/', 'error') + self.assertEqual(c.namespaces, []) + self.assertFalse(c.connected) + c._trigger_event.assert_called_once_with('connect_error', '/', 'error') + + def test_handle_error_with_no_arguments(self): + c = client.Client() + c.connected = True + c.namespaces = ['/foo', '/bar'] + c._trigger_event = mock.MagicMock() + c._handle_error('/', None) self.assertEqual(c.namespaces, []) self.assertFalse(c.connected) + c._trigger_event.assert_called_once_with('connect_error', '/') def test_handle_error_namespace(self): c = client.Client() c.connected = True c.namespaces = ['/foo', '/bar'] - c._handle_error('/bar') + c._trigger_event = mock.MagicMock() + c._handle_error('/bar', ['error', 'message']) + self.assertEqual(c.namespaces, ['/foo']) + self.assertTrue(c.connected) + c._trigger_event.assert_called_once_with('connect_error', '/bar', + 'error', 'message') + + def test_handle_error_namespace_with_no_arguments(self): + c = client.Client() + c.connected = True + c.namespaces = ['/foo', '/bar'] + c._trigger_event = mock.MagicMock() + c._handle_error('/bar', None) self.assertEqual(c.namespaces, ['/foo']) self.assertTrue(c.connected) + c._trigger_event.assert_called_once_with('connect_error', '/bar') def test_handle_error_unknown_namespace(self): c = client.Client() c.connected = True c.namespaces = ['/foo', '/bar'] - c._handle_error('/baz') + c._handle_error('/baz', 'error') self.assertEqual(c.namespaces, ['/foo', '/bar']) self.assertTrue(c.connected) @@ -809,9 +834,15 @@ class TestClient(unittest.TestCase): c._handle_eio_message('3/foo,["bar"]') c._handle_ack.assert_called_with('/foo', None, ['bar']) c._handle_eio_message('4') - c._handle_error.assert_called_with(None) + c._handle_error.assert_called_with(None, None) + c._handle_eio_message('4"foo"') + c._handle_error.assert_called_with(None, 'foo') + c._handle_eio_message('4["foo"]') + c._handle_error.assert_called_with(None, ['foo']) c._handle_eio_message('4/foo') - c._handle_error.assert_called_with('/foo') + c._handle_error.assert_called_with('/foo', None) + c._handle_eio_message('4/foo,["foo","bar"]') + c._handle_error.assert_called_with('/foo', ['foo', 'bar']) c._handle_eio_message('51-{"_placeholder":true,"num":0}') self.assertEqual(c._binary_packet.packet_type, packet.BINARY_EVENT) c._handle_eio_message(b'foo') diff --git a/tests/common/test_server.py b/tests/common/test_server.py index 3c4dbc2..abecaf0 100644 --- a/tests/common/test_server.py +++ b/tests/common/test_server.py @@ -341,7 +341,6 @@ class TestServer(unittest.TestCase): s._handle_eio_message('123', '0/foo') self.assertEqual(s.manager.connect.call_count, 2) self.assertEqual(s.manager.disconnect.call_count, 1) - print(s.eio.send.call_args) s.eio.send.assert_any_call('123', '4/foo,"fail_reason"', binary=False) def test_handle_disconnect(self, eio):