Browse Source

Support ext_type in the MsgPackPacket class (#1521)

pull/1536/head
Miguel Grinberg 7 months ago
committed by GitHub
parent
commit
208925344a
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 30
      src/socketio/msgpack_packet.py
  2. 20
      tests/async/test_client.py
  3. 20
      tests/async/test_server.py
  4. 20
      tests/common/test_client.py
  5. 104
      tests/common/test_msgpack_packet.py
  6. 20
      tests/common/test_server.py

30
src/socketio/msgpack_packet.py

@ -4,14 +4,40 @@ from . import packet
class MsgPackPacket(packet.Packet): class MsgPackPacket(packet.Packet):
uses_binary_events = False uses_binary_events = False
dumps_default = None
ext_hook = msgpack.ExtType
@classmethod
def configure(cls, dumps_default=None, ext_hook=msgpack.ExtType):
"""Change the default options for msgpack encoding and decoding.
:param dumps_default: a function called for objects that cannot be
serialized by default msgpack. The function
receives one argument, the object to serialize.
It should return a serializable object or a
``msgpack.ExtType`` instance.
:param ext_hook: a function called when a ``msgpack.ExtType`` object is
seen during decoding. The function receives two
arguments, the code and the data. It should return the
decoded object.
"""
class CustomMsgPackPacket(MsgPackPacket):
dumps_default = None
ext_hook = None
CustomMsgPackPacket.dumps_default = dumps_default
CustomMsgPackPacket.ext_hook = ext_hook
return CustomMsgPackPacket
def encode(self): def encode(self):
"""Encode the packet for transmission.""" """Encode the packet for transmission."""
return msgpack.dumps(self._to_dict()) return msgpack.dumps(self._to_dict(),
default=self.__class__.dumps_default)
def decode(self, encoded_packet): def decode(self, encoded_packet):
"""Decode a transmitted package.""" """Decode a transmitted package."""
decoded = msgpack.loads(encoded_packet) decoded = msgpack.loads(encoded_packet,
ext_hook=self.__class__.ext_hook)
self.packet_type = decoded['type'] self.packet_type = decoded['type']
self.data = decoded.get('data') self.data = decoded.get('data')
self.id = decoded.get('id') self.id = decoded.get('id')

20
tests/async/test_client.py

@ -1,5 +1,6 @@
import asyncio import asyncio
from unittest import mock from unittest import mock
from datetime import datetime, timezone, timedelta
import pytest import pytest
@ -8,6 +9,7 @@ from socketio import async_namespace
from engineio import exceptions as engineio_exceptions from engineio import exceptions as engineio_exceptions
from socketio import exceptions from socketio import exceptions
from socketio import packet from socketio import packet
from socketio.msgpack_packet import MsgPackPacket
class TestAsyncClient: class TestAsyncClient:
@ -1242,3 +1244,21 @@ class TestAsyncClient:
assert c.sid is None assert c.sid is None
assert not c.connected assert not c.connected
c.start_background_task.assert_not_called() c.start_background_task.assert_not_called()
def test_serializer_args_with_msgpack(self):
def default(o):
if isinstance(o, datetime):
return o.isoformat()
raise TypeError("Unknown type")
data = {"current": datetime.now(timezone(timedelta(0)))}
c = async_client.AsyncClient(
serializer=MsgPackPacket.configure(dumps_default=default))
p = c.packet_class(data=data)
p2 = c.packet_class(encoded_packet=p.encode())
assert p.data != p2.data
assert isinstance(p2.data, dict)
assert "current" in p2.data
assert isinstance(p2.data["current"], str)
assert default(data["current"]) == p2.data["current"]

20
tests/async/test_server.py

@ -1,6 +1,7 @@
import asyncio import asyncio
import logging import logging
from unittest import mock from unittest import mock
from datetime import datetime, timezone, timedelta
from engineio import json from engineio import json
from engineio import packet as eio_packet from engineio import packet as eio_packet
@ -11,6 +12,7 @@ from socketio import async_namespace
from socketio import exceptions from socketio import exceptions
from socketio import namespace from socketio import namespace
from socketio import packet from socketio import packet
from socketio.msgpack_packet import MsgPackPacket
@mock.patch('socketio.server.engineio.AsyncServer', **{ @mock.patch('socketio.server.engineio.AsyncServer', **{
@ -1089,3 +1091,21 @@ class TestAsyncServer:
s = async_server.AsyncServer() s = async_server.AsyncServer()
await s.sleep(1.23) await s.sleep(1.23)
s.eio.sleep.assert_awaited_once_with(1.23) s.eio.sleep.assert_awaited_once_with(1.23)
def test_serializer_args_with_msgpack(self, eio):
def default(o):
if isinstance(o, datetime):
return o.isoformat()
raise TypeError("Unknown type")
data = {"current": datetime.now(timezone(timedelta(0)))}
s = async_server.AsyncServer(
serializer=MsgPackPacket.configure(dumps_default=default))
p = s.packet_class(data=data)
p2 = s.packet_class(encoded_packet=p.encode())
assert p.data != p2.data
assert isinstance(p2.data, dict)
assert "current" in p2.data
assert isinstance(p2.data["current"], str)
assert default(data["current"]) == p2.data["current"]

20
tests/common/test_client.py

@ -1,6 +1,7 @@
import logging import logging
import time import time
from unittest import mock from unittest import mock
from datetime import datetime, timezone, timedelta
from engineio import exceptions as engineio_exceptions from engineio import exceptions as engineio_exceptions
from engineio import json from engineio import json
@ -13,6 +14,7 @@ from socketio import exceptions
from socketio import msgpack_packet from socketio import msgpack_packet
from socketio import namespace from socketio import namespace
from socketio import packet from socketio import packet
from socketio.msgpack_packet import MsgPackPacket
class TestClient: class TestClient:
@ -1386,3 +1388,21 @@ class TestClient:
assert c.sid is None assert c.sid is None
assert not c.connected assert not c.connected
c.start_background_task.assert_not_called() c.start_background_task.assert_not_called()
def test_serializer_args_with_msgpack(self):
def default(o):
if isinstance(o, datetime):
return o.isoformat()
raise TypeError("Unknown type")
data = {"current": datetime.now(timezone(timedelta(0)))}
c = client.Client(
serializer=MsgPackPacket.configure(dumps_default=default))
p = c.packet_class(data=data)
p2 = c.packet_class(encoded_packet=p.encode())
assert p.data != p2.data
assert isinstance(p2.data, dict)
assert "current" in p2.data
assert isinstance(p2.data["current"], str)
assert default(data["current"]) == p2.data["current"]

104
tests/common/test_msgpack_packet.py

@ -1,3 +1,8 @@
from datetime import datetime, timedelta, timezone
import pytest
import msgpack
from socketio import msgpack_packet from socketio import msgpack_packet
from socketio import packet from socketio import packet
@ -32,3 +37,102 @@ class TestMsgPackPacket:
assert p.packet_type == packet.ACK assert p.packet_type == packet.ACK
p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode()) p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode())
assert p2.data == {'foo': b'bar'} assert p2.data == {'foo': b'bar'}
def test_encode_with_dumps_default(self):
def default(obj):
if isinstance(obj, datetime):
return obj.isoformat()
raise TypeError('Unknown type')
data = {
'current': datetime.now(tz=timezone(timedelta(0))),
'key': 'value',
}
p = msgpack_packet.MsgPackPacket.configure(dumps_default=default)(
data=data)
p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode())
assert p.packet_type == p2.packet_type
assert p.id == p2.id
assert p.namespace == p2.namespace
assert p.data != p2.data
assert isinstance(p2.data, dict)
assert 'current' in p2.data
assert isinstance(p2.data['current'], str)
assert default(data['current']) == p2.data['current']
data.pop('current')
p2_data_without_current = p2.data.copy()
p2_data_without_current.pop('current')
assert data == p2_data_without_current
def test_encode_without_dumps_default(self):
data = {
'current': datetime.now(tz=timezone(timedelta(0))),
'key': 'value',
}
p_without_default = msgpack_packet.MsgPackPacket(data=data)
with pytest.raises(TypeError):
p_without_default.encode()
def test_encode_decode_with_ext_hook(self):
class Custom:
def __init__(self, value):
self.value = value
def __eq__(self, value: object) -> bool:
return isinstance(value, Custom) and self.value == value.value
def default(obj):
if isinstance(obj, Custom):
return msgpack.ExtType(1, obj.value)
raise TypeError('Unknown type')
def ext_hook(code, data):
if code == 1:
return Custom(data)
raise TypeError('Unknown ext type')
data = {'custom': Custom(b'custom_data'), 'key': 'value'}
p = msgpack_packet.MsgPackPacket.configure(dumps_default=default)(
data=data)
p2 = msgpack_packet.MsgPackPacket.configure(ext_hook=ext_hook)(
encoded_packet=p.encode()
)
assert p.packet_type == p2.packet_type
assert p.id == p2.id
assert p.data == p2.data
assert p.namespace == p2.namespace
def test_encode_decode_without_ext_hook(self):
class Custom:
def __init__(self, value):
self.value = value
def __eq__(self, value: object) -> bool:
return isinstance(value, Custom) and self.value == value.value
def default(obj):
if isinstance(obj, Custom):
return msgpack.ExtType(1, obj.value)
raise TypeError('Unknown type')
data = {'custom': Custom(b'custom_data'), 'key': 'value'}
p = msgpack_packet.MsgPackPacket.configure(dumps_default=default)(
data=data)
p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode())
assert p.packet_type == p2.packet_type
assert p.id == p2.id
assert p.namespace == p2.namespace
assert p.data != p2.data
assert isinstance(p2.data, dict)
assert 'custom' in p2.data
assert isinstance(p2.data['custom'], msgpack.ExtType)
assert p2.data['custom'].code == 1
assert p2.data['custom'].data == b'custom_data'
data.pop('custom')
p2_data_without_custom = p2.data.copy()
p2_data_without_custom.pop('custom')
assert data == p2_data_without_custom

20
tests/common/test_server.py

@ -1,5 +1,6 @@
import logging import logging
from unittest import mock from unittest import mock
from datetime import datetime, timezone, timedelta
from engineio import json from engineio import json
from engineio import packet as eio_packet from engineio import packet as eio_packet
@ -10,6 +11,7 @@ from socketio import msgpack_packet
from socketio import namespace from socketio import namespace
from socketio import packet from socketio import packet
from socketio import server from socketio import server
from socketio.msgpack_packet import MsgPackPacket
@mock.patch('socketio.server.engineio.Server', **{ @mock.patch('socketio.server.engineio.Server', **{
@ -1032,3 +1034,21 @@ class TestServer:
s = server.Server() s = server.Server()
s.sleep(1.23) s.sleep(1.23)
s.eio.sleep.assert_called_once_with(1.23) s.eio.sleep.assert_called_once_with(1.23)
def test_serializer_args_with_msgpack(self, eio):
def default(o):
if isinstance(o, datetime):
return o.isoformat()
raise TypeError("Unknown type")
data = {"current": datetime.now(timezone(timedelta(0)))}
s = server.Server(
serializer=MsgPackPacket.configure(dumps_default=default))
p = s.packet_class(data=data)
p2 = s.packet_class(encoded_packet=p.encode())
assert p.data != p2.data
assert isinstance(p2.data, dict)
assert "current" in p2.data
assert isinstance(p2.data["current"], str)
assert default(data["current"]) == p2.data["current"]

Loading…
Cancel
Save