diff --git a/docs/index.rst b/docs/index.rst index 835b2b0..4aa2fde 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -226,6 +226,10 @@ subset of them: * ``after_event(*args)`` is called after the event handler with the event name, the namespace and the list of values the event handler returned. It may alter that values eventually. +* ``ignore_event`` is called before the middleware is applied to an + event handler with the event name and namespace as arguments. If its + return value resolves to ``True`` the middleware is not applied to that + particular event handler. If one of these methods returns something else than ``None``, execution is stopped at that point and the returned value is treated as if it was diff --git a/socketio/__init__.py b/socketio/__init__.py index ff68424..514583f 100644 --- a/socketio/__init__.py +++ b/socketio/__init__.py @@ -5,6 +5,7 @@ from .pubsub_manager import PubSubManager from .kombu_manager import KombuManager from .redis_manager import RedisManager from .server import Server +from .util import apply_middleware __all__ = [Middleware, Namespace, Server, BaseManager, PubSubManager, - KombuManager, RedisManager] + KombuManager, RedisManager, apply_middleware] diff --git a/socketio/namespace.py b/socketio/namespace.py index c778d1e..8fe4a20 100644 --- a/socketio/namespace.py +++ b/socketio/namespace.py @@ -36,7 +36,8 @@ class Namespace(object): # ... sio = socketio.Server() - sio.register_namespace("/chat", ChatNamespace) + ns = sio.register_namespace("/chat", ChatNamespace) + # ns now holds the instantiated ChatNamespace object """ def __init__(self, name, server): @@ -73,10 +74,7 @@ class Namespace(object): else: continue if _event_name == event_name: - extra_middlewares = getattr(attr, '_sio_middlewares', []) - return util._apply_middlewares( - self.middlewares + extra_middlewares, event_name, - self.name, attr) + return attr @staticmethod def event_name(name): @@ -86,14 +84,12 @@ class Namespace(object): def foo(self, sid, data): return "received: %s" % data - Note that you must not add third-party decorators after the ones - provided by this library because you'll otherwise loose metadata - that this decorators create. You can add them before instead. + Ensure that you only add well-behaving decorators after this one + (meaning such that preserve attributes) because you'll loose them + otherwise. """ + @util._simple_decorator def wrapper(handler): - def wrapped_handler(*args, **kwargs): - return handler(*args, **kwargs) - util._copy_sio_properties(handler, wrapped_handler) - wrapped_handler._sio_event_name = name - return wrapped_handler + handler._sio_event_name = name + return handler return wrapper diff --git a/socketio/server.py b/socketio/server.py index f1a0c3f..2215199 100644 --- a/socketio/server.py +++ b/socketio/server.py @@ -6,7 +6,6 @@ import six from . import base_manager from . import namespace as sio_namespace from . import packet -from . import util class Server(object): @@ -173,25 +172,6 @@ class Server(object): self.handlers[name] = namespace return namespace - def _get_event_handler(self, event, namespace): - """Returns the event handler for given ``event`` and ``namespace`` or - ``None``, if none exists. - - :param event: The event name the handler is required for. - :param namespace: The Socket.IO namespace for the event. - """ - handler = None - ns = self.handlers.get(namespace) - if isinstance(ns, sio_namespace.Namespace): - handler = ns._get_event_handler(event) - elif isinstance(ns, dict): - handler = ns.get(event) - if handler is not None: - extra_middlewares = getattr(handler, '_sio_middlewares', []) - return util._apply_middlewares( - self.middlewares + extra_middlewares, event, namespace, - handler) - def emit(self, event, data=None, room=None, skip_sid=None, namespace=None, callback=None): """Emit a custom event to one or more connected clients. @@ -466,10 +446,62 @@ class Server(object): def _trigger_event(self, event, namespace, *args): """Invoke an application event handler.""" - handler = self._get_event_handler(event, namespace) + handler = None + middlewares = list(self.middlewares) + ns = self.handlers.get(namespace) + if isinstance(ns, sio_namespace.Namespace): + middlewares.extend(ns.middlewares) + handler = ns._get_event_handler(event) + elif isinstance(ns, dict): + handler = ns.get(event) if handler is not None: + middlewares.extend(getattr(handler, '_sio_middlewares', [])) + handler = self._apply_middlewares(middlewares, event, namespace, + handler) return handler(*args) + @staticmethod + def _apply_middlewares(middlewares, event, namespace, handler): + """Wraps the given handler with a wrapper that executes middlewares + before and after the real event handler.""" + + _middlewares = [] + for middleware in middlewares: + if isinstance(middleware, type): + middleware = middleware() + if not hasattr(middleware, 'ignore_event') or \ + not middleware.ignore_event(event, namespace): + _middlewares.append(middleware) + if not _middlewares: + return handler + + def wrapped(*args): + args = list(args) + + for middleware in _middlewares: + if hasattr(middleware, 'before_event'): + result = middleware.before_event(event, namespace, args) + if result is not None: + return result + + result = handler(*args) + if result is None: + data = [] + elif isinstance(result, tuple): + data = list(result) + else: + data = [result] + + for middleware in reversed(_middlewares): + if hasattr(middleware, 'after_event'): + result = middleware.after_event(event, namespace, data) + if result is not None: + return result + + return tuple(data) + + return wrapped + def _handle_eio_connect(self, sid, environ): """Handle the Engine.IO connection event.""" self.environ[sid] = environ diff --git a/socketio/util.py b/socketio/util.py index fd89ac7..5fd885c 100644 --- a/socketio/util.py +++ b/socketio/util.py @@ -1,48 +1,33 @@ -def _copy_sio_properties(from_func, to_func): - """Copies all properties starting with ``'_sio'`` from one function to - another.""" - for key in dir(from_func): - if key.startswith('_sio'): - setattr(to_func, key, getattr(from_func, key)) - - -def _apply_middlewares(middlewares, event, namespace, handler): - """Wraps the given handler with a wrapper that executes middlewares - before and after the real event handler.""" - if not middlewares: - return handler - - def wrapped(*args): - _middlewares = [] - for middleware in middlewares: - if isinstance(middleware, type): - _middlewares.append(middleware()) - else: - _middlewares.append(middleware) - - for middleware in _middlewares: - if hasattr(middleware, 'before_event'): - result = middleware.before_event(event, namespace, args) - if result is not None: - return result - - result = handler(*args) - if result is None: - data = [] - elif isinstance(result, tuple): - data = list(result) - else: - data = [result] - - for middleware in reversed(_middlewares): - if hasattr(middleware, 'after_event'): - result = middleware.after_event(event, namespace, data) - if result is not None: - return result +def _simple_decorator(decorator): + """This decorator can be used to turn simple functions + into well-behaved decorators, so long as the decorators + are fairly simple. If a decorator expects a function and + returns a function (no descriptors), and if it doesn't + modify function attributes or docstring, then it is + eligible to use this. Simply apply @_simple_decorator to + your decorator and it will automatically preserve the + docstring and function attributes of functions to which + it is applied. + + Also preserves all properties starting with ``'_sio'``. + """ + def copy_attrs(a, b): + """Copies attributes from a to b.""" + for attr_name in ('__name__', '__doc__'): + if hasattr(a, attr_name): + setattr(b, attr_name, getattr(a, attr_name)) + if hasattr(a, '__dict__') and hasattr(b, '__dict__'): + b.__dict__.update(a.__dict__) - return tuple(data) + def new_decorator(f): + g = decorator(f) + copy_attrs(f, g) + return g - return wrapped + # Now a few lines needed to make _simple_decorator itself + # be a well-behaved decorator. + copy_attrs(decorator, new_decorator) + return new_decorator def apply_middleware(middleware): @@ -51,10 +36,11 @@ def apply_middleware(middleware): :param middleware: The middleware to add - Note that you must not add third-party decorators after the ones - provided by this library because you'll otherwise loose metadata - that this decorators create. You can add them before instead. + Ensure that you only add well-behaving decorators after this one + (meaning such that preserve attributes) because you'll loose them + otherwise. """ + @_simple_decorator def wrapper(handler): if not hasattr(handler, '_sio_middlewares'): handler._sio_middlewares = [] diff --git a/tests/test_server.py b/tests/test_server.py index b11b530..1a137d8 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -11,6 +11,7 @@ else: from socketio import namespace from socketio import packet from socketio import server +from socketio import util @mock.patch('engineio.Server') @@ -62,6 +63,97 @@ class TestServer(unittest.TestCase): self.assertIsNotNone(s.handlers['/ns']._get_event_handler('foo bar')) self.assertIsNone(s.handlers['/ns']._get_event_handler('abc')) + def test_middleware(self, eio): + class MW: + def __init__(self): + self.ignore_event = mock.MagicMock(side_effect=100*[False]) + self.before_event = mock.MagicMock(side_effect=100*[None]) + self.after_event = mock.MagicMock(side_effect=100*[None]) + + mw1 = MW() + mw2 = MW() + mw3 = MW() + mw4 = MW() + mw4.ignore_event = mock.MagicMock(side_effect=[True]+100*[False]) + mw4.before_event = mock.MagicMock(side_effect=['x']+100*[None]) + mw4.after_event = mock.MagicMock(side_effect=['x']+100*[None]) + + class NS(namespace.Namespace): + def on_foo(self, sid): + pass + + @namespace.Namespace.event_name('foo bar') + @util.apply_middleware(mw4) + def some_name(self, sid): + pass + + s = server.Server() + s.middlewares.append(mw1) + + @s.on('abc') + @util.apply_middleware(mw2) + def abc(sid): + pass + + ns = s.register_namespace('/ns', NS) + ns.middlewares.append(mw3) + + # only mw1 and mw3 should run completely + s._trigger_event('foo', '/ns', '123') + self.assertEqual(mw1.before_event.call_count, 1) + self.assertEqual(mw1.after_event.call_count, 1) + self.assertEqual(mw2.before_event.call_count, 0) + self.assertEqual(mw2.after_event.call_count, 0) + self.assertEqual(mw3.before_event.call_count, 1) + self.assertEqual(mw3.after_event.call_count, 1) + self.assertEqual(mw4.before_event.call_count, 0) + self.assertEqual(mw4.after_event.call_count, 0) + + # only mw1 and mw3 should run completely, mw4 is enabled but ignored + s._trigger_event('foo bar', '/ns', '123') + self.assertEqual(mw1.before_event.call_count, 2) + self.assertEqual(mw1.after_event.call_count, 2) + self.assertEqual(mw2.before_event.call_count, 0) + self.assertEqual(mw2.after_event.call_count, 0) + self.assertEqual(mw3.before_event.call_count, 2) + self.assertEqual(mw3.after_event.call_count, 2) + self.assertEqual(mw4.before_event.call_count, 0) + self.assertEqual(mw4.after_event.call_count, 0) + + # again, this time mw4 before_event should be triggered + s._trigger_event('foo bar', '/ns', '123') + self.assertEqual(mw1.before_event.call_count, 3) + self.assertEqual(mw1.after_event.call_count, 2) + self.assertEqual(mw2.before_event.call_count, 0) + self.assertEqual(mw2.after_event.call_count, 0) + self.assertEqual(mw3.before_event.call_count, 3) + self.assertEqual(mw3.after_event.call_count, 2) + self.assertEqual(mw4.before_event.call_count, 1) + self.assertEqual(mw4.after_event.call_count, 0) + + # again, this time mw4 before + after_event should be triggered + # but after_event should abort execution + s._trigger_event('foo bar', '/ns', '123') + self.assertEqual(mw1.before_event.call_count, 4) + self.assertEqual(mw1.after_event.call_count, 2) + self.assertEqual(mw2.before_event.call_count, 0) + self.assertEqual(mw2.after_event.call_count, 0) + self.assertEqual(mw3.before_event.call_count, 4) + self.assertEqual(mw3.after_event.call_count, 2) + self.assertEqual(mw4.before_event.call_count, 2) + self.assertEqual(mw4.after_event.call_count, 1) + + # only mw1 and mw2 should run completely + s._trigger_event('abc', '/', '123') + self.assertEqual(mw1.before_event.call_count, 5) + self.assertEqual(mw1.after_event.call_count, 3) + self.assertEqual(mw2.before_event.call_count, 1) + self.assertEqual(mw2.after_event.call_count, 1) + self.assertEqual(mw3.before_event.call_count, 4) + self.assertEqual(mw3.after_event.call_count, 2) + self.assertEqual(mw4.before_event.call_count, 2) + self.assertEqual(mw4.after_event.call_count, 1) + def test_on_bad_event_name(self, eio): s = server.Server() self.assertRaises(ValueError, s.on, 'two-words')