diff --git a/src/socketio/asyncio_server.py b/src/socketio/asyncio_server.py index 59fab4a..c08adab 100644 --- a/src/socketio/asyncio_server.py +++ b/src/socketio/asyncio_server.py @@ -442,7 +442,15 @@ class AsyncServer(server.Server): async def _handle_connect(self, eio_sid, namespace, data): """Handle a client connection request.""" namespace = namespace or '/' - sid = self.manager.connect(eio_sid, namespace) + sid = None + if namespace in self.handlers or namespace in self.namespace_handlers: + sid = self.manager.connect(eio_sid, namespace) + if sid is None: + self._send_packet(eio_sid, self.packet_class( + packet.CONNECT_ERROR, data='Unable to connect', + namespace=namespace)) + return + if self.always_connect: await self._send_packet(eio_sid, self.packet_class( packet.CONNECT, {'sid': sid}, namespace=namespace)) @@ -547,7 +555,7 @@ class AsyncServer(server.Server): return ret # or else, forward the event to a namepsace handler if one exists - elif namespace in self.namespace_handlers: + elif namespace in self.namespace_handlers: # pragma: no branch return await self.namespace_handlers[namespace].trigger_event( event, *args) diff --git a/src/socketio/server.py b/src/socketio/server.py index cdf255b..6a0210e 100644 --- a/src/socketio/server.py +++ b/src/socketio/server.py @@ -648,7 +648,9 @@ class Server(object): def _handle_connect(self, eio_sid, namespace, data): """Handle a client connection request.""" namespace = namespace or '/' - sid = self.manager.connect(eio_sid, namespace) + sid = None + if namespace in self.handlers or namespace in self.namespace_handlers: + sid = self.manager.connect(eio_sid, namespace) if sid is None: self._send_packet(eio_sid, self.packet_class( packet.CONNECT_ERROR, data='Unable to connect', @@ -748,7 +750,7 @@ class Server(object): return self.handlers[namespace]['*'](event, *args) # or else, forward the event to a namespace handler if one exists - elif namespace in self.namespace_handlers: + elif namespace in self.namespace_handlers: # pragma: no branch return self.namespace_handlers[namespace].trigger_event( event, *args) diff --git a/tests/asyncio/test_asyncio_server.py b/tests/asyncio/test_asyncio_server.py index 824d4a5..704eea6 100644 --- a/tests/asyncio/test_asyncio_server.py +++ b/tests/asyncio/test_asyncio_server.py @@ -425,6 +425,13 @@ class TestAsyncServer(unittest.TestCase): _run(s._handle_eio_message('456', '0')) assert s.manager.initialize.call_count == 1 + def test_handle_connect_with_bad_namespace(self, eio): + eio.return_value.send = AsyncMock() + s = asyncio_server.AsyncServer() + _run(s._handle_eio_connect('123', 'environ')) + _run(s._handle_eio_message('123', '0')) + assert not s.manager.is_connected('1', '/') + def test_handle_connect_namespace(self, eio): eio.return_value.send = AsyncMock() s = asyncio_server.AsyncServer() @@ -752,6 +759,7 @@ class TestAsyncServer(unittest.TestCase): def test_send_with_ack(self, eio): eio.return_value.send = AsyncMock() s = asyncio_server.AsyncServer() + s.handlers['/'] = {} _run(s._handle_eio_connect('123', 'environ')) _run(s._handle_eio_message('123', '0')) cb = mock.MagicMock() @@ -765,6 +773,7 @@ class TestAsyncServer(unittest.TestCase): def test_send_with_ack_namespace(self, eio): eio.return_value.send = AsyncMock() s = asyncio_server.AsyncServer() + s.handlers['/foo'] = {} _run(s._handle_eio_connect('123', 'environ')) _run(s._handle_eio_message('123', '0/foo,')) cb = mock.MagicMock() @@ -791,6 +800,8 @@ class TestAsyncServer(unittest.TestCase): eio.return_value.send = AsyncMock() s = asyncio_server.AsyncServer() + s.handlers['/'] = {} + s.handlers['/ns'] = {} s.eio.get_session = fake_get_session s.eio.save_session = fake_save_session @@ -822,6 +833,7 @@ class TestAsyncServer(unittest.TestCase): eio.return_value.send = AsyncMock() eio.return_value.disconnect = AsyncMock() s = asyncio_server.AsyncServer() + s.handlers['/'] = {} _run(s._handle_eio_connect('123', 'environ')) _run(s._handle_eio_message('123', '0')) _run(s.disconnect('1')) @@ -832,6 +844,7 @@ class TestAsyncServer(unittest.TestCase): eio.return_value.send = AsyncMock() eio.return_value.disconnect = AsyncMock() s = asyncio_server.AsyncServer() + s.handlers['/'] = {} _run(s._handle_eio_connect('123', 'environ')) _run(s._handle_eio_message('123', '0')) _run(s.disconnect('1', ignore_queue=True)) @@ -842,6 +855,7 @@ class TestAsyncServer(unittest.TestCase): eio.return_value.send = AsyncMock() eio.return_value.disconnect = AsyncMock() s = asyncio_server.AsyncServer() + s.handlers['/foo'] = {} _run(s._handle_eio_connect('123', 'environ')) _run(s._handle_eio_message('123', '0/foo,')) _run(s.disconnect('1', namespace='/foo')) diff --git a/tests/common/test_server.py b/tests/common/test_server.py index 3b89c3b..860e9d0 100644 --- a/tests/common/test_server.py +++ b/tests/common/test_server.py @@ -356,6 +356,12 @@ class TestServer(unittest.TestCase): s._handle_eio_connect('456', 'environ') assert s.manager.initialize.call_count == 1 + def test_handle_connect_with_bad_namespace(self, eio): + s = server.Server() + s._handle_eio_connect('123', 'environ') + s._handle_eio_message('123', '0') + assert not s.manager.is_connected('1', '/') + def test_handle_connect_namespace(self, eio): s = server.Server() handler = mock.MagicMock() @@ -663,6 +669,7 @@ class TestServer(unittest.TestCase): def test_send_with_ack(self, eio): s = server.Server() + s.handlers['/'] = {} s._handle_eio_connect('123', 'environ') s._handle_eio_message('123', '0') cb = mock.MagicMock() @@ -675,6 +682,7 @@ class TestServer(unittest.TestCase): def test_send_with_ack_namespace(self, eio): s = server.Server() + s.handlers['/foo'] = {} s._handle_eio_connect('123', 'environ') s._handle_eio_message('123', '0/foo,') cb = mock.MagicMock() @@ -696,6 +704,8 @@ class TestServer(unittest.TestCase): fake_session = session s = server.Server() + s.handlers['/'] = {} + s.handlers['/ns'] = {} s.eio.get_session = fake_get_session s.eio.save_session = fake_save_session s._handle_eio_connect('123', 'environ') @@ -721,6 +731,7 @@ class TestServer(unittest.TestCase): def test_disconnect(self, eio): s = server.Server() + s.handlers['/'] = {} s._handle_eio_connect('123', 'environ') s._handle_eio_message('123', '0') s.disconnect('1') @@ -728,6 +739,7 @@ class TestServer(unittest.TestCase): def test_disconnect_ignore_queue(self, eio): s = server.Server() + s.handlers['/'] = {} s._handle_eio_connect('123', 'environ') s._handle_eio_message('123', '0') s.disconnect('1', ignore_queue=True) @@ -735,6 +747,7 @@ class TestServer(unittest.TestCase): def test_disconnect_namespace(self, eio): s = server.Server() + s.handlers['/foo'] = {} s._handle_eio_connect('123', 'environ') s._handle_eio_message('123', '0/foo,') s.disconnect('1', namespace='/foo') @@ -813,6 +826,7 @@ class TestServer(unittest.TestCase): def test_get_environ(self, eio): s = server.Server() + s.handlers['/'] = {} s._handle_eio_connect('123', 'environ') s._handle_eio_message('123', '0') sid = s.manager.sid_from_eio_sid('123', '/')