You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

349 lines
9.2 KiB

import functools
import inspect
import sys
from datetime import datetime as real_datetime

import gevent
import six
from holster.enum import BaseEnumMeta, EnumAttr

from disco.util.functional import CachedSlotProperty
from disco.util.hashmap import HashMap
# strptime formats tried, in order, by datetime() and lazy_datetime():
# first with fractional seconds, then without.
DATETIME_FORMATS = [
    '%Y-%m-%dT%H:%M:%S.%f',
    '%Y-%m-%dT%H:%M:%S',
]
class Unset(object):
    """Sentinel type for field values that have no value and no default.

    Instances are always falsy.
    """

    def __nonzero__(self):
        return False

    # Python 3 consults __bool__, not __nonzero__; without this alias UNSET
    # would be truthy there while being falsy on Python 2.
    __bool__ = __nonzero__


UNSET = Unset()
class ConversionError(Exception):
    """Raised when a Field fails to deserialize a raw value.

    :param field: the Field whose deserializer failed (its ``src_name`` and
        ``deserializer`` are included in the message).
    :param raw: the offending raw value; its ``str()`` is truncated to 144
        characters in the message.
    :param e: the underlying exception.
    """

    def __init__(self, field, raw, e):
        template = 'Failed to convert `{}` (`{}`) to {}: {}'
        message = template.format(str(raw)[:144], field.src_name, field.deserializer, e)
        super(ConversionError, self).__init__(message)

        if six.PY3:
            # Record the original exception as the explicit cause.
            self.__cause__ = e
class Field(object):
    """Declarative description of one model attribute.

    A Field records the key it is read from in a raw payload (``src_name``),
    the attribute it is stored under (``dst_name``), how raw values are
    converted (``deserializer``), and the default used when the payload
    lacks a value.
    """

    def __init__(self, value_type, alias=None, default=None, **kwargs):
        self.src_name = alias
        self.dst_name = None
        # Extra keyword arguments are kept verbatim for callers to inspect.
        self.metadata = kwargs

        # An explicit default wins; otherwise keep a class-level default
        # (e.g. ``ListField.default = list``) when the subclass defines one.
        if default is not None:
            self.default = default
        elif not hasattr(self, 'default'):
            self.default = None

        self.deserializer = None

        if value_type:
            self.deserializer = self.type_to_deserializer(value_type)

            # Nested Fields and Model classes supply the default when none
            # was given: a nested Field's own default, or the Model class
            # itself (called later to build an empty instance).
            if isinstance(self.deserializer, Field) and self.default is None:
                self.default = self.deserializer.default
            elif inspect.isclass(self.deserializer) and issubclass(self.deserializer, Model) and self.default is None:
                self.default = self.deserializer

    @property
    def name(self):
        # The getter always returns None; the resolved names are available
        # as ``src_name`` / ``dst_name``.
        return None

    @name.setter
    def name(self, name):
        """Assign the attribute name (ModelMeta does ``field.name = attr``).

        Sets ``dst_name`` and, unless an alias was given, ``src_name``; names
        that are already set are left untouched.
        """
        if not self.dst_name:
            self.dst_name = name

        if not self.src_name:
            self.src_name = name

    def has_default(self):
        """Return True when a non-None default is configured."""
        return self.default is not None

    def try_convert(self, raw, client):
        """Deserialize ``raw``, wrapping any failure in a ConversionError.

        :raises ConversionError: if the deserializer raises.
        """
        try:
            return self.deserializer(raw, client)
        except Exception as e:
            # Pass the active traceback through so the frame that actually
            # failed is preserved (six.reraise drops it when tb is omitted,
            # which matters on Python 2).
            six.reraise(ConversionError, ConversionError(self, raw, e), sys.exc_info()[2])

    @staticmethod
    def type_to_deserializer(typ):
        """Return a ``(raw, client) -> value`` callable for ``typ``.

        Fields and Model classes are returned as-is (both are callable with
        ``(raw, client)``); enums are looked up with ``typ.get``; ``None``
        yields a converter that always returns None; any other callable is
        applied to the raw value alone.
        """
        if isinstance(typ, Field) or (inspect.isclass(typ) and issubclass(typ, Model)):
            return typ
        elif isinstance(typ, BaseEnumMeta):
            return lambda raw, _: typ.get(raw)
        elif typ is None:
            return lambda x, y: None
        else:
            return lambda raw, _: typ(raw)

    @staticmethod
    def serialize(value):
        """Convert a deserialized value back to a plain representation.

        Enum attributes become their ``.value``, models become ``to_dict()``,
        everything else is returned unchanged.
        """
        if isinstance(value, EnumAttr):
            return value.value
        elif isinstance(value, Model):
            return value.to_dict()
        else:
            return value

    def __call__(self, raw, client):
        """Allow a Field to be used directly as a deserializer."""
        return self.try_convert(raw, client)
class DictField(Field):
    """Field holding a mapping whose keys and values are deserialized separately.

    When ``value_type`` is not given, values are converted with the key type.
    The deserialized result is a HashMap.
    """
    default = HashMap

    def __init__(self, key_type, value_type=None, **kwargs):
        super(DictField, self).__init__(None, **kwargs)
        self.key_de = self.type_to_deserializer(key_type)
        self.value_de = self.type_to_deserializer(value_type if value_type else key_type)

    @staticmethod
    def serialize(value):
        result = {}
        for key, item in six.iteritems(value):
            plain_key = Field.serialize(key)
            result[plain_key] = Field.serialize(item)
        return result

    def try_convert(self, raw, client):
        converted = {}
        for key, item in six.iteritems(raw):
            new_key = self.key_de(key, client)
            converted[new_key] = self.value_de(item, client)
        return HashMap(converted)
class ListField(Field):
    """Field holding a list; each element goes through the field's deserializer."""
    default = list

    @staticmethod
    def serialize(value):
        return [Field.serialize(element) for element in value]

    def try_convert(self, raw, client):
        items = []
        for element in raw:
            items.append(self.deserializer(element, client))
        return items
class AutoDictField(Field):
    """Field built from a sequence of items, keyed by attribute ``key`` of each.

    Every raw item is deserialized with ``value_type`` and stored in a HashMap
    under ``getattr(item, key)``; later items win on duplicate keys.
    """
    default = HashMap

    def __init__(self, value_type, key, **kwargs):
        super(AutoDictField, self).__init__(None, **kwargs)
        self.value_de = self.type_to_deserializer(value_type)
        self.key = key

    def try_convert(self, raw, client):
        by_key = {}
        for entry in raw:
            item = self.value_de(entry, client)
            by_key[getattr(item, self.key)] = item
        return HashMap(by_key)
def _make(typ, data, client):
if inspect.isclass(typ) and issubclass(typ, Model):
return typ(data, client)
return typ(data)
def snowflake(data):
    """Convert a snowflake ID to ``int``; any falsy input yields None."""
    if not data:
        return None
    return int(data)
def enum(typ):
    """Build a converter that looks raw values up in the enum ``typ``.

    String input is lower-cased before ``typ.get`` is called; ``None`` is
    passed through as ``None``.
    """
    def _f(data):
        # six.string_types covers both ``str`` and ``unicode`` on Python 2;
        # a bare ``str`` check never lower-cased unicode input there.
        if isinstance(data, six.string_types):
            data = data.lower()
        return typ.get(data) if data is not None else None
    return _f
# TODO: make lazy
def lazy_datetime(data):
    """Parse a timestamp into a naive ``datetime``.

    Accepts a Unix timestamp (``int`` or ``float``, converted with
    ``utcfromtimestamp``) or a string in one of ``DATETIME_FORMATS``. For
    strings, everything from the last ``+`` onward (e.g. a ``+00:00``
    offset) is discarded before parsing. Falsy input returns None.

    :raises ValueError: if ``data`` matches none of the known formats,
        including when it is neither a number nor a string.
    """
    if not data:
        return None

    # Floats are Unix timestamps too; without this they reached the string
    # parser below and failed on ``rsplit``.
    if isinstance(data, (int, float)):
        return real_datetime.utcfromtimestamp(data)

    for fmt in DATETIME_FORMATS:
        try:
            return real_datetime.strptime(data.rsplit('+', 1)[0], fmt)
        except (ValueError, TypeError, AttributeError):
            # AttributeError: ``data`` is not a string (no ``rsplit``); fall
            # through to the ValueError below instead of leaking it.
            continue

    raise ValueError('Failed to convert `{}` to datetime'.format(data))
def datetime(data):
    """Parse a timestamp string into a naive ``datetime``.

    The string is matched against ``DATETIME_FORMATS`` in order, after
    discarding everything from the last ``+`` onward (e.g. a ``+00:00``
    offset). Falsy input returns None.

    :raises ValueError: if ``data`` matches none of the known formats,
        including when it is not a string at all.
    """
    if not data:
        return None

    for fmt in DATETIME_FORMATS:
        try:
            return real_datetime.strptime(data.rsplit('+', 1)[0], fmt)
        except (ValueError, TypeError, AttributeError):
            # AttributeError: ``data`` is not a string (no ``rsplit``); report
            # it through the ValueError below rather than leaking it.
            continue

    raise ValueError('Failed to convert `{}` to datetime'.format(data))
def text(obj):
    """Coerce ``obj`` to text.

    Python 3: returns ``str(obj)``.
    Python 2: byte strings (``str``) are UTF-8 decoded; any other value is
    returned unchanged.
    """
    if not six.PY2:
        return str(obj)
    return obj.decode('utf-8') if isinstance(obj, str) else obj
def binary(obj):
    """Coerce ``obj`` for a binary field.

    Python 3: always returns ``bytes``. Byte-like input is returned as
    ``bytes``, ``str`` is UTF-8 encoded, and other values are converted
    with ``str()`` first.
    Python 2: byte strings are UTF-8 decoded and other values converted to
    unicode.
    """
    if six.PY2:
        if isinstance(obj, str):
            return obj.decode('utf-8')
        # six.text_type is ``unicode`` on Python 2; spelling it through six
        # keeps the name resolvable on Python 3 as well.
        return six.text_type(obj)

    # bytes(obj, 'utf-8') raises TypeError for already-binary input and for
    # non-strings, so handle those cases explicitly.
    if isinstance(obj, (bytes, bytearray)):
        return bytes(obj)
    if isinstance(obj, str):
        return bytes(obj, 'utf-8')
    return str(obj).encode('utf-8')
def with_equality(field):
    """Return a mixin class whose instances compare equal by attribute ``field``.

    Objects lacking that attribute are treated as not comparable: ``__eq__``
    returns ``NotImplemented``, so Python falls back to its default identity
    comparison instead of raising AttributeError (e.g. for ``obj == None``).
    """
    missing = object()

    class T(object):
        def __eq__(self, other):
            other_value = getattr(other, field, missing)
            if other_value is missing:
                return NotImplemented
            return getattr(self, field) == other_value

        # Python 2 does not derive ``!=`` from ``__eq__``.
        def __ne__(self, other):
            result = self.__eq__(other)
            if result is NotImplemented:
                return result
            return not result

    return T
def with_hash(field):
    """Return a mixin class whose instances hash by the value of attribute ``field``."""
    class T(object):
        def __hash__(self):
            key = getattr(self, field)
            return hash(key)

    return T
# Resolution hacks :(
# Placeholders so ModelMeta can test these names while Model and
# SlottedModel themselves are still being created; the class statements
# later in the module rebind them.
Model = SlottedModel = None
class ModelMeta(type):
    """Metaclass that collects declared ``Field`` attributes into ``_fields``.

    For subclasses of ``SlottedModel`` the field names, plus the
    ``stored_name`` of every ``CachedSlotProperty`` in the class body, are
    also added to ``__slots__``.
    """
    def __new__(mcs, name, parents, dct):
        fields = {}

        # Inherit fields from direct parents that are Model subclasses.
        # ``Model`` is still None while Model itself is being created.
        for parent in parents:
            if Model and issubclass(parent, Model) and parent != Model:
                fields.update(parent._fields)

        # Collect Field declarations from this class body. Assigning ``.name``
        # fills in the field's dst_name/src_name when they are not already set.
        for k, v in six.iteritems(dct):
            if not isinstance(v, Field):
                continue

            v.name = k
            fields[k] = v

        if SlottedModel and any(map(lambda k: issubclass(k, SlottedModel), parents)):
            # Storage attributes used by cached slot properties need slots too.
            bases = set(v.stored_name for v in six.itervalues(dct) if isinstance(v, CachedSlotProperty))

            if '__slots__' in dct:
                dct['__slots__'] = tuple(set(dct['__slots__']) | set(fields.keys()) | bases)
            else:
                dct['__slots__'] = tuple(fields.keys()) + tuple(bases)

            # Python rejects class attributes that share a name with a slot,
            # so drop every class-body entry that is now a slot.
            dct = {k: v for k, v in six.iteritems(dct) if k not in dct['__slots__']}
        else:
            # Non-slotted models: drop class-body entries named after any field
            # (declared here or inherited); instance values are set by load().
            dct = {k: v for k, v in six.iteritems(dct) if k not in fields}

        dct['_fields'] = fields
        return super(ModelMeta, mcs).__new__(mcs, name, parents, dct)
class AsyncChainable(object):
    """Mixin adding a chainable delay, e.g. ``obj.after(5).delete()``."""
    __slots__ = []

    def after(self, delay):
        """Sleep for ``delay`` seconds via ``gevent.sleep``, then return ``self``."""
        gevent.sleep(delay)
        return self
class Model(six.with_metaclass(ModelMeta, AsyncChainable)):
    """Base class for declarative models populated from raw dicts.

    Subclasses declare ``Field`` class attributes. ``ModelMeta`` gathers them
    into ``_fields``, and :meth:`load` turns a payload into instance
    attributes.
    """
    __slots__ = ['client']

    def __init__(self, *args, **kwargs):
        """Build a model from a payload.

        Call forms: ``Model(obj)``, ``Model(obj, client)``, or keyword
        arguments used as the payload itself. ``client`` may also be passed
        as a keyword.
        """
        self.client = kwargs.pop('client', None)

        if len(args) == 1:
            obj = args[0]
        elif len(args) == 2:
            obj, self.client = args
        else:
            obj = kwargs

        self.load(obj)

    @property
    def fields(self):
        """The ``{name: Field}`` mapping collected by ModelMeta for this class."""
        return self.__class__._fields

    def load(self, obj, consume=False, skip=None):
        """Populate field attributes from the raw dict ``obj``.

        :param obj: raw payload, read by each field's ``src_name``.
        :param consume: if true, consumed keys are popped out of ``obj``.
        :param skip: optional collection of field names to leave at default.

        Missing (or None) values and skipped fields get the field's default,
        called first if it is callable, or ``UNSET`` when there is none.
        """
        for name, field in six.iteritems(self.fields):
            should_skip = skip and name in skip

            if consume and not should_skip:
                raw = obj.pop(field.src_name, None)
            else:
                raw = obj.get(field.src_name, None)

            if raw is None or should_skip:
                if field.has_default():
                    default = field.default() if callable(field.default) else field.default
                else:
                    default = UNSET
                setattr(self, field.dst_name, default)
                continue

            value = field.try_convert(raw, self.client)
            setattr(self, field.dst_name, value)

    def update(self, other):
        """Copy every field from ``other`` that it has and that is not UNSET.

        Afterwards, tries to delete every ``property`` of this class from the
        instance (best effort; failures are ignored).
        """
        for name in six.iterkeys(self.fields):
            if hasattr(other, name) and getattr(other, name) is not UNSET:
                setattr(self, name, getattr(other, name))

        # Clear cached properties
        for name in dir(type(self)):
            if isinstance(getattr(type(self), name), property):
                try:
                    delattr(self, name)
                except Exception:
                    # Best effort: properties without a deleter raise here.
                    # ``except Exception`` rather than a bare except lets
                    # BaseException subclasses (KeyboardInterrupt, SystemExit,
                    # gevent's GreenletExit) propagate.
                    pass

    def to_dict(self):
        """Serialize every field that is not UNSET into a dict keyed by field name."""
        obj = {}
        for name, field in six.iteritems(self.__class__._fields):
            value = getattr(self, name)
            # Identity check: ``==`` would dispatch to the value's own
            # __eq__, which may not expect to be compared with UNSET.
            if value is UNSET:
                continue
            obj[name] = field.serialize(value)
        return obj

    @classmethod
    def create(cls, client, data, **kwargs):
        """Build an instance from ``data`` merged with ``kwargs``.

        Note: ``data`` is updated in place with ``kwargs``.
        """
        data.update(kwargs)
        inst = cls(data, client)
        return inst

    @classmethod
    def create_map(cls, client, data, **kwargs):
        """Apply :meth:`create` to every payload in ``data``; return a list."""
        return list(map(functools.partial(cls.create, client, **kwargs), data))

    @classmethod
    def attach(cls, it, data):
        """Set each key/value of ``data`` as an attribute on every item of ``it``.

        Best effort: attributes that cannot be set are skipped.
        """
        for item in it:
            for k, v in six.iteritems(data):
                try:
                    setattr(item, k, v)
                except Exception:
                    # Unknown slot names / read-only properties are skipped on
                    # purpose; BaseException subclasses still propagate.
                    pass
class SlottedModel(Model):
    # Base for models stored in __slots__: for SlottedModel subclasses,
    # ModelMeta adds the field names (and CachedSlotProperty storage names)
    # to __slots__ instead of keeping them as class attributes.
    # NOTE(review): 'client' is already declared in Model.__slots__. The
    # Python data model docs say redeclaring a base-class slot leaves the base
    # slot inaccessible and wastes space; consider dropping it here.
    __slots__ = ['client']