diff --git a/src/socketio/msgpack_packet.py b/src/socketio/msgpack_packet.py index 6e86a72..846db2f 100644 --- a/src/socketio/msgpack_packet.py +++ b/src/socketio/msgpack_packet.py @@ -16,11 +16,11 @@ class MsgPackPacket(packet.Packet): dumps_default=None, ext_hook=None, ): + self.dumps_default = dumps_default + self.ext_hook = ext_hook super().__init__( packet_type, data, namespace, id, binary, encoded_packet ) - self.dumps_default = dumps_default - self.ext_hook = ext_hook def encode(self): """Encode the packet for transmission.""" diff --git a/tests/common/test_msgpack_packet.py b/tests/common/test_msgpack_packet.py index e0197a2..1079e01 100644 --- a/tests/common/test_msgpack_packet.py +++ b/tests/common/test_msgpack_packet.py @@ -1,3 +1,8 @@ +from datetime import datetime, timedelta, timezone + +import pytest +import msgpack + from socketio import msgpack_packet from socketio import packet @@ -5,7 +10,8 @@ from socketio import packet class TestMsgPackPacket: def test_encode_decode(self): p = msgpack_packet.MsgPackPacket( - packet.CONNECT, data={'auth': {'token': '123'}}, namespace='/foo') + packet.CONNECT, data={'auth': {'token': '123'}}, namespace='/foo' + ) p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode()) assert p.packet_type == p2.packet_type assert p.data == p2.data @@ -14,7 +20,8 @@ class TestMsgPackPacket: def test_encode_decode_with_id(self): p = msgpack_packet.MsgPackPacket( - packet.EVENT, data=['ev', 42], id=123, namespace='/foo') + packet.EVENT, data=['ev', 42], id=123, namespace='/foo' + ) p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode()) assert p.packet_type == p2.packet_type assert p.data == p2.data @@ -32,3 +39,101 @@ class TestMsgPackPacket: assert p.packet_type == packet.ACK p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode()) assert p2.data == {'foo': b'bar'} + + def test_encode_with_dumps_default(self): + def default(obj): + if isinstance(obj, datetime): + return obj.isoformat() + raise TypeError('Unknown type') + + data = { + 'current': datetime.now(tz=timezone(timedelta(0))), + 'key': 'value', + } + p = msgpack_packet.MsgPackPacket(data=data, dumps_default=default) + p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode()) + assert p.packet_type == p2.packet_type + assert p.id == p2.id + assert p.namespace == p2.namespace + assert p.data != p2.data + + assert isinstance(p2.data, dict) + assert 'current' in p2.data + assert isinstance(p2.data['current'], str) + assert default(data['current']) == p2.data['current'] + + data.pop('current') + p2_data_without_current = p2.data.copy() + p2_data_without_current.pop('current') + assert data == p2_data_without_current + + def test_encode_without_dumps_default(self): + data = { + 'current': datetime.now(tz=timezone(timedelta(0))), + 'key': 'value', + } + p_without_default = msgpack_packet.MsgPackPacket(data=data) + with pytest.raises( + TypeError, match="can not serialize 'datetime.datetime' object" + ): + p_without_default.encode() + + def test_encode_decode_with_ext_hook(self): + class Custom: + def __init__(self, value): + self.value = value + + def __eq__(self, value: object) -> bool: + return isinstance(value, Custom) and self.value == value.value + + def default(obj): + if isinstance(obj, Custom): + return msgpack.ExtType(1, obj.value) + raise TypeError('Unknown type') + + def ext_hook(code, data): + if code == 1: + return Custom(data) + raise TypeError('Unknown ext type') + + data = {'custom': Custom(b'custom_data'), 'key': 'value'} + p = msgpack_packet.MsgPackPacket(data=data, dumps_default=default) + p2 = msgpack_packet.MsgPackPacket( + encoded_packet=p.encode(), ext_hook=ext_hook + ) + assert p.packet_type == p2.packet_type + assert p.id == p2.id + assert p.data == p2.data + assert p.namespace == p2.namespace + + def test_encode_decode_without_ext_hook(self): + class Custom: + def __init__(self, value): + self.value = value + + def __eq__(self, value: object) -> bool: + return isinstance(value, Custom) and self.value == value.value + + def default(obj): + if isinstance(obj, Custom): + return msgpack.ExtType(1, obj.value) + raise TypeError('Unknown type') + + data = {'custom': Custom(b'custom_data'), 'key': 'value'} + p = msgpack_packet.MsgPackPacket(data=data, dumps_default=default) + p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode()) + assert p.packet_type == p2.packet_type + assert p.id == p2.id + assert p.namespace == p2.namespace + assert p.data != p2.data + + assert isinstance(p2.data, dict) + assert 'custom' in p2.data + assert isinstance(p2.data['custom'], msgpack.ExtType) + assert p2.data['custom'].code == 1 + assert p2.data['custom'].data == b'custom_data' + + data.pop('custom') + p2_data_without_custom = p2.data.copy() + p2_data_without_custom.pop('custom') + assert data == p2_data_without_custom