From 17eac0cca339146c990680b4323eccf4b02c02c8 Mon Sep 17 00:00:00 2001 From: Rossen Georgiev Date: Mon, 28 Nov 2016 22:13:28 +0000 Subject: [PATCH] refactor proto_fill_from_dict & proto_to_dict; #56 * added tests for both function * both function now support py3 iterators types like map, range, filter * proto_fill_from_dict will correctly overwrite lists when clear=False --- protobufs/test_messages.proto | 25 +++ steam/protobufs/test_messages_pb2.py | 238 +++++++++++++++++++++++++++ steam/util/__init__.py | 41 +++-- tests/test_util.py | 115 +++++++++++++ 4 files changed, 409 insertions(+), 10 deletions(-) create mode 100644 protobufs/test_messages.proto create mode 100644 steam/protobufs/test_messages_pb2.py diff --git a/protobufs/test_messages.proto b/protobufs/test_messages.proto new file mode 100644 index 0000000..8e2ebee --- /dev/null +++ b/protobufs/test_messages.proto @@ -0,0 +1,25 @@ +syntax = "proto2"; + +message ComplexProtoMessage { + message InnerMessage { + optional string text = 1; + repeated uint32 numbers = 2; + } + message InnerBuffer { + message Flags { + optional bool flag = 1; + } + + optional bytes data = 1; + repeated .ComplexProtoMessage.InnerBuffer.Flags flags = 2; + } + + optional uint32 number32 = 1; + optional uint64 number64 = 2; + + repeated uint32 list_number32 = 3; + repeated uint64 list_number64 = 4; + + repeated .ComplexProtoMessage.InnerMessage messages = 5; + repeated .ComplexProtoMessage.InnerBuffer buffers = 6; +} diff --git a/steam/protobufs/test_messages_pb2.py b/steam/protobufs/test_messages_pb2.py new file mode 100644 index 0000000..43456bf --- /dev/null +++ b/steam/protobufs/test_messages_pb2.py @@ -0,0 +1,238 @@ +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: test_messages.proto + +import sys +_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +from google.protobuf import descriptor_pb2 +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='test_messages.proto', + package='', + syntax='proto2', + serialized_pb=_b('\n\x13test_messages.proto\"\xe9\x02\n\x13\x43omplexProtoMessage\x12\x10\n\x08number32\x18\x01 \x01(\r\x12\x10\n\x08number64\x18\x02 \x01(\x04\x12\x15\n\rlist_number32\x18\x03 \x03(\r\x12\x15\n\rlist_number64\x18\x04 \x03(\x04\x12\x33\n\x08messages\x18\x05 \x03(\x0b\x32!.ComplexProtoMessage.InnerMessage\x12\x31\n\x07\x62uffers\x18\x06 \x03(\x0b\x32 .ComplexProtoMessage.InnerBuffer\x1a-\n\x0cInnerMessage\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x0f\n\x07numbers\x18\x02 \x03(\r\x1ai\n\x0bInnerBuffer\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\x12\x35\n\x05\x66lags\x18\x02 \x03(\x0b\x32&.ComplexProtoMessage.InnerBuffer.Flags\x1a\x15\n\x05\x46lags\x12\x0c\n\x04\x66lag\x18\x01 \x01(\x08') +) +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + + + + +_COMPLEXPROTOMESSAGE_INNERMESSAGE = _descriptor.Descriptor( + name='InnerMessage', + full_name='ComplexProtoMessage.InnerMessage', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='text', full_name='ComplexProtoMessage.InnerMessage.text', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='numbers', full_name='ComplexProtoMessage.InnerMessage.numbers', index=1, + number=2, type=13, cpp_type=3, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=233, + serialized_end=278, +) + +_COMPLEXPROTOMESSAGE_INNERBUFFER_FLAGS = _descriptor.Descriptor( + name='Flags', + full_name='ComplexProtoMessage.InnerBuffer.Flags', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='flag', full_name='ComplexProtoMessage.InnerBuffer.Flags.flag', index=0, + number=1, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=364, + serialized_end=385, +) + +_COMPLEXPROTOMESSAGE_INNERBUFFER = _descriptor.Descriptor( + name='InnerBuffer', + full_name='ComplexProtoMessage.InnerBuffer', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='data', full_name='ComplexProtoMessage.InnerBuffer.data', index=0, + number=1, type=12, cpp_type=9, label=1, + has_default_value=False, default_value=_b(""), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='flags', full_name='ComplexProtoMessage.InnerBuffer.flags', index=1, + number=2, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[_COMPLEXPROTOMESSAGE_INNERBUFFER_FLAGS, ], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=280, + serialized_end=385, +) + +_COMPLEXPROTOMESSAGE = _descriptor.Descriptor( + name='ComplexProtoMessage', + full_name='ComplexProtoMessage', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='number32', full_name='ComplexProtoMessage.number32', index=0, + number=1, type=13, cpp_type=3, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='number64', full_name='ComplexProtoMessage.number64', index=1, + number=2, type=4, cpp_type=4, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='list_number32', full_name='ComplexProtoMessage.list_number32', index=2, + number=3, type=13, cpp_type=3, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='list_number64', full_name='ComplexProtoMessage.list_number64', index=3, + number=4, type=4, cpp_type=4, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='messages', full_name='ComplexProtoMessage.messages', index=4, + number=5, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='buffers', full_name='ComplexProtoMessage.buffers', index=5, + number=6, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[_COMPLEXPROTOMESSAGE_INNERMESSAGE, _COMPLEXPROTOMESSAGE_INNERBUFFER, ], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=24, + serialized_end=385, +) + +_COMPLEXPROTOMESSAGE_INNERMESSAGE.containing_type = _COMPLEXPROTOMESSAGE +_COMPLEXPROTOMESSAGE_INNERBUFFER_FLAGS.containing_type = _COMPLEXPROTOMESSAGE_INNERBUFFER +_COMPLEXPROTOMESSAGE_INNERBUFFER.fields_by_name['flags'].message_type = _COMPLEXPROTOMESSAGE_INNERBUFFER_FLAGS +_COMPLEXPROTOMESSAGE_INNERBUFFER.containing_type = _COMPLEXPROTOMESSAGE +_COMPLEXPROTOMESSAGE.fields_by_name['messages'].message_type = _COMPLEXPROTOMESSAGE_INNERMESSAGE +_COMPLEXPROTOMESSAGE.fields_by_name['buffers'].message_type = _COMPLEXPROTOMESSAGE_INNERBUFFER +DESCRIPTOR.message_types_by_name['ComplexProtoMessage'] = _COMPLEXPROTOMESSAGE + +ComplexProtoMessage = _reflection.GeneratedProtocolMessageType('ComplexProtoMessage', (_message.Message,), dict( + + InnerMessage = _reflection.GeneratedProtocolMessageType('InnerMessage', (_message.Message,), dict( + DESCRIPTOR = _COMPLEXPROTOMESSAGE_INNERMESSAGE, + __module__ = 'test_messages_pb2' + # @@protoc_insertion_point(class_scope:ComplexProtoMessage.InnerMessage) + )) + , + + InnerBuffer = _reflection.GeneratedProtocolMessageType('InnerBuffer', (_message.Message,), dict( + + Flags = _reflection.GeneratedProtocolMessageType('Flags', (_message.Message,), dict( + DESCRIPTOR = _COMPLEXPROTOMESSAGE_INNERBUFFER_FLAGS, + __module__ = 'test_messages_pb2' + # @@protoc_insertion_point(class_scope:ComplexProtoMessage.InnerBuffer.Flags) + )) + , + DESCRIPTOR = _COMPLEXPROTOMESSAGE_INNERBUFFER, + __module__ = 'test_messages_pb2' + # @@protoc_insertion_point(class_scope:ComplexProtoMessage.InnerBuffer) + )) + , + DESCRIPTOR = _COMPLEXPROTOMESSAGE, + __module__ = 'test_messages_pb2' + # @@protoc_insertion_point(class_scope:ComplexProtoMessage) + )) +_sym_db.RegisterMessage(ComplexProtoMessage) +_sym_db.RegisterMessage(ComplexProtoMessage.InnerMessage) +_sym_db.RegisterMessage(ComplexProtoMessage.InnerBuffer) +_sym_db.RegisterMessage(ComplexProtoMessage.InnerBuffer.Flags) + + +# @@protoc_insertion_point(module_scope) diff --git a/steam/util/__init__.py b/steam/util/__init__.py index e540116..07bf11a 100644 --- a/steam/util/__init__.py +++ b/steam/util/__init__.py @@ -4,7 +4,16 @@ import weakref import struct import socket import sys +import six from six.moves import xrange as _range +from types import GeneratorType as _GeneratorType +from google.protobuf.internal.python_message import GeneratedProtocolMessageType as _ProtoMessageType + +if six.PY2: + _list_types = list, xrange, _GeneratorType +else: + _list_types = list, range, _GeneratorType, map, filter + def ip_from_int(ip): """Convert IP to :py:class:`int` @@ -61,13 +70,17 @@ def proto_to_dict(message): :param message: protobuf message instance :return: parameters and their values :rtype: dict + :raises: :class:`.TypeError` if ``message`` is not a proto message """ + if not isinstance(message.__class__, _ProtoMessageType): + raise TypeError("Expected `message` to be a instance of protobuf message") + data = {} for desc, field in message.ListFields(): if desc.type == desc.TYPE_MESSAGE: if desc.label == desc.LABEL_REPEATED: - data[desc.name] = map(proto_to_dict, field) + data[desc.name] = list(map(proto_to_dict, field)) else: data[desc.name] = proto_to_dict(field) else: @@ -86,6 +99,11 @@ def proto_fill_from_dict(message, data, clear=True): :return: value of message paramater :raises: incorrect types or values will raise """ + if not isinstance(message.__class__, _ProtoMessageType): + raise TypeError("Expected `message` to be a instance of protobuf message") + if not isinstance(data, dict): + raise TypeError("Expected `data` to be of type `dict`") + if clear: message.Clear() field_descs = message.DESCRIPTOR.fields_by_name @@ -94,24 +112,27 @@ def proto_fill_from_dict(message, data, clear=True): if desc.type == desc.TYPE_MESSAGE: if desc.label == desc.LABEL_REPEATED: - if not isinstance(val, list): - raise TypeError("Expected %s to be of type list, got %s" % ( - repr(key), type(val) - )) + if not isinstance(val, _list_types): + raise TypeError("Expected %s to be of type list, got %s" % (repr(key), type(val))) + + list_ref = getattr(message, key) + + # Takes care of overwriting list fields when merging partial data (clear=False) + if not clear: del list_ref[:] # clears the list for item in val: item_message = getattr(message, key).add() proto_fill_from_dict(item_message, item) else: if not isinstance(val, dict): - raise TypeError("Expected %s to be of type dict, got %s" % ( - repr(key), type(dict) - )) + raise TypeError("Expected %s to be of type dict, got %s" % (repr(key), type(dict))) proto_fill_from_dict(getattr(message, key), val) else: - if isinstance(val, list): - getattr(message, key).extend(val) + if isinstance(val, _list_types): + list_ref = getattr(message, key) + if not clear: del list_ref[:] # clears the list + list_ref.extend(val) else: setattr(message, key, val) diff --git a/tests/test_util.py b/tests/test_util.py index 7fc7efb..7a69719 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -2,6 +2,7 @@ import unittest import steam.util as ut import steam.util.web as uweb import requests +from steam.protobufs.test_messages_pb2 import ComplexProtoMessage proto_mask = 0x80000000 @@ -37,3 +38,117 @@ class Util_Functions(unittest.TestCase): def test_make_requests_session(self): self.assertIsInstance(uweb.make_requests_session(), requests.Session) + +class Util_Proto(unittest.TestCase): + def setUp(self): + self.msg = ComplexProtoMessage() + + def cleanUp(self): + self.msg.Clear() + + def test_proto_from_dict_to_dict(self): + DATA = {'buffers': [ + {'data': b'some data', 'flags': [{'flag': True}, {'flag': False}, {'flag': False}]}, + {'data': b'\x01\x02\x03\x04', 'flags': [{'flag': False}, {'flag': True}, {'flag': True}]} + ], + 'list_number32': [4,16,64,256,1024,4096,16384,65536,262144,1048576,4194304,16777216,67108864,268435456,1073741824], + 'list_number64': [4,64,1024,16384,262144,1125899906842624,18014398509481984,288230376151711744,4611686018427387904], + 'messages': [{'text': 'test string'}, {'text': 'another one'}, {'text': 'third'}], + 'number32': 16777216, + 'number64': 72057594037927936 + } + + ut.proto_fill_from_dict(self.msg, DATA) + + RESULT = ut.proto_to_dict(self.msg) + + self.assertEqual(DATA, RESULT) + + def test_proto_from_dict_merge(self): + self.msg.list_number32.extend([1,2,3]) + + ut.proto_fill_from_dict(self.msg, {'list_number32': [4,5,6]}, clear=False) + + self.assertEqual(self.msg.list_number32, [4,5,6]) + + def test_proto_from_dict_merge_dict(self): + self.msg.messages.add(text='one') + self.msg.messages.add(text='two') + + ut.proto_fill_from_dict(self.msg, {'messages': [{'text': 'three'}]}, clear=False) + + self.assertEqual(len(self.msg.messages), 1) + self.assertEqual(self.msg.messages[0].text, 'three') + + def test_proto_from_dict__dict_insteadof_list(self): + with self.assertRaises(TypeError): + ut.proto_fill_from_dict(self.msg, {'list_number32': [{}, {}]}) + + def test_proto_from_dict__list_insteadof_dict(self): + with self.assertRaises(TypeError): + ut.proto_fill_from_dict(self.msg, {'messages': [1,2,3]}) + + def test_proto_fill_from_dict__list(self): + ut.proto_fill_from_dict(self.msg, {'list_number32': [1,2,3]}) + self.assertEqual(self.msg.list_number32, [1,2,3]) + + def test_proto_fill_from_dict__dict_list(self): + ut.proto_fill_from_dict(self.msg, {'messages': [{'text': 'one'}, {'text': 'two'}]}) + self.assertEqual(len(self.msg.messages), 2) + self.assertEqual(self.msg.messages[0].text, 'one') + self.assertEqual(self.msg.messages[1].text, 'two') + + def test_proto_fill_from_dict__list(self): + ut.proto_fill_from_dict(self.msg, {'list_number32': range(10)}) + self.assertEqual(self.msg.list_number32, list(range(10))) + + + def test_proto_fill_from_dict__generator(self): + ut.proto_fill_from_dict(self.msg, {'list_number32': (x for x in [1,2,3])}) + self.assertEqual(self.msg.list_number32, [1,2,3]) + + def test_proto_fill_from_dict__dict_generator(self): + ut.proto_fill_from_dict(self.msg, {'messages': (x for x in [{'text': 'one'}, {'text': 'two'}])}) + self.assertEqual(len(self.msg.messages), 2) + self.assertEqual(self.msg.messages[0].text, 'one') + self.assertEqual(self.msg.messages[1].text, 'two') + + def test_proto_fill_from_dict__func_generator(self): + def number_gen(): + yield 1 + yield 2 + yield 3 + + ut.proto_fill_from_dict(self.msg, {'list_number32': number_gen()}) + self.assertEqual(self.msg.list_number32, [1,2,3]) + + def test_proto_fill_from_dict__dict_func_generator(self): + def dict_gen(): + yield {'text': 'one'} + yield {'text': 'two'} + + ut.proto_fill_from_dict(self.msg, {'messages': dict_gen()}) + self.assertEqual(len(self.msg.messages), 2) + self.assertEqual(self.msg.messages[0].text, 'one') + self.assertEqual(self.msg.messages[1].text, 'two') + + + def test_proto_fill_from_dict__map(self): + ut.proto_fill_from_dict(self.msg, {'list_number32': map(int, [1,2,3])}) + self.assertEqual(self.msg.list_number32, [1,2,3]) + + def test_proto_fill_from_dict__dict_map(self): + ut.proto_fill_from_dict(self.msg, {'messages': map(dict, [{'text': 'one'}, {'text': 'two'}])}) + self.assertEqual(len(self.msg.messages), 2) + self.assertEqual(self.msg.messages[0].text, 'one') + self.assertEqual(self.msg.messages[1].text, 'two') + + def test_proto_fill_from_dict__filter(self): + ut.proto_fill_from_dict(self.msg, {'list_number32': filter(lambda x: True, [1,2,3])}) + self.assertEqual(self.msg.list_number32, [1,2,3]) + + def test_proto_fill_from_dict__dict_filter(self): + ut.proto_fill_from_dict(self.msg, {'messages': filter(lambda x: True, [{'text': 'one'}, {'text': 'two'}])}) + self.assertEqual(len(self.msg.messages), 2) + self.assertEqual(self.msg.messages[0].text, 'one') + self.assertEqual(self.msg.messages[1].text, 'two')