You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
423 lines
12 KiB
423 lines
12 KiB
import six
|
|
import gevent
|
|
import inspect
|
|
import functools
|
|
|
|
from datetime import datetime as real_datetime
|
|
|
|
from disco.util.chains import Chainable
|
|
from disco.util.hashmap import HashMap
|
|
from disco.util.enum import get_enum_members
|
|
|
|
# strptime patterns tried in order by `datetime()` below: first with
# fractional seconds, then without.
DATETIME_FORMATS = [
    '%Y-%m-%dT%H:%M:%S.%f',
    '%Y-%m-%dT%H:%M:%S',
]
|
|
|
|
|
|
def get_item_by_path(obj, path):
    """
    Resolve a dotted attribute path (e.g. ``'a.b.c'``) starting from `obj`.

    Each dot-separated segment of `path` is looked up with `getattr` on the
    result of the previous lookup; the final attribute value is returned.
    An `AttributeError` propagates if any segment is missing.
    """
    return functools.reduce(getattr, path.split('.'), obj)
|
|
|
|
|
|
class Unset(object):
    """
    Type of the `UNSET` sentinel; instances are always falsy.
    """

    def __bool__(self):
        # Python 3 truthiness hook
        return False

    # Python 2 looks up `__nonzero__` instead of `__bool__`
    __nonzero__ = __bool__
|
|
|
|
|
|
# Module-wide sentinel, compared by identity (`is UNSET`) elsewhere in this
# file to tell "no value supplied" apart from an explicit `None`.
UNSET = Unset()
|
|
|
|
|
|
def cached_property(method):
    """
    Decorator that tags `method` with an empty `_cached_property` set.

    The function itself is returned unchanged apart from the marker; the
    marker is what `ModelMeta` checks for when it swaps tagged methods for
    caching properties.
    """
    setattr(method, '_cached_property', set())
    return method
|
|
|
|
|
|
def strict_cached_property(*args):
    """
    Decorator factory: like `cached_property`, but the marker set stored on
    the method contains the positional `args` given here.

    NOTE(review): nothing in this file reads the contents of the set, so the
    intended meaning of `args` cannot be confirmed from here.
    """
    def _decorate(method):
        # A fresh set per decorated method, so markers are never shared
        method._cached_property = set(args)
        return method

    return _decorate
|
|
|
|
|
|
class ConversionError(Exception):
    """
    Raised when a field's deserializer fails on a raw value.

    The message contains the first 144 characters of `str(raw)`, the field's
    source name, its declared type and the underlying exception `e`.
    """

    def __init__(self, field, raw, e):
        preview = str(raw)[:144]
        message = 'Failed to convert `{}` (`{}`) to {}: {}'.format(
            preview, field.src_name, field.true_type, e)
        super(ConversionError, self).__init__(message)

        # On Python 3, chain the original exception explicitly
        if six.PY3:
            self.__cause__ = e
|
|
|
|
|
|
class Field(object):
    """
    Declares one attribute of a model: where it is read from, how the raw
    value is converted, and which default applies when it is missing.

    Attributes set by `__init__`:
      true_type    -- the `value_type` exactly as passed in
      src_name     -- key read from the raw object (`alias`, or the attribute
                      name once `name` is assigned)
      dst_name     -- attribute name written on the model instance
      ignore_dump  -- names passed as `ignore` when serializing a nested model
      cast         -- optional callable applied to non-model values on
                      serialization
      metadata     -- any extra keyword arguments
      default      -- default value or factory (`UNSET` when there is none)
      deserializer -- callable `(raw, client, **kwargs)`, or None
    """

    def __init__(self, value_type, alias=None, default=UNSET, create=True, ignore_dump=None, cast=None, **kwargs):
        # TODO: fix default bullshit
        self.true_type = value_type
        self.src_name = alias
        self.dst_name = None
        self.ignore_dump = ignore_dump or []
        self.cast = cast
        self.metadata = kwargs

        if default is UNSET:
            # Nothing was passed: keep a class-level default if a subclass
            # defines one, otherwise mark the field as having no default
            if not hasattr(self, 'default'):
                self.default = UNSET
        else:
            # Only set the default value if we were given one
            self.default = default

        self.deserializer = self.type_to_deserializer(value_type) if value_type else None

        # With no default resolved so far, borrow one from the deserializer
        if self.default is UNSET:
            deserializer = self.deserializer
            if isinstance(deserializer, Field):
                self.default = deserializer.default
            elif create and inspect.isclass(deserializer) and issubclass(deserializer, Model):
                # The model class itself acts as a default factory
                self.default = deserializer

    @property
    def name(self):
        # Write-only in practice: the getter always yields None
        return None

    @name.setter
    def name(self, name):
        # Fill in whichever of the two names has not been set yet
        self.dst_name = self.dst_name or name
        self.src_name = self.src_name or name

    def has_default(self):
        """Return True when a default other than `UNSET` is configured."""
        return not (self.default is UNSET)

    def try_convert(self, raw, client, **kwargs):
        """
        Run the deserializer on `raw`; any exception is re-raised as a
        `ConversionError` describing this field and the raw value.
        """
        try:
            return self.deserializer(raw, client, **kwargs)
        except Exception as exc:
            failure = ConversionError(self, raw, exc)
            six.reraise(ConversionError, failure)

    @staticmethod
    def type_to_deserializer(typ):
        """
        Turn a declared type into a callable taking `(raw, client, **kwargs)`.

        Field instances and Model subclasses are returned as-is; `None`
        yields a callable that always returns None; anything else is wrapped
        so that only `raw` is passed to it.
        """
        if isinstance(typ, Field) or inspect.isclass(typ) and issubclass(typ, Model):
            return typ

        # elif isinstance(typ, BaseEnumMeta):
        #     def _f(raw, client, **kwargs):
        #         return typ.get(raw)
        #     return _f

        if typ is None:
            def _f(*args, **kwargs):
                return None
            return _f

        def _f(raw, client, **kwargs):
            return typ(raw)
        return _f

    @staticmethod
    def serialize(value, inst=None):
        """
        Convert `value` for dumping: models become dicts (honouring
        `inst.ignore_dump`), other values go through `inst.cast` when one is
        set, and are returned untouched otherwise.
        """
        if isinstance(value, Model):
            ignored = inst.ignore_dump if inst else []
            return value.to_dict(ignore=ignored)

        if inst and inst.cast:
            return inst.cast(value)

        return value

    def __call__(self, raw, client, **kwargs):
        # Lets a Field instance be used directly as a deserializer
        return self.try_convert(raw, client, **kwargs)
|
|
|
|
|
|
class DictField(Field):
    """
    Field holding a mapping whose keys and values are each deserialized.

    When `value_type` is omitted, `key_type` is used for the values as well.
    The default for a missing field is an empty `HashMap`.
    """
    default = HashMap

    def __init__(self, key_type, value_type=None, **kwargs):
        # The base class gets an empty (falsy) type, so it builds no
        # deserializer of its own
        super(DictField, self).__init__({}, **kwargs)
        self.true_key_type, self.true_value_type = key_type, value_type
        self.key_de = self.type_to_deserializer(key_type)
        self.value_de = self.type_to_deserializer(value_type or key_type)

    @staticmethod
    def serialize(value, inst=None):
        """Serialize keys and values, skipping keys in `inst.ignore_dump`."""
        dumped = {}
        for key, item in six.iteritems(value):
            if key in (inst.ignore_dump if inst else []):
                continue
            dumped[Field.serialize(key)] = Field.serialize(item)
        return dumped

    def try_convert(self, raw, client, **kwargs):
        """Build a `HashMap` of converted keys to converted values."""
        converted = {}
        for key, item in six.iteritems(raw):
            converted[self.key_de(key, client)] = self.value_de(item, client)
        return HashMap(converted)
|
|
|
|
|
|
class ListField(Field):
    """
    Field holding a list whose items are each run through the deserializer.
    The default for a missing field is a new empty list.
    """
    default = list

    @staticmethod
    def serialize(value, inst=None):
        """Serialize every item; `inst` is not consulted for list items."""
        return [Field.serialize(item) for item in value]

    def try_convert(self, raw, client, **kwargs):
        """Return a list with each element of `raw` converted in order."""
        convert = self.deserializer
        return [convert(item, client) for item in raw]
|
|
|
|
|
|
class AutoDictField(Field):
    """
    Field that turns an iterable of raw items into a `HashMap`, keyed by the
    attribute named `key` on each converted item.
    """
    default = HashMap

    def __init__(self, value_type, key, **kwargs):
        # The base class gets an empty (falsy) type, so it builds no
        # deserializer of its own
        super(AutoDictField, self).__init__({}, **kwargs)
        self.value_de = self.type_to_deserializer(value_type)
        self.key = key

    def try_convert(self, raw, client, **kwargs):
        """Convert each raw item and index it by `getattr(item, self.key)`."""
        indexed = {}
        for raw_item in raw:
            item = self.value_de(raw_item, client)
            indexed[getattr(item, self.key)] = item
        return HashMap(indexed)
|
|
|
|
|
|
def _make(typ, data, client):
|
|
if inspect.isclass(typ) and issubclass(typ, Model):
|
|
return typ(data, client)
|
|
return typ(data)
|
|
|
|
|
|
def snowflake(data):
    """
    Convert `data` to an int; any falsy input (None, '', 0) yields None.
    """
    if not data:
        return None
    return int(data)
|
|
|
|
|
|
def enum(typ):
    """
    Build a converter that maps raw data onto a member of the enum `typ`.

    The returned callable walks `get_enum_members(typ)` and returns the first
    member value whose name equals the upper-cased string input, or whose
    name or value equals the input as given. None is returned for None input
    or when nothing matches.
    """
    def _f(data):
        if data is None:
            return None

        # Strings are additionally matched case-insensitively against names
        is_text = isinstance(data, six.string_types)
        upper = data.upper() if is_text else None

        for key, member in get_enum_members(typ):
            if (is_text and key == upper) or key == data or member == data:
                return member

        return None
    return _f
|
|
|
|
|
|
def datetime(data):
    """
    Convert a raw timestamp into a naive `datetime` object.

    Accepted inputs:
      * falsy values (None, '', 0) -> None
      * int or float               -> seconds since the epoch, read as UTC
      * string                     -> parsed with each of DATETIME_FORMATS

    For strings, everything after the last '+' (a UTC offset such as
    '+00:00') and a trailing 'Z' are dropped before parsing. The offset is
    discarded, not applied, so the result carries no timezone information.

    Raises ValueError when a string matches none of the known formats.
    """
    if not data:
        return None

    # Floats used to fall through to the string branch and fail with an
    # AttributeError on `.rsplit`; treat them as timestamps like ints.
    if isinstance(data, (int, float)):
        return real_datetime.utcfromtimestamp(data)

    for fmt in DATETIME_FORMATS:
        try:
            stamp = data.rsplit('+', 1)[0]
            # 'Z' is the ISO 8601 designator for UTC; strptime would reject it
            if stamp.endswith('Z'):
                stamp = stamp[:-1]
            return real_datetime.strptime(stamp, fmt)
        except (ValueError, TypeError):
            continue

    raise ValueError('Failed to convert `{}` to datetime'.format(data))
|
|
|
|
|
|
def text(obj):
    """
    Convert `obj` to the unicode text type; None is passed through.

    On Python 2, byte strings are decoded as UTF-8; everything else goes
    through `six.text_type`.
    """
    if obj is None:
        return None

    if six.PY2 and isinstance(obj, str):
        return obj.decode('utf-8')

    return six.text_type(obj)
|
|
|
|
|
|
def with_equality(field):
    """
    Build a mixin class whose instances compare by the attribute `field`.

    Compared with another instance of the same class, the two `field` values
    are compared; compared with anything else, this instance's `field` value
    is compared against the other object directly.

    `__ne__` is defined alongside `__eq__` because Python 2 does not derive
    `!=` from `==`; without it `a != b` would fall back to identity there.
    """
    class T(object):
        def __eq__(self, other):
            if isinstance(other, self.__class__):
                return getattr(self, field) == getattr(other, field)
            else:
                return getattr(self, field) == other

        def __ne__(self, other):
            # Keep != the exact inverse of == on every Python version
            return not self.__eq__(other)
    return T
|
|
|
|
|
|
def with_hash(field):
    """
    Build a mixin class whose instances hash to the hash of their `field`
    attribute.
    """
    class T(object):
        def __hash__(self):
            key = getattr(self, field)
            return hash(key)

    return T
|
|
|
|
|
|
# Resolution hacks :(
# `Field` and `ModelMeta` refer to `Model` / `SlottedModel` before the real
# classes are defined further down in this file. Binding the names to None
# here lets those references resolve, and lets `ModelMeta` skip its
# Model-specific handling (`if Model and ...`, `if SlottedModel and ...`)
# while the base classes themselves are being created.
Model = None
SlottedModel = None
|
|
|
|
|
|
def _get_cached_property(name, func):
|
|
def _getattr(self):
|
|
try:
|
|
return getattr(self, '_' + name)
|
|
except AttributeError:
|
|
value = func(self)
|
|
setattr(self, '_' + name, value)
|
|
return value
|
|
|
|
def _setattr(self, value):
|
|
setattr(self, '_' + name)
|
|
|
|
def _delattr(self):
|
|
delattr(self, '_' + name)
|
|
|
|
prop = property(_getattr, _setattr, _delattr)
|
|
return prop
|
|
|
|
|
|
class ModelMeta(type):
    """
    Metaclass that collects `Field` declarations into a `_fields` mapping.

    While a class is being created it:
      * inherits `_fields` from every parent that is a strict Model subclass
      * replaces attributes tagged with `_cached_property` by caching
        properties
      * records each `Field` under its attribute name and assigns that name
        to the field
      * for subclasses of `SlottedModel`, extends `__slots__` with the field
        names and the cache attributes
      * removes the field objects from the class namespace
    """

    def __new__(mcs, name, parents, dct):
        fields = {}
        slots = set()

        # Start from the fields declared on parent models
        for base in parents:
            if Model and issubclass(base, Model) and base != Model:
                fields.update(base._fields)

        for attr, value in six.iteritems(dct):
            if hasattr(value, '_cached_property'):
                # Swap the tagged method for a caching property; its backing
                # attribute needs a slot of its own
                dct[attr] = _get_cached_property(attr, value)
                slots.add('_' + attr)

            if isinstance(value, Field):
                value.name = attr
                fields[attr] = value
                slots.add(attr)

        is_slotted = SlottedModel and any(issubclass(base, SlottedModel) for base in parents)
        if is_slotted:
            # Merge our set of field slots with any other slots from the mro
            slot_names = tuple(set(dct.get('__slots__', [])) | slots)
            dct['__slots__'] = slot_names

            # Remove all fields from the dict
            dct = {attr: value for attr, value in six.iteritems(dct) if attr not in slot_names}
        else:
            dct = {attr: value for attr, value in six.iteritems(dct) if attr not in fields}

        dct['_fields'] = fields
        return super(ModelMeta, mcs).__new__(mcs, name, parents, dct)
|
|
|
|
|
|
class Model(six.with_metaclass(ModelMeta, Chainable)):
    """
    Base class for objects populated from raw data through declared `Field`s.

    Construction accepts one of:
      * ``Model(obj)``            -- raw mapping only
      * ``Model(obj, client)``    -- raw mapping plus client
      * ``Model(**attrs)``        -- keyword arguments used as the raw mapping
    A ``client=`` keyword is always consumed and stored on the instance.
    """
    __slots__ = ['client']

    def __init__(self, *args, **kwargs):
        self.client = kwargs.pop('client', None)

        positional = len(args)
        if positional == 1:
            obj = args[0]
        elif positional == 2:
            obj, self.client = args
        else:
            # No usable positional data: the keywords are the raw object
            obj, kwargs = kwargs, {}

        self.load(obj, **kwargs)
        self.validate()

    def after(self, delay):
        """Sleep for `delay` via gevent, then return self."""
        gevent.sleep(delay)
        return self

    def validate(self):
        """Hook called at the end of `__init__`; does nothing here."""
        pass

    @property
    def _fields(self):
        return self.__class__._fields

    def load(self, *args, **kwargs):
        """Populate this instance; see `load_into`."""
        return self.load_into(self, *args, **kwargs)

    @classmethod
    def load_into(cls, inst, obj, consume=False):
        """
        Set every declared field on `inst` from the raw mapping `obj`.

        Missing or None values take the field default when there is one
        (called first if it is callable); missing values without a default
        are stored as `UNSET`; everything else goes through the field's
        `try_convert`. With `consume` set, non-dict raw values are deleted
        from `obj` as they are read.
        """
        for field in six.itervalues(cls._fields):
            try:
                raw = obj[field.src_name]

                if consume and not isinstance(raw, dict):
                    del obj[field.src_name]
            except KeyError:
                raw = UNSET

            if raw in (None, UNSET) and field.has_default():
                # If the field is unset/none, and we have a default we need to set it
                value = field.default() if callable(field.default) else field.default
            elif raw is UNSET:
                # Otherwise if the field is UNSET and has no default, skip conversion
                value = raw
            else:
                value = field.try_convert(raw, inst.client, consume=consume)

            setattr(inst, field.dst_name, value)

    def inplace_update(self, other, ignored=None):
        """
        Copy every field that is set on `other` onto self, except the names
        in `ignored`, then drop the cached values of all properties.
        """
        for name in six.iterkeys(self._fields):
            if ignored and name in ignored:
                continue

            if not hasattr(other, name):
                continue

            incoming = getattr(other, name)
            if incoming is not UNSET:
                setattr(self, name, incoming)

        # Clear cached properties
        klass = type(self)
        for name in dir(klass):
            if not isinstance(getattr(klass, name), property):
                continue

            # Best effort: properties without a deleter, or with nothing
            # cached yet, raise here and are skipped
            try:
                delattr(self, name)
            except Exception:
                pass

    def to_dict(self, ignore=None):
        """
        Serialize all fields that are not `UNSET` into a dict keyed by field
        name, leaving out the names listed in `ignore`.
        """
        dumped = {}
        for name, field in six.iteritems(self.__class__._fields):
            if ignore and name in ignore:
                continue

            current = getattr(self, name)
            if current is UNSET:
                continue

            dumped[name] = field.serialize(current, field)
        return dumped

    @classmethod
    def create(cls, client, data, **kwargs):
        """
        Build an instance from `data` bound to `client`. The extra keywords
        are merged into `data` in place before loading.
        """
        data.update(kwargs)
        return cls(data, client)

    @classmethod
    def create_map(cls, client, data, *args, **kwargs):
        """Create one instance per item of `data` and return them as a list."""
        build = functools.partial(cls.create, client, *args, **kwargs)
        return [build(item) for item in data]

    @classmethod
    def create_hash(cls, client, key, data, **kwargs):
        """
        Create one instance per item of `data` and return a `HashMap` keyed
        by the dotted attribute path `key` resolved on each instance.
        """
        instances = [cls.create(client, raw, **kwargs) for raw in data]
        return HashMap({
            get_item_by_path(inst, key): inst for inst in instances
        })

    @classmethod
    def attach(cls, it, data):
        """
        Set every key/value of `data` as an attribute on each item of `it`,
        silently skipping assignments that fail.
        """
        for item in it:
            for attr, value in six.iteritems(data):
                try:
                    setattr(item, attr, value)
                except Exception:
                    pass
|
|
|
|
|
|
# Base class for models that use `__slots__`: `ModelMeta` checks for
# subclasses of this class and adds their field names (and cached-property
# backing attributes) to `__slots__`.
class SlottedModel(Model):
    __slots__ = ['client']
|
|
|