diff --git a/socketio/server.py b/socketio/server.py index 227ff80..5e088d8 100644 --- a/socketio/server.py +++ b/socketio/server.py @@ -56,12 +56,16 @@ class Server(object): default. :param cors_credentials: Whether credentials (cookies, authentication) are allowed in requests to this server. + :param pass_environ: If set to ``True``, the current ``environ`` is passed + to all event handlers except for ``"disconnect"`` as + second argument (right after ``sid``). + (default is ``False``) :param engineio_logger: To enable Engine.IO logging set to ``True`` or pass a logger object to use. To disable logging set to ``False``. """ def __init__(self, client_manager=None, logger=False, binary=False, - json=None, **kwargs): + json=None, pass_environ=False, **kwargs): engineio_options = kwargs engineio_logger = engineio_options.pop('engineio_logger', None) if engineio_logger is not None: @@ -69,6 +73,8 @@ class Server(object): if json is not None: packet.Packet.json = json engineio_options['json'] = json + engineio_options['pass_environ'] = pass_environ + self.pass_environ = pass_environ self.eio = engineio.Server(**engineio_options) self.eio.on('connect', self._handle_eio_connect) self.eio.on('message', self._handle_eio_message) @@ -392,12 +398,16 @@ class Server(object): if sid in self.environ: del self.environ[sid] - def _handle_event(self, sid, namespace, id, data): + def _handle_event(self, environ, sid, namespace, id, data): """Handle an incoming client event.""" namespace = namespace or '/' self.logger.info('received event "%s" from %s [%s]', data[0], sid, namespace) - r = self._trigger_event(data[0], namespace, sid, *data[1:]) + if self.pass_environ: + r = self._trigger_event(data[0], namespace, sid, environ, + *data[1:]) + else: + r = self._trigger_event(data[0], namespace, sid, *data[1:]) if id is not None: # send ACK packet with the response returned by the handler # tuples are expanded as multiple arguments @@ -432,14 +442,20 @@ class Server(object): self.environ[sid] = environ return self._handle_connect(sid, '/') - def _handle_eio_message(self, sid, data): + def _handle_eio_message(self, *args): """Dispatch Engine.IO messages.""" + if self.pass_environ: + sid, environ, data = args + else: + sid, data = args + environ = None if len(self._binary_packet): pkt = self._binary_packet[0] if pkt.add_attachment(data): self._binary_packet.pop(0) if pkt.packet_type == packet.BINARY_EVENT: - self._handle_event(sid, pkt.namespace, pkt.id, pkt.data) + self._handle_event(environ, sid, pkt.namespace, pkt.id, + pkt.data) else: self._handle_ack(sid, pkt.namespace, pkt.id, pkt.data) else: @@ -449,7 +465,8 @@ class Server(object): elif pkt.packet_type == packet.DISCONNECT: self._handle_disconnect(sid, pkt.namespace) elif pkt.packet_type == packet.EVENT: - self._handle_event(sid, pkt.namespace, pkt.id, pkt.data) + self._handle_event(environ, sid, pkt.namespace, pkt.id, + pkt.data) elif pkt.packet_type == packet.ACK: self._handle_ack(sid, pkt.namespace, pkt.id, pkt.data) elif pkt.packet_type == packet.BINARY_EVENT or \ diff --git a/tests/test_server.py b/tests/test_server.py index 372f61a..65ce135 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -23,7 +23,7 @@ class TestServer(unittest.TestCase): s = server.Server(client_manager=mgr, binary=True, foo='bar') s.handle_request({}, None) s.handle_request({}, None) - eio.assert_called_once_with(**{'foo': 'bar'}) + eio.assert_called_once_with(**{'pass_environ': False, 'foo': 'bar'}) self.assertEqual(s.manager, mgr) self.assertEqual(s.eio.on.call_count, 3) self.assertEqual(s.binary, True) @@ -401,7 +401,7 @@ class TestServer(unittest.TestCase): def test_engineio_logger(self, eio): server.Server(engineio_logger='foo') - eio.assert_called_once_with(**{'logger': 'foo'}) + eio.assert_called_once_with(**{'pass_environ': False, 'logger': 'foo'}) def test_custom_json(self, eio): # Warning: this test cannot run in parallel with other tests, as it @@ -417,7 +417,8 @@ class TestServer(unittest.TestCase): return '+++ decoded +++' server.Server(json=CustomJSON) - eio.assert_called_once_with(**{'json': CustomJSON}) + eio.assert_called_once_with(**{'pass_environ': False, + 'json': CustomJSON}) pkt = packet.Packet(packet_type=packet.EVENT, data={six.text_type('foo'): six.text_type('bar')})