From cca53f0a7072e3b7914cad62159d15f9afe9a172 Mon Sep 17 00:00:00 2001 From: andrei Date: Fri, 23 Jun 2017 16:11:56 -0700 Subject: [PATCH] [types] cleanup model loading and make consume greedy --- disco/types/base.py | 51 ++++++++++++++++++++++++++------------------- tests/test_types.py | 46 +++++++++++++++++++++++++++++++--------- 2 files changed, 65 insertions(+), 32 deletions(-) diff --git a/disco/types/base.py b/disco/types/base.py index 883ba38..04679ff 100644 --- a/disco/types/base.py +++ b/disco/types/base.py @@ -83,9 +83,9 @@ class Field(object): def has_default(self): return self.default is not None - def try_convert(self, raw, client): + def try_convert(self, raw, client, **kwargs): try: - return self.deserializer(raw, client) + return self.deserializer(raw, client, **kwargs) except Exception as e: six.reraise(ConversionError, ConversionError(self, raw, e)) @@ -94,11 +94,16 @@ class Field(object): if isinstance(typ, Field) or inspect.isclass(typ) and issubclass(typ, Model): return typ elif isinstance(typ, BaseEnumMeta): - return lambda raw, _: typ.get(raw) + def _f(raw, client, **kwargs): + return typ.get(raw) + return _f elif typ is None: - return lambda x, y: None + def _f(*args, **kwargs): + return None else: - return lambda raw, _: typ(raw) + def _f(raw, client, **kwargs): + return typ(raw) + return _f @staticmethod def serialize(value, inst=None): @@ -111,8 +116,8 @@ class Field(object): return inst.cast(value) return value - def __call__(self, raw, client): - return self.try_convert(raw, client) + def __call__(self, raw, client, **kwargs): + return self.try_convert(raw, client, **kwargs) class DictField(Field): @@ -132,7 +137,7 @@ class DictField(Field): if k not in (inst.ignore_dump if inst else []) } - def try_convert(self, raw, client): + def try_convert(self, raw, client, **kwargs): return HashMap({ self.key_de(k, client): self.value_de(v, client) for k, v in six.iteritems(raw) }) @@ -145,7 +150,7 @@ class ListField(Field): def serialize(value, inst=None): return list(map(Field.serialize, value)) - def try_convert(self, raw, client): + def try_convert(self, raw, client, **kwargs): return [self.deserializer(i, client) for i in raw] @@ -157,7 +162,7 @@ class AutoDictField(Field): self.value_de = self.type_to_deserializer(value_type) self.key = key - def try_convert(self, raw, client): + def try_convert(self, raw, client, **kwargs): return HashMap({ getattr(b, self.key): b for b in (self.value_de(a, client) for a in raw) }) @@ -274,8 +279,9 @@ class Model(six.with_metaclass(ModelMeta, Chainable)): obj, self.client = args else: obj = kwargs + kwargs = {} - self.load(obj) + self.load(obj, **kwargs) self.validate() def after(self, delay): @@ -289,31 +295,32 @@ class Model(six.with_metaclass(ModelMeta, Chainable)): def _fields(self): return self.__class__._fields - def load(self, obj, consume=False, skip=None): - return self.load_into(self, obj, consume, skip) + def load(self, *args, **kwargs): + return self.load_into(self, *args, **kwargs) @classmethod - def load_into(cls, inst, obj, consume=False, skip=None): + def load_into(cls, inst, obj, consume=False): for name, field in six.iteritems(cls._fields): - should_skip = skip and name in skip + try: + raw = obj[field.src_name] - if consume and not should_skip: - raw = obj.pop(field.src_name, UNSET) - else: - raw = obj.get(field.src_name, UNSET) + if consume and not isinstance(raw, dict): + del obj[field.src_name] + except KeyError: + raw = UNSET # If the field is unset/none, and we have a default we need to set it - if (raw in (None, UNSET) or should_skip) and field.has_default(): + if raw in (None, UNSET) and field.has_default(): default = field.default() if callable(field.default) else field.default setattr(inst, field.dst_name, default) continue # Otherwise if the field is UNSET and has no default, skip conversion - if raw is UNSET or should_skip: + if raw is UNSET: setattr(inst, field.dst_name, raw) continue - value = field.try_convert(raw, inst.client) + value = field.try_convert(raw, inst.client, consume=consume) setattr(inst, field.dst_name, value) def update(self, other, ignored=None): diff --git a/tests/test_types.py b/tests/test_types.py index c5df25c..0f96560 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -2,32 +2,58 @@ from unittest import TestCase from disco.types.base import Model, Field -class _M(Model): +class _A(Model): a = Field(int) b = Field(float) c = Field(str) +class _B(Model): + a = Field(int) + b = Field(float) + c = Field(str) + + +class _C(Model): + a = Field(_A) + b = Field(_B) + + class TestModel(TestCase): def test_model_simple_loading(self): - inst = _M(dict(a=1, b=1.1, c='test')) + inst = _A(dict(a=1, b=1.1, c='test')) self.assertEquals(inst.a, 1) self.assertEquals(inst.b, 1.1) self.assertEquals(inst.c, 'test') def test_model_load_into(self): - inst = _M() - _M.load_into(inst, dict(a=1, b=1.1, c='test')) + inst = _A() + _A.load_into(inst, dict(a=1, b=1.1, c='test')) self.assertEquals(inst.a, 1) self.assertEquals(inst.b, 1.1) self.assertEquals(inst.c, 'test') def test_model_loading_consume(self): - obj = dict(a=5, b=33.33, c='wtf') - inst = _M() + obj = { + 'a': { + 'a': 1, + 'b': 2.2, + 'c': '3', + 'd': 'wow', + }, + 'b': { + 'a': 3, + 'b': 2.2, + 'c': '1', + 'z': 'wtf' + }, + 'g': 'lmao' + } + + inst = _C() inst.load(obj, consume=True) - self.assertEquals(obj, {}) - self.assertEquals(inst.a, 5) - self.assertEquals(inst.b, 33.33) - self.assertEquals(inst.c, 'wtf') + self.assertEquals(inst.a.a, 1) + self.assertEquals(inst.b.c, '1') + + self.assertEquals(obj, {'a': {'d': 'wow'}, 'b': {'z': 'wtf'}, 'g': 'lmao'})