From 2970466311248978a15f3bc3a0ab6e7f55f72614 Mon Sep 17 00:00:00 2001 From: Andrei Date: Mon, 24 Apr 2017 15:30:43 -0700 Subject: [PATCH] Add Chainable/async utils --- disco/types/base.py | 17 +++++------ disco/util/chains.py | 71 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 78 insertions(+), 10 deletions(-) create mode 100644 disco/util/chains.py diff --git a/disco/types/base.py b/disco/types/base.py index 4c0d916..d396762 100644 --- a/disco/types/base.py +++ b/disco/types/base.py @@ -6,8 +6,9 @@ import functools from holster.enum import BaseEnumMeta, EnumAttr from datetime import datetime as real_datetime -from disco.util.functional import CachedSlotProperty +from disco.util.chains import Chainable from disco.util.hashmap import HashMap +from disco.util.functional import CachedSlotProperty DATETIME_FORMATS = [ '%Y-%m-%dT%H:%M:%S.%f', @@ -273,15 +274,7 @@ class ModelMeta(type): return super(ModelMeta, mcs).__new__(mcs, name, parents, dct) -class AsyncChainable(object): - __slots__ = [] - - def after(self, delay): - gevent.sleep(delay) - return self - - -class Model(six.with_metaclass(ModelMeta, AsyncChainable)): +class Model(six.with_metaclass(ModelMeta, Chainable)): __slots__ = ['client'] def __init__(self, *args, **kwargs): @@ -297,6 +290,10 @@ class Model(six.with_metaclass(ModelMeta, AsyncChainable)): self.load(obj) self.validate() + def after(self, delay): + gevent.sleep(delay) + return self + def validate(self): pass diff --git a/disco/util/chains.py b/disco/util/chains.py new file mode 100644 index 0000000..5ea261f --- /dev/null +++ b/disco/util/chains.py @@ -0,0 +1,71 @@ +import gevent + +""" +Object.chain -> creates a chain where each action happens after the last + pass_result = False -> whether the result of the last action is passed, or the original + +Object.async_chain -> creates an async chain where each action happens at the same time +""" + + +class Chainable(object): + __slots__ = [] + + def chain(self, pass_result=True): + return Chain(self, pass_result=pass_result, async_=False) + + def async_chain(self): + return Chain(self, pass_result=False, async_=True) + + +class Chain(object): + def __init__(self, obj, pass_result=True, async_=False): + self._obj = obj + self._pass_result = pass_result + self._async = async_ + self._parts = [] + + @property + def obj(self): + if isinstance(self._obj, Chain): + return self._obj._next() + return self._obj + + def __getattr__(self, item): + func = getattr(self.obj, item) + if not func or not callable(func): + return func + + def _wrapped(*args, **kwargs): + inst = gevent.spawn(func, *args, **kwargs) + self._parts.append(inst) + + # If async, just return instantly + if self._async: + return self + + # Otherwise return a chain + return Chain(self) + return _wrapped + + def _next(self): + res = self._parts[0].get() + if self._pass_result: + return res + return self + + def then(self, func, *args, **kwargs): + inst = gevent.spawn(func, *args, **kwargs) + self._parts.append(inst) + if self._async: + return self + return Chain(self) + + def first(self): + return self._obj + + def get(self, timeout=None): + return gevent.wait(self._parts, timeout=timeout) + + def wait(self, timeout=None): + gevent.joinall(self._parts, timeout=None)