Browse Source

Support msgpack and custom packet serializers (Fixes #749)

pull/789/head
Miguel Grinberg 4 years ago
parent
commit
5159e84c49
No known key found for this signature in database GPG Key ID: 36848B262DF5F06C
  1. 10
      src/socketio/asyncio_client.py
  2. 16
      src/socketio/asyncio_server.py
  3. 34
      src/socketio/client.py
  4. 16
      src/socketio/msgpack_packet.py
  5. 12
      src/socketio/packet.py
  6. 46
      src/socketio/server.py
  7. 15
      tests/common/test_client.py
  8. 24
      tests/common/test_msgpack_packet.py
  9. 13
      tests/common/test_server.py
  10. 1
      tox.ini

10
src/socketio/asyncio_client.py

@ -220,7 +220,7 @@ class AsyncClient(client.Client):
data = [data]
else:
data = []
await self._send_packet(packet.Packet(
await self._send_packet(self.packet_class(
packet.EVENT, namespace=namespace, data=[event] + data, id=id))
async def send(self, data, namespace=None, callback=None):
@ -296,7 +296,7 @@ class AsyncClient(client.Client):
# here we just request the disconnection
# later in _handle_eio_disconnect we invoke the disconnect handler
for n in self.namespaces:
await self._send_packet(packet.Packet(packet.DISCONNECT,
await self._send_packet(self.packet_class(packet.DISCONNECT,
namespace=n))
await self.eio.disconnect(abort=True)
@ -379,7 +379,7 @@ class AsyncClient(client.Client):
data = list(r)
else:
data = [r]
await self._send_packet(packet.Packet(
await self._send_packet(self.packet_class(
packet.ACK, namespace=namespace, id=id, data=data))
async def _handle_ack(self, namespace, id, data):
@ -482,7 +482,7 @@ class AsyncClient(client.Client):
self.sid = self.eio.sid
real_auth = await self._get_real_value(self.connection_auth)
for n in self.connection_namespaces:
await self._send_packet(packet.Packet(
await self._send_packet(self.packet_class(
packet.CONNECT, data=real_auth, namespace=n))
async def _handle_eio_message(self, data):
@ -496,7 +496,7 @@ class AsyncClient(client.Client):
else:
await self._handle_ack(pkt.namespace, pkt.id, pkt.data)
else:
pkt = packet.Packet(encoded_packet=data)
pkt = self.packet_class(encoded_packet=data)
if pkt.packet_type == packet.CONNECT:
await self._handle_connect(pkt.namespace, pkt.data)
elif pkt.packet_type == packet.DISCONNECT:

16
src/socketio/asyncio_server.py

@ -369,7 +369,7 @@ class AsyncServer(server.Server):
if delete_it:
self.logger.info('Disconnecting %s [%s]', sid, namespace)
eio_sid = self.manager.pre_disconnect(sid, namespace=namespace)
await self._send_packet(eio_sid, packet.Packet(
await self._send_packet(eio_sid, self.packet_class(
packet.DISCONNECT, namespace=namespace))
await self._trigger_event('disconnect', namespace, sid)
self.manager.disconnect(sid, namespace=namespace)
@ -423,7 +423,7 @@ class AsyncServer(server.Server):
data = [data]
else:
data = []
await self._send_packet(sid, packet.Packet(
await self._send_packet(sid, self.packet_class(
packet.EVENT, namespace=namespace, data=[event] + data, id=id))
async def _send_packet(self, eio_sid, pkt):
@ -440,7 +440,7 @@ class AsyncServer(server.Server):
namespace = namespace or '/'
sid = self.manager.connect(eio_sid, namespace)
if self.always_connect:
await self._send_packet(eio_sid, packet.Packet(
await self._send_packet(eio_sid, self.packet_class(
packet.CONNECT, {'sid': sid}, namespace=namespace))
fail_reason = exceptions.ConnectionRefusedError().error_args
try:
@ -461,15 +461,15 @@ class AsyncServer(server.Server):
if success is False:
if self.always_connect:
self.manager.pre_disconnect(sid, namespace)
await self._send_packet(eio_sid, packet.Packet(
await self._send_packet(eio_sid, self.packet_class(
packet.DISCONNECT, data=fail_reason, namespace=namespace))
else:
await self._send_packet(eio_sid, packet.Packet(
await self._send_packet(eio_sid, self.packet_class(
packet.CONNECT_ERROR, data=fail_reason,
namespace=namespace))
self.manager.disconnect(sid, namespace)
elif not self.always_connect:
await self._send_packet(eio_sid, packet.Packet(
await self._send_packet(eio_sid, self.packet_class(
packet.CONNECT, {'sid': sid}, namespace=namespace))
async def _handle_disconnect(self, eio_sid, namespace):
@ -511,7 +511,7 @@ class AsyncServer(server.Server):
data = list(r)
else:
data = [r]
await server._send_packet(eio_sid, packet.Packet(
await server._send_packet(eio_sid, self.packet_class(
packet.ACK, namespace=namespace, id=id, data=data))
async def _handle_ack(self, eio_sid, namespace, id, data):
@ -560,7 +560,7 @@ class AsyncServer(server.Server):
await self._handle_ack(eio_sid, pkt.namespace, pkt.id,
pkt.data)
else:
pkt = packet.Packet(encoded_packet=data)
pkt = self.packet_class(encoded_packet=data)
if pkt.packet_type == packet.CONNECT:
await self._handle_connect(eio_sid, pkt.namespace, pkt.data)
elif pkt.packet_type == packet.DISCONNECT:

34
src/socketio/client.py

@ -57,6 +57,13 @@ class Client(object):
use. To disable logging set to ``False``. The default is
``False``. Note that fatal errors are logged even when
``logger`` is ``False``.
:param serializer: The serialization method to use when transmitting
packets. Valid values are ``'default'``, ``'pickle'``,
``'msgpack'`` and ``'cbor'``. Alternatively, a subclass
of the :class:`Packet` class with custom implementations
of the ``encode()`` and ``decode()`` methods can be
provided. Client and server must use compatible
serializers.
:param json: An alternative json module to use for encoding and decoding
packets. Custom json modules must have ``dumps`` and ``loads``
functions that are compatible with the standard library
@ -82,7 +89,8 @@ class Client(object):
"""
def __init__(self, reconnection=True, reconnection_attempts=0,
reconnection_delay=1, reconnection_delay_max=5,
randomization_factor=0.5, logger=False, json=None, **kwargs):
randomization_factor=0.5, logger=False, serializer='default',
json=None, **kwargs):
global original_signal_handler
if original_signal_handler is None and \
threading.current_thread() == threading.main_thread():
@ -98,8 +106,15 @@ class Client(object):
engineio_logger = engineio_options.pop('engineio_logger', None)
if engineio_logger is not None:
engineio_options['logger'] = engineio_logger
if serializer == 'default':
self.packet_class = packet.Packet
elif serializer == 'msgpack':
from . import msgpack_packet
self.packet_class = msgpack_packet.MsgPackPacket
else:
self.packet_class = serializer
if json is not None:
packet.Packet.json = json
self.packet_class.json = json
engineio_options['json'] = json
self.eio = self._engineio_client_class()(**engineio_options)
@ -381,8 +396,8 @@ class Client(object):
data = [data]
else:
data = []
self._send_packet(packet.Packet(packet.EVENT, namespace=namespace,
data=[event] + data, id=id))
self._send_packet(self.packet_class(packet.EVENT, namespace=namespace,
data=[event] + data, id=id))
def send(self, data, namespace=None, callback=None):
"""Send a message to one or more connected clients.
@ -448,7 +463,8 @@ class Client(object):
# here we just request the disconnection
# later in _handle_eio_disconnect we invoke the disconnect handler
for n in self.namespaces:
self._send_packet(packet.Packet(packet.DISCONNECT, namespace=n))
self._send_packet(self.packet_class(
packet.DISCONNECT, namespace=n))
self.eio.disconnect(abort=True)
def get_sid(self, namespace=None):
@ -557,8 +573,8 @@ class Client(object):
data = list(r)
else:
data = [r]
self._send_packet(packet.Packet(packet.ACK, namespace=namespace,
id=id, data=data))
self._send_packet(self.packet_class(
packet.ACK, namespace=namespace, id=id, data=data))
def _handle_ack(self, namespace, id, data):
namespace = namespace or '/'
@ -647,7 +663,7 @@ class Client(object):
self.sid = self.eio.sid
real_auth = self._get_real_value(self.connection_auth)
for n in self.connection_namespaces:
self._send_packet(packet.Packet(
self._send_packet(self.packet_class(
packet.CONNECT, data=real_auth, namespace=n))
def _handle_eio_message(self, data):
@ -661,7 +677,7 @@ class Client(object):
else:
self._handle_ack(pkt.namespace, pkt.id, pkt.data)
else:
pkt = packet.Packet(encoded_packet=data)
pkt = self.packet_class(encoded_packet=data)
if pkt.packet_type == packet.CONNECT:
self._handle_connect(pkt.namespace, pkt.data)
elif pkt.packet_type == packet.DISCONNECT:

16
src/socketio/msgpack_packet.py

@ -0,0 +1,16 @@
import msgpack
from . import packet
class MsgPackPacket(packet.Packet):
def encode(self):
"""Encode the packet for transmission."""
return msgpack.dumps(self._to_dict())
def decode(self, encoded_packet):
"""Decode a transmitted package."""
decoded = msgpack.loads(encoded_packet)
self.packet_type = decoded['type']
self.data = decoded['data']
self.id = decoded.get('id')
self.namespace = decoded['nsp']

12
src/socketio/packet.py

@ -37,7 +37,7 @@ class Packet(object):
self.attachment_count = 0
self.attachments = []
if encoded_packet:
self.attachment_count = self.decode(encoded_packet)
self.attachment_count = self.decode(encoded_packet) or 0
def encode(self):
"""Encode the packet for transmission.
@ -175,3 +175,13 @@ class Packet(object):
False)
else:
return False
def _to_dict(self):
d = {
'type': self.packet_type,
'data': self.data,
'nsp': self.namespace,
}
if self.id:
d['id'] = self.id
return d

46
src/socketio/server.py

@ -24,6 +24,13 @@ class Server(object):
use. To disable logging set to ``False``. The default is
``False``. Note that fatal errors are logged even when
``logger`` is ``False``.
:param serializer: The serialization method to use when transmitting
packets. Valid values are ``'default'``, ``'pickle'``,
``'msgpack'`` and ``'cbor'``. Alternatively, a subclass
of the :class:`Packet` class with custom implementations
of the ``encode()`` and ``decode()`` methods can be
provided. Client and server must use compatible
serializers.
:param json: An alternative json module to use for encoding and decoding
packets. Custom json modules must have ``dumps`` and ``loads``
functions that are compatible with the standard library
@ -48,10 +55,11 @@ class Server(object):
:param async_mode: The asynchronous model to use. See the Deployment
section in the documentation for a description of the
available options. Valid async modes are "threading",
"eventlet", "gevent" and "gevent_uwsgi". If this
argument is not given, "eventlet" is tried first, then
"gevent_uwsgi", then "gevent", and finally "threading".
available options. Valid async modes are
``'threading'``, ``'eventlet'``, ``'gevent'`` and
``'gevent_uwsgi'``. If this argument is not given,
``'eventlet'`` is tried first, then ``'gevent_uwsgi'``,
then ``'gevent'``, and finally ``'threading'``.
The first async mode that has all its dependencies
installed is then one that is chosen.
:param ping_interval: The interval in seconds at which the server pings
@ -98,14 +106,22 @@ class Server(object):
fatal errors are logged even when
``engineio_logger`` is ``False``.
"""
def __init__(self, client_manager=None, logger=False, json=None,
async_handlers=True, always_connect=False, **kwargs):
def __init__(self, client_manager=None, logger=False, serializer='default',
json=None, async_handlers=True, always_connect=False,
**kwargs):
engineio_options = kwargs
engineio_logger = engineio_options.pop('engineio_logger', None)
if engineio_logger is not None:
engineio_options['logger'] = engineio_logger
if serializer == 'default':
self.packet_class = packet.Packet
elif serializer == 'msgpack':
from . import msgpack_packet
self.packet_class = msgpack_packet.MsgPackPacket
else:
self.packet_class = serializer
if json is not None:
packet.Packet.json = json
self.packet_class.json = json
engineio_options['json'] = json
engineio_options['async_handlers'] = False
self.eio = self._engineio_server_class()(**engineio_options)
@ -531,7 +547,7 @@ class Server(object):
if delete_it:
self.logger.info('Disconnecting %s [%s]', sid, namespace)
eio_sid = self.manager.pre_disconnect(sid, namespace=namespace)
self._send_packet(eio_sid, packet.Packet(
self._send_packet(eio_sid, self.packet_class(
packet.DISCONNECT, namespace=namespace))
self._trigger_event('disconnect', namespace, sid)
self.manager.disconnect(sid, namespace=namespace)
@ -609,7 +625,7 @@ class Server(object):
data = [data]
else:
data = []
self._send_packet(eio_sid, packet.Packet(
self._send_packet(eio_sid, self.packet_class(
packet.EVENT, namespace=namespace, data=[event] + data, id=id))
def _send_packet(self, eio_sid, pkt):
@ -626,7 +642,7 @@ class Server(object):
namespace = namespace or '/'
sid = self.manager.connect(eio_sid, namespace)
if self.always_connect:
self._send_packet(eio_sid, packet.Packet(
self._send_packet(eio_sid, self.packet_class(
packet.CONNECT, {'sid': sid}, namespace=namespace))
fail_reason = exceptions.ConnectionRefusedError().error_args
try:
@ -647,15 +663,15 @@ class Server(object):
if success is False:
if self.always_connect:
self.manager.pre_disconnect(sid, namespace)
self._send_packet(eio_sid, packet.Packet(
self._send_packet(eio_sid, self.packet_class(
packet.DISCONNECT, data=fail_reason, namespace=namespace))
else:
self._send_packet(eio_sid, packet.Packet(
self._send_packet(eio_sid, self.packet_class(
packet.CONNECT_ERROR, data=fail_reason,
namespace=namespace))
self.manager.disconnect(sid, namespace)
elif not self.always_connect:
self._send_packet(eio_sid, packet.Packet(
self._send_packet(eio_sid, self.packet_class(
packet.CONNECT, {'sid': sid}, namespace=namespace))
def _handle_disconnect(self, eio_sid, namespace):
@ -697,7 +713,7 @@ class Server(object):
data = list(r)
else:
data = [r]
server._send_packet(eio_sid, packet.Packet(
server._send_packet(eio_sid, self.packet_class(
packet.ACK, namespace=namespace, id=id, data=data))
def _handle_ack(self, eio_sid, namespace, id, data):
@ -737,7 +753,7 @@ class Server(object):
else:
self._handle_ack(eio_sid, pkt.namespace, pkt.id, pkt.data)
else:
pkt = packet.Packet(encoded_packet=data)
pkt = self.packet_class(encoded_packet=data)
if pkt.packet_type == packet.CONNECT:
self._handle_connect(eio_sid, pkt.namespace, pkt.data)
elif pkt.packet_type == packet.DISCONNECT:

15
tests/common/test_client.py

@ -11,6 +11,7 @@ import pytest
from socketio import asyncio_namespace
from socketio import client
from socketio import exceptions
from socketio import msgpack_packet
from socketio import namespace
from socketio import packet
@ -49,8 +50,20 @@ class TestClient(unittest.TestCase):
assert c.callbacks == {}
assert c._binary_packet is None
assert c._reconnect_task is None
assert c.packet_class == packet.Packet
def test_custon_json(self):
def test_msgpack(self):
c = client.Client(serializer='msgpack')
assert c.packet_class == msgpack_packet.MsgPackPacket
def test_custom_serializer(self):
class CustomPacket(packet.Packet):
pass
c = client.Client(serializer=CustomPacket)
assert c.packet_class == CustomPacket
def test_custom_json(self):
client.Client()
assert packet.Packet.json == json
assert engineio_packet.Packet.json == json

24
tests/common/test_msgpack_packet.py

@ -0,0 +1,24 @@
import unittest
from socketio import msgpack_packet
from socketio import packet
class TestMsgPackPacket(unittest.TestCase):
def test_encode_decode(self):
p = msgpack_packet.MsgPackPacket(
packet.CONNECT, data={'auth': {'token': '123'}}, namespace='/foo')
p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode())
assert p.packet_type == p2.packet_type
assert p.data == p2.data
assert p.id == p2.id
assert p.namespace == p2.namespace
def test_encode_decode_with_id(self):
p = msgpack_packet.MsgPackPacket(
packet.EVENT, data=['ev', 42], id=123, namespace='/foo')
p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode())
assert p.packet_type == p2.packet_type
assert p.data == p2.data
assert p.id == p2.id
assert p.namespace == p2.namespace

13
tests/common/test_server.py

@ -6,6 +6,7 @@ from engineio import json
import pytest
from socketio import exceptions
from socketio import msgpack_packet
from socketio import namespace
from socketio import packet
from socketio import server
@ -29,6 +30,7 @@ class TestServer(unittest.TestCase):
assert s.manager == mgr
assert s.eio.on.call_count == 3
assert s.async_handlers
assert s.packet_class == packet.Packet
def test_on_event(self, eio):
s = server.Server()
@ -813,6 +815,17 @@ class TestServer(unittest.TestCase):
**{'logger': 'foo', 'async_handlers': False}
)
def test_msgpack(self, eio):
s = server.Server(serializer='msgpack')
assert s.packet_class == msgpack_packet.MsgPackPacket
def test_custom_serializer(self, eio):
class CustomPacket(packet.Packet):
pass
s = server.Server(serializer=CustomPacket)
assert s.packet_class == CustomPacket
def test_custom_json(self, eio):
# Warning: this test cannot run in parallel with other tests, as it
# changes the JSON encoding/decoding functions

1
tox.ini

@ -15,6 +15,7 @@ commands=
pip install -e .
pytest -p no:logging --cov=socketio --cov-branch --cov-report=term-missing
deps=
msgpack
pytest
pytest-cov

Loading…
Cancel
Save