Browse Source

[types] cleanup model loading and make consume greedy

pull/39/head
andrei 8 years ago
parent
commit
cca53f0a70
  1. 51
      disco/types/base.py
  2. 46
      tests/test_types.py

51
disco/types/base.py

@ -83,9 +83,9 @@ class Field(object):
def has_default(self): def has_default(self):
return self.default is not None return self.default is not None
def try_convert(self, raw, client): def try_convert(self, raw, client, **kwargs):
try: try:
return self.deserializer(raw, client) return self.deserializer(raw, client, **kwargs)
except Exception as e: except Exception as e:
six.reraise(ConversionError, ConversionError(self, raw, e)) six.reraise(ConversionError, ConversionError(self, raw, e))
@ -94,11 +94,16 @@ class Field(object):
if isinstance(typ, Field) or inspect.isclass(typ) and issubclass(typ, Model): if isinstance(typ, Field) or inspect.isclass(typ) and issubclass(typ, Model):
return typ return typ
elif isinstance(typ, BaseEnumMeta): elif isinstance(typ, BaseEnumMeta):
return lambda raw, _: typ.get(raw) def _f(raw, client, **kwargs):
return typ.get(raw)
return _f
elif typ is None: elif typ is None:
return lambda x, y: None def _f(*args, **kwargs):
return None
else: else:
return lambda raw, _: typ(raw) def _f(raw, client, **kwargs):
return typ(raw)
return _f
@staticmethod @staticmethod
def serialize(value, inst=None): def serialize(value, inst=None):
@ -111,8 +116,8 @@ class Field(object):
return inst.cast(value) return inst.cast(value)
return value return value
def __call__(self, raw, client): def __call__(self, raw, client, **kwargs):
return self.try_convert(raw, client) return self.try_convert(raw, client, **kwargs)
class DictField(Field): class DictField(Field):
@ -132,7 +137,7 @@ class DictField(Field):
if k not in (inst.ignore_dump if inst else []) if k not in (inst.ignore_dump if inst else [])
} }
def try_convert(self, raw, client): def try_convert(self, raw, client, **kwargs):
return HashMap({ return HashMap({
self.key_de(k, client): self.value_de(v, client) for k, v in six.iteritems(raw) self.key_de(k, client): self.value_de(v, client) for k, v in six.iteritems(raw)
}) })
@ -145,7 +150,7 @@ class ListField(Field):
def serialize(value, inst=None): def serialize(value, inst=None):
return list(map(Field.serialize, value)) return list(map(Field.serialize, value))
def try_convert(self, raw, client): def try_convert(self, raw, client, **kwargs):
return [self.deserializer(i, client) for i in raw] return [self.deserializer(i, client) for i in raw]
@ -157,7 +162,7 @@ class AutoDictField(Field):
self.value_de = self.type_to_deserializer(value_type) self.value_de = self.type_to_deserializer(value_type)
self.key = key self.key = key
def try_convert(self, raw, client): def try_convert(self, raw, client, **kwargs):
return HashMap({ return HashMap({
getattr(b, self.key): b for b in (self.value_de(a, client) for a in raw) getattr(b, self.key): b for b in (self.value_de(a, client) for a in raw)
}) })
@ -274,8 +279,9 @@ class Model(six.with_metaclass(ModelMeta, Chainable)):
obj, self.client = args obj, self.client = args
else: else:
obj = kwargs obj = kwargs
kwargs = {}
self.load(obj) self.load(obj, **kwargs)
self.validate() self.validate()
def after(self, delay): def after(self, delay):
@ -289,31 +295,32 @@ class Model(six.with_metaclass(ModelMeta, Chainable)):
def _fields(self): def _fields(self):
return self.__class__._fields return self.__class__._fields
def load(self, obj, consume=False, skip=None): def load(self, *args, **kwargs):
return self.load_into(self, obj, consume, skip) return self.load_into(self, *args, **kwargs)
@classmethod @classmethod
def load_into(cls, inst, obj, consume=False, skip=None): def load_into(cls, inst, obj, consume=False):
for name, field in six.iteritems(cls._fields): for name, field in six.iteritems(cls._fields):
should_skip = skip and name in skip try:
raw = obj[field.src_name]
if consume and not should_skip: if consume and not isinstance(raw, dict):
raw = obj.pop(field.src_name, UNSET) del obj[field.src_name]
else: except KeyError:
raw = obj.get(field.src_name, UNSET) raw = UNSET
# If the field is unset/none, and we have a default we need to set it # If the field is unset/none, and we have a default we need to set it
if (raw in (None, UNSET) or should_skip) and field.has_default(): if raw in (None, UNSET) and field.has_default():
default = field.default() if callable(field.default) else field.default default = field.default() if callable(field.default) else field.default
setattr(inst, field.dst_name, default) setattr(inst, field.dst_name, default)
continue continue
# Otherwise if the field is UNSET and has no default, skip conversion # Otherwise if the field is UNSET and has no default, skip conversion
if raw is UNSET or should_skip: if raw is UNSET:
setattr(inst, field.dst_name, raw) setattr(inst, field.dst_name, raw)
continue continue
value = field.try_convert(raw, inst.client) value = field.try_convert(raw, inst.client, consume=consume)
setattr(inst, field.dst_name, value) setattr(inst, field.dst_name, value)
def update(self, other, ignored=None): def update(self, other, ignored=None):

46
tests/test_types.py

@ -2,32 +2,58 @@ from unittest import TestCase
from disco.types.base import Model, Field from disco.types.base import Model, Field
class _M(Model): class _A(Model):
a = Field(int) a = Field(int)
b = Field(float) b = Field(float)
c = Field(str) c = Field(str)
class _B(Model):
a = Field(int)
b = Field(float)
c = Field(str)
class _C(Model):
a = Field(_A)
b = Field(_B)
class TestModel(TestCase): class TestModel(TestCase):
def test_model_simple_loading(self): def test_model_simple_loading(self):
inst = _M(dict(a=1, b=1.1, c='test')) inst = _A(dict(a=1, b=1.1, c='test'))
self.assertEquals(inst.a, 1) self.assertEquals(inst.a, 1)
self.assertEquals(inst.b, 1.1) self.assertEquals(inst.b, 1.1)
self.assertEquals(inst.c, 'test') self.assertEquals(inst.c, 'test')
def test_model_load_into(self): def test_model_load_into(self):
inst = _M() inst = _A()
_M.load_into(inst, dict(a=1, b=1.1, c='test')) _A.load_into(inst, dict(a=1, b=1.1, c='test'))
self.assertEquals(inst.a, 1) self.assertEquals(inst.a, 1)
self.assertEquals(inst.b, 1.1) self.assertEquals(inst.b, 1.1)
self.assertEquals(inst.c, 'test') self.assertEquals(inst.c, 'test')
def test_model_loading_consume(self): def test_model_loading_consume(self):
obj = dict(a=5, b=33.33, c='wtf') obj = {
inst = _M() 'a': {
'a': 1,
'b': 2.2,
'c': '3',
'd': 'wow',
},
'b': {
'a': 3,
'b': 2.2,
'c': '1',
'z': 'wtf'
},
'g': 'lmao'
}
inst = _C()
inst.load(obj, consume=True) inst.load(obj, consume=True)
self.assertEquals(obj, {}) self.assertEquals(inst.a.a, 1)
self.assertEquals(inst.a, 5) self.assertEquals(inst.b.c, '1')
self.assertEquals(inst.b, 33.33)
self.assertEquals(inst.c, 'wtf') self.assertEquals(obj, {'a': {'d': 'wow'}, 'b': {'z': 'wtf'}, 'g': 'lmao'})

Loading…
Cancel
Save