Browse Source

Pass custom authentication data with client connection (Fixes #661)

pull/667/head
Miguel Grinberg 4 years ago
parent
commit
a07eedf54e
No known key found for this signature in database GPG Key ID: 36848B262DF5F06C
  1. 2
      examples/client/asyncio/fiddle_client.py
  2. 2
      examples/client/javascript/fiddle_client.js
  3. 2012
      examples/client/javascript/package-lock.json
  4. 4
      examples/client/javascript/package.json
  5. 2
      examples/client/threads/fiddle_client.py
  6. 4
      examples/server/aiohttp/fiddle.py
  7. 4
      examples/server/asgi/fiddle.py
  8. 2
      examples/server/javascript/fiddle.js
  9. 2012
      examples/server/javascript/package-lock.json
  10. 6
      examples/server/javascript/package.json
  11. 4
      examples/server/sanic/fiddle.py
  12. 4
      examples/server/tornado/fiddle.py
  13. 4
      examples/server/wsgi/fiddle.py
  14. 10
      socketio/asyncio_client.py
  15. 11
      socketio/client.py
  16. 11
      tests/asyncio/test_asyncio_client.py
  17. 9
      tests/common/test_client.py

2
examples/client/asyncio/fiddle_client.py

@ -20,7 +20,7 @@ def hello(a, b, c):
async def start_server():
await sio.connect('http://localhost:5000')
await sio.connect('http://localhost:5000', auth={'token': 'my-token'})
await sio.wait()

2
examples/client/javascript/fiddle_client.js

@ -1,7 +1,7 @@
const io = require('socket.io-client')
const port = process.env.PORT || 5000;
const socket = io('http://localhost:' + port);
const socket = io('http://localhost:' + port, {auth: {token: 'my-token'}});
socket.on('connect', () => {
console.log(`connect ${socket.id}`);

2012
examples/client/javascript/package-lock.json

File diff suppressed because it is too large

4
examples/client/javascript/package.json

@ -4,7 +4,7 @@
"dependencies": {
"express": "^4.17.1",
"smoothie": "1.19.0",
"socket.io": "^3.0.0",
"socket.io-client": "^3.0.0"
"socket.io": "^3.0.3",
"socket.io-client": "^3.0.3"
}
}

2
examples/client/threads/fiddle_client.py

@ -19,5 +19,5 @@ def hello(a, b, c):
if __name__ == '__main__':
sio.connect('http://localhost:5000')
sio.connect('http://localhost:5000', auth={'token': 'my-token'})
sio.wait()

4
examples/server/aiohttp/fiddle.py

@ -13,8 +13,8 @@ async def index(request):
@sio.event
async def connect(sid, environ):
print('connected', sid)
async def connect(sid, environ, auth):
print(f'connected auth={auth} sid={sid}')
await sio.emit('hello', (1, 2, {'hello': 'you'}), to=sid)

4
examples/server/asgi/fiddle.py

@ -11,8 +11,8 @@ app = socketio.ASGIApp(sio, static_files={
@sio.event
async def connect(sid, environ):
print('connected', sid)
async def connect(sid, environ, auth):
print(f'connected auth={auth} sid={sid}')
await sio.emit('hello', (1, 2, {'hello': 'you'}), to=sid)

2
examples/server/javascript/fiddle.js

@ -7,7 +7,7 @@ const port = process.env.PORT || 5000;
app.use(express.static(__dirname + '/fiddle_public'));
io.on('connection', socket => {
console.log(`connect ${socket.id}`);
console.log(`connect auth=${JSON.stringify(socket.handshake.auth)} sid=${socket.id}`);
socket.emit('hello', 1, '2', {
hello: 'you'

2012
examples/server/javascript/package-lock.json

File diff suppressed because it is too large

6
examples/server/javascript/package.json

@ -2,9 +2,9 @@
"name": "socketio-examples",
"version": "0.1.0",
"dependencies": {
"socket.io": "^3.0.0",
"socket.io-client": "^3.0.0",
"express": "^4.17.1",
"smoothie": "1.19.0"
"smoothie": "1.19.0",
"socket.io": "^3.0.3",
"socket.io-client": "^3.0.3"
}
}

4
examples/server/sanic/fiddle.py

@ -15,8 +15,8 @@ def index(request):
@sio.event
async def connect(sid, environ):
print('connected', sid)
async def connect(sid, environ, auth):
print(f'connected auth={auth} sid={sid}')
await sio.emit('hello', (1, 2, {'hello': 'you'}), to=sid)

4
examples/server/tornado/fiddle.py

@ -18,8 +18,8 @@ class MainHandler(tornado.web.RequestHandler):
@sio.event
async def connect(sid, environ):
print('connected', sid)
async def connect(sid, environ, auth):
print(f'connected auth={auth} sid={sid}')
await sio.emit('hello', (1, 2, {'hello': 'you'}), to=sid)

4
examples/server/wsgi/fiddle.py

@ -17,8 +17,8 @@ def index():
@sio.event
def connect(sid, environ):
print('connected', sid)
def connect(sid, environ, auth):
print(f'connected auth={auth} sid={sid}')
sio.emit('hello', (1, 2, {'hello': 'you'}), to=sid)

10
socketio/asyncio_client.py

@ -62,7 +62,7 @@ class AsyncClient(client.Client):
def is_asyncio_based(self):
return True
async def connect(self, url, headers={}, transports=None,
async def connect(self, url, headers={}, auth=None, transports=None,
namespaces=None, socketio_path='socket.io', wait=True,
wait_timeout=1):
"""Connect to a Socket.IO server.
@ -71,6 +71,9 @@ class AsyncClient(client.Client):
query string parameters if required by the server.
:param headers: A dictionary with custom headers to send with the
connection request.
:param auth: Authentication data passed to the server with the
connection request, normally a dictionary with one or
more string key/value pairs.
:param transports: The list of allowed transports. Valid transports
are ``'polling'`` and ``'websocket'``. If not
given, the polling transport is connected first,
@ -103,6 +106,7 @@ class AsyncClient(client.Client):
self.connection_url = url
self.connection_headers = headers
self.connection_auth = auth
self.connection_transports = transports
self.connection_namespaces = namespaces
self.socketio_path = socketio_path
@ -437,6 +441,7 @@ class AsyncClient(client.Client):
try:
await self.connect(self.connection_url,
headers=self.connection_headers,
auth=self.connection_auth,
transports=self.connection_transports,
namespaces=self.connection_namespaces,
socketio_path=self.socketio_path)
@ -458,7 +463,8 @@ class AsyncClient(client.Client):
self.logger.info('Engine.IO connection established')
self.sid = self.eio.sid
for n in self.connection_namespaces:
await self._send_packet(packet.Packet(packet.CONNECT, namespace=n))
await self._send_packet(packet.Packet(
packet.CONNECT, data=self.connection_auth, namespace=n))
async def _handle_eio_message(self, data):
"""Dispatch Engine.IO messages."""

11
socketio/client.py

@ -120,6 +120,7 @@ class Client(object):
self.connection_url = None
self.connection_headers = None
self.connection_auth = None
self.connection_transports = None
self.connection_namespaces = []
self.socketio_path = None
@ -233,7 +234,7 @@ class Client(object):
self.namespace_handlers[namespace_handler.namespace] = \
namespace_handler
def connect(self, url, headers={}, transports=None,
def connect(self, url, headers={}, auth=None, transports=None,
namespaces=None, socketio_path='socket.io', wait=True,
wait_timeout=1):
"""Connect to a Socket.IO server.
@ -242,6 +243,9 @@ class Client(object):
query string parameters if required by the server.
:param headers: A dictionary with custom headers to send with the
connection request.
:param auth: Authentication data passed to the server with the
connection request, normally a dictionary with one or
more string key/value pairs.
:param transports: The list of allowed transports. Valid transports
are ``'polling'`` and ``'websocket'``. If not
given, the polling transport is connected first,
@ -272,6 +276,7 @@ class Client(object):
self.connection_url = url
self.connection_headers = headers
self.connection_auth = auth
self.connection_transports = transports
self.connection_namespaces = namespaces
self.socketio_path = socketio_path
@ -602,6 +607,7 @@ class Client(object):
try:
self.connect(self.connection_url,
headers=self.connection_headers,
auth=self.connection_auth,
transports=self.connection_transports,
namespaces=self.connection_namespaces,
socketio_path=self.socketio_path)
@ -623,7 +629,8 @@ class Client(object):
self.logger.info('Engine.IO connection established')
self.sid = self.eio.sid
for n in self.connection_namespaces:
self._send_packet(packet.Packet(packet.CONNECT, namespace=n))
self._send_packet(packet.Packet(
packet.CONNECT, data=self.connection_auth, namespace=n))
def _handle_eio_message(self, data):
"""Dispatch Engine.IO messages."""

11
tests/asyncio/test_asyncio_client.py

@ -55,6 +55,7 @@ class TestAsyncClient(unittest.TestCase):
c.connect(
'url',
headers='headers',
auth='auth',
transports='transports',
namespaces=['/foo', '/', '/bar'],
socketio_path='path',
@ -63,6 +64,7 @@ class TestAsyncClient(unittest.TestCase):
)
assert c.connection_url == 'url'
assert c.connection_headers == 'headers'
assert c.connection_auth == 'auth'
assert c.connection_transports == 'transports'
assert c.connection_namespaces == ['/foo', '/', '/bar']
assert c.socketio_path == 'path'
@ -934,21 +936,24 @@ class TestAsyncClient(unittest.TestCase):
]
assert c._reconnect_task == 'foo'
def test_eio_connect(self):
def test_handle_eio_connect(self):
c = asyncio_client.AsyncClient()
c.connection_namespaces = ['/', '/foo']
c.connection_auth = 'auth'
c._send_packet = AsyncMock()
c.eio.sid = 'foo'
assert c.sid is None
_run(c._handle_eio_connect())
assert c.sid == 'foo'
assert c._send_packet.mock.call_count == 2
expected_packet = packet.Packet(packet.CONNECT, namespace='/')
expected_packet = packet.Packet(
packet.CONNECT, data='auth', namespace='/')
assert (
c._send_packet.mock.call_args_list[0][0][0].encode()
== expected_packet.encode()
)
expected_packet = packet.Packet(packet.CONNECT, namespace='/foo')
expected_packet = packet.Packet(
packet.CONNECT, data='auth', namespace='/foo')
assert (
c._send_packet.mock.call_args_list[1][0][0].encode()
== expected_packet.encode()

9
tests/common/test_client.py

@ -154,6 +154,7 @@ class TestClient(unittest.TestCase):
c.connect(
'url',
headers='headers',
auth='auth',
transports='transports',
namespaces=['/foo', '/', '/bar'],
socketio_path='path',
@ -161,6 +162,7 @@ class TestClient(unittest.TestCase):
)
assert c.connection_url == 'url'
assert c.connection_headers == 'headers'
assert c.connection_auth == 'auth'
assert c.connection_transports == 'transports'
assert c.connection_namespaces == ['/foo', '/', '/bar']
assert c.socketio_path == 'path'
@ -1008,18 +1010,21 @@ class TestClient(unittest.TestCase):
def test_handle_eio_connect(self):
c = client.Client()
c.connection_namespaces = ['/', '/foo']
c.connection_auth = 'auth'
c._send_packet = mock.MagicMock()
c.eio.sid = 'foo'
assert c.sid is None
c._handle_eio_connect()
assert c.sid == 'foo'
assert c._send_packet.call_count == 2
expected_packet = packet.Packet(packet.CONNECT, namespace='/')
expected_packet = packet.Packet(
packet.CONNECT, data='auth', namespace='/')
assert (
c._send_packet.call_args_list[0][0][0].encode()
== expected_packet.encode()
)
expected_packet = packet.Packet(packet.CONNECT, namespace='/foo')
expected_packet = packet.Packet(
packet.CONNECT, data='auth', namespace='/foo')
assert (
c._send_packet.call_args_list[1][0][0].encode()
== expected_packet.encode()

Loading…
Cancel
Save