Browse Source

Pass auth information sent by client to the connect handler

pull/602/head
Miguel Grinberg 4 years ago
parent
commit
11b6f1a08d
No known key found for this signature in database GPG Key ID: 36848B262DF5F06C
  1. 8
      docs/server.rst
  2. 16
      socketio/asyncio_server.py
  3. 16
      socketio/server.py
  4. 32
      tests/asyncio/test_asyncio_server.py
  5. 28
      tests/common/test_server.py

8
docs/server.rst

@ -182,7 +182,7 @@ The ``connect`` and ``disconnect`` events are special; they are invoked
automatically when a client connects or disconnects from the server:: automatically when a client connects or disconnects from the server::
@sio.event @sio.event
def connect(sid, environ): def connect(sid, environ, auth):
print('connect ', sid) print('connect ', sid)
@sio.event @sio.event
@ -193,8 +193,10 @@ The ``connect`` event is an ideal place to perform user authentication, and
any necessary mapping between user entities in the application and the ``sid`` any necessary mapping between user entities in the application and the ``sid``
that was assigned to the client. The ``environ`` argument is a dictionary in that was assigned to the client. The ``environ`` argument is a dictionary in
standard WSGI format containing the request information, including HTTP standard WSGI format containing the request information, including HTTP
headers. After inspecting the request, the connect event handler can return headers. The ``auth`` argument contains any authentication details passed by
``False`` to reject the connection with the client. the client, or ``None`` if the client did not pass anything. After inspecting
the request, the connect event handler can return ``False`` to reject the
connection with the client.
Sometimes it is useful to pass data back to the client being rejected. In that Sometimes it is useful to pass data back to the client being rejected. In that
case instead of returning ``False`` case instead of returning ``False``

16
socketio/asyncio_server.py

@ -433,7 +433,7 @@ class AsyncServer(server.Server):
else: else:
await self.eio.send(eio_sid, encoded_packet) await self.eio.send(eio_sid, encoded_packet)
async def _handle_connect(self, eio_sid, namespace): async def _handle_connect(self, eio_sid, namespace, data):
"""Handle a client connection request.""" """Handle a client connection request."""
namespace = namespace or '/' namespace = namespace or '/'
sid = self.manager.connect(eio_sid, namespace) sid = self.manager.connect(eio_sid, namespace)
@ -442,8 +442,16 @@ class AsyncServer(server.Server):
packet.CONNECT, {'sid': sid}, namespace=namespace)) packet.CONNECT, {'sid': sid}, namespace=namespace))
fail_reason = exceptions.ConnectionRefusedError().error_args fail_reason = exceptions.ConnectionRefusedError().error_args
try: try:
success = await self._trigger_event('connect', namespace, sid, if data:
self.environ[eio_sid]) success = await self._trigger_event(
'connect', namespace, sid, self.environ[eio_sid], data)
else:
try:
success = await self._trigger_event(
'connect', namespace, sid, self.environ[eio_sid])
except TypeError:
success = await self._trigger_event(
'connect', namespace, sid, self.environ[eio_sid], None)
except exceptions.ConnectionRefusedError as exc: except exceptions.ConnectionRefusedError as exc:
fail_reason = exc.error_args fail_reason = exc.error_args
success = False success = False
@ -552,7 +560,7 @@ class AsyncServer(server.Server):
else: else:
pkt = packet.Packet(encoded_packet=data) pkt = packet.Packet(encoded_packet=data)
if pkt.packet_type == packet.CONNECT: if pkt.packet_type == packet.CONNECT:
await self._handle_connect(eio_sid, pkt.namespace) await self._handle_connect(eio_sid, pkt.namespace, pkt.data)
elif pkt.packet_type == packet.DISCONNECT: elif pkt.packet_type == packet.DISCONNECT:
await self._handle_disconnect(eio_sid, pkt.namespace) await self._handle_disconnect(eio_sid, pkt.namespace)
elif pkt.packet_type == packet.EVENT: elif pkt.packet_type == packet.EVENT:

16
socketio/server.py

@ -619,7 +619,7 @@ class Server(object):
else: else:
self.eio.send(eio_sid, encoded_packet) self.eio.send(eio_sid, encoded_packet)
def _handle_connect(self, eio_sid, namespace): def _handle_connect(self, eio_sid, namespace, data):
"""Handle a client connection request.""" """Handle a client connection request."""
namespace = namespace or '/' namespace = namespace or '/'
sid = self.manager.connect(eio_sid, namespace) sid = self.manager.connect(eio_sid, namespace)
@ -628,8 +628,16 @@ class Server(object):
packet.CONNECT, {'sid': sid}, namespace=namespace)) packet.CONNECT, {'sid': sid}, namespace=namespace))
fail_reason = exceptions.ConnectionRefusedError().error_args fail_reason = exceptions.ConnectionRefusedError().error_args
try: try:
success = self._trigger_event('connect', namespace, sid, if data:
self.environ[eio_sid]) success = self._trigger_event(
'connect', namespace, sid, self.environ[eio_sid], data)
else:
try:
success = self._trigger_event(
'connect', namespace, sid, self.environ[eio_sid])
except TypeError:
success = self._trigger_event(
'connect', namespace, sid, self.environ[eio_sid], None)
except exceptions.ConnectionRefusedError as exc: except exceptions.ConnectionRefusedError as exc:
fail_reason = exc.error_args fail_reason = exc.error_args
success = False success = False
@ -729,7 +737,7 @@ class Server(object):
else: else:
pkt = packet.Packet(encoded_packet=data) pkt = packet.Packet(encoded_packet=data)
if pkt.packet_type == packet.CONNECT: if pkt.packet_type == packet.CONNECT:
self._handle_connect(eio_sid, pkt.namespace) self._handle_connect(eio_sid, pkt.namespace, pkt.data)
elif pkt.packet_type == packet.DISCONNECT: elif pkt.packet_type == packet.DISCONNECT:
self._handle_disconnect(eio_sid, pkt.namespace) self._handle_disconnect(eio_sid, pkt.namespace)
elif pkt.packet_type == packet.EVENT: elif pkt.packet_type == packet.EVENT:

32
tests/asyncio/test_asyncio_server.py

@ -377,6 +377,38 @@ class TestAsyncServer(unittest.TestCase):
_run(s._handle_eio_message('456', '0')) _run(s._handle_eio_message('456', '0'))
assert s.manager.initialize.call_count == 1 assert s.manager.initialize.call_count == 1
def test_handle_connect_with_auth(self, eio):
eio.return_value.send = AsyncMock()
s = asyncio_server.AsyncServer()
s.manager.initialize = mock.MagicMock()
handler = mock.MagicMock()
s.on('connect', handler)
_run(s._handle_eio_connect('123', 'environ'))
_run(s._handle_eio_message('123', '0{"token":"abc"}'))
assert s.manager.is_connected('1', '/')
handler.assert_called_once_with('1', 'environ', {'token': 'abc'})
s.eio.send.mock.assert_called_once_with('123', '0{"sid":"1"}')
assert s.manager.initialize.call_count == 1
_run(s._handle_eio_connect('456', 'environ'))
_run(s._handle_eio_message('456', '0'))
assert s.manager.initialize.call_count == 1
def test_handle_connect_with_auth_none(self, eio):
eio.return_value.send = AsyncMock()
s = asyncio_server.AsyncServer()
s.manager.initialize = mock.MagicMock()
handler = mock.MagicMock(side_effect=[TypeError, None, None])
s.on('connect', handler)
_run(s._handle_eio_connect('123', 'environ'))
_run(s._handle_eio_message('123', '0'))
assert s.manager.is_connected('1', '/')
handler.assert_called_with('1', 'environ', None)
s.eio.send.mock.assert_called_once_with('123', '0{"sid":"1"}')
assert s.manager.initialize.call_count == 1
_run(s._handle_eio_connect('456', 'environ'))
_run(s._handle_eio_message('456', '0'))
assert s.manager.initialize.call_count == 1
def test_handle_connect_async(self, eio): def test_handle_connect_async(self, eio):
eio.return_value.send = AsyncMock() eio.return_value.send = AsyncMock()
s = asyncio_server.AsyncServer() s = asyncio_server.AsyncServer()

28
tests/common/test_server.py

@ -326,6 +326,34 @@ class TestServer(unittest.TestCase):
s._handle_eio_connect('456', 'environ') s._handle_eio_connect('456', 'environ')
assert s.manager.initialize.call_count == 1 assert s.manager.initialize.call_count == 1
def test_handle_connect_with_auth(self, eio):
s = server.Server()
s.manager.initialize = mock.MagicMock()
handler = mock.MagicMock()
s.on('connect', handler)
s._handle_eio_connect('123', 'environ')
s._handle_eio_message('123', '0{"token":"abc"}')
assert s.manager.is_connected('1', '/')
handler.assert_called_with('1', 'environ', {'token': 'abc'})
s.eio.send.assert_called_once_with('123', '0{"sid":"1"}')
assert s.manager.initialize.call_count == 1
s._handle_eio_connect('456', 'environ')
assert s.manager.initialize.call_count == 1
def test_handle_connect_with_auth_none(self, eio):
s = server.Server()
s.manager.initialize = mock.MagicMock()
handler = mock.MagicMock(side_effect=[TypeError, None])
s.on('connect', handler)
s._handle_eio_connect('123', 'environ')
s._handle_eio_message('123', '0')
assert s.manager.is_connected('1', '/')
handler.assert_called_with('1', 'environ', None)
s.eio.send.assert_called_once_with('123', '0{"sid":"1"}')
assert s.manager.initialize.call_count == 1
s._handle_eio_connect('456', 'environ')
assert s.manager.initialize.call_count == 1
def test_handle_connect_namespace(self, eio): def test_handle_connect_namespace(self, eio):
s = server.Server() s = server.Server()
handler = mock.MagicMock() handler = mock.MagicMock()

Loading…
Cancel
Save