Browse Source

Support ext_type in the MsgPackPacket class (#1521)

pull/1536/head
Miguel Grinberg 7 months ago
committed by GitHub
parent
commit
208925344a
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 30
      src/socketio/msgpack_packet.py
  2. 20
      tests/async/test_client.py
  3. 20
      tests/async/test_server.py
  4. 20
      tests/common/test_client.py
  5. 104
      tests/common/test_msgpack_packet.py
  6. 20
      tests/common/test_server.py

30
src/socketio/msgpack_packet.py

@ -4,14 +4,40 @@ from . import packet
class MsgPackPacket(packet.Packet):
uses_binary_events = False
dumps_default = None
ext_hook = msgpack.ExtType
@classmethod
def configure(cls, dumps_default=None, ext_hook=msgpack.ExtType):
"""Change the default options for msgpack encoding and decoding.
:param dumps_default: a function called for objects that cannot be
serialized by default msgpack. The function
receives one argument, the object to serialize.
It should return a serializable object or a
``msgpack.ExtType`` instance.
:param ext_hook: a function called when a ``msgpack.ExtType`` object is
seen during decoding. The function receives two
arguments, the code and the data. It should return the
decoded object.
"""
class CustomMsgPackPacket(MsgPackPacket):
dumps_default = None
ext_hook = None
CustomMsgPackPacket.dumps_default = dumps_default
CustomMsgPackPacket.ext_hook = ext_hook
return CustomMsgPackPacket
def encode(self):
"""Encode the packet for transmission."""
return msgpack.dumps(self._to_dict())
return msgpack.dumps(self._to_dict(),
default=self.__class__.dumps_default)
def decode(self, encoded_packet):
"""Decode a transmitted package."""
decoded = msgpack.loads(encoded_packet)
decoded = msgpack.loads(encoded_packet,
ext_hook=self.__class__.ext_hook)
self.packet_type = decoded['type']
self.data = decoded.get('data')
self.id = decoded.get('id')

20
tests/async/test_client.py

@ -1,5 +1,6 @@
import asyncio
from unittest import mock
from datetime import datetime, timezone, timedelta
import pytest
@ -8,6 +9,7 @@ from socketio import async_namespace
from engineio import exceptions as engineio_exceptions
from socketio import exceptions
from socketio import packet
from socketio.msgpack_packet import MsgPackPacket
class TestAsyncClient:
@ -1242,3 +1244,21 @@ class TestAsyncClient:
assert c.sid is None
assert not c.connected
c.start_background_task.assert_not_called()
def test_serializer_args_with_msgpack(self):
def default(o):
if isinstance(o, datetime):
return o.isoformat()
raise TypeError("Unknown type")
data = {"current": datetime.now(timezone(timedelta(0)))}
c = async_client.AsyncClient(
serializer=MsgPackPacket.configure(dumps_default=default))
p = c.packet_class(data=data)
p2 = c.packet_class(encoded_packet=p.encode())
assert p.data != p2.data
assert isinstance(p2.data, dict)
assert "current" in p2.data
assert isinstance(p2.data["current"], str)
assert default(data["current"]) == p2.data["current"]

20
tests/async/test_server.py

@ -1,6 +1,7 @@
import asyncio
import logging
from unittest import mock
from datetime import datetime, timezone, timedelta
from engineio import json
from engineio import packet as eio_packet
@ -11,6 +12,7 @@ from socketio import async_namespace
from socketio import exceptions
from socketio import namespace
from socketio import packet
from socketio.msgpack_packet import MsgPackPacket
@mock.patch('socketio.server.engineio.AsyncServer', **{
@ -1089,3 +1091,21 @@ class TestAsyncServer:
s = async_server.AsyncServer()
await s.sleep(1.23)
s.eio.sleep.assert_awaited_once_with(1.23)
def test_serializer_args_with_msgpack(self, eio):
def default(o):
if isinstance(o, datetime):
return o.isoformat()
raise TypeError("Unknown type")
data = {"current": datetime.now(timezone(timedelta(0)))}
s = async_server.AsyncServer(
serializer=MsgPackPacket.configure(dumps_default=default))
p = s.packet_class(data=data)
p2 = s.packet_class(encoded_packet=p.encode())
assert p.data != p2.data
assert isinstance(p2.data, dict)
assert "current" in p2.data
assert isinstance(p2.data["current"], str)
assert default(data["current"]) == p2.data["current"]

20
tests/common/test_client.py

@ -1,6 +1,7 @@
import logging
import time
from unittest import mock
from datetime import datetime, timezone, timedelta
from engineio import exceptions as engineio_exceptions
from engineio import json
@ -13,6 +14,7 @@ from socketio import exceptions
from socketio import msgpack_packet
from socketio import namespace
from socketio import packet
from socketio.msgpack_packet import MsgPackPacket
class TestClient:
@ -1386,3 +1388,21 @@ class TestClient:
assert c.sid is None
assert not c.connected
c.start_background_task.assert_not_called()
def test_serializer_args_with_msgpack(self):
def default(o):
if isinstance(o, datetime):
return o.isoformat()
raise TypeError("Unknown type")
data = {"current": datetime.now(timezone(timedelta(0)))}
c = client.Client(
serializer=MsgPackPacket.configure(dumps_default=default))
p = c.packet_class(data=data)
p2 = c.packet_class(encoded_packet=p.encode())
assert p.data != p2.data
assert isinstance(p2.data, dict)
assert "current" in p2.data
assert isinstance(p2.data["current"], str)
assert default(data["current"]) == p2.data["current"]

104
tests/common/test_msgpack_packet.py

@ -1,3 +1,8 @@
from datetime import datetime, timedelta, timezone
import pytest
import msgpack
from socketio import msgpack_packet
from socketio import packet
@ -32,3 +37,102 @@ class TestMsgPackPacket:
assert p.packet_type == packet.ACK
p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode())
assert p2.data == {'foo': b'bar'}
def test_encode_with_dumps_default(self):
def default(obj):
if isinstance(obj, datetime):
return obj.isoformat()
raise TypeError('Unknown type')
data = {
'current': datetime.now(tz=timezone(timedelta(0))),
'key': 'value',
}
p = msgpack_packet.MsgPackPacket.configure(dumps_default=default)(
data=data)
p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode())
assert p.packet_type == p2.packet_type
assert p.id == p2.id
assert p.namespace == p2.namespace
assert p.data != p2.data
assert isinstance(p2.data, dict)
assert 'current' in p2.data
assert isinstance(p2.data['current'], str)
assert default(data['current']) == p2.data['current']
data.pop('current')
p2_data_without_current = p2.data.copy()
p2_data_without_current.pop('current')
assert data == p2_data_without_current
def test_encode_without_dumps_default(self):
data = {
'current': datetime.now(tz=timezone(timedelta(0))),
'key': 'value',
}
p_without_default = msgpack_packet.MsgPackPacket(data=data)
with pytest.raises(TypeError):
p_without_default.encode()
def test_encode_decode_with_ext_hook(self):
class Custom:
def __init__(self, value):
self.value = value
def __eq__(self, value: object) -> bool:
return isinstance(value, Custom) and self.value == value.value
def default(obj):
if isinstance(obj, Custom):
return msgpack.ExtType(1, obj.value)
raise TypeError('Unknown type')
def ext_hook(code, data):
if code == 1:
return Custom(data)
raise TypeError('Unknown ext type')
data = {'custom': Custom(b'custom_data'), 'key': 'value'}
p = msgpack_packet.MsgPackPacket.configure(dumps_default=default)(
data=data)
p2 = msgpack_packet.MsgPackPacket.configure(ext_hook=ext_hook)(
encoded_packet=p.encode()
)
assert p.packet_type == p2.packet_type
assert p.id == p2.id
assert p.data == p2.data
assert p.namespace == p2.namespace
def test_encode_decode_without_ext_hook(self):
class Custom:
def __init__(self, value):
self.value = value
def __eq__(self, value: object) -> bool:
return isinstance(value, Custom) and self.value == value.value
def default(obj):
if isinstance(obj, Custom):
return msgpack.ExtType(1, obj.value)
raise TypeError('Unknown type')
data = {'custom': Custom(b'custom_data'), 'key': 'value'}
p = msgpack_packet.MsgPackPacket.configure(dumps_default=default)(
data=data)
p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode())
assert p.packet_type == p2.packet_type
assert p.id == p2.id
assert p.namespace == p2.namespace
assert p.data != p2.data
assert isinstance(p2.data, dict)
assert 'custom' in p2.data
assert isinstance(p2.data['custom'], msgpack.ExtType)
assert p2.data['custom'].code == 1
assert p2.data['custom'].data == b'custom_data'
data.pop('custom')
p2_data_without_custom = p2.data.copy()
p2_data_without_custom.pop('custom')
assert data == p2_data_without_custom

20
tests/common/test_server.py

@ -1,5 +1,6 @@
import logging
from unittest import mock
from datetime import datetime, timezone, timedelta
from engineio import json
from engineio import packet as eio_packet
@ -10,6 +11,7 @@ from socketio import msgpack_packet
from socketio import namespace
from socketio import packet
from socketio import server
from socketio.msgpack_packet import MsgPackPacket
@mock.patch('socketio.server.engineio.Server', **{
@ -1032,3 +1034,21 @@ class TestServer:
s = server.Server()
s.sleep(1.23)
s.eio.sleep.assert_called_once_with(1.23)
def test_serializer_args_with_msgpack(self, eio):
def default(o):
if isinstance(o, datetime):
return o.isoformat()
raise TypeError("Unknown type")
data = {"current": datetime.now(timezone(timedelta(0)))}
s = server.Server(
serializer=MsgPackPacket.configure(dumps_default=default))
p = s.packet_class(data=data)
p2 = s.packet_class(encoded_packet=p.encode())
assert p.data != p2.data
assert isinstance(p2.data, dict)
assert "current" in p2.data
assert isinstance(p2.data["current"], str)
assert default(data["current"]) == p2.data["current"]

Loading…
Cancel
Save