Browse Source

a few asyncio related fixes

pull/31/merge
Miguel Grinberg 8 years ago
parent
commit
763583226a
  1. 4
      examples/aiohttp/app.py
  2. 43
      socketio/asyncio_server.py
  3. 6
      socketio/server.py
  4. 2
      tests/test_server.py

4
examples/aiohttp/app.py

@ -13,7 +13,7 @@ async def background_task():
"""Example of how to send server generated events to clients.""" """Example of how to send server generated events to clients."""
count = 0 count = 0
while True: while True:
await asyncio.sleep(10) await sio.sleep(10)
count += 1 count += 1
await sio.emit('my response', {'data': 'Server generated event'}, await sio.emit('my response', {'data': 'Server generated event'},
namespace='/test') namespace='/test')
@ -84,5 +84,5 @@ app.router.add_get('/', index)
if __name__ == '__main__': if __name__ == '__main__':
asyncio.ensure_future(background_task()) sio.start_background_task(background_task)
web.run_app(app) web.run_app(app)

43
socketio/asyncio_server.py

@ -106,7 +106,7 @@ class AsyncServer(server.Server):
to always leave this parameter with its default to always leave this parameter with its default
value of ``False``. value of ``False``.
Note: this method is asynchronous. Note: this method is a coroutine.
""" """
namespace = namespace or '/' namespace = namespace or '/'
self.logger.info('emitting event "%s" to %s [%s]', event, self.logger.info('emitting event "%s" to %s [%s]', event,
@ -148,7 +148,7 @@ class AsyncServer(server.Server):
to always leave this parameter with its default to always leave this parameter with its default
value of ``False``. value of ``False``.
Note: this method is asynchronous. Note: this method is a coroutine.
""" """
await self.emit('message', data, room, skip_sid, namespace, callback, await self.emit('message', data, room, skip_sid, namespace, callback,
**kwargs) **kwargs)
@ -160,7 +160,7 @@ class AsyncServer(server.Server):
:param namespace: The Socket.IO namespace to disconnect. If this :param namespace: The Socket.IO namespace to disconnect. If this
argument is omitted the default namespace is used. argument is omitted the default namespace is used.
Note: this method is asynchronous. Note: this method is a coroutine.
""" """
namespace = namespace or '/' namespace = namespace or '/'
if self.manager.is_connected(sid, namespace=namespace): if self.manager.is_connected(sid, namespace=namespace):
@ -185,18 +185,38 @@ class AsyncServer(server.Server):
This function returns the HTTP response body to deliver to the client This function returns the HTTP response body to deliver to the client
as a byte sequence. as a byte sequence.
Note: this method is asynchronous. Note: this method is a coroutine.
""" """
if not self.manager_initialized:
self.manager_initialized = True
self.manager.initialize()
return await self.eio.handle_request(environ) return await self.eio.handle_request(environ)
def start_background_task(self, target, *args, **kwargs): def start_background_task(self, target, *args, **kwargs):
raise RuntimeError('Not implemented, use asyncio.') """Start a background task using the appropriate async model.
This is a utility function that applications can use to start a
background task using the method that is compatible with the
selected async mode.
:param target: the target function to execute. Must be a coroutine.
:param args: arguments to pass to the function.
:param kwargs: keyword arguments to pass to the function.
The return value is a ``asyncio.Task`` object.
Note: this method is a coroutine.
"""
return self.eio.start_background_task(target, *args, **kwargs)
async def sleep(self, seconds=0):
"""Sleep for the requested amount of time using the appropriate async
model.
def sleep(self, seconds=0): This is a utility function that applications can use to put a task to
raise RuntimeError('Not implemented, use asyncio.') sleep without having to worry about using the correct call for the
selected async mode.
Note: this method is a coroutine.
"""
return await self.eio.sleep(seconds)
async def _emit_internal(self, sid, event, data, namespace=None, id=None): async def _emit_internal(self, sid, event, data, namespace=None, id=None):
"""Send a message to a client.""" """Send a message to a client."""
@ -301,6 +321,9 @@ class AsyncServer(server.Server):
async def _handle_eio_connect(self, sid, environ): async def _handle_eio_connect(self, sid, environ):
"""Handle the Engine.IO connection event.""" """Handle the Engine.IO connection event."""
if not self.manager_initialized:
self.manager_initialized = True
self.manager.initialize()
self.environ[sid] = environ self.environ[sid] = environ
return await self._handle_connect(sid, '/') return await self._handle_connect(sid, '/')

6
socketio/server.py

@ -345,9 +345,6 @@ class Server(object):
This function returns the HTTP response body to deliver to the client This function returns the HTTP response body to deliver to the client
as a byte sequence. as a byte sequence.
""" """
if not self.manager_initialized:
self.manager_initialized = True
self.manager.initialize()
return self.eio.handle_request(environ, start_response) return self.eio.handle_request(environ, start_response)
def start_background_task(self, target, *args, **kwargs): def start_background_task(self, target, *args, **kwargs):
@ -485,6 +482,9 @@ class Server(object):
def _handle_eio_connect(self, sid, environ): def _handle_eio_connect(self, sid, environ):
"""Handle the Engine.IO connection event.""" """Handle the Engine.IO connection event."""
if not self.manager_initialized:
self.manager_initialized = True
self.manager.initialize()
self.environ[sid] = environ self.environ[sid] = environ
return self._handle_connect(sid, '/') return self._handle_connect(sid, '/')

2
tests/test_server.py

@ -30,7 +30,6 @@ class TestServer(unittest.TestCase):
self.assertEqual(s.eio.on.call_count, 3) self.assertEqual(s.eio.on.call_count, 3)
self.assertEqual(s.binary, True) self.assertEqual(s.binary, True)
self.assertEqual(s.async_handlers, True) self.assertEqual(s.async_handlers, True)
self.assertEqual(mgr.initialize.call_count, 1)
def test_on_event(self, eio): def test_on_event(self, eio):
s = server.Server() s = server.Server()
@ -180,6 +179,7 @@ class TestServer(unittest.TestCase):
handler.assert_called_once_with('123', 'environ') handler.assert_called_once_with('123', 'environ')
s.manager.connect.assert_called_once_with('123', '/') s.manager.connect.assert_called_once_with('123', '/')
s.eio.send.assert_called_once_with('123', '0', binary=False) s.eio.send.assert_called_once_with('123', '0', binary=False)
self.assertEqual(mgr.initialize.call_count, 1)
def test_handle_connect_namespace(self, eio): def test_handle_connect_namespace(self, eio):
mgr = mock.MagicMock() mgr = mock.MagicMock()

Loading…
Cancel
Save