Browse Source

Added pass_environ=False keyword argument to Server.

pull/44/head
Robert Schindler 9 years ago
parent
commit
cdf0f71f56
  1. 29
      socketio/server.py
  2. 7
      tests/test_server.py

29
socketio/server.py

@ -56,12 +56,16 @@ class Server(object):
default.
:param cors_credentials: Whether credentials (cookies, authentication) are
allowed in requests to this server.
:param pass_environ: If set to ``True``, the current ``environ`` is passed
to all event handlers except for ``"disconnect"`` as
second argument (right after ``sid``).
(default is ``False``)
:param engineio_logger: To enable Engine.IO logging set to ``True`` or pass
a logger object to use. To disable logging set to
``False``.
"""
def __init__(self, client_manager=None, logger=False, binary=False,
json=None, **kwargs):
json=None, pass_environ=False, **kwargs):
engineio_options = kwargs
engineio_logger = engineio_options.pop('engineio_logger', None)
if engineio_logger is not None:
@ -69,6 +73,8 @@ class Server(object):
if json is not None:
packet.Packet.json = json
engineio_options['json'] = json
engineio_options['pass_environ'] = pass_environ
self.pass_environ = pass_environ
self.eio = engineio.Server(**engineio_options)
self.eio.on('connect', self._handle_eio_connect)
self.eio.on('message', self._handle_eio_message)
@ -392,12 +398,16 @@ class Server(object):
if sid in self.environ:
del self.environ[sid]
def _handle_event(self, sid, namespace, id, data):
def _handle_event(self, environ, sid, namespace, id, data):
"""Handle an incoming client event."""
namespace = namespace or '/'
self.logger.info('received event "%s" from %s [%s]', data[0], sid,
namespace)
r = self._trigger_event(data[0], namespace, sid, *data[1:])
if self.pass_environ:
r = self._trigger_event(data[0], namespace, sid, environ,
*data[1:])
else:
r = self._trigger_event(data[0], namespace, sid, *data[1:])
if id is not None:
# send ACK packet with the response returned by the handler
# tuples are expanded as multiple arguments
@ -432,14 +442,20 @@ class Server(object):
self.environ[sid] = environ
return self._handle_connect(sid, '/')
def _handle_eio_message(self, sid, data):
def _handle_eio_message(self, *args):
"""Dispatch Engine.IO messages."""
if self.pass_environ:
sid, environ, data = args
else:
sid, data = args
environ = None
if len(self._binary_packet):
pkt = self._binary_packet[0]
if pkt.add_attachment(data):
self._binary_packet.pop(0)
if pkt.packet_type == packet.BINARY_EVENT:
self._handle_event(sid, pkt.namespace, pkt.id, pkt.data)
self._handle_event(environ, sid, pkt.namespace, pkt.id,
pkt.data)
else:
self._handle_ack(sid, pkt.namespace, pkt.id, pkt.data)
else:
@ -449,7 +465,8 @@ class Server(object):
elif pkt.packet_type == packet.DISCONNECT:
self._handle_disconnect(sid, pkt.namespace)
elif pkt.packet_type == packet.EVENT:
self._handle_event(sid, pkt.namespace, pkt.id, pkt.data)
self._handle_event(environ, sid, pkt.namespace, pkt.id,
pkt.data)
elif pkt.packet_type == packet.ACK:
self._handle_ack(sid, pkt.namespace, pkt.id, pkt.data)
elif pkt.packet_type == packet.BINARY_EVENT or \

7
tests/test_server.py

@ -23,7 +23,7 @@ class TestServer(unittest.TestCase):
s = server.Server(client_manager=mgr, binary=True, foo='bar')
s.handle_request({}, None)
s.handle_request({}, None)
eio.assert_called_once_with(**{'foo': 'bar'})
eio.assert_called_once_with(**{'pass_environ': False, 'foo': 'bar'})
self.assertEqual(s.manager, mgr)
self.assertEqual(s.eio.on.call_count, 3)
self.assertEqual(s.binary, True)
@ -401,7 +401,7 @@ class TestServer(unittest.TestCase):
def test_engineio_logger(self, eio):
server.Server(engineio_logger='foo')
eio.assert_called_once_with(**{'logger': 'foo'})
eio.assert_called_once_with(**{'pass_environ': False, 'logger': 'foo'})
def test_custom_json(self, eio):
# Warning: this test cannot run in parallel with other tests, as it
@ -417,7 +417,8 @@ class TestServer(unittest.TestCase):
return '+++ decoded +++'
server.Server(json=CustomJSON)
eio.assert_called_once_with(**{'json': CustomJSON})
eio.assert_called_once_with(**{'pass_environ': False,
'json': CustomJSON})
pkt = packet.Packet(packet_type=packet.EVENT,
data={six.text_type('foo'): six.text_type('bar')})

Loading…
Cancel
Save