From 5159e84c49daaf2da0579bfc6ee954a9c738a076 Mon Sep 17 00:00:00 2001 From: Miguel Grinberg Date: Wed, 21 Jul 2021 00:17:34 +0100 Subject: [PATCH] Support msgpack and custom packet serializers (Fixes #749) --- src/socketio/asyncio_client.py | 10 +++---- src/socketio/asyncio_server.py | 16 +++++----- src/socketio/client.py | 34 +++++++++++++++------ src/socketio/msgpack_packet.py | 16 ++++++++++ src/socketio/packet.py | 12 +++++++- src/socketio/server.py | 46 +++++++++++++++++++---------- tests/common/test_client.py | 15 +++++++++- tests/common/test_msgpack_packet.py | 24 +++++++++++++++ tests/common/test_server.py | 13 ++++++++ tox.ini | 1 + 10 files changed, 148 insertions(+), 39 deletions(-) create mode 100644 src/socketio/msgpack_packet.py create mode 100644 tests/common/test_msgpack_packet.py diff --git a/src/socketio/asyncio_client.py b/src/socketio/asyncio_client.py index 157cd87..63d0899 100644 --- a/src/socketio/asyncio_client.py +++ b/src/socketio/asyncio_client.py @@ -220,7 +220,7 @@ class AsyncClient(client.Client): data = [data] else: data = [] - await self._send_packet(packet.Packet( + await self._send_packet(self.packet_class( packet.EVENT, namespace=namespace, data=[event] + data, id=id)) async def send(self, data, namespace=None, callback=None): @@ -296,7 +296,7 @@ class AsyncClient(client.Client): # here we just request the disconnection # later in _handle_eio_disconnect we invoke the disconnect handler for n in self.namespaces: - await self._send_packet(packet.Packet(packet.DISCONNECT, + await self._send_packet(self.packet_class(packet.DISCONNECT, namespace=n)) await self.eio.disconnect(abort=True) @@ -379,7 +379,7 @@ class AsyncClient(client.Client): data = list(r) else: data = [r] - await self._send_packet(packet.Packet( + await self._send_packet(self.packet_class( packet.ACK, namespace=namespace, id=id, data=data)) async def _handle_ack(self, namespace, id, data): @@ -482,7 +482,7 @@ class AsyncClient(client.Client): self.sid = self.eio.sid real_auth = await self._get_real_value(self.connection_auth) for n in self.connection_namespaces: - await self._send_packet(packet.Packet( + await self._send_packet(self.packet_class( packet.CONNECT, data=real_auth, namespace=n)) async def _handle_eio_message(self, data): @@ -496,7 +496,7 @@ class AsyncClient(client.Client): else: await self._handle_ack(pkt.namespace, pkt.id, pkt.data) else: - pkt = packet.Packet(encoded_packet=data) + pkt = self.packet_class(encoded_packet=data) if pkt.packet_type == packet.CONNECT: await self._handle_connect(pkt.namespace, pkt.data) elif pkt.packet_type == packet.DISCONNECT: diff --git a/src/socketio/asyncio_server.py b/src/socketio/asyncio_server.py index 8726e30..ffdcdec 100644 --- a/src/socketio/asyncio_server.py +++ b/src/socketio/asyncio_server.py @@ -369,7 +369,7 @@ class AsyncServer(server.Server): if delete_it: self.logger.info('Disconnecting %s [%s]', sid, namespace) eio_sid = self.manager.pre_disconnect(sid, namespace=namespace) - await self._send_packet(eio_sid, packet.Packet( + await self._send_packet(eio_sid, self.packet_class( packet.DISCONNECT, namespace=namespace)) await self._trigger_event('disconnect', namespace, sid) self.manager.disconnect(sid, namespace=namespace) @@ -423,7 +423,7 @@ class AsyncServer(server.Server): data = [data] else: data = [] - await self._send_packet(sid, packet.Packet( + await self._send_packet(sid, self.packet_class( packet.EVENT, namespace=namespace, data=[event] + data, id=id)) async def _send_packet(self, eio_sid, pkt): @@ -440,7 +440,7 @@ class AsyncServer(server.Server): namespace = namespace or '/' sid = self.manager.connect(eio_sid, namespace) if self.always_connect: - await self._send_packet(eio_sid, packet.Packet( + await self._send_packet(eio_sid, self.packet_class( packet.CONNECT, {'sid': sid}, namespace=namespace)) fail_reason = exceptions.ConnectionRefusedError().error_args try: @@ -461,15 +461,15 @@ class AsyncServer(server.Server): if success is False: if self.always_connect: self.manager.pre_disconnect(sid, namespace) - await self._send_packet(eio_sid, packet.Packet( + await self._send_packet(eio_sid, self.packet_class( packet.DISCONNECT, data=fail_reason, namespace=namespace)) else: - await self._send_packet(eio_sid, packet.Packet( + await self._send_packet(eio_sid, self.packet_class( packet.CONNECT_ERROR, data=fail_reason, namespace=namespace)) self.manager.disconnect(sid, namespace) elif not self.always_connect: - await self._send_packet(eio_sid, packet.Packet( + await self._send_packet(eio_sid, self.packet_class( packet.CONNECT, {'sid': sid}, namespace=namespace)) async def _handle_disconnect(self, eio_sid, namespace): @@ -511,7 +511,7 @@ class AsyncServer(server.Server): data = list(r) else: data = [r] - await server._send_packet(eio_sid, packet.Packet( + await server._send_packet(eio_sid, self.packet_class( packet.ACK, namespace=namespace, id=id, data=data)) async def _handle_ack(self, eio_sid, namespace, id, data): @@ -560,7 +560,7 @@ class AsyncServer(server.Server): await self._handle_ack(eio_sid, pkt.namespace, pkt.id, pkt.data) else: - pkt = packet.Packet(encoded_packet=data) + pkt = self.packet_class(encoded_packet=data) if pkt.packet_type == packet.CONNECT: await self._handle_connect(eio_sid, pkt.namespace, pkt.data) elif pkt.packet_type == packet.DISCONNECT: diff --git a/src/socketio/client.py b/src/socketio/client.py index 84fa7e0..b631608 100644 --- a/src/socketio/client.py +++ b/src/socketio/client.py @@ -57,6 +57,13 @@ class Client(object): use. To disable logging set to ``False``. The default is ``False``. Note that fatal errors are logged even when ``logger`` is ``False``. + :param serializer: The serialization method to use when transmitting + packets. Valid values are ``'default'``, ``'pickle'``, + ``'msgpack'`` and ``'cbor'``. Alternatively, a subclass + of the :class:`Packet` class with custom implementations + of the ``encode()`` and ``decode()`` methods can be + provided. Client and server must use compatible + serializers. :param json: An alternative json module to use for encoding and decoding packets. Custom json modules must have ``dumps`` and ``loads`` functions that are compatible with the standard library @@ -82,7 +89,8 @@ class Client(object): """ def __init__(self, reconnection=True, reconnection_attempts=0, reconnection_delay=1, reconnection_delay_max=5, - randomization_factor=0.5, logger=False, json=None, **kwargs): + randomization_factor=0.5, logger=False, serializer='default', + json=None, **kwargs): global original_signal_handler if original_signal_handler is None and \ threading.current_thread() == threading.main_thread(): @@ -98,8 +106,15 @@ class Client(object): engineio_logger = engineio_options.pop('engineio_logger', None) if engineio_logger is not None: engineio_options['logger'] = engineio_logger + if serializer == 'default': + self.packet_class = packet.Packet + elif serializer == 'msgpack': + from . import msgpack_packet + self.packet_class = msgpack_packet.MsgPackPacket + else: + self.packet_class = serializer if json is not None: - packet.Packet.json = json + self.packet_class.json = json engineio_options['json'] = json self.eio = self._engineio_client_class()(**engineio_options) @@ -381,8 +396,8 @@ class Client(object): data = [data] else: data = [] - self._send_packet(packet.Packet(packet.EVENT, namespace=namespace, - data=[event] + data, id=id)) + self._send_packet(self.packet_class(packet.EVENT, namespace=namespace, + data=[event] + data, id=id)) def send(self, data, namespace=None, callback=None): """Send a message to one or more connected clients. @@ -448,7 +463,8 @@ class Client(object): # here we just request the disconnection # later in _handle_eio_disconnect we invoke the disconnect handler for n in self.namespaces: - self._send_packet(packet.Packet(packet.DISCONNECT, namespace=n)) + self._send_packet(self.packet_class( + packet.DISCONNECT, namespace=n)) self.eio.disconnect(abort=True) def get_sid(self, namespace=None): @@ -557,8 +573,8 @@ class Client(object): data = list(r) else: data = [r] - self._send_packet(packet.Packet(packet.ACK, namespace=namespace, - id=id, data=data)) + self._send_packet(self.packet_class( + packet.ACK, namespace=namespace, id=id, data=data)) def _handle_ack(self, namespace, id, data): namespace = namespace or '/' @@ -647,7 +663,7 @@ class Client(object): self.sid = self.eio.sid real_auth = self._get_real_value(self.connection_auth) for n in self.connection_namespaces: - self._send_packet(packet.Packet( + self._send_packet(self.packet_class( packet.CONNECT, data=real_auth, namespace=n)) def _handle_eio_message(self, data): @@ -661,7 +677,7 @@ class Client(object): else: self._handle_ack(pkt.namespace, pkt.id, pkt.data) else: - pkt = packet.Packet(encoded_packet=data) + pkt = self.packet_class(encoded_packet=data) if pkt.packet_type == packet.CONNECT: self._handle_connect(pkt.namespace, pkt.data) elif pkt.packet_type == packet.DISCONNECT: diff --git a/src/socketio/msgpack_packet.py b/src/socketio/msgpack_packet.py new file mode 100644 index 0000000..d883b57 --- /dev/null +++ b/src/socketio/msgpack_packet.py @@ -0,0 +1,16 @@ +import msgpack +from . import packet + + +class MsgPackPacket(packet.Packet): + def encode(self): + """Encode the packet for transmission.""" + return msgpack.dumps(self._to_dict()) + + def decode(self, encoded_packet): + """Decode a transmitted package.""" + decoded = msgpack.loads(encoded_packet) + self.packet_type = decoded['type'] + self.data = decoded['data'] + self.id = decoded.get('id') + self.namespace = decoded['nsp'] diff --git a/src/socketio/packet.py b/src/socketio/packet.py index 0d6f808..e85e58c 100644 --- a/src/socketio/packet.py +++ b/src/socketio/packet.py @@ -37,7 +37,7 @@ class Packet(object): self.attachment_count = 0 self.attachments = [] if encoded_packet: - self.attachment_count = self.decode(encoded_packet) + self.attachment_count = self.decode(encoded_packet) or 0 def encode(self): """Encode the packet for transmission. @@ -175,3 +175,13 @@ class Packet(object): False) else: return False + + def _to_dict(self): + d = { + 'type': self.packet_type, + 'data': self.data, + 'nsp': self.namespace, + } + if self.id: + d['id'] = self.id + return d diff --git a/src/socketio/server.py b/src/socketio/server.py index b8f33ea..9084120 100644 --- a/src/socketio/server.py +++ b/src/socketio/server.py @@ -24,6 +24,13 @@ class Server(object): use. To disable logging set to ``False``. The default is ``False``. Note that fatal errors are logged even when ``logger`` is ``False``. + :param serializer: The serialization method to use when transmitting + packets. Valid values are ``'default'``, ``'pickle'``, + ``'msgpack'`` and ``'cbor'``. Alternatively, a subclass + of the :class:`Packet` class with custom implementations + of the ``encode()`` and ``decode()`` methods can be + provided. Client and server must use compatible + serializers. :param json: An alternative json module to use for encoding and decoding packets. Custom json modules must have ``dumps`` and ``loads`` functions that are compatible with the standard library @@ -48,10 +55,11 @@ class Server(object): :param async_mode: The asynchronous model to use. See the Deployment section in the documentation for a description of the - available options. Valid async modes are "threading", - "eventlet", "gevent" and "gevent_uwsgi". If this - argument is not given, "eventlet" is tried first, then - "gevent_uwsgi", then "gevent", and finally "threading". + available options. Valid async modes are + ``'threading'``, ``'eventlet'``, ``'gevent'`` and + ``'gevent_uwsgi'``. If this argument is not given, + ``'eventlet'`` is tried first, then ``'gevent_uwsgi'``, + then ``'gevent'``, and finally ``'threading'``. The first async mode that has all its dependencies installed is then one that is chosen. :param ping_interval: The interval in seconds at which the server pings @@ -98,14 +106,22 @@ class Server(object): fatal errors are logged even when ``engineio_logger`` is ``False``. """ - def __init__(self, client_manager=None, logger=False, json=None, - async_handlers=True, always_connect=False, **kwargs): + def __init__(self, client_manager=None, logger=False, serializer='default', + json=None, async_handlers=True, always_connect=False, + **kwargs): engineio_options = kwargs engineio_logger = engineio_options.pop('engineio_logger', None) if engineio_logger is not None: engineio_options['logger'] = engineio_logger + if serializer == 'default': + self.packet_class = packet.Packet + elif serializer == 'msgpack': + from . import msgpack_packet + self.packet_class = msgpack_packet.MsgPackPacket + else: + self.packet_class = serializer if json is not None: - packet.Packet.json = json + self.packet_class.json = json engineio_options['json'] = json engineio_options['async_handlers'] = False self.eio = self._engineio_server_class()(**engineio_options) @@ -531,7 +547,7 @@ class Server(object): if delete_it: self.logger.info('Disconnecting %s [%s]', sid, namespace) eio_sid = self.manager.pre_disconnect(sid, namespace=namespace) - self._send_packet(eio_sid, packet.Packet( + self._send_packet(eio_sid, self.packet_class( packet.DISCONNECT, namespace=namespace)) self._trigger_event('disconnect', namespace, sid) self.manager.disconnect(sid, namespace=namespace) @@ -609,7 +625,7 @@ class Server(object): data = [data] else: data = [] - self._send_packet(eio_sid, packet.Packet( + self._send_packet(eio_sid, self.packet_class( packet.EVENT, namespace=namespace, data=[event] + data, id=id)) def _send_packet(self, eio_sid, pkt): @@ -626,7 +642,7 @@ class Server(object): namespace = namespace or '/' sid = self.manager.connect(eio_sid, namespace) if self.always_connect: - self._send_packet(eio_sid, packet.Packet( + self._send_packet(eio_sid, self.packet_class( packet.CONNECT, {'sid': sid}, namespace=namespace)) fail_reason = exceptions.ConnectionRefusedError().error_args try: @@ -647,15 +663,15 @@ class Server(object): if success is False: if self.always_connect: self.manager.pre_disconnect(sid, namespace) - self._send_packet(eio_sid, packet.Packet( + self._send_packet(eio_sid, self.packet_class( packet.DISCONNECT, data=fail_reason, namespace=namespace)) else: - self._send_packet(eio_sid, packet.Packet( + self._send_packet(eio_sid, self.packet_class( packet.CONNECT_ERROR, data=fail_reason, namespace=namespace)) self.manager.disconnect(sid, namespace) elif not self.always_connect: - self._send_packet(eio_sid, packet.Packet( + self._send_packet(eio_sid, self.packet_class( packet.CONNECT, {'sid': sid}, namespace=namespace)) def _handle_disconnect(self, eio_sid, namespace): @@ -697,7 +713,7 @@ class Server(object): data = list(r) else: data = [r] - server._send_packet(eio_sid, packet.Packet( + server._send_packet(eio_sid, self.packet_class( packet.ACK, namespace=namespace, id=id, data=data)) def _handle_ack(self, eio_sid, namespace, id, data): @@ -737,7 +753,7 @@ class Server(object): else: self._handle_ack(eio_sid, pkt.namespace, pkt.id, pkt.data) else: - pkt = packet.Packet(encoded_packet=data) + pkt = self.packet_class(encoded_packet=data) if pkt.packet_type == packet.CONNECT: self._handle_connect(eio_sid, pkt.namespace, pkt.data) elif pkt.packet_type == packet.DISCONNECT: diff --git a/tests/common/test_client.py b/tests/common/test_client.py index 9eb8211..391187f 100644 --- a/tests/common/test_client.py +++ b/tests/common/test_client.py @@ -11,6 +11,7 @@ import pytest from socketio import asyncio_namespace from socketio import client from socketio import exceptions +from socketio import msgpack_packet from socketio import namespace from socketio import packet @@ -49,8 +50,20 @@ class TestClient(unittest.TestCase): assert c.callbacks == {} assert c._binary_packet is None assert c._reconnect_task is None + assert c.packet_class == packet.Packet - def test_custon_json(self): + def test_msgpack(self): + c = client.Client(serializer='msgpack') + assert c.packet_class == msgpack_packet.MsgPackPacket + + def test_custom_serializer(self): + class CustomPacket(packet.Packet): + pass + + c = client.Client(serializer=CustomPacket) + assert c.packet_class == CustomPacket + + def test_custom_json(self): client.Client() assert packet.Packet.json == json assert engineio_packet.Packet.json == json diff --git a/tests/common/test_msgpack_packet.py b/tests/common/test_msgpack_packet.py new file mode 100644 index 0000000..d8049a0 --- /dev/null +++ b/tests/common/test_msgpack_packet.py @@ -0,0 +1,24 @@ +import unittest + +from socketio import msgpack_packet +from socketio import packet + + +class TestMsgPackPacket(unittest.TestCase): + def test_encode_decode(self): + p = msgpack_packet.MsgPackPacket( + packet.CONNECT, data={'auth': {'token': '123'}}, namespace='/foo') + p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode()) + assert p.packet_type == p2.packet_type + assert p.data == p2.data + assert p.id == p2.id + assert p.namespace == p2.namespace + + def test_encode_decode_with_id(self): + p = msgpack_packet.MsgPackPacket( + packet.EVENT, data=['ev', 42], id=123, namespace='/foo') + p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode()) + assert p.packet_type == p2.packet_type + assert p.data == p2.data + assert p.id == p2.id + assert p.namespace == p2.namespace diff --git a/tests/common/test_server.py b/tests/common/test_server.py index 556dab7..3325fcb 100644 --- a/tests/common/test_server.py +++ b/tests/common/test_server.py @@ -6,6 +6,7 @@ from engineio import json import pytest from socketio import exceptions +from socketio import msgpack_packet from socketio import namespace from socketio import packet from socketio import server @@ -29,6 +30,7 @@ class TestServer(unittest.TestCase): assert s.manager == mgr assert s.eio.on.call_count == 3 assert s.async_handlers + assert s.packet_class == packet.Packet def test_on_event(self, eio): s = server.Server() @@ -813,6 +815,17 @@ class TestServer(unittest.TestCase): **{'logger': 'foo', 'async_handlers': False} ) + def test_msgpack(self, eio): + s = server.Server(serializer='msgpack') + assert s.packet_class == msgpack_packet.MsgPackPacket + + def test_custom_serializer(self, eio): + class CustomPacket(packet.Packet): + pass + + s = server.Server(serializer=CustomPacket) + assert s.packet_class == CustomPacket + def test_custom_json(self, eio): # Warning: this test cannot run in parallel with other tests, as it # changes the JSON encoding/decoding functions diff --git a/tox.ini b/tox.ini index 34c16ac..f47e77d 100644 --- a/tox.ini +++ b/tox.ini @@ -15,6 +15,7 @@ commands= pip install -e . pytest -p no:logging --cov=socketio --cov-branch --cov-report=term-missing deps= + msgpack pytest pytest-cov