You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
349 lines
9.2 KiB
349 lines
9.2 KiB
import functools
import inspect
import sys
from datetime import datetime as real_datetime

import gevent
import six
from holster.enum import BaseEnumMeta, EnumAttr

from disco.util.functional import CachedSlotProperty
from disco.util.hashmap import HashMap
|
|
|
|
# strptime layouts tried in order: with microseconds (%f) first, then without.
DATETIME_FORMATS = ['%Y-%m-%dT%H:%M:%S.%f', '%Y-%m-%dT%H:%M:%S']
|
|
|
|
|
|
class Unset(object):
    """Type of the ``UNSET`` sentinel.

    ``Model.load`` assigns ``UNSET`` to fields that have neither a raw value
    nor a default. Instances are always falsy.
    """

    def __nonzero__(self):
        return False

    # Python 3 consults ``__bool__`` instead of ``__nonzero__``; without this
    # alias the sentinel would be truthy there.
    __bool__ = __nonzero__


UNSET = Unset()
|
|
|
|
|
|
class ConversionError(Exception):
    """Raised when a field's deserializer fails on a raw value.

    :param field: the field being converted (its ``src_name`` and
        ``deserializer`` appear in the message).
    :param raw: the offending raw value; its ``str()`` is truncated to 144
        characters in the message.
    :param e: the underlying exception, also recorded as ``__cause__`` on
        Python 3.
    """

    def __init__(self, field, raw, e):
        preview = str(raw)[:144]
        message = 'Failed to convert `{}` (`{}`) to {}: {}'.format(
            preview, field.src_name, field.deserializer, e)
        super(ConversionError, self).__init__(message)

        if six.PY3:
            self.__cause__ = e
|
|
|
|
|
|
class Field(object):
    """Declarative description of one model attribute and its converter.

    :param value_type: how raw values are converted; normalised by
        :meth:`type_to_deserializer`.
    :param alias: key read from the raw mapping (``src_name``); when omitted,
        the attribute name assigned through :attr:`name` is used.
    :param default: value, or zero-argument callable, used when the raw value
        is missing. ``None`` means "no explicit default". In that case a
        class-level ``default`` (as declared on subclasses) or one derived
        from ``value_type`` is used instead.
    :param kwargs: stored unchanged on ``self.metadata``.
    """

    def __init__(self, value_type, alias=None, default=None, **kwargs):
        self.src_name = alias
        self.dst_name = None
        self.metadata = kwargs

        # Keep a subclass's class-level ``default`` unless one was passed in.
        if default is not None:
            self.default = default
        elif not hasattr(self, 'default'):
            self.default = None

        self.deserializer = None

        if value_type:
            self.deserializer = self.type_to_deserializer(value_type)

            # Derive a default from the converter when none was given. Use a
            # nested Field's own default, or the Model class itself, which is
            # callable, so ``Model.load`` builds a fresh empty instance.
            if isinstance(self.deserializer, Field) and self.default is None:
                self.default = self.deserializer.default
            elif inspect.isclass(self.deserializer) and issubclass(self.deserializer, Model) and self.default is None:
                self.default = self.deserializer

    @property
    def name(self):
        """Attribute name on the model (``dst_name``); ``None`` until assigned."""
        return self.dst_name

    @name.setter
    def name(self, name):
        # Only fill in names that were not already set (e.g. via ``alias``).
        if not self.dst_name:
            self.dst_name = name

        if not self.src_name:
            self.src_name = name

    def has_default(self):
        """Return True when a non-``None`` default is configured."""
        return self.default is not None

    def try_convert(self, raw, client):
        """Run the deserializer on ``raw`` with ``client``.

        :raises ConversionError: wrapping whatever the deserializer raised.
        """
        try:
            return self.deserializer(raw, client)
        except Exception as e:
            # Pass the active traceback along so the failing frame inside the
            # deserializer stays visible. Python 2 has no exception chaining.
            six.reraise(ConversionError, ConversionError(self, raw, e), sys.exc_info()[2])

    @staticmethod
    def type_to_deserializer(typ):
        """Normalise ``typ`` into a callable taking ``(raw, client)``.

        * ``Field`` instances and ``Model`` subclasses are returned unchanged.
        * A holster enum class (``BaseEnumMeta``) maps raw values via ``typ.get``.
        * ``None`` yields a converter that always returns ``None``.
        * Anything else is called with the raw value only.
        """
        if isinstance(typ, Field) or (inspect.isclass(typ) and issubclass(typ, Model)):
            return typ
        elif isinstance(typ, BaseEnumMeta):
            return lambda raw, _: typ.get(raw)
        elif typ is None:
            return lambda x, y: None
        else:
            return lambda raw, _: typ(raw)

    @staticmethod
    def serialize(value):
        """Turn ``value`` into plain data.

        Enum attributes become ``.value``, models become ``to_dict()``, and
        everything else is returned unchanged.
        """
        if isinstance(value, EnumAttr):
            return value.value
        elif isinstance(value, Model):
            return value.to_dict()
        else:
            return value

    def __call__(self, raw, client):
        """Let a Field serve as another Field's deserializer."""
        return self.try_convert(raw, client)
|
|
|
|
|
|
class DictField(Field):
    """Field holding a mapping whose keys and values are converted separately.

    The converted mapping is wrapped in a ``HashMap``, which is also the
    class-level default.
    """

    default = HashMap

    def __init__(self, key_type, value_type=None, **kwargs):
        super(DictField, self).__init__(None, **kwargs)
        make = self.type_to_deserializer
        self.key_de = make(key_type)
        # Without an explicit value type, values use the key type's converter.
        self.value_de = make(value_type if value_type else key_type)

    @staticmethod
    def serialize(value):
        """Serialize both keys and values with ``Field.serialize``."""
        return dict(
            (Field.serialize(key), Field.serialize(item))
            for key, item in six.iteritems(value)
        )

    def try_convert(self, raw, client):
        """Convert every key/value pair of ``raw`` and wrap them in a HashMap."""
        converted = {}
        for raw_key, raw_value in six.iteritems(raw):
            key = self.key_de(raw_key, client)
            converted[key] = self.value_de(raw_value, client)
        return HashMap(converted)
|
|
|
|
|
|
class ListField(Field):
    """Field holding a sequence; each element goes through ``deserializer``.

    The class-level default is ``list``, so missing values become ``[]``.
    """

    default = list

    @staticmethod
    def serialize(value):
        """Serialize each element with ``Field.serialize`` into a new list."""
        return [Field.serialize(item) for item in value]

    def try_convert(self, raw, client):
        """Return a new list with every element of ``raw`` converted."""
        converted = []
        for element in raw:
            converted.append(self.deserializer(element, client))
        return converted
|
|
|
|
|
|
class AutoDictField(Field):
    """Field that turns a sequence of raw items into a HashMap.

    Each converted item is keyed by one of its own attributes.
    """

    default = HashMap

    def __init__(self, value_type, key, **kwargs):
        super(AutoDictField, self).__init__(None, **kwargs)
        self.value_de = self.type_to_deserializer(value_type)
        # Attribute name read from each converted item to build its key.
        self.key = key

    def try_convert(self, raw, client):
        """Convert each item of ``raw`` and index it by ``getattr(item, key)``.

        Later items win when two share the same key.
        """
        keyed = {}
        for raw_item in raw:
            item = self.value_de(raw_item, client)
            keyed[getattr(item, self.key)] = item
        return HashMap(keyed)
|
|
|
|
|
|
def _make(typ, data, client):
    """Instantiate ``typ`` from ``data``.

    ``client`` is passed along only when ``typ`` is a ``Model`` subclass.
    """
    if not (inspect.isclass(typ) and issubclass(typ, Model)):
        return typ(data)
    return typ(data, client)
|
|
|
|
|
|
def snowflake(data):
    """Convert a raw ID to ``int``; any falsy input (``None``, ``''``, ``0``) gives ``None``."""
    if not data:
        return None
    return int(data)
|
|
|
|
|
|
def enum(typ):
    """Build a converter that looks raw values up via ``typ.get``.

    String input is lower-cased before the lookup. This covers both ``str``
    and Python 2 ``unicode``, which is what JSON decoding produces there.
    ``None`` maps to ``None``.
    """
    def _f(data):
        if isinstance(data, six.string_types):
            data = data.lower()
        return typ.get(data) if data is not None else None
    return _f
|
|
|
|
|
|
# TODO: make lazy
def lazy_datetime(data):
    """Convert a raw timestamp into a naive ``datetime``.

    Accepted input:

    * a numeric POSIX timestamp (int or float), passed to
      ``utcfromtimestamp``;
    * a string matching one of ``DATETIME_FORMATS``. Anything after the last
      ``+`` (presumably a UTC offset) is discarded, not applied.

    Falsy input (``None``, ``''``, ``0``) returns ``None``.

    :raises ValueError: when a string matches none of the formats.
    """
    if not data:
        return None

    # Floats used to fall through to ``.rsplit`` and raise AttributeError.
    # ``six.integer_types`` also covers Python 2 ``long``.
    if isinstance(data, six.integer_types + (float,)):
        return real_datetime.utcfromtimestamp(data)

    for fmt in DATETIME_FORMATS:
        try:
            return real_datetime.strptime(data.rsplit('+', 1)[0], fmt)
        except (ValueError, TypeError):
            continue

    raise ValueError('Failed to convert `{}` to datetime'.format(data))
|
|
|
|
|
|
def datetime(data):
    """Parse a timestamp string into a naive ``datetime``.

    Anything after the last ``+`` (presumably a UTC offset) is discarded, not
    applied. Each of ``DATETIME_FORMATS`` is then tried in order. Falsy
    input returns ``None``.

    :raises ValueError: if no format matches.
    """
    if not data:
        return None

    for fmt in DATETIME_FORMATS:
        try:
            return real_datetime.strptime(data.rsplit('+', 1)[0], fmt)
        except (ValueError, TypeError):
            continue

    raise ValueError('Failed to convert `{}` to datetime'.format(data))
|
|
|
|
|
|
def text(obj):
    """Coerce ``obj`` to the native text type.

    On Python 3 the result is ``str(obj)``. On Python 2 a byte ``str`` is
    decoded as UTF-8 and anything else is returned untouched.
    """
    if not six.PY2:
        return str(obj)

    if isinstance(obj, str):
        return obj.decode('utf-8')
    return obj
|
|
|
|
|
|
def binary(obj):
    """Coerce ``obj`` for use as a binary value.

    On Python 3 the result is ``bytes``. ``bytes``/``bytearray`` input is
    passed through (as ``bytes``), and ``str`` input is UTF-8 encoded.

    On Python 2 the result is ``unicode``; a byte ``str`` is UTF-8 decoded.
    """
    if six.PY2:
        # NOTE(review): this yields text (``unicode``) rather than bytes on
        # Python 2, unlike the Python 3 branch; kept for compatibility.
        if isinstance(obj, str):
            return obj.decode('utf-8')
        # ``six.text_type`` is ``unicode`` here, without naming a builtin
        # that does not exist on Python 3.
        return six.text_type(obj)

    # Already-binary input would make ``bytes(obj, 'utf-8')`` raise TypeError.
    if isinstance(obj, (bytes, bytearray)):
        return bytes(obj)
    return bytes(obj, 'utf-8')
|
|
|
|
|
|
def with_equality(field):
    """Return a mixin class whose instances compare by one attribute.

    :param field: name of the attribute compared by ``==`` and ``!=``.

    Comparing against an object without ``field`` returns ``NotImplemented``.
    Python then falls back to its default (identity) comparison instead of
    raising ``AttributeError``.
    """
    missing = object()

    class T(object):
        def __eq__(self, other):
            theirs = getattr(other, field, missing)
            if theirs is missing:
                return NotImplemented
            return getattr(self, field) == theirs

        # Python 2 does not derive ``!=`` from ``__eq__``.
        def __ne__(self, other):
            result = self.__eq__(other)
            if result is NotImplemented:
                return result
            return not result

    # Defining ``__eq__`` makes Python 3 store ``__hash__ = None`` on the class.
    # That would shadow a ``with_hash`` mixin listed after this one in the MRO.
    # Removing it lets hashing resolve further along the MRO.
    if T.__dict__.get('__hash__', missing) is None:
        del T.__hash__

    return T
|
|
|
|
|
|
def with_hash(field):
    """Return a mixin class whose instances hash by the attribute ``field``."""
    def _hash_by_field(self):
        return hash(getattr(self, field))

    class T(object):
        __hash__ = _hash_by_field

    return T
|
|
|
|
|
|
# Resolution hacks :(
# Placeholders so ModelMeta can test for Model / SlottedModel before those
# classes exist; the class statements later in this module rebind both names.
Model = SlottedModel = None
|
|
|
|
|
|
class ModelMeta(type):
    """Metaclass that gathers ``Field`` declarations into ``_fields``.

    For each new class it builds a ``fields`` dict from two sources:

    * the ``_fields`` of any ``Model``-subclass parents other than ``Model``
      itself;
    * the ``Field`` instances declared in the class body.

    Those declarations are then removed from the class namespace. Subclasses
    of ``SlottedModel`` also get ``__slots__`` covering every field plus the
    ``stored_name`` of each ``CachedSlotProperty``.
    """

    def __new__(mcs, name, parents, dct):
        fields = {}

        # Inherit fields from model parents. ``Model`` is still the ``None``
        # placeholder while ``Model`` itself is being created, hence the guard.
        for parent in parents:
            if Model and issubclass(parent, Model) and parent != Model:
                fields.update(parent._fields)

        # Collect this class's own Field declarations. Assigning ``name``
        # fills in the field's src/dst names unless they were already set.
        for k, v in six.iteritems(dct):
            if not isinstance(v, Field):
                continue

            v.name = k
            fields[k] = v

        if SlottedModel and any(map(lambda k: issubclass(k, SlottedModel), parents)):
            # Backing attribute names of CachedSlotProperty descriptors need
            # slots as well.
            bases = set(v.stored_name for v in six.itervalues(dct) if isinstance(v, CachedSlotProperty))

            # Note: this writes into the incoming namespace before a filtered
            # copy replaces ``dct`` below.
            if '__slots__' in dct:
                dct['__slots__'] = tuple(set(dct['__slots__']) | set(fields.keys()) | bases)
            else:
                dct['__slots__'] = tuple(fields.keys()) + tuple(bases)

            # A class attribute may not share its name with a slot, so drop
            # every namespace entry that is now a slot.
            dct = {k: v for k, v in six.iteritems(dct) if k not in dct['__slots__']}
        else:
            # Non-slotted models: strip the Field objects from the class
            # namespace; they remain reachable through ``_fields``.
            dct = {k: v for k, v in six.iteritems(dct) if k not in fields}

        dct['_fields'] = fields
        return super(ModelMeta, mcs).__new__(mcs, name, parents, dct)
|
|
|
|
|
|
class AsyncChainable(object):
    """Mixin providing :meth:`after` for delayed, fluent call chains."""

    __slots__ = []

    def after(self, delay):
        """Pause via ``gevent.sleep(delay)``, then return ``self``.

        Returning ``self`` allows chaining, e.g. ``obj.after(5).<method>()``.
        """
        gevent.sleep(delay)
        return self
|
|
|
|
|
|
class Model(six.with_metaclass(ModelMeta, AsyncChainable)):
    """Base class for models populated from raw mappings.

    Subclasses declare ``Field`` class attributes. ``ModelMeta`` gathers them
    into ``_fields``, and :meth:`load` converts raw values onto the instance.
    Fields absent from the input get their default, or ``UNSET``.
    """

    __slots__ = ['client']

    def __init__(self, *args, **kwargs):
        """Create the model and load it from a raw mapping.

        Supported call shapes:

        * ``Model(obj)``: load from ``obj``.
        * ``Model(obj, client)``: load from ``obj`` and keep ``client``.
        * ``Model(**kwargs)``: load from the keyword arguments.

        A ``client`` keyword argument is always popped first. Any other number
        of positional arguments falls back to loading from ``kwargs``.
        """
        self.client = kwargs.pop('client', None)

        if len(args) == 1:
            obj = args[0]
        elif len(args) == 2:
            obj, self.client = args
        else:
            obj = kwargs

        self.load(obj)

    @property
    def fields(self):
        """The ``{attribute name: Field}`` mapping built by ``ModelMeta``."""
        return self.__class__._fields

    def load(self, obj, consume=False, skip=None):
        """Convert values from ``obj`` onto this instance, one per field.

        :param obj: mapping read by each field's ``src_name``.
        :param consume: if true, keys are popped from ``obj``. Skipped fields
            are still only read, never popped.
        :param skip: optional container of field names to leave at their
            default even when ``obj`` has a value for them.
        """
        for name, field in six.iteritems(self.fields):
            should_skip = skip and name in skip

            if consume and not should_skip:
                raw = obj.pop(field.src_name, None)
            else:
                raw = obj.get(field.src_name, None)

            if raw is None or should_skip:
                # Callable defaults (e.g. ``list`` or a Model class) are
                # invoked so each instance gets a fresh value.
                if field.has_default():
                    default = field.default() if callable(field.default) else field.default
                else:
                    default = UNSET
                setattr(self, field.dst_name, default)
                continue

            value = field.try_convert(raw, self.client)
            setattr(self, field.dst_name, value)

    def update(self, other):
        """Copy every field of ``other`` that is not ``UNSET`` onto ``self``.

        Afterwards, try to delete each ``property`` attribute on the instance
        to clear cached values.
        """
        for name in six.iterkeys(self.fields):
            if hasattr(other, name) and getattr(other, name) is not UNSET:
                setattr(self, name, getattr(other, name))

        # Clear cached properties. Deleting a property without a deleter
        # raises, so this stays best-effort. It no longer swallows
        # KeyboardInterrupt / SystemExit.
        for name in dir(type(self)):
            if isinstance(getattr(type(self), name), property):
                try:
                    delattr(self, name)
                except Exception:
                    pass

    def to_dict(self):
        """Serialize every set field into a plain dict keyed by attribute name.

        ``UNSET`` fields are omitted. The check uses identity, so a value with
        a custom ``__eq__`` (e.g. one built with ``with_equality``) is never
        asked to compare itself against the sentinel.
        """
        obj = {}
        for name, field in six.iteritems(self.__class__._fields):
            value = getattr(self, name)
            if value is UNSET:
                continue
            obj[name] = field.serialize(value)
        return obj

    @classmethod
    def create(cls, client, data, **kwargs):
        """Merge ``kwargs`` into ``data`` (in place) and build an instance."""
        data.update(kwargs)
        inst = cls(data, client)
        return inst

    @classmethod
    def create_map(cls, client, data, **kwargs):
        """Build a list with one instance per item of ``data`` via :meth:`create`."""
        return list(map(functools.partial(cls.create, client, **kwargs), data))

    @classmethod
    def attach(cls, it, data):
        """Best-effort: set each ``data`` key/value as an attribute on every item.

        Assignments that fail (e.g. names not allowed on slotted instances) are
        skipped.
        """
        for item in it:
            for k, v in six.iteritems(data):
                try:
                    setattr(item, k, v)
                except Exception:
                    pass
|
|
|
|
|
|
class SlottedModel(Model):
    """Base for models whose field storage is declared in ``__slots__``.

    ``ModelMeta`` builds ``__slots__`` for subclasses of this class from their
    declared fields and ``CachedSlotProperty`` backing names.
    """

    # NOTE(review): ``Model`` already declares a ``client`` slot; redeclaring
    # it here looks redundant. Confirm before removing.
    __slots__ = ['client']
|
|
|