Browse Source

Pass auth information sent by client to the connect handler

pull/602/head
Miguel Grinberg 4 years ago
parent
commit
11b6f1a08d
No known key found for this signature in database GPG Key ID: 36848B262DF5F06C
  1. 8
      docs/server.rst
  2. 16
      socketio/asyncio_server.py
  3. 16
      socketio/server.py
  4. 32
      tests/asyncio/test_asyncio_server.py
  5. 28
      tests/common/test_server.py

8
docs/server.rst

@ -182,7 +182,7 @@ The ``connect`` and ``disconnect`` events are special; they are invoked
automatically when a client connects or disconnects from the server::
@sio.event
def connect(sid, environ):
def connect(sid, environ, auth):
print('connect ', sid)
@sio.event
@ -193,8 +193,10 @@ The ``connect`` event is an ideal place to perform user authentication, and
any necessary mapping between user entities in the application and the ``sid``
that was assigned to the client. The ``environ`` argument is a dictionary in
standard WSGI format containing the request information, including HTTP
headers. After inspecting the request, the connect event handler can return
``False`` to reject the connection with the client.
headers. The ``auth`` argument contains any authentication details passed by
the client, or ``None`` if the client did not pass anything. After inspecting
the request, the connect event handler can return ``False`` to reject the
connection with the client.
Sometimes it is useful to pass data back to the client being rejected. In that
case instead of returning ``False``

16
socketio/asyncio_server.py

@ -433,7 +433,7 @@ class AsyncServer(server.Server):
else:
await self.eio.send(eio_sid, encoded_packet)
async def _handle_connect(self, eio_sid, namespace):
async def _handle_connect(self, eio_sid, namespace, data):
"""Handle a client connection request."""
namespace = namespace or '/'
sid = self.manager.connect(eio_sid, namespace)
@ -442,8 +442,16 @@ class AsyncServer(server.Server):
packet.CONNECT, {'sid': sid}, namespace=namespace))
fail_reason = exceptions.ConnectionRefusedError().error_args
try:
success = await self._trigger_event('connect', namespace, sid,
self.environ[eio_sid])
if data:
success = await self._trigger_event(
'connect', namespace, sid, self.environ[eio_sid], data)
else:
try:
success = await self._trigger_event(
'connect', namespace, sid, self.environ[eio_sid])
except TypeError:
success = await self._trigger_event(
'connect', namespace, sid, self.environ[eio_sid], None)
except exceptions.ConnectionRefusedError as exc:
fail_reason = exc.error_args
success = False
@ -552,7 +560,7 @@ class AsyncServer(server.Server):
else:
pkt = packet.Packet(encoded_packet=data)
if pkt.packet_type == packet.CONNECT:
await self._handle_connect(eio_sid, pkt.namespace)
await self._handle_connect(eio_sid, pkt.namespace, pkt.data)
elif pkt.packet_type == packet.DISCONNECT:
await self._handle_disconnect(eio_sid, pkt.namespace)
elif pkt.packet_type == packet.EVENT:

16
socketio/server.py

@ -619,7 +619,7 @@ class Server(object):
else:
self.eio.send(eio_sid, encoded_packet)
def _handle_connect(self, eio_sid, namespace):
def _handle_connect(self, eio_sid, namespace, data):
"""Handle a client connection request."""
namespace = namespace or '/'
sid = self.manager.connect(eio_sid, namespace)
@ -628,8 +628,16 @@ class Server(object):
packet.CONNECT, {'sid': sid}, namespace=namespace))
fail_reason = exceptions.ConnectionRefusedError().error_args
try:
success = self._trigger_event('connect', namespace, sid,
self.environ[eio_sid])
if data:
success = self._trigger_event(
'connect', namespace, sid, self.environ[eio_sid], data)
else:
try:
success = self._trigger_event(
'connect', namespace, sid, self.environ[eio_sid])
except TypeError:
success = self._trigger_event(
'connect', namespace, sid, self.environ[eio_sid], None)
except exceptions.ConnectionRefusedError as exc:
fail_reason = exc.error_args
success = False
@ -729,7 +737,7 @@ class Server(object):
else:
pkt = packet.Packet(encoded_packet=data)
if pkt.packet_type == packet.CONNECT:
self._handle_connect(eio_sid, pkt.namespace)
self._handle_connect(eio_sid, pkt.namespace, pkt.data)
elif pkt.packet_type == packet.DISCONNECT:
self._handle_disconnect(eio_sid, pkt.namespace)
elif pkt.packet_type == packet.EVENT:

32
tests/asyncio/test_asyncio_server.py

@ -377,6 +377,38 @@ class TestAsyncServer(unittest.TestCase):
_run(s._handle_eio_message('456', '0'))
assert s.manager.initialize.call_count == 1
def test_handle_connect_with_auth(self, eio):
eio.return_value.send = AsyncMock()
s = asyncio_server.AsyncServer()
s.manager.initialize = mock.MagicMock()
handler = mock.MagicMock()
s.on('connect', handler)
_run(s._handle_eio_connect('123', 'environ'))
_run(s._handle_eio_message('123', '0{"token":"abc"}'))
assert s.manager.is_connected('1', '/')
handler.assert_called_once_with('1', 'environ', {'token': 'abc'})
s.eio.send.mock.assert_called_once_with('123', '0{"sid":"1"}')
assert s.manager.initialize.call_count == 1
_run(s._handle_eio_connect('456', 'environ'))
_run(s._handle_eio_message('456', '0'))
assert s.manager.initialize.call_count == 1
def test_handle_connect_with_auth_none(self, eio):
eio.return_value.send = AsyncMock()
s = asyncio_server.AsyncServer()
s.manager.initialize = mock.MagicMock()
handler = mock.MagicMock(side_effect=[TypeError, None, None])
s.on('connect', handler)
_run(s._handle_eio_connect('123', 'environ'))
_run(s._handle_eio_message('123', '0'))
assert s.manager.is_connected('1', '/')
handler.assert_called_with('1', 'environ', None)
s.eio.send.mock.assert_called_once_with('123', '0{"sid":"1"}')
assert s.manager.initialize.call_count == 1
_run(s._handle_eio_connect('456', 'environ'))
_run(s._handle_eio_message('456', '0'))
assert s.manager.initialize.call_count == 1
def test_handle_connect_async(self, eio):
eio.return_value.send = AsyncMock()
s = asyncio_server.AsyncServer()

28
tests/common/test_server.py

@ -326,6 +326,34 @@ class TestServer(unittest.TestCase):
s._handle_eio_connect('456', 'environ')
assert s.manager.initialize.call_count == 1
def test_handle_connect_with_auth(self, eio):
s = server.Server()
s.manager.initialize = mock.MagicMock()
handler = mock.MagicMock()
s.on('connect', handler)
s._handle_eio_connect('123', 'environ')
s._handle_eio_message('123', '0{"token":"abc"}')
assert s.manager.is_connected('1', '/')
handler.assert_called_with('1', 'environ', {'token': 'abc'})
s.eio.send.assert_called_once_with('123', '0{"sid":"1"}')
assert s.manager.initialize.call_count == 1
s._handle_eio_connect('456', 'environ')
assert s.manager.initialize.call_count == 1
def test_handle_connect_with_auth_none(self, eio):
s = server.Server()
s.manager.initialize = mock.MagicMock()
handler = mock.MagicMock(side_effect=[TypeError, None])
s.on('connect', handler)
s._handle_eio_connect('123', 'environ')
s._handle_eio_message('123', '0')
assert s.manager.is_connected('1', '/')
handler.assert_called_with('1', 'environ', None)
s.eio.send.assert_called_once_with('123', '0{"sid":"1"}')
assert s.manager.initialize.call_count == 1
s._handle_eio_connect('456', 'environ')
assert s.manager.initialize.call_count == 1
def test_handle_connect_namespace(self, eio):
s = server.Server()
handler = mock.MagicMock()

Loading…
Cancel
Save