Browse Source

Revert all

This reverts commit db7fcdaaf7.
pull/1517/head
phi 8 months ago
parent
commit
aac7612685
  1. 13
      src/socketio/async_client.py
  2. 19
      src/socketio/async_server.py
  3. 7
      src/socketio/base_client.py
  4. 7
      src/socketio/base_server.py
  5. 16
      src/socketio/client.py
  6. 24
      src/socketio/msgpack_packet.py
  7. 19
      src/socketio/server.py
  8. 31
      tests/async/test_client.py
  9. 31
      tests/async/test_server.py
  10. 29
      tests/common/test_client.py
  11. 107
      tests/common/test_msgpack_packet.py
  12. 29
      tests/common/test_server.py

13
src/socketio/async_client.py

@ -45,9 +45,6 @@ class AsyncClient(base_client.BaseClient):
leave interrupt handling to the calling application. leave interrupt handling to the calling application.
Interrupt handling can only be enabled when the Interrupt handling can only be enabled when the
client instance is created in the main thread. client instance is created in the main thread.
:param serializer_args: A mapping of additional parameters to pass to
the serializer. The content of this dictionary
depends on the selected serialization method.
The Engine.IO configuration supports the following settings: The Engine.IO configuration supports the following settings:
@ -246,7 +243,7 @@ class AsyncClient(base_client.BaseClient):
data = [data] data = [data]
else: else:
data = [] data = []
await self._send_packet(self._create_packet( await self._send_packet(self.packet_class(
packet.EVENT, namespace=namespace, data=[event] + data, id=id)) packet.EVENT, namespace=namespace, data=[event] + data, id=id))
async def send(self, data, namespace=None, callback=None): async def send(self, data, namespace=None, callback=None):
@ -328,7 +325,7 @@ class AsyncClient(base_client.BaseClient):
# here we just request the disconnection # here we just request the disconnection
# later in _handle_eio_disconnect we invoke the disconnect handler # later in _handle_eio_disconnect we invoke the disconnect handler
for n in self.namespaces: for n in self.namespaces:
await self._send_packet(self._create_packet(packet.DISCONNECT, await self._send_packet(self.packet_class(packet.DISCONNECT,
namespace=n)) namespace=n))
await self.eio.disconnect() await self.eio.disconnect()
@ -425,7 +422,7 @@ class AsyncClient(base_client.BaseClient):
data = list(r) data = list(r)
else: else:
data = [r] data = [r]
await self._send_packet(self._create_packet( await self._send_packet(self.packet_class(
packet.ACK, namespace=namespace, id=id, data=data)) packet.ACK, namespace=namespace, id=id, data=data))
async def _handle_ack(self, namespace, id, data): async def _handle_ack(self, namespace, id, data):
@ -558,7 +555,7 @@ class AsyncClient(base_client.BaseClient):
self.sid = self.eio.sid self.sid = self.eio.sid
real_auth = await self._get_real_value(self.connection_auth) or {} real_auth = await self._get_real_value(self.connection_auth) or {}
for n in self.connection_namespaces: for n in self.connection_namespaces:
await self._send_packet(self._create_packet( await self._send_packet(self.packet_class(
packet.CONNECT, data=real_auth, namespace=n)) packet.CONNECT, data=real_auth, namespace=n))
async def _handle_eio_message(self, data): async def _handle_eio_message(self, data):
@ -572,7 +569,7 @@ class AsyncClient(base_client.BaseClient):
else: else:
await self._handle_ack(pkt.namespace, pkt.id, pkt.data) await self._handle_ack(pkt.namespace, pkt.id, pkt.data)
else: else:
pkt = self._create_packet(encoded_packet=data) pkt = self.packet_class(encoded_packet=data)
if pkt.packet_type == packet.CONNECT: if pkt.packet_type == packet.CONNECT:
await self._handle_connect(pkt.namespace, pkt.data) await self._handle_connect(pkt.namespace, pkt.data)
elif pkt.packet_type == packet.DISCONNECT: elif pkt.packet_type == packet.DISCONNECT:

19
src/socketio/async_server.py

@ -50,9 +50,6 @@ class AsyncServer(base_server.BaseServer):
default is `['/']`, which always accepts connections to default is `['/']`, which always accepts connections to
the default namespace. Set to `'*'` to accept all the default namespace. Set to `'*'` to accept all
namespaces. namespaces.
:param serializer_args: A mapping of additional parameters to pass to
the serializer. The content of this dictionary
depends on the selected serialization method.
:param kwargs: Connection parameters for the underlying Engine.IO server. :param kwargs: Connection parameters for the underlying Engine.IO server.
The Engine.IO configuration supports the following settings: The Engine.IO configuration supports the following settings:
@ -428,7 +425,7 @@ class AsyncServer(base_server.BaseServer):
if delete_it: if delete_it:
self.logger.info('Disconnecting %s [%s]', sid, namespace) self.logger.info('Disconnecting %s [%s]', sid, namespace)
eio_sid = self.manager.pre_disconnect(sid, namespace=namespace) eio_sid = self.manager.pre_disconnect(sid, namespace=namespace)
await self._send_packet(eio_sid, self._create_packet( await self._send_packet(eio_sid, self.packet_class(
packet.DISCONNECT, namespace=namespace)) packet.DISCONNECT, namespace=namespace))
await self._trigger_event('disconnect', namespace, sid, await self._trigger_event('disconnect', namespace, sid,
self.reason.SERVER_DISCONNECT) self.reason.SERVER_DISCONNECT)
@ -541,13 +538,13 @@ class AsyncServer(base_server.BaseServer):
or self.namespaces == '*' or namespace in self.namespaces: or self.namespaces == '*' or namespace in self.namespaces:
sid = await self.manager.connect(eio_sid, namespace) sid = await self.manager.connect(eio_sid, namespace)
if sid is None: if sid is None:
await self._send_packet(eio_sid, self._create_packet( await self._send_packet(eio_sid, self.packet_class(
packet.CONNECT_ERROR, data='Unable to connect', packet.CONNECT_ERROR, data='Unable to connect',
namespace=namespace)) namespace=namespace))
return return
if self.always_connect: if self.always_connect:
await self._send_packet(eio_sid, self._create_packet( await self._send_packet(eio_sid, self.packet_class(
packet.CONNECT, {'sid': sid}, namespace=namespace)) packet.CONNECT, {'sid': sid}, namespace=namespace))
fail_reason = exceptions.ConnectionRefusedError().error_args fail_reason = exceptions.ConnectionRefusedError().error_args
try: try:
@ -571,15 +568,15 @@ class AsyncServer(base_server.BaseServer):
if success is False: if success is False:
if self.always_connect: if self.always_connect:
self.manager.pre_disconnect(sid, namespace) self.manager.pre_disconnect(sid, namespace)
await self._send_packet(eio_sid, self._create_packet( await self._send_packet(eio_sid, self.packet_class(
packet.DISCONNECT, data=fail_reason, namespace=namespace)) packet.DISCONNECT, data=fail_reason, namespace=namespace))
else: else:
await self._send_packet(eio_sid, self._create_packet( await self._send_packet(eio_sid, self.packet_class(
packet.CONNECT_ERROR, data=fail_reason, packet.CONNECT_ERROR, data=fail_reason,
namespace=namespace)) namespace=namespace))
await self.manager.disconnect(sid, namespace, ignore_queue=True) await self.manager.disconnect(sid, namespace, ignore_queue=True)
elif not self.always_connect: elif not self.always_connect:
await self._send_packet(eio_sid, self._create_packet( await self._send_packet(eio_sid, self.packet_class(
packet.CONNECT, {'sid': sid}, namespace=namespace)) packet.CONNECT, {'sid': sid}, namespace=namespace))
async def _handle_disconnect(self, eio_sid, namespace, reason=None): async def _handle_disconnect(self, eio_sid, namespace, reason=None):
@ -625,7 +622,7 @@ class AsyncServer(base_server.BaseServer):
data = list(r) data = list(r)
else: else:
data = [r] data = [r]
await server._send_packet(eio_sid, self._create_packet( await server._send_packet(eio_sid, self.packet_class(
packet.ACK, namespace=namespace, id=id, data=data)) packet.ACK, namespace=namespace, id=id, data=data))
async def _handle_ack(self, eio_sid, namespace, id, data): async def _handle_ack(self, eio_sid, namespace, id, data):
@ -689,7 +686,7 @@ class AsyncServer(base_server.BaseServer):
await self._handle_ack(eio_sid, pkt.namespace, pkt.id, await self._handle_ack(eio_sid, pkt.namespace, pkt.id,
pkt.data) pkt.data)
else: else:
pkt = self._create_packet(encoded_packet=data) pkt = self.packet_class(encoded_packet=data)
if pkt.packet_type == packet.CONNECT: if pkt.packet_type == packet.CONNECT:
await self._handle_connect(eio_sid, pkt.namespace, pkt.data) await self._handle_connect(eio_sid, pkt.namespace, pkt.data)
elif pkt.packet_type == packet.DISCONNECT: elif pkt.packet_type == packet.DISCONNECT:

7
src/socketio/base_client.py

@ -38,8 +38,7 @@ class BaseClient:
def __init__(self, reconnection=True, reconnection_attempts=0, def __init__(self, reconnection=True, reconnection_attempts=0,
reconnection_delay=1, reconnection_delay_max=5, reconnection_delay=1, reconnection_delay_max=5,
randomization_factor=0.5, logger=False, serializer='default', randomization_factor=0.5, logger=False, serializer='default',
json=None, handle_sigint=True, serializer_args=None, json=None, handle_sigint=True, **kwargs):
**kwargs):
global original_signal_handler global original_signal_handler
if handle_sigint and original_signal_handler is None and \ if handle_sigint and original_signal_handler is None and \
threading.current_thread() == threading.main_thread(): threading.current_thread() == threading.main_thread():
@ -64,7 +63,6 @@ class BaseClient:
self.packet_class = msgpack_packet.MsgPackPacket self.packet_class = msgpack_packet.MsgPackPacket
else: else:
self.packet_class = serializer self.packet_class = serializer
self.packet_class_args = serializer_args or {}
if json is not None: if json is not None:
self.packet_class.json = json self.packet_class.json = json
engineio_options['json'] = json engineio_options['json'] = json
@ -285,9 +283,6 @@ class BaseClient:
self.callbacks[namespace][id] = callback self.callbacks[namespace][id] = callback
return id return id
def _create_packet(self, *args, **kwargs):
return self.packet_class(*args, **kwargs, **self.packet_class_args)
def _handle_eio_connect(self): # pragma: no cover def _handle_eio_connect(self): # pragma: no cover
raise NotImplementedError() raise NotImplementedError()

7
src/socketio/base_server.py

@ -15,7 +15,7 @@ class BaseServer:
def __init__(self, client_manager=None, logger=False, serializer='default', def __init__(self, client_manager=None, logger=False, serializer='default',
json=None, async_handlers=True, always_connect=False, json=None, async_handlers=True, always_connect=False,
namespaces=None, serializer_args=None, **kwargs): namespaces=None, **kwargs):
engineio_options = kwargs engineio_options = kwargs
engineio_logger = engineio_options.pop('engineio_logger', None) engineio_logger = engineio_options.pop('engineio_logger', None)
if engineio_logger is not None: if engineio_logger is not None:
@ -27,7 +27,6 @@ class BaseServer:
self.packet_class = msgpack_packet.MsgPackPacket self.packet_class = msgpack_packet.MsgPackPacket
else: else:
self.packet_class = serializer self.packet_class = serializer
self.packet_class_args = serializer_args or {}
if json is not None: if json is not None:
self.packet_class.json = json self.packet_class.json = json
engineio_options['json'] = json engineio_options['json'] = json
@ -254,10 +253,6 @@ class BaseServer:
args = (namespace, *args) args = (namespace, *args)
return handler, args return handler, args
def _create_packet(self, *args, **kwargs):
return self.packet_class(*args, **kwargs,
**self.packet_class_args)
def _handle_eio_connect(self): # pragma: no cover def _handle_eio_connect(self): # pragma: no cover
raise NotImplementedError() raise NotImplementedError()

16
src/socketio/client.py

@ -48,9 +48,6 @@ class Client(base_client.BaseClient):
leave interrupt handling to the calling application. leave interrupt handling to the calling application.
Interrupt handling can only be enabled when the Interrupt handling can only be enabled when the
client instance is created in the main thread. client instance is created in the main thread.
:param serializer_args: A mapping of additional parameters to pass to
the serializer. The content of this dictionary
depends on the selected serialization method.
The Engine.IO configuration supports the following settings: The Engine.IO configuration supports the following settings:
@ -237,9 +234,8 @@ class Client(base_client.BaseClient):
data = [data] data = [data]
else: else:
data = [] data = []
self._send_packet( self._send_packet(self.packet_class(packet.EVENT, namespace=namespace,
self._create_packet(packet.EVENT, namespace=namespace, data=[event] + data, id=id))
data=[event] + data, id=id))
def send(self, data, namespace=None, callback=None): def send(self, data, namespace=None, callback=None):
"""Send a message to the server. """Send a message to the server.
@ -311,7 +307,7 @@ class Client(base_client.BaseClient):
# here we just request the disconnection # here we just request the disconnection
# later in _handle_eio_disconnect we invoke the disconnect handler # later in _handle_eio_disconnect we invoke the disconnect handler
for n in self.namespaces: for n in self.namespaces:
self._send_packet(self._create_packet( self._send_packet(self.packet_class(
packet.DISCONNECT, namespace=n)) packet.DISCONNECT, namespace=n))
self.eio.disconnect() self.eio.disconnect()
@ -406,7 +402,7 @@ class Client(base_client.BaseClient):
data = list(r) data = list(r)
else: else:
data = [r] data = [r]
self._send_packet(self._create_packet( self._send_packet(self.packet_class(
packet.ACK, namespace=namespace, id=id, data=data)) packet.ACK, namespace=namespace, id=id, data=data))
def _handle_ack(self, namespace, id, data): def _handle_ack(self, namespace, id, data):
@ -510,7 +506,7 @@ class Client(base_client.BaseClient):
self.sid = self.eio.sid self.sid = self.eio.sid
real_auth = self._get_real_value(self.connection_auth) or {} real_auth = self._get_real_value(self.connection_auth) or {}
for n in self.connection_namespaces: for n in self.connection_namespaces:
self._send_packet(self._create_packet( self._send_packet(self.packet_class(
packet.CONNECT, data=real_auth, namespace=n)) packet.CONNECT, data=real_auth, namespace=n))
def _handle_eio_message(self, data): def _handle_eio_message(self, data):
@ -524,7 +520,7 @@ class Client(base_client.BaseClient):
else: else:
self._handle_ack(pkt.namespace, pkt.id, pkt.data) self._handle_ack(pkt.namespace, pkt.id, pkt.data)
else: else:
pkt = self._create_packet(encoded_packet=data) pkt = self.packet_class(encoded_packet=data)
if pkt.packet_type == packet.CONNECT: if pkt.packet_type == packet.CONNECT:
self._handle_connect(pkt.namespace, pkt.data) self._handle_connect(pkt.namespace, pkt.data)
elif pkt.packet_type == packet.DISCONNECT: elif pkt.packet_type == packet.DISCONNECT:

24
src/socketio/msgpack_packet.py

@ -5,33 +5,13 @@ from . import packet
class MsgPackPacket(packet.Packet): class MsgPackPacket(packet.Packet):
uses_binary_events = False uses_binary_events = False
def __init__(
self,
packet_type=packet.EVENT,
data=None,
namespace=None,
id=None,
binary=None,
encoded_packet=None,
dumps_default=None,
ext_hook=None,
):
self.dumps_default = dumps_default
self.ext_hook = ext_hook
super().__init__(
packet_type, data, namespace, id, binary, encoded_packet
)
def encode(self): def encode(self):
"""Encode the packet for transmission.""" """Encode the packet for transmission."""
return msgpack.dumps(self._to_dict(), default=self.dumps_default) return msgpack.dumps(self._to_dict())
def decode(self, encoded_packet): def decode(self, encoded_packet):
"""Decode a transmitted package.""" """Decode a transmitted package."""
if self.ext_hook is None: decoded = msgpack.loads(encoded_packet)
decoded = msgpack.loads(encoded_packet)
else:
decoded = msgpack.loads(encoded_packet, ext_hook=self.ext_hook)
self.packet_type = decoded['type'] self.packet_type = decoded['type']
self.data = decoded.get('data') self.data = decoded.get('data')
self.id = decoded.get('id') self.id = decoded.get('id')

19
src/socketio/server.py

@ -53,9 +53,6 @@ class Server(base_server.BaseServer):
default is `['/']`, which always accepts connections to default is `['/']`, which always accepts connections to
the default namespace. Set to `'*'` to accept all the default namespace. Set to `'*'` to accept all
namespaces. namespaces.
:param serializer_args: A mapping of additional parameters to pass to
the serializer. The content of this dictionary
depends on the selected serialization method.
:param kwargs: Connection parameters for the underlying Engine.IO server. :param kwargs: Connection parameters for the underlying Engine.IO server.
The Engine.IO configuration supports the following settings: The Engine.IO configuration supports the following settings:
@ -404,7 +401,7 @@ class Server(base_server.BaseServer):
if delete_it: if delete_it:
self.logger.info('Disconnecting %s [%s]', sid, namespace) self.logger.info('Disconnecting %s [%s]', sid, namespace)
eio_sid = self.manager.pre_disconnect(sid, namespace=namespace) eio_sid = self.manager.pre_disconnect(sid, namespace=namespace)
self._send_packet(eio_sid, self._create_packet( self._send_packet(eio_sid, self.packet_class(
packet.DISCONNECT, namespace=namespace)) packet.DISCONNECT, namespace=namespace))
self._trigger_event('disconnect', namespace, sid, self._trigger_event('disconnect', namespace, sid,
self.reason.SERVER_DISCONNECT) self.reason.SERVER_DISCONNECT)
@ -523,13 +520,13 @@ class Server(base_server.BaseServer):
or self.namespaces == '*' or namespace in self.namespaces: or self.namespaces == '*' or namespace in self.namespaces:
sid = self.manager.connect(eio_sid, namespace) sid = self.manager.connect(eio_sid, namespace)
if sid is None: if sid is None:
self._send_packet(eio_sid, self._create_packet( self._send_packet(eio_sid, self.packet_class(
packet.CONNECT_ERROR, data='Unable to connect', packet.CONNECT_ERROR, data='Unable to connect',
namespace=namespace)) namespace=namespace))
return return
if self.always_connect: if self.always_connect:
self._send_packet(eio_sid, self._create_packet( self._send_packet(eio_sid, self.packet_class(
packet.CONNECT, {'sid': sid}, namespace=namespace)) packet.CONNECT, {'sid': sid}, namespace=namespace))
fail_reason = exceptions.ConnectionRefusedError().error_args fail_reason = exceptions.ConnectionRefusedError().error_args
try: try:
@ -553,15 +550,15 @@ class Server(base_server.BaseServer):
if success is False: if success is False:
if self.always_connect: if self.always_connect:
self.manager.pre_disconnect(sid, namespace) self.manager.pre_disconnect(sid, namespace)
self._send_packet(eio_sid, self._create_packet( self._send_packet(eio_sid, self.packet_class(
packet.DISCONNECT, data=fail_reason, namespace=namespace)) packet.DISCONNECT, data=fail_reason, namespace=namespace))
else: else:
self._send_packet(eio_sid, self._create_packet( self._send_packet(eio_sid, self.packet_class(
packet.CONNECT_ERROR, data=fail_reason, packet.CONNECT_ERROR, data=fail_reason,
namespace=namespace)) namespace=namespace))
self.manager.disconnect(sid, namespace, ignore_queue=True) self.manager.disconnect(sid, namespace, ignore_queue=True)
elif not self.always_connect: elif not self.always_connect:
self._send_packet(eio_sid, self._create_packet( self._send_packet(eio_sid, self.packet_class(
packet.CONNECT, {'sid': sid}, namespace=namespace)) packet.CONNECT, {'sid': sid}, namespace=namespace))
def _handle_disconnect(self, eio_sid, namespace, reason=None): def _handle_disconnect(self, eio_sid, namespace, reason=None):
@ -604,7 +601,7 @@ class Server(base_server.BaseServer):
data = list(r) data = list(r)
else: else:
data = [r] data = [r]
server._send_packet(eio_sid, self._create_packet( server._send_packet(eio_sid, self.packet_class(
packet.ACK, namespace=namespace, id=id, data=data)) packet.ACK, namespace=namespace, id=id, data=data))
def _handle_ack(self, eio_sid, namespace, id, data): def _handle_ack(self, eio_sid, namespace, id, data):
@ -653,7 +650,7 @@ class Server(base_server.BaseServer):
else: else:
self._handle_ack(eio_sid, pkt.namespace, pkt.id, pkt.data) self._handle_ack(eio_sid, pkt.namespace, pkt.id, pkt.data)
else: else:
pkt = self._create_packet(encoded_packet=data) pkt = self.packet_class(encoded_packet=data)
if pkt.packet_type == packet.CONNECT: if pkt.packet_type == packet.CONNECT:
self._handle_connect(eio_sid, pkt.namespace, pkt.data) self._handle_connect(eio_sid, pkt.namespace, pkt.data)
elif pkt.packet_type == packet.DISCONNECT: elif pkt.packet_type == packet.DISCONNECT:

31
tests/async/test_client.py

@ -1,6 +1,5 @@
import asyncio import asyncio
from unittest import mock from unittest import mock
from datetime import datetime, timezone, timedelta
import pytest import pytest
@ -1243,33 +1242,3 @@ class TestAsyncClient:
assert c.sid is None assert c.sid is None
assert not c.connected assert not c.connected
c.start_background_task.assert_not_called() c.start_background_task.assert_not_called()
def test_serializer_args(self):
args = {"foo": "bar"}
c = async_client.AsyncClient(serializer_args=args)
assert c.packet_class_args == args
def test_serializer_args_with_msgpack(self):
def default(o):
if isinstance(o, datetime):
return o.isoformat()
raise TypeError("Unknown type")
args = {"dumps_default": default}
data = {"current": datetime.now(timezone(timedelta(0)))}
c = async_client.AsyncClient(serializer='msgpack',
serializer_args=args)
p = c._create_packet(data=data)
p2 = c._create_packet(encoded_packet=p.encode())
assert p.data != p2.data
assert isinstance(p2.data, dict)
assert "current" in p2.data
assert isinstance(p2.data["current"], str)
assert default(data["current"]) == p2.data["current"]
def test_invalid_serializer_args(self):
args = {"invalid_arg": 123}
c = async_client.AsyncClient(serializer='msgpack',
serializer_args=args)
with pytest.raises(TypeError):
c._create_packet(data={"foo": "bar"}).encode()

31
tests/async/test_server.py

@ -1,7 +1,6 @@
import asyncio import asyncio
import logging import logging
from unittest import mock from unittest import mock
from datetime import datetime, timezone, timedelta
from engineio import json from engineio import json
from engineio import packet as eio_packet from engineio import packet as eio_packet
@ -1090,33 +1089,3 @@ class TestAsyncServer:
s = async_server.AsyncServer() s = async_server.AsyncServer()
await s.sleep(1.23) await s.sleep(1.23)
s.eio.sleep.assert_awaited_once_with(1.23) s.eio.sleep.assert_awaited_once_with(1.23)
def test_serializer_args(self, eio):
args = {"foo": "bar"}
s = async_server.AsyncServer(serializer_args=args)
assert s.packet_class_args == args
def test_serializer_args_with_msgpack(self, eio):
def default(o):
if isinstance(o, datetime):
return o.isoformat()
raise TypeError("Unknown type")
args = {"dumps_default": default}
data = {"current": datetime.now(timezone(timedelta(0)))}
s = async_server.AsyncServer(serializer='msgpack',
serializer_args=args)
p = s._create_packet(data=data)
p2 = s._create_packet(encoded_packet=p.encode())
assert p.data != p2.data
assert isinstance(p2.data, dict)
assert "current" in p2.data
assert isinstance(p2.data["current"], str)
assert default(data["current"]) == p2.data["current"]
def test_invalid_serializer_args(self, eio):
args = {"invalid_arg": 123}
s = async_server.AsyncServer(serializer='msgpack',
serializer_args=args)
with pytest.raises(TypeError):
s._create_packet(data={"foo": "bar"}).encode()

29
tests/common/test_client.py

@ -1,7 +1,6 @@
import logging import logging
import time import time
from unittest import mock from unittest import mock
from datetime import datetime, timezone, timedelta
from engineio import exceptions as engineio_exceptions from engineio import exceptions as engineio_exceptions
from engineio import json from engineio import json
@ -1387,31 +1386,3 @@ class TestClient:
assert c.sid is None assert c.sid is None
assert not c.connected assert not c.connected
c.start_background_task.assert_not_called() c.start_background_task.assert_not_called()
def test_serializer_args(self):
args = {"foo": "bar"}
c = client.Client(serializer_args=args)
assert c.packet_class_args == args
def test_serializer_args_with_msgpack(self):
def default(o):
if isinstance(o, datetime):
return o.isoformat()
raise TypeError("Unknown type")
args = {"dumps_default": default}
data = {"current": datetime.now(timezone(timedelta(0)))}
c = client.Client(serializer='msgpack', serializer_args=args)
p = c._create_packet(data=data)
p2 = c._create_packet(encoded_packet=p.encode())
assert p.data != p2.data
assert isinstance(p2.data, dict)
assert "current" in p2.data
assert isinstance(p2.data["current"], str)
assert default(data["current"]) == p2.data["current"]
def test_invalid_serializer_args(self):
args = {"invalid_arg": 123}
c = client.Client(serializer='msgpack', serializer_args=args)
with pytest.raises(TypeError):
c._create_packet(data={"foo": "bar"}).encode()

107
tests/common/test_msgpack_packet.py

@ -1,8 +1,3 @@
from datetime import datetime, timedelta, timezone
import pytest
import msgpack
from socketio import msgpack_packet from socketio import msgpack_packet
from socketio import packet from socketio import packet
@ -10,8 +5,7 @@ from socketio import packet
class TestMsgPackPacket: class TestMsgPackPacket:
def test_encode_decode(self): def test_encode_decode(self):
p = msgpack_packet.MsgPackPacket( p = msgpack_packet.MsgPackPacket(
packet.CONNECT, data={'auth': {'token': '123'}}, namespace='/foo' packet.CONNECT, data={'auth': {'token': '123'}}, namespace='/foo')
)
p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode()) p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode())
assert p.packet_type == p2.packet_type assert p.packet_type == p2.packet_type
assert p.data == p2.data assert p.data == p2.data
@ -20,8 +14,7 @@ class TestMsgPackPacket:
def test_encode_decode_with_id(self): def test_encode_decode_with_id(self):
p = msgpack_packet.MsgPackPacket( p = msgpack_packet.MsgPackPacket(
packet.EVENT, data=['ev', 42], id=123, namespace='/foo' packet.EVENT, data=['ev', 42], id=123, namespace='/foo')
)
p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode()) p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode())
assert p.packet_type == p2.packet_type assert p.packet_type == p2.packet_type
assert p.data == p2.data assert p.data == p2.data
@ -39,99 +32,3 @@ class TestMsgPackPacket:
assert p.packet_type == packet.ACK assert p.packet_type == packet.ACK
p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode()) p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode())
assert p2.data == {'foo': b'bar'} assert p2.data == {'foo': b'bar'}
def test_encode_with_dumps_default(self):
def default(obj):
if isinstance(obj, datetime):
return obj.isoformat()
raise TypeError('Unknown type')
data = {
'current': datetime.now(tz=timezone(timedelta(0))),
'key': 'value',
}
p = msgpack_packet.MsgPackPacket(data=data, dumps_default=default)
p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode())
assert p.packet_type == p2.packet_type
assert p.id == p2.id
assert p.namespace == p2.namespace
assert p.data != p2.data
assert isinstance(p2.data, dict)
assert 'current' in p2.data
assert isinstance(p2.data['current'], str)
assert default(data['current']) == p2.data['current']
data.pop('current')
p2_data_without_current = p2.data.copy()
p2_data_without_current.pop('current')
assert data == p2_data_without_current
def test_encode_without_dumps_default(self):
data = {
'current': datetime.now(tz=timezone(timedelta(0))),
'key': 'value',
}
p_without_default = msgpack_packet.MsgPackPacket(data=data)
with pytest.raises(TypeError):
p_without_default.encode()
def test_encode_decode_with_ext_hook(self):
class Custom:
def __init__(self, value):
self.value = value
def __eq__(self, value: object) -> bool:
return isinstance(value, Custom) and self.value == value.value
def default(obj):
if isinstance(obj, Custom):
return msgpack.ExtType(1, obj.value)
raise TypeError('Unknown type')
def ext_hook(code, data):
if code == 1:
return Custom(data)
raise TypeError('Unknown ext type')
data = {'custom': Custom(b'custom_data'), 'key': 'value'}
p = msgpack_packet.MsgPackPacket(data=data, dumps_default=default)
p2 = msgpack_packet.MsgPackPacket(
encoded_packet=p.encode(), ext_hook=ext_hook
)
assert p.packet_type == p2.packet_type
assert p.id == p2.id
assert p.data == p2.data
assert p.namespace == p2.namespace
def test_encode_decode_without_ext_hook(self):
class Custom:
def __init__(self, value):
self.value = value
def __eq__(self, value: object) -> bool:
return isinstance(value, Custom) and self.value == value.value
def default(obj):
if isinstance(obj, Custom):
return msgpack.ExtType(1, obj.value)
raise TypeError('Unknown type')
data = {'custom': Custom(b'custom_data'), 'key': 'value'}
p = msgpack_packet.MsgPackPacket(data=data, dumps_default=default)
p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode())
assert p.packet_type == p2.packet_type
assert p.id == p2.id
assert p.namespace == p2.namespace
assert p.data != p2.data
assert isinstance(p2.data, dict)
assert 'custom' in p2.data
assert isinstance(p2.data['custom'], msgpack.ExtType)
assert p2.data['custom'].code == 1
assert p2.data['custom'].data == b'custom_data'
data.pop('custom')
p2_data_without_custom = p2.data.copy()
p2_data_without_custom.pop('custom')
assert data == p2_data_without_custom

29
tests/common/test_server.py

@ -1,6 +1,5 @@
import logging import logging
from unittest import mock from unittest import mock
from datetime import datetime, timezone, timedelta
from engineio import json from engineio import json
from engineio import packet as eio_packet from engineio import packet as eio_packet
@ -1033,31 +1032,3 @@ class TestServer:
s = server.Server() s = server.Server()
s.sleep(1.23) s.sleep(1.23)
s.eio.sleep.assert_called_once_with(1.23) s.eio.sleep.assert_called_once_with(1.23)
def test_serializer_args(self, eio):
args = {"foo": "bar"}
s = server.Server(serializer_args=args)
assert s.packet_class_args == args
def test_serializer_args_with_msgpack(self, eio):
def default(o):
if isinstance(o, datetime):
return o.isoformat()
raise TypeError("Unknown type")
args = {"dumps_default": default}
data = {"current": datetime.now(timezone(timedelta(0)))}
s = server.Server(serializer='msgpack', serializer_args=args)
p = s._create_packet(data=data)
p2 = s._create_packet(encoded_packet=p.encode())
assert p.data != p2.data
assert isinstance(p2.data, dict)
assert "current" in p2.data
assert isinstance(p2.data["current"], str)
assert default(data["current"]) == p2.data["current"]
def test_invalid_serializer_args(self, eio):
args = {"invalid_arg": 123}
s = server.Server(serializer='msgpack', serializer_args=args)
with pytest.raises(TypeError):
s._create_packet(data={"foo": "bar"}).encode()

Loading…
Cancel
Save