Browse Source

MsgPackPacket.configure method

pull/1521/head
Miguel Grinberg 7 months ago
parent
commit
c7b872915c
Failed to extract signature
  1. 13
      src/socketio/async_client.py
  2. 19
      src/socketio/async_server.py
  3. 7
      src/socketio/base_client.py
  4. 7
      src/socketio/base_server.py
  5. 16
      src/socketio/client.py
  6. 36
      src/socketio/msgpack_packet.py
  7. 19
      src/socketio/server.py
  8. 23
      tests/async/test_client.py
  9. 23
      tests/async/test_server.py
  10. 21
      tests/common/test_client.py
  11. 19
      tests/common/test_msgpack_packet.py
  12. 21
      tests/common/test_server.py

13
src/socketio/async_client.py

@ -45,9 +45,6 @@ class AsyncClient(base_client.BaseClient):
leave interrupt handling to the calling application. leave interrupt handling to the calling application.
Interrupt handling can only be enabled when the Interrupt handling can only be enabled when the
client instance is created in the main thread. client instance is created in the main thread.
:param serializer_args: A mapping of additional parameters to pass to
the serializer. The content of this dictionary
depends on the selected serialization method.
The Engine.IO configuration supports the following settings: The Engine.IO configuration supports the following settings:
@ -246,7 +243,7 @@ class AsyncClient(base_client.BaseClient):
data = [data] data = [data]
else: else:
data = [] data = []
await self._send_packet(self._create_packet( await self._send_packet(self.packet_class(
packet.EVENT, namespace=namespace, data=[event] + data, id=id)) packet.EVENT, namespace=namespace, data=[event] + data, id=id))
async def send(self, data, namespace=None, callback=None): async def send(self, data, namespace=None, callback=None):
@ -328,7 +325,7 @@ class AsyncClient(base_client.BaseClient):
# here we just request the disconnection # here we just request the disconnection
# later in _handle_eio_disconnect we invoke the disconnect handler # later in _handle_eio_disconnect we invoke the disconnect handler
for n in self.namespaces: for n in self.namespaces:
await self._send_packet(self._create_packet(packet.DISCONNECT, await self._send_packet(self.packet_class(packet.DISCONNECT,
namespace=n)) namespace=n))
await self.eio.disconnect() await self.eio.disconnect()
@ -425,7 +422,7 @@ class AsyncClient(base_client.BaseClient):
data = list(r) data = list(r)
else: else:
data = [r] data = [r]
await self._send_packet(self._create_packet( await self._send_packet(self.packet_class(
packet.ACK, namespace=namespace, id=id, data=data)) packet.ACK, namespace=namespace, id=id, data=data))
async def _handle_ack(self, namespace, id, data): async def _handle_ack(self, namespace, id, data):
@ -558,7 +555,7 @@ class AsyncClient(base_client.BaseClient):
self.sid = self.eio.sid self.sid = self.eio.sid
real_auth = await self._get_real_value(self.connection_auth) or {} real_auth = await self._get_real_value(self.connection_auth) or {}
for n in self.connection_namespaces: for n in self.connection_namespaces:
await self._send_packet(self._create_packet( await self._send_packet(self.packet_class(
packet.CONNECT, data=real_auth, namespace=n)) packet.CONNECT, data=real_auth, namespace=n))
async def _handle_eio_message(self, data): async def _handle_eio_message(self, data):
@ -572,7 +569,7 @@ class AsyncClient(base_client.BaseClient):
else: else:
await self._handle_ack(pkt.namespace, pkt.id, pkt.data) await self._handle_ack(pkt.namespace, pkt.id, pkt.data)
else: else:
pkt = self._create_packet(encoded_packet=data) pkt = self.packet_class(encoded_packet=data)
if pkt.packet_type == packet.CONNECT: if pkt.packet_type == packet.CONNECT:
await self._handle_connect(pkt.namespace, pkt.data) await self._handle_connect(pkt.namespace, pkt.data)
elif pkt.packet_type == packet.DISCONNECT: elif pkt.packet_type == packet.DISCONNECT:

19
src/socketio/async_server.py

@ -50,9 +50,6 @@ class AsyncServer(base_server.BaseServer):
default is `['/']`, which always accepts connections to default is `['/']`, which always accepts connections to
the default namespace. Set to `'*'` to accept all the default namespace. Set to `'*'` to accept all
namespaces. namespaces.
:param serializer_args: A mapping of additional parameters to pass to
the serializer. The content of this dictionary
depends on the selected serialization method.
:param kwargs: Connection parameters for the underlying Engine.IO server. :param kwargs: Connection parameters for the underlying Engine.IO server.
The Engine.IO configuration supports the following settings: The Engine.IO configuration supports the following settings:
@ -428,7 +425,7 @@ class AsyncServer(base_server.BaseServer):
if delete_it: if delete_it:
self.logger.info('Disconnecting %s [%s]', sid, namespace) self.logger.info('Disconnecting %s [%s]', sid, namespace)
eio_sid = self.manager.pre_disconnect(sid, namespace=namespace) eio_sid = self.manager.pre_disconnect(sid, namespace=namespace)
await self._send_packet(eio_sid, self._create_packet( await self._send_packet(eio_sid, self.packet_class(
packet.DISCONNECT, namespace=namespace)) packet.DISCONNECT, namespace=namespace))
await self._trigger_event('disconnect', namespace, sid, await self._trigger_event('disconnect', namespace, sid,
self.reason.SERVER_DISCONNECT) self.reason.SERVER_DISCONNECT)
@ -541,13 +538,13 @@ class AsyncServer(base_server.BaseServer):
or self.namespaces == '*' or namespace in self.namespaces: or self.namespaces == '*' or namespace in self.namespaces:
sid = await self.manager.connect(eio_sid, namespace) sid = await self.manager.connect(eio_sid, namespace)
if sid is None: if sid is None:
await self._send_packet(eio_sid, self._create_packet( await self._send_packet(eio_sid, self.packet_class(
packet.CONNECT_ERROR, data='Unable to connect', packet.CONNECT_ERROR, data='Unable to connect',
namespace=namespace)) namespace=namespace))
return return
if self.always_connect: if self.always_connect:
await self._send_packet(eio_sid, self._create_packet( await self._send_packet(eio_sid, self.packet_class(
packet.CONNECT, {'sid': sid}, namespace=namespace)) packet.CONNECT, {'sid': sid}, namespace=namespace))
fail_reason = exceptions.ConnectionRefusedError().error_args fail_reason = exceptions.ConnectionRefusedError().error_args
try: try:
@ -571,15 +568,15 @@ class AsyncServer(base_server.BaseServer):
if success is False: if success is False:
if self.always_connect: if self.always_connect:
self.manager.pre_disconnect(sid, namespace) self.manager.pre_disconnect(sid, namespace)
await self._send_packet(eio_sid, self._create_packet( await self._send_packet(eio_sid, self.packet_class(
packet.DISCONNECT, data=fail_reason, namespace=namespace)) packet.DISCONNECT, data=fail_reason, namespace=namespace))
else: else:
await self._send_packet(eio_sid, self._create_packet( await self._send_packet(eio_sid, self.packet_class(
packet.CONNECT_ERROR, data=fail_reason, packet.CONNECT_ERROR, data=fail_reason,
namespace=namespace)) namespace=namespace))
await self.manager.disconnect(sid, namespace, ignore_queue=True) await self.manager.disconnect(sid, namespace, ignore_queue=True)
elif not self.always_connect: elif not self.always_connect:
await self._send_packet(eio_sid, self._create_packet( await self._send_packet(eio_sid, self.packet_class(
packet.CONNECT, {'sid': sid}, namespace=namespace)) packet.CONNECT, {'sid': sid}, namespace=namespace))
async def _handle_disconnect(self, eio_sid, namespace, reason=None): async def _handle_disconnect(self, eio_sid, namespace, reason=None):
@ -625,7 +622,7 @@ class AsyncServer(base_server.BaseServer):
data = list(r) data = list(r)
else: else:
data = [r] data = [r]
await server._send_packet(eio_sid, self._create_packet( await server._send_packet(eio_sid, self.packet_class(
packet.ACK, namespace=namespace, id=id, data=data)) packet.ACK, namespace=namespace, id=id, data=data))
async def _handle_ack(self, eio_sid, namespace, id, data): async def _handle_ack(self, eio_sid, namespace, id, data):
@ -689,7 +686,7 @@ class AsyncServer(base_server.BaseServer):
await self._handle_ack(eio_sid, pkt.namespace, pkt.id, await self._handle_ack(eio_sid, pkt.namespace, pkt.id,
pkt.data) pkt.data)
else: else:
pkt = self._create_packet(encoded_packet=data) pkt = self.packet_class(encoded_packet=data)
if pkt.packet_type == packet.CONNECT: if pkt.packet_type == packet.CONNECT:
await self._handle_connect(eio_sid, pkt.namespace, pkt.data) await self._handle_connect(eio_sid, pkt.namespace, pkt.data)
elif pkt.packet_type == packet.DISCONNECT: elif pkt.packet_type == packet.DISCONNECT:

7
src/socketio/base_client.py

@ -38,8 +38,7 @@ class BaseClient:
def __init__(self, reconnection=True, reconnection_attempts=0, def __init__(self, reconnection=True, reconnection_attempts=0,
reconnection_delay=1, reconnection_delay_max=5, reconnection_delay=1, reconnection_delay_max=5,
randomization_factor=0.5, logger=False, serializer='default', randomization_factor=0.5, logger=False, serializer='default',
json=None, handle_sigint=True, serializer_args=None, json=None, handle_sigint=True, **kwargs):
**kwargs):
global original_signal_handler global original_signal_handler
if handle_sigint and original_signal_handler is None and \ if handle_sigint and original_signal_handler is None and \
threading.current_thread() == threading.main_thread(): threading.current_thread() == threading.main_thread():
@ -64,7 +63,6 @@ class BaseClient:
self.packet_class = msgpack_packet.MsgPackPacket self.packet_class = msgpack_packet.MsgPackPacket
else: else:
self.packet_class = serializer self.packet_class = serializer
self.packet_class_args = serializer_args or {}
if json is not None: if json is not None:
self.packet_class.json = json self.packet_class.json = json
engineio_options['json'] = json engineio_options['json'] = json
@ -285,9 +283,6 @@ class BaseClient:
self.callbacks[namespace][id] = callback self.callbacks[namespace][id] = callback
return id return id
def _create_packet(self, *args, **kwargs):
return self.packet_class(*args, **kwargs, **self.packet_class_args)
def _handle_eio_connect(self): # pragma: no cover def _handle_eio_connect(self): # pragma: no cover
raise NotImplementedError() raise NotImplementedError()

7
src/socketio/base_server.py

@ -15,7 +15,7 @@ class BaseServer:
def __init__(self, client_manager=None, logger=False, serializer='default', def __init__(self, client_manager=None, logger=False, serializer='default',
json=None, async_handlers=True, always_connect=False, json=None, async_handlers=True, always_connect=False,
namespaces=None, serializer_args=None, **kwargs): namespaces=None, **kwargs):
engineio_options = kwargs engineio_options = kwargs
engineio_logger = engineio_options.pop('engineio_logger', None) engineio_logger = engineio_options.pop('engineio_logger', None)
if engineio_logger is not None: if engineio_logger is not None:
@ -27,7 +27,6 @@ class BaseServer:
self.packet_class = msgpack_packet.MsgPackPacket self.packet_class = msgpack_packet.MsgPackPacket
else: else:
self.packet_class = serializer self.packet_class = serializer
self.packet_class_args = serializer_args or {}
if json is not None: if json is not None:
self.packet_class.json = json self.packet_class.json = json
engineio_options['json'] = json engineio_options['json'] = json
@ -254,10 +253,6 @@ class BaseServer:
args = (namespace, *args) args = (namespace, *args)
return handler, args return handler, args
def _create_packet(self, *args, **kwargs):
return self.packet_class(*args, **kwargs,
**self.packet_class_args)
def _handle_eio_connect(self): # pragma: no cover def _handle_eio_connect(self): # pragma: no cover
raise NotImplementedError() raise NotImplementedError()

16
src/socketio/client.py

@ -48,9 +48,6 @@ class Client(base_client.BaseClient):
leave interrupt handling to the calling application. leave interrupt handling to the calling application.
Interrupt handling can only be enabled when the Interrupt handling can only be enabled when the
client instance is created in the main thread. client instance is created in the main thread.
:param serializer_args: A mapping of additional parameters to pass to
the serializer. The content of this dictionary
depends on the selected serialization method.
The Engine.IO configuration supports the following settings: The Engine.IO configuration supports the following settings:
@ -237,9 +234,8 @@ class Client(base_client.BaseClient):
data = [data] data = [data]
else: else:
data = [] data = []
self._send_packet( self._send_packet(self.packet_class(packet.EVENT, namespace=namespace,
self._create_packet(packet.EVENT, namespace=namespace, data=[event] + data, id=id))
data=[event] + data, id=id))
def send(self, data, namespace=None, callback=None): def send(self, data, namespace=None, callback=None):
"""Send a message to the server. """Send a message to the server.
@ -311,7 +307,7 @@ class Client(base_client.BaseClient):
# here we just request the disconnection # here we just request the disconnection
# later in _handle_eio_disconnect we invoke the disconnect handler # later in _handle_eio_disconnect we invoke the disconnect handler
for n in self.namespaces: for n in self.namespaces:
self._send_packet(self._create_packet( self._send_packet(self.packet_class(
packet.DISCONNECT, namespace=n)) packet.DISCONNECT, namespace=n))
self.eio.disconnect() self.eio.disconnect()
@ -406,7 +402,7 @@ class Client(base_client.BaseClient):
data = list(r) data = list(r)
else: else:
data = [r] data = [r]
self._send_packet(self._create_packet( self._send_packet(self.packet_class(
packet.ACK, namespace=namespace, id=id, data=data)) packet.ACK, namespace=namespace, id=id, data=data))
def _handle_ack(self, namespace, id, data): def _handle_ack(self, namespace, id, data):
@ -510,7 +506,7 @@ class Client(base_client.BaseClient):
self.sid = self.eio.sid self.sid = self.eio.sid
real_auth = self._get_real_value(self.connection_auth) or {} real_auth = self._get_real_value(self.connection_auth) or {}
for n in self.connection_namespaces: for n in self.connection_namespaces:
self._send_packet(self._create_packet( self._send_packet(self.packet_class(
packet.CONNECT, data=real_auth, namespace=n)) packet.CONNECT, data=real_auth, namespace=n))
def _handle_eio_message(self, data): def _handle_eio_message(self, data):
@ -524,7 +520,7 @@ class Client(base_client.BaseClient):
else: else:
self._handle_ack(pkt.namespace, pkt.id, pkt.data) self._handle_ack(pkt.namespace, pkt.id, pkt.data)
else: else:
pkt = self._create_packet(encoded_packet=data) pkt = self.packet_class(encoded_packet=data)
if pkt.packet_type == packet.CONNECT: if pkt.packet_type == packet.CONNECT:
self._handle_connect(pkt.namespace, pkt.data) self._handle_connect(pkt.namespace, pkt.data)
elif pkt.packet_type == packet.DISCONNECT: elif pkt.packet_type == packet.DISCONNECT:

36
src/socketio/msgpack_packet.py

@ -4,34 +4,28 @@ from . import packet
class MsgPackPacket(packet.Packet): class MsgPackPacket(packet.Packet):
uses_binary_events = False uses_binary_events = False
dumps_default = None
ext_hook = msgpack.ExtType
def __init__( @classmethod
self, def configure(cls, dumps_default=None, ext_hook=msgpack.ExtType):
packet_type=packet.EVENT, class CustomMsgPackPacket(MsgPackPacket):
data=None, dumps_default = None
namespace=None, ext_hook = None
id=None,
binary=None, CustomMsgPackPacket.dumps_default = dumps_default
encoded_packet=None, CustomMsgPackPacket.ext_hook = ext_hook
dumps_default=None, return CustomMsgPackPacket
ext_hook=None,
):
self.dumps_default = dumps_default
self.ext_hook = ext_hook
super().__init__(
packet_type, data, namespace, id, binary, encoded_packet
)
def encode(self): def encode(self):
"""Encode the packet for transmission.""" """Encode the packet for transmission."""
return msgpack.dumps(self._to_dict(), default=self.dumps_default) return msgpack.dumps(self._to_dict(),
default=self.__class__.dumps_default)
def decode(self, encoded_packet): def decode(self, encoded_packet):
"""Decode a transmitted package.""" """Decode a transmitted package."""
if self.ext_hook is None: decoded = msgpack.loads(encoded_packet,
decoded = msgpack.loads(encoded_packet) ext_hook=self.__class__.ext_hook)
else:
decoded = msgpack.loads(encoded_packet, ext_hook=self.ext_hook)
self.packet_type = decoded['type'] self.packet_type = decoded['type']
self.data = decoded.get('data') self.data = decoded.get('data')
self.id = decoded.get('id') self.id = decoded.get('id')

19
src/socketio/server.py

@ -53,9 +53,6 @@ class Server(base_server.BaseServer):
default is `['/']`, which always accepts connections to default is `['/']`, which always accepts connections to
the default namespace. Set to `'*'` to accept all the default namespace. Set to `'*'` to accept all
namespaces. namespaces.
:param serializer_args: A mapping of additional parameters to pass to
the serializer. The content of this dictionary
depends on the selected serialization method.
:param kwargs: Connection parameters for the underlying Engine.IO server. :param kwargs: Connection parameters for the underlying Engine.IO server.
The Engine.IO configuration supports the following settings: The Engine.IO configuration supports the following settings:
@ -404,7 +401,7 @@ class Server(base_server.BaseServer):
if delete_it: if delete_it:
self.logger.info('Disconnecting %s [%s]', sid, namespace) self.logger.info('Disconnecting %s [%s]', sid, namespace)
eio_sid = self.manager.pre_disconnect(sid, namespace=namespace) eio_sid = self.manager.pre_disconnect(sid, namespace=namespace)
self._send_packet(eio_sid, self._create_packet( self._send_packet(eio_sid, self.packet_class(
packet.DISCONNECT, namespace=namespace)) packet.DISCONNECT, namespace=namespace))
self._trigger_event('disconnect', namespace, sid, self._trigger_event('disconnect', namespace, sid,
self.reason.SERVER_DISCONNECT) self.reason.SERVER_DISCONNECT)
@ -523,13 +520,13 @@ class Server(base_server.BaseServer):
or self.namespaces == '*' or namespace in self.namespaces: or self.namespaces == '*' or namespace in self.namespaces:
sid = self.manager.connect(eio_sid, namespace) sid = self.manager.connect(eio_sid, namespace)
if sid is None: if sid is None:
self._send_packet(eio_sid, self._create_packet( self._send_packet(eio_sid, self.packet_class(
packet.CONNECT_ERROR, data='Unable to connect', packet.CONNECT_ERROR, data='Unable to connect',
namespace=namespace)) namespace=namespace))
return return
if self.always_connect: if self.always_connect:
self._send_packet(eio_sid, self._create_packet( self._send_packet(eio_sid, self.packet_class(
packet.CONNECT, {'sid': sid}, namespace=namespace)) packet.CONNECT, {'sid': sid}, namespace=namespace))
fail_reason = exceptions.ConnectionRefusedError().error_args fail_reason = exceptions.ConnectionRefusedError().error_args
try: try:
@ -553,15 +550,15 @@ class Server(base_server.BaseServer):
if success is False: if success is False:
if self.always_connect: if self.always_connect:
self.manager.pre_disconnect(sid, namespace) self.manager.pre_disconnect(sid, namespace)
self._send_packet(eio_sid, self._create_packet( self._send_packet(eio_sid, self.packet_class(
packet.DISCONNECT, data=fail_reason, namespace=namespace)) packet.DISCONNECT, data=fail_reason, namespace=namespace))
else: else:
self._send_packet(eio_sid, self._create_packet( self._send_packet(eio_sid, self.packet_class(
packet.CONNECT_ERROR, data=fail_reason, packet.CONNECT_ERROR, data=fail_reason,
namespace=namespace)) namespace=namespace))
self.manager.disconnect(sid, namespace, ignore_queue=True) self.manager.disconnect(sid, namespace, ignore_queue=True)
elif not self.always_connect: elif not self.always_connect:
self._send_packet(eio_sid, self._create_packet( self._send_packet(eio_sid, self.packet_class(
packet.CONNECT, {'sid': sid}, namespace=namespace)) packet.CONNECT, {'sid': sid}, namespace=namespace))
def _handle_disconnect(self, eio_sid, namespace, reason=None): def _handle_disconnect(self, eio_sid, namespace, reason=None):
@ -604,7 +601,7 @@ class Server(base_server.BaseServer):
data = list(r) data = list(r)
else: else:
data = [r] data = [r]
server._send_packet(eio_sid, self._create_packet( server._send_packet(eio_sid, self.packet_class(
packet.ACK, namespace=namespace, id=id, data=data)) packet.ACK, namespace=namespace, id=id, data=data))
def _handle_ack(self, eio_sid, namespace, id, data): def _handle_ack(self, eio_sid, namespace, id, data):
@ -653,7 +650,7 @@ class Server(base_server.BaseServer):
else: else:
self._handle_ack(eio_sid, pkt.namespace, pkt.id, pkt.data) self._handle_ack(eio_sid, pkt.namespace, pkt.id, pkt.data)
else: else:
pkt = self._create_packet(encoded_packet=data) pkt = self.packet_class(encoded_packet=data)
if pkt.packet_type == packet.CONNECT: if pkt.packet_type == packet.CONNECT:
self._handle_connect(eio_sid, pkt.namespace, pkt.data) self._handle_connect(eio_sid, pkt.namespace, pkt.data)
elif pkt.packet_type == packet.DISCONNECT: elif pkt.packet_type == packet.DISCONNECT:

23
tests/async/test_client.py

@ -9,6 +9,7 @@ from socketio import async_namespace
from engineio import exceptions as engineio_exceptions from engineio import exceptions as engineio_exceptions
from socketio import exceptions from socketio import exceptions
from socketio import packet from socketio import packet
from socketio.msgpack_packet import MsgPackPacket
class TestAsyncClient: class TestAsyncClient:
@ -1244,32 +1245,20 @@ class TestAsyncClient:
assert not c.connected assert not c.connected
c.start_background_task.assert_not_called() c.start_background_task.assert_not_called()
def test_serializer_args(self):
args = {"foo": "bar"}
c = async_client.AsyncClient(serializer_args=args)
assert c.packet_class_args == args
def test_serializer_args_with_msgpack(self): def test_serializer_args_with_msgpack(self):
def default(o): def default(o):
if isinstance(o, datetime): if isinstance(o, datetime):
return o.isoformat() return o.isoformat()
raise TypeError("Unknown type") raise TypeError("Unknown type")
args = {"dumps_default": default}
data = {"current": datetime.now(timezone(timedelta(0)))} data = {"current": datetime.now(timezone(timedelta(0)))}
c = async_client.AsyncClient(serializer='msgpack', c = async_client.AsyncClient(
serializer_args=args) serializer=MsgPackPacket.configure(dumps_default=default))
p = c._create_packet(data=data) p = c.packet_class(data=data)
p2 = c._create_packet(encoded_packet=p.encode()) p2 = c.packet_class(encoded_packet=p.encode())
assert p.data != p2.data assert p.data != p2.data
assert isinstance(p2.data, dict) assert isinstance(p2.data, dict)
assert "current" in p2.data assert "current" in p2.data
assert isinstance(p2.data["current"], str) assert isinstance(p2.data["current"], str)
assert default(data["current"]) == p2.data["current"] assert default(data["current"]) == p2.data["current"]
def test_invalid_serializer_args(self):
args = {"invalid_arg": 123}
c = async_client.AsyncClient(serializer='msgpack',
serializer_args=args)
with pytest.raises(TypeError):
c._create_packet(data={"foo": "bar"}).encode()

23
tests/async/test_server.py

@ -12,6 +12,7 @@ from socketio import async_namespace
from socketio import exceptions from socketio import exceptions
from socketio import namespace from socketio import namespace
from socketio import packet from socketio import packet
from socketio.msgpack_packet import MsgPackPacket
@mock.patch('socketio.server.engineio.AsyncServer', **{ @mock.patch('socketio.server.engineio.AsyncServer', **{
@ -1091,32 +1092,20 @@ class TestAsyncServer:
await s.sleep(1.23) await s.sleep(1.23)
s.eio.sleep.assert_awaited_once_with(1.23) s.eio.sleep.assert_awaited_once_with(1.23)
def test_serializer_args(self, eio):
args = {"foo": "bar"}
s = async_server.AsyncServer(serializer_args=args)
assert s.packet_class_args == args
def test_serializer_args_with_msgpack(self, eio): def test_serializer_args_with_msgpack(self, eio):
def default(o): def default(o):
if isinstance(o, datetime): if isinstance(o, datetime):
return o.isoformat() return o.isoformat()
raise TypeError("Unknown type") raise TypeError("Unknown type")
args = {"dumps_default": default}
data = {"current": datetime.now(timezone(timedelta(0)))} data = {"current": datetime.now(timezone(timedelta(0)))}
s = async_server.AsyncServer(serializer='msgpack', s = async_server.AsyncServer(
serializer_args=args) serializer=MsgPackPacket.configure(dumps_default=default))
p = s._create_packet(data=data) p = s.packet_class(data=data)
p2 = s._create_packet(encoded_packet=p.encode()) p2 = s.packet_class(encoded_packet=p.encode())
assert p.data != p2.data assert p.data != p2.data
assert isinstance(p2.data, dict) assert isinstance(p2.data, dict)
assert "current" in p2.data assert "current" in p2.data
assert isinstance(p2.data["current"], str) assert isinstance(p2.data["current"], str)
assert default(data["current"]) == p2.data["current"] assert default(data["current"]) == p2.data["current"]
def test_invalid_serializer_args(self, eio):
args = {"invalid_arg": 123}
s = async_server.AsyncServer(serializer='msgpack',
serializer_args=args)
with pytest.raises(TypeError):
s._create_packet(data={"foo": "bar"}).encode()

21
tests/common/test_client.py

@ -14,6 +14,7 @@ from socketio import exceptions
from socketio import msgpack_packet from socketio import msgpack_packet
from socketio import namespace from socketio import namespace
from socketio import packet from socketio import packet
from socketio.msgpack_packet import MsgPackPacket
class TestClient: class TestClient:
@ -1388,30 +1389,20 @@ class TestClient:
assert not c.connected assert not c.connected
c.start_background_task.assert_not_called() c.start_background_task.assert_not_called()
def test_serializer_args(self):
args = {"foo": "bar"}
c = client.Client(serializer_args=args)
assert c.packet_class_args == args
def test_serializer_args_with_msgpack(self): def test_serializer_args_with_msgpack(self):
def default(o): def default(o):
if isinstance(o, datetime): if isinstance(o, datetime):
return o.isoformat() return o.isoformat()
raise TypeError("Unknown type") raise TypeError("Unknown type")
args = {"dumps_default": default}
data = {"current": datetime.now(timezone(timedelta(0)))} data = {"current": datetime.now(timezone(timedelta(0)))}
c = client.Client(serializer='msgpack', serializer_args=args) c = client.Client(
p = c._create_packet(data=data) serializer=MsgPackPacket.configure(dumps_default=default))
p2 = c._create_packet(encoded_packet=p.encode()) p = c.packet_class(data=data)
p2 = c.packet_class(encoded_packet=p.encode())
assert p.data != p2.data assert p.data != p2.data
assert isinstance(p2.data, dict) assert isinstance(p2.data, dict)
assert "current" in p2.data assert "current" in p2.data
assert isinstance(p2.data["current"], str) assert isinstance(p2.data["current"], str)
assert default(data["current"]) == p2.data["current"] assert default(data["current"]) == p2.data["current"]
def test_invalid_serializer_args(self):
args = {"invalid_arg": 123}
c = client.Client(serializer='msgpack', serializer_args=args)
with pytest.raises(TypeError):
c._create_packet(data={"foo": "bar"}).encode()

19
tests/common/test_msgpack_packet.py

@ -10,8 +10,7 @@ from socketio import packet
class TestMsgPackPacket: class TestMsgPackPacket:
def test_encode_decode(self): def test_encode_decode(self):
p = msgpack_packet.MsgPackPacket( p = msgpack_packet.MsgPackPacket(
packet.CONNECT, data={'auth': {'token': '123'}}, namespace='/foo' packet.CONNECT, data={'auth': {'token': '123'}}, namespace='/foo')
)
p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode()) p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode())
assert p.packet_type == p2.packet_type assert p.packet_type == p2.packet_type
assert p.data == p2.data assert p.data == p2.data
@ -20,8 +19,7 @@ class TestMsgPackPacket:
def test_encode_decode_with_id(self): def test_encode_decode_with_id(self):
p = msgpack_packet.MsgPackPacket( p = msgpack_packet.MsgPackPacket(
packet.EVENT, data=['ev', 42], id=123, namespace='/foo' packet.EVENT, data=['ev', 42], id=123, namespace='/foo')
)
p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode()) p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode())
assert p.packet_type == p2.packet_type assert p.packet_type == p2.packet_type
assert p.data == p2.data assert p.data == p2.data
@ -50,7 +48,8 @@ class TestMsgPackPacket:
'current': datetime.now(tz=timezone(timedelta(0))), 'current': datetime.now(tz=timezone(timedelta(0))),
'key': 'value', 'key': 'value',
} }
p = msgpack_packet.MsgPackPacket(data=data, dumps_default=default) p = msgpack_packet.MsgPackPacket.configure(dumps_default=default)(
data=data)
p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode()) p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode())
assert p.packet_type == p2.packet_type assert p.packet_type == p2.packet_type
assert p.id == p2.id assert p.id == p2.id
@ -95,9 +94,10 @@ class TestMsgPackPacket:
raise TypeError('Unknown ext type') raise TypeError('Unknown ext type')
data = {'custom': Custom(b'custom_data'), 'key': 'value'} data = {'custom': Custom(b'custom_data'), 'key': 'value'}
p = msgpack_packet.MsgPackPacket(data=data, dumps_default=default) p = msgpack_packet.MsgPackPacket.configure(dumps_default=default)(
p2 = msgpack_packet.MsgPackPacket( data=data)
encoded_packet=p.encode(), ext_hook=ext_hook p2 = msgpack_packet.MsgPackPacket.configure(ext_hook=ext_hook)(
encoded_packet=p.encode()
) )
assert p.packet_type == p2.packet_type assert p.packet_type == p2.packet_type
assert p.id == p2.id assert p.id == p2.id
@ -118,7 +118,8 @@ class TestMsgPackPacket:
raise TypeError('Unknown type') raise TypeError('Unknown type')
data = {'custom': Custom(b'custom_data'), 'key': 'value'} data = {'custom': Custom(b'custom_data'), 'key': 'value'}
p = msgpack_packet.MsgPackPacket(data=data, dumps_default=default) p = msgpack_packet.MsgPackPacket.configure(dumps_default=default)(
data=data)
p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode()) p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode())
assert p.packet_type == p2.packet_type assert p.packet_type == p2.packet_type
assert p.id == p2.id assert p.id == p2.id

21
tests/common/test_server.py

@ -11,6 +11,7 @@ from socketio import msgpack_packet
from socketio import namespace from socketio import namespace
from socketio import packet from socketio import packet
from socketio import server from socketio import server
from socketio.msgpack_packet import MsgPackPacket
@mock.patch('socketio.server.engineio.Server', **{ @mock.patch('socketio.server.engineio.Server', **{
@ -1034,30 +1035,20 @@ class TestServer:
s.sleep(1.23) s.sleep(1.23)
s.eio.sleep.assert_called_once_with(1.23) s.eio.sleep.assert_called_once_with(1.23)
def test_serializer_args(self, eio):
args = {"foo": "bar"}
s = server.Server(serializer_args=args)
assert s.packet_class_args == args
def test_serializer_args_with_msgpack(self, eio): def test_serializer_args_with_msgpack(self, eio):
def default(o): def default(o):
if isinstance(o, datetime): if isinstance(o, datetime):
return o.isoformat() return o.isoformat()
raise TypeError("Unknown type") raise TypeError("Unknown type")
args = {"dumps_default": default}
data = {"current": datetime.now(timezone(timedelta(0)))} data = {"current": datetime.now(timezone(timedelta(0)))}
s = server.Server(serializer='msgpack', serializer_args=args) s = server.Server(
p = s._create_packet(data=data) serializer=MsgPackPacket.configure(dumps_default=default))
p2 = s._create_packet(encoded_packet=p.encode()) p = s.packet_class(data=data)
p2 = s.packet_class(encoded_packet=p.encode())
assert p.data != p2.data assert p.data != p2.data
assert isinstance(p2.data, dict) assert isinstance(p2.data, dict)
assert "current" in p2.data assert "current" in p2.data
assert isinstance(p2.data["current"], str) assert isinstance(p2.data["current"], str)
assert default(data["current"]) == p2.data["current"] assert default(data["current"]) == p2.data["current"]
def test_invalid_serializer_args(self, eio):
args = {"invalid_arg": 123}
s = server.Server(serializer='msgpack', serializer_args=args)
with pytest.raises(TypeError):
s._create_packet(data={"foo": "bar"}).encode()

Loading…
Cancel
Save