From fc85deea521719b4756331d7421b1342b3e6d3ed Mon Sep 17 00:00:00 2001 From: Andrei Date: Tue, 11 Oct 2016 01:36:46 -0500 Subject: [PATCH] This one weird trick to reduce memory usage by 60%! Tl;dr we now use __slots__ in a bunch of places. This could still be better, and we do a bit too much magic in the modeling to make me happy. But that's for later, for now we're going from ~250mb on 2500 guilds to ~160mb. - Allow configuring the state module within the normal configuration (under the 'state' key) - Rate limit events being sent on the gateway socket - Convert to using lazy_datetime in a bunch of places - Allow configuring guild member sync - Better logic around loading guilds, add State.ready condition which can be waited on - Fix inheritance in the modeling framework (how was this not working before lol wut) - Added __slots__ to a bunch of low-hanging fruit models - Move member sync onto the guild object as Guild.sync() - Convert to Danny's CachedSlotProperty (could still be better, will improve later) - Added util.snowflake.calculate_shard --- disco/client.py | 4 +-- disco/gateway/client.py | 13 ++++++++-- disco/gateway/events.py | 6 ++--- disco/state.py | 48 +++++++++++++++++++++-------------- disco/types/base.py | 54 ++++++++++++++++++++++++++++++++++------ disco/types/channel.py | 1 + disco/types/guild.py | 28 +++++++++++++++++++-- disco/types/invite.py | 4 +-- disco/types/message.py | 6 ++--- disco/types/user.py | 6 +++++ disco/types/voice.py | 5 ++++ disco/util/config.py | 3 +++ disco/util/functional.py | 54 +++++++++++----------------------------- disco/util/limiter.py | 37 +++++++++++++++++++++++++++ disco/util/snowflake.py | 4 +++ 15 files changed, 193 insertions(+), 80 deletions(-) create mode 100644 disco/util/limiter.py diff --git a/disco/client.py b/disco/client.py index 4d3bf54..54dacd5 100644 --- a/disco/client.py +++ b/disco/client.py @@ -2,7 +2,7 @@ import gevent from holster.emitter import Emitter -from disco.state import State +from disco.state 
import State, StateConfig from disco.api.client import APIClient from disco.gateway.client import GatewayClient from disco.util.config import Config @@ -82,9 +82,9 @@ class Client(object): self.events = Emitter(gevent.spawn) self.packets = Emitter(gevent.spawn) - self.state = State(self) self.api = APIClient(self) self.gw = GatewayClient(self, self.config.encoder) + self.state = State(self, StateConfig(self.config.get('state', {}))) if self.config.manhole_enable: self.manhole_locals = { diff --git a/disco/gateway/client.py b/disco/gateway/client.py index ac7ea79..b3f8012 100644 --- a/disco/gateway/client.py +++ b/disco/gateway/client.py @@ -8,6 +8,7 @@ from disco.gateway.events import GatewayEvent from disco.gateway.encoding import ENCODERS from disco.util.websocket import Websocket from disco.util.logging import LoggingClass +from disco.util.limiter import SimpleLimiter TEN_MEGABYTES = 10490000 @@ -24,6 +25,9 @@ class GatewayClient(LoggingClass): self.events = client.events self.packets = client.packets + # Its actually 60, 120 but lets give ourselves a buffer + self.limiter = SimpleLimiter(60, 130) + # Create emitter and bind to gateway payloads self.packets.on((RECV, OPCode.DISPATCH), self.handle_dispatch) self.packets.on((RECV, OPCode.HEARTBEAT), self.handle_heartbeat) @@ -51,6 +55,11 @@ class GatewayClient(LoggingClass): self._heartbeat_task = None def send(self, op, data): + self.limiter.check() + return self._send(op, data) + + def _send(self, op, data): + self.log.debug('SEND %s', op) self.packets.emit((SEND, op), data) self.ws.send(self.encoder.encode({ 'op': op.value, @@ -59,7 +68,7 @@ class GatewayClient(LoggingClass): def heartbeat_task(self, interval): while True: - self.send(OPCode.HEARTBEAT, self.seq) + self._send(OPCode.HEARTBEAT, self.seq) gevent.sleep(interval / 1000) def handle_dispatch(self, packet): @@ -68,7 +77,7 @@ class GatewayClient(LoggingClass): self.client.events.emit(obj.__class__.__name__, obj) def handle_heartbeat(self, packet): - 
self.send(OPCode.HEARTBEAT, self.seq) + self._send(OPCode.HEARTBEAT, self.seq) def handle_reconnect(self, packet): self.log.warning('Received RECONNECT request, forcing a fresh reconnect') diff --git a/disco/gateway/events.py b/disco/gateway/events.py index 18e505b..a31b027 100644 --- a/disco/gateway/events.py +++ b/disco/gateway/events.py @@ -9,7 +9,7 @@ from disco.types.message import Message from disco.types.voice import VoiceState from disco.types.guild import Guild, GuildMember, Role -from disco.types.base import Model, ModelMeta, Field, snowflake, listof +from disco.types.base import SlottedModel, ModelMeta, Field, snowflake, listof, lazy_datetime # Mapping of discords event name to our event classes EVENTS_MAP = {} @@ -25,7 +25,7 @@ class GatewayEventMeta(ModelMeta): return obj -class GatewayEvent(six.with_metaclass(GatewayEventMeta, Model)): +class GatewayEvent(six.with_metaclass(GatewayEventMeta, SlottedModel)): """ The GatewayEvent class wraps various functionality for events passed to us over the gateway websocket, and serves as a simple proxy to inner values for @@ -167,7 +167,7 @@ class ChannelPinsUpdate(GatewayEvent): Sent when a channel's pins are updated. """ channel_id = Field(snowflake) - last_pin_timestamp = Field(int) + last_pin_timestamp = Field(lazy_datetime) @wraps_model(User) diff --git a/disco/state.py b/disco/state.py index 3356f5b..442ae76 100644 --- a/disco/state.py +++ b/disco/state.py @@ -3,8 +3,9 @@ import inflection from collections import defaultdict, deque, namedtuple from weakref import WeakValueDictionary +from gevent.event import Event -from disco.gateway.packets import OPCode +from disco.util.config import Config class StackMessage(namedtuple('StackMessage', ['id', 'channel_id', 'author_id'])): @@ -23,7 +24,7 @@ class StackMessage(namedtuple('StackMessage', ['id', 'channel_id', 'author_id']) """ -class StateConfig(object): +class StateConfig(Config): """ A configuration object for determining how the State tracking behaves. 
@@ -43,10 +44,16 @@ class StateConfig(object): total number of possible :class:`StackMessage` objects kept in memory, using: `total_mesages_size * total_channels`. This can be tweaked based on usage to help prevent memory pressure. + sync_guild_members : bool + If true, guilds will be automatically synced when they are initially loaded + or joined. Generally this setting is OK for smaller bots, however bots in over + 50 guilds will notice this operation can take a while to complete. """ track_messages = True track_messages_size = 100 + sync_guild_members = True + class State(object): """ @@ -84,9 +91,12 @@ class State(object): 'PresenceUpdate' ] - def __init__(self, client, config=None): + def __init__(self, client, config): self.client = client - self.config = config or StateConfig() + self.config = config + + self.ready = Event() + self.guilds_waiting_sync = 0 self.me = None self.dms = {} @@ -129,6 +139,7 @@ class State(object): def on_ready(self, event): self.me = event.user + self.guilds_waiting_sync = len(event.guilds) def on_message_create(self, event): if self.config.track_messages: @@ -158,26 +169,27 @@ class State(object): self.messages[event.channel_id].remove(sm) def on_guild_create(self, event): + if event.unavailable is False: + self.guilds_waiting_sync -= 1 + if self.guilds_waiting_sync <= 0: + self.ready.set() + self.guilds[event.guild.id] = event.guild self.channels.update(event.guild.channels) for member in six.itervalues(event.guild.members): self.users[member.user.id] = member.user - # Request full member list - self.client.gw.send(OPCode.REQUEST_GUILD_MEMBERS, { - 'guild_id': event.guild.id, - 'query': '', - 'limit': 0, - }) + if self.config.sync_guild_members: + event.guild.sync() def on_guild_update(self, event): self.guilds[event.guild.id].update(event.guild) def on_guild_delete(self, event): - if event.guild_id in self.guilds: + if event.id in self.guilds: # Just delete the guild, channel references will fall - del 
self.guilds[event.guild_id] + del self.guilds[event.id] def on_channel_create(self, event): if event.channel.is_guild and event.channel.guild_id in self.guilds: @@ -192,14 +204,16 @@ class State(object): self.channels[event.channel.id].update(event.channel) def on_channel_delete(self, event): - if event.channel.is_guild and event.channel.guild_id in self.guilds: - del self.guilds[event.channel.id] - elif event.channel.is_dm: + if event.channel.is_guild and event.channel.guild and event.channel.id in event.channel.guild.channels: + del event.channel.guild.channels[event.channel.id] + elif event.channel.is_dm and event.channel.id in self.dms: del self.dms[event.channel.id] def on_voice_state_update(self, event): # Happy path: we have the voice state and want to update/delete it guild = self.guilds.get(event.state.guild_id) + if not guild: + return if event.state.session_id in guild.voice_states: if event.state.channel_id: @@ -218,14 +232,12 @@ class State(object): if event.member.guild_id not in self.guilds: return - event.member.guild = self.guilds[event.member.guild_id] self.guilds[event.member.guild_id].members[event.member.id] = event.member def on_guild_member_update(self, event): if event.member.guild_id not in self.guilds: return - event.member.guild = self.guilds[event.member.guild_id] self.guilds[event.member.guild_id].members[event.member.id].update(event.member) def on_guild_member_remove(self, event): @@ -243,10 +255,10 @@ class State(object): guild = self.guilds[event.guild_id] for member in event.members: - member.guild = guild member.guild_id = guild.id guild.members[member.id] = member self.users[member.id] = member.user + guild.synced = True def on_guild_role_create(self, event): if event.guild_id not in self.guilds: diff --git a/disco/types/base.py b/disco/types/base.py index 855ddfe..b271ddb 100644 --- a/disco/types/base.py +++ b/disco/types/base.py @@ -35,14 +35,14 @@ class FieldType(object): class Field(FieldType): - def __init__(self, typ, 
alias=None): + def __init__(self, typ, alias=None, default=None): super(Field, self).__init__(typ) # Set names self.src_name = alias self.dst_name = None - self.default = None + self.default = default if isinstance(self.typ, FieldType): self.default = self.typ.default @@ -110,6 +110,21 @@ def dictof(*args, **kwargs): return _Dict(*args, **kwargs) +def lazy_datetime(data): + if not data: + return property(lambda: None) + + def get(): + for fmt in DATETIME_FORMATS: + try: + return real_datetime.strptime(data.rsplit('+', 1)[0], fmt) + except (ValueError, TypeError): + continue + raise ValueError('Failed to conver `{}` to datetime'.format(data)) + + return property(get) + + def datetime(data): if not data: return None @@ -155,17 +170,33 @@ def with_hash(field): return T +# Resolution hacks :( +Model = None +SlottedModel = None + + class ModelMeta(type): def __new__(cls, name, parents, dct): fields = {} + for parent in parents: + if issubclass(parent, Model) and parent != Model: + fields.update(parent._fields) + for k, v in six.iteritems(dct): if not isinstance(v, Field): continue v.set_name(k) fields[k] = v - dct[k] = None + + dct = {k: v for k, v in six.iteritems(dct) if k not in fields} + + if SlottedModel in parents and '__slots__' not in dct: + dct['__slots__'] = tuple(fields.keys()) + elif '__slots__' in dct and Model in parents and SlottedModel: + dct['__slots__'] = tuple(dct['__slots__']) + parents = tuple([SlottedModel] + [i for i in parents if i != Model]) dct['_fields'] = fields return super(ModelMeta, cls).__new__(cls, name, parents, dct) @@ -182,17 +213,20 @@ class Model(six.with_metaclass(ModelMeta)): else: obj = kwargs - for name, field in six.iteritems(self._fields): - if field.src_name not in obj or not obj[field.src_name]: + for name, field in six.iteritems(self.__class__._fields): + if field.src_name not in obj or obj[field.src_name] is None: if field.has_default(): - setattr(self, field.dst_name, field.default()) + default = field.default() if 
callable(field.default) else field.default + else: + default = None + setattr(self, field.dst_name, default) continue value = field.try_convert(obj[field.src_name], self.client) setattr(self, field.dst_name, value) def update(self, other): - for name in six.iterkeys(self._fields): + for name in six.iterkeys(self.__class__._fields): value = getattr(other, name) if value: setattr(self, name, value) @@ -206,7 +240,7 @@ class Model(six.with_metaclass(ModelMeta)): pass def to_dict(self): - return {k: getattr(self, k) for k in six.iterkeys(self._fields)} + return {k: getattr(self, k) for k in six.iterkeys(self.__class__._fields)} @classmethod def create(cls, client, data, **kwargs): @@ -227,3 +261,7 @@ class Model(six.with_metaclass(ModelMeta)): except: # TODO: wtf pass + + +class SlottedModel(Model): + __slots__ = ['client'] diff --git a/disco/types/channel.py b/disco/types/channel.py index 726af4a..7d5cffe 100644 --- a/disco/types/channel.py +++ b/disco/types/channel.py @@ -39,6 +39,7 @@ class PermissionOverwrite(Model): denied : :class:`PermissionValue` All denied permissions """ + __slots__ = ['id', 'type', 'allow', 'deny', 'channel', 'channel_id'] id = Field(snowflake) type = Field(enum(PermissionOverwriteType)) diff --git a/disco/types/guild.py b/disco/types/guild.py index 6948e0d..1f3d745 100644 --- a/disco/types/guild.py +++ b/disco/types/guild.py @@ -2,10 +2,11 @@ import six from holster.enum import Enum +from disco.gateway.packets import OPCode from disco.api.http import APIException from disco.util.snowflake import to_snowflake from disco.util.functional import cached_property -from disco.types.base import Model, Field, snowflake, listof, dictof, datetime, text, binary, enum +from disco.types.base import Model, Field, snowflake, listof, dictof, lazy_datetime, text, binary, enum from disco.types.user import User from disco.types.voice import VoiceState from disco.types.channel import Channel @@ -38,6 +39,8 @@ class Emoji(Model): roles : list(snowflake) Roles 
this emoji is attached to. """ + __slots__ = ['id', 'name', 'require_colons', 'managed', 'roles', 'guild', 'guild_id'] + id = Field(snowflake) name = Field(text) require_colons = Field(bool) @@ -66,6 +69,11 @@ class Role(Model): position : int The position of this role in the hierarchy. """ + __slots__ = [ + 'id', 'name', 'hoist', 'managed', 'color', 'permissions', 'position', 'mentionable', + 'guild', 'guild_id' + ] + id = Field(snowflake) name = Field(text) hoist = Field(bool) @@ -108,12 +116,16 @@ class GuildMember(Model): roles : list(snowflake) Roles this member is part of. """ + __slots__ = [ + 'user', 'guild_id', 'nick', 'mute', 'deaf', 'joined_at', 'roles', 'guild' + ] + user = Field(User) guild_id = Field(snowflake) nick = Field(text) mute = Field(bool) deaf = Field(bool) - joined_at = Field(datetime) + joined_at = Field(lazy_datetime) roles = Field(listof(snowflake)) def get_voice_state(self): @@ -242,6 +254,8 @@ class Guild(Model, Permissible): emojis = Field(dictof(Emoji, key='id')) voice_states = Field(dictof(VoiceState, key='session_id')) + synced = Field(bool, default=False) + def __init__(self, *args, **kwargs): super(Guild, self).__init__(*args, **kwargs) @@ -326,3 +340,13 @@ class Guild(Model, Permissible): 'hoist': role.hoist, 'mentionable': role.mentionable, }) + + def sync(self): + if self.synced: + return + + self.client.gw.send(OPCode.REQUEST_GUILD_MEMBERS, { + 'guild_id': self.id, + 'query': '', + 'limit': 0, + }) diff --git a/disco/types/invite.py b/disco/types/invite.py index 2bdb514..709f668 100644 --- a/disco/types/invite.py +++ b/disco/types/invite.py @@ -1,4 +1,4 @@ -from disco.types.base import Model, Field, datetime +from disco.types.base import Model, Field, lazy_datetime from disco.types.user import User from disco.types.guild import Guild from disco.types.channel import Channel @@ -37,4 +37,4 @@ class Invite(Model): max_uses = Field(int) uses = Field(int) temporary = Field(bool) - created_at = Field(datetime) + created_at = 
Field(lazy_datetime) diff --git a/disco/types/message.py b/disco/types/message.py index 52e2bb6..8548867 100644 --- a/disco/types/message.py +++ b/disco/types/message.py @@ -2,7 +2,7 @@ import re from holster.enum import Enum -from disco.types.base import Model, Field, snowflake, text, datetime, dictof, listof, enum +from disco.types.base import Model, Field, snowflake, text, lazy_datetime, dictof, listof, enum from disco.util.snowflake import to_snowflake from disco.util.functional import cached_property from disco.types.user import User @@ -113,8 +113,8 @@ class Message(Model): author = Field(User) content = Field(text) nonce = Field(snowflake) - timestamp = Field(datetime) - edited_timestamp = Field(datetime) + timestamp = Field(lazy_datetime) + edited_timestamp = Field(lazy_datetime) tts = Field(bool) mention_everyone = Field(bool) pinned = Field(bool) diff --git a/disco/types/user.py b/disco/types/user.py index 0a1d006..8bb4a0c 100644 --- a/disco/types/user.py +++ b/disco/types/user.py @@ -4,6 +4,8 @@ from disco.types.base import Model, Field, snowflake, text, binary, with_equalit class User(Model, with_equality('id'), with_hash('id')): + __slots__ = ['id', 'username', 'discriminator', 'avatar', 'verified', 'email', 'presence'] + id = Field(snowflake) username = Field(text) discriminator = Field(str) @@ -42,12 +44,16 @@ Status = Enum( class Game(Model): + __slots__ = ['type', 'name', 'url'] + type = Field(GameType) name = Field(text) url = Field(text) class Presence(Model): + __slots__ = ['user', 'game', 'status'] + user = Field(User) game = Field(Game) status = Field(Status) diff --git a/disco/types/voice.py b/disco/types/voice.py index aadfa58..84cb092 100644 --- a/disco/types/voice.py +++ b/disco/types/voice.py @@ -12,6 +12,11 @@ class VoiceState(Model): self_mute = Field(bool) suppress = Field(bool) + __slots__ = [ + 'session_id', 'guild_id', 'channel_id', 'user_id', 'deaf', 'mute', 'self_deaf', + 'self_mute', 'suppress' + ] + @property def guild(self): 
return self.client.state.guilds.get(self.guild_id) diff --git a/disco/util/config.py b/disco/util/config.py index de98947..29147c2 100644 --- a/disco/util/config.py +++ b/disco/util/config.py @@ -13,6 +13,9 @@ class Config(object): if obj: self.__dict__.update(obj) + def get(self, key, default=None): + return self.__dict__.get(key, default) + @classmethod def from_file(cls, path): inst = cls() diff --git a/disco/util/functional.py b/disco/util/functional.py index 8641457..af72be3 100644 --- a/disco/util/functional.py +++ b/disco/util/functional.py @@ -1,5 +1,3 @@ -from gevent.lock import RLock - from six.moves import range NO_MORE_SENTINEL = object() @@ -50,46 +48,22 @@ def one_or_many(f): return _f -def cached_property(f): - """ - Creates a cached class property out of ``f``. When the property is resolved - for the first time, the function will be called and its result will be cached. - Subsequent calls will return the cached value. If this property is set, the - cached value will be replaced (or set initially) with the value provided. If - this property is deleted, the cache will be cleared and the next call will - refill it with a new value. +class CachedSlotProperty(object): + __slots__ = ['name', 'function', '__doc__'] - Notes - ----- - This function is greenlet safe. + def __init__(self, name, function): + self.name = name + self.function = function + self.__doc__ = getattr(function, '__doc__') - Args - ---- - f : function - The function to wrap. + def __get__(self, instance, owner): + if instance is None: + return self - Returns - ------- - property - The cached property created. 
- """ - lock = RLock() - value_name = '_' + f.__name__ + value = self.function(instance) + setattr(instance, self.name, value) + return value - def getf(self, *args, **kwargs): - if not hasattr(self, value_name): - with lock: - if hasattr(self, value_name): - return getattr(self, value_name) - setattr(self, value_name, f(self, *args, **kwargs)) - - return getattr(self, value_name) - - def setf(self, value): - setattr(self, value_name, value) - - def delf(self): - delattr(self, value_name) - - return property(getf, setf, delf) +def cached_property(f): + return CachedSlotProperty(f.__name__, f) diff --git a/disco/util/limiter.py b/disco/util/limiter.py new file mode 100644 index 0000000..6992832 --- /dev/null +++ b/disco/util/limiter.py @@ -0,0 +1,37 @@ +import time +import gevent + + +class SimpleLimiter(object): + def __init__(self, total, per): + self.total = total + self.per = per + + self.count = 0 + self.reset_at = 0 + + self.event = None + + def backoff(self): + self.event = gevent.event.Event() + gevent.sleep(self.reset_at - time.time()) + self.count = 0 + self.reset_at = 0 + self.event.set() + self.event = None + + def check(self): + if self.event: + self.event.wait() + + self.count += 1 + + if not self.reset_at: + self.reset_at = time.time() + self.per + return + elif self.reset_at < time.time(): + self.count = 1 + self.reset_at = time.time() + + if self.count > self.total and self.reset_at > time.time(): + self.backoff() diff --git a/disco/util/snowflake.py b/disco/util/snowflake.py index e642aa6..241e2de 100644 --- a/disco/util/snowflake.py +++ b/disco/util/snowflake.py @@ -29,3 +29,7 @@ def to_snowflake(i): return i.id raise Exception('{} ({}) is not convertable to a snowflake'.format(type(i), i)) + + +def calculate_shard(shard_count, guild_id): + return (guild_id >> 22) % shard_count