Browse Source

Add namespaces argument to Server and AsyncServer (Fixes #822)

pull/962/head
Miguel Grinberg 3 years ago
parent
commit
efe87d867a
  1. 13
      src/socketio/asyncio_server.py
  2. 11
      src/socketio/server.py
  3. 22
      tests/asyncio/test_asyncio_server.py
  4. 20
      tests/common/test_server.py

13
src/socketio/asyncio_server.py

@ -40,6 +40,11 @@ class AsyncServer(server.Server):
connect handler and your client is confused when it connect handler and your client is confused when it
receives events before the connection acceptance. receives events before the connection acceptance.
In any other case use the default of ``False``. In any other case use the default of ``False``.
:param namespaces: a list of namespaces that are accepted, in addition to
any namespaces for which handlers have been defined. The
default is `['/']`, which always accepts connections to
the default namespace. Set to `'*'` to accept all
namespaces.
:param kwargs: Connection parameters for the underlying Engine.IO server. :param kwargs: Connection parameters for the underlying Engine.IO server.
The Engine.IO configuration supports the following settings: The Engine.IO configuration supports the following settings:
@ -97,11 +102,12 @@ class AsyncServer(server.Server):
``engineio_logger`` is ``False``. ``engineio_logger`` is ``False``.
""" """
def __init__(self, client_manager=None, logger=False, json=None, def __init__(self, client_manager=None, logger=False, json=None,
async_handlers=True, **kwargs): async_handlers=True, namespaces=None, **kwargs):
if client_manager is None: if client_manager is None:
client_manager = asyncio_manager.AsyncManager() client_manager = asyncio_manager.AsyncManager()
super().__init__(client_manager=client_manager, logger=logger, super().__init__(client_manager=client_manager, logger=logger,
json=json, async_handlers=async_handlers, **kwargs) json=json, async_handlers=async_handlers,
namespaces=namespaces, **kwargs)
def is_asyncio_based(self): def is_asyncio_based(self):
return True return True
@ -443,7 +449,8 @@ class AsyncServer(server.Server):
"""Handle a client connection request.""" """Handle a client connection request."""
namespace = namespace or '/' namespace = namespace or '/'
sid = None sid = None
if namespace in self.handlers or namespace in self.namespace_handlers: if namespace in self.handlers or namespace in self.namespace_handlers \
or self.namespaces == '*' or namespace in self.namespaces:
sid = self.manager.connect(eio_sid, namespace) sid = self.manager.connect(eio_sid, namespace)
if sid is None: if sid is None:
await self._send_packet(eio_sid, self.packet_class( await self._send_packet(eio_sid, self.packet_class(

11
src/socketio/server.py

@ -49,6 +49,11 @@ class Server(object):
connect handler and your client is confused when it connect handler and your client is confused when it
receives events before the connection acceptance. receives events before the connection acceptance.
In any other case use the default of ``False``. In any other case use the default of ``False``.
:param namespaces: a list of namespaces that are accepted, in addition to
any namespaces for which handlers have been defined. The
default is `['/']`, which always accepts connections to
the default namespace. Set to `'*'` to accept all
namespaces.
:param kwargs: Connection parameters for the underlying Engine.IO server. :param kwargs: Connection parameters for the underlying Engine.IO server.
The Engine.IO configuration supports the following settings: The Engine.IO configuration supports the following settings:
@ -110,7 +115,7 @@ class Server(object):
def __init__(self, client_manager=None, logger=False, serializer='default', def __init__(self, client_manager=None, logger=False, serializer='default',
json=None, async_handlers=True, always_connect=False, json=None, async_handlers=True, always_connect=False,
**kwargs): namespaces=None, **kwargs):
engineio_options = kwargs engineio_options = kwargs
engineio_logger = engineio_options.pop('engineio_logger', None) engineio_logger = engineio_options.pop('engineio_logger', None)
if engineio_logger is not None: if engineio_logger is not None:
@ -157,6 +162,7 @@ class Server(object):
self.async_handlers = async_handlers self.async_handlers = async_handlers
self.always_connect = always_connect self.always_connect = always_connect
self.namespaces = namespaces or ['/']
self.async_mode = self.eio.async_mode self.async_mode = self.eio.async_mode
@ -650,7 +656,8 @@ class Server(object):
"""Handle a client connection request.""" """Handle a client connection request."""
namespace = namespace or '/' namespace = namespace or '/'
sid = None sid = None
if namespace in self.handlers or namespace in self.namespace_handlers: if namespace in self.handlers or namespace in self.namespace_handlers \
or self.namespaces == '*' or namespace in self.namespaces:
sid = self.manager.connect(eio_sid, namespace) sid = self.manager.connect(eio_sid, namespace)
if sid is None: if sid is None:
self._send_packet(eio_sid, self.packet_class( self._send_packet(eio_sid, self.packet_class(

22
tests/asyncio/test_asyncio_server.py

@ -425,12 +425,32 @@ class TestAsyncServer(unittest.TestCase):
_run(s._handle_eio_message('456', '0')) _run(s._handle_eio_message('456', '0'))
assert s.manager.initialize.call_count == 1 assert s.manager.initialize.call_count == 1
def test_handle_connect_with_bad_namespace(self, eio): def test_handle_connect_with_default_implied_namespaces(self, eio):
eio.return_value.send = AsyncMock() eio.return_value.send = AsyncMock()
s = asyncio_server.AsyncServer() s = asyncio_server.AsyncServer()
_run(s._handle_eio_connect('123', 'environ')) _run(s._handle_eio_connect('123', 'environ'))
_run(s._handle_eio_message('123', '0')) _run(s._handle_eio_message('123', '0'))
_run(s._handle_eio_message('123', '0/foo,'))
assert s.manager.is_connected('1', '/')
assert not s.manager.is_connected('2', '/foo')
def test_handle_connect_with_implied_namespaces(self, eio):
eio.return_value.send = AsyncMock()
s = asyncio_server.AsyncServer(namespaces=['/foo'])
_run(s._handle_eio_connect('123', 'environ'))
_run(s._handle_eio_message('123', '0'))
_run(s._handle_eio_message('123', '0/foo,'))
assert not s.manager.is_connected('1', '/') assert not s.manager.is_connected('1', '/')
assert s.manager.is_connected('1', '/foo')
def test_handle_connect_with_all_implied_namespaces(self, eio):
eio.return_value.send = AsyncMock()
s = asyncio_server.AsyncServer(namespaces='*')
_run(s._handle_eio_connect('123', 'environ'))
_run(s._handle_eio_message('123', '0'))
_run(s._handle_eio_message('123', '0/foo,'))
assert s.manager.is_connected('1', '/')
assert s.manager.is_connected('2', '/foo')
def test_handle_connect_namespace(self, eio): def test_handle_connect_namespace(self, eio):
eio.return_value.send = AsyncMock() eio.return_value.send = AsyncMock()

20
tests/common/test_server.py

@ -356,11 +356,29 @@ class TestServer(unittest.TestCase):
s._handle_eio_connect('456', 'environ') s._handle_eio_connect('456', 'environ')
assert s.manager.initialize.call_count == 1 assert s.manager.initialize.call_count == 1
def test_handle_connect_with_bad_namespace(self, eio): def test_handle_connect_with_default_implied_namespaces(self, eio):
s = server.Server() s = server.Server()
s._handle_eio_connect('123', 'environ') s._handle_eio_connect('123', 'environ')
s._handle_eio_message('123', '0') s._handle_eio_message('123', '0')
s._handle_eio_message('123', '0/foo,')
assert s.manager.is_connected('1', '/')
assert not s.manager.is_connected('2', '/foo')
def test_handle_connect_with_implied_namespaces(self, eio):
s = server.Server(namespaces=['/foo'])
s._handle_eio_connect('123', 'environ')
s._handle_eio_message('123', '0')
s._handle_eio_message('123', '0/foo,')
assert not s.manager.is_connected('1', '/') assert not s.manager.is_connected('1', '/')
assert s.manager.is_connected('1', '/foo')
def test_handle_connect_with_all_implied_namespaces(self, eio):
s = server.Server(namespaces='*')
s._handle_eio_connect('123', 'environ')
s._handle_eio_message('123', '0')
s._handle_eio_message('123', '0/foo,')
assert s.manager.is_connected('1', '/')
assert s.manager.is_connected('2', '/foo')
def test_handle_connect_namespace(self, eio): def test_handle_connect_namespace(self, eio):
s = server.Server() s = server.Server()

Loading…
Cancel
Save