from functools import wraps
import threading
import time
from unittest import mock
import pytest
try:
    from engineio.async_socket import AsyncSocket as EngineIOSocket
except ImportError:
    from engineio.asyncio_socket import AsyncSocket as EngineIOSocket
import socketio
from socketio.exceptions import ConnectionError
from tests.asyncio_web_server import SocketIOWebServer


def with_instrumented_server(auth=False, **ikwargs):
    """This decorator can be applied to test functions or methods so that they
    run with a Socket.IO server that has been instrumented for the official
    Admin UI project. The arguments passed to the decorator are passed directly
    to the ``instrument()`` method of the server.

    :param auth: value forwarded as the ``auth`` argument of
                 ``instrument()``.
    :param ikwargs: additional keyword arguments forwarded to
                    ``instrument()``. When ``server_stats_interval`` is not
                    given, it defaults to 0.25.

    The decorated callable must be a method: the instrumentation object
    returned by ``instrument()`` is stored in ``self.isvr`` while the test
    runs and reset to ``None`` afterwards. The return value of the decorated
    callable is passed through.
    """
    def decorator(f):
        @wraps(f)
        def wrapped(self, *args, **kwargs):
            sio = socketio.AsyncServer(async_mode='asgi')

            @sio.event
            async def enter_room(sid, data):
                await sio.enter_room(sid, data)

            @sio.event
            async def emit(sid, event):
                await sio.emit(event, skip_sid=sid)

            @sio.event(namespace='/foo')
            def connect(sid, environ, auth):
                pass

            async def shutdown():
                await self.isvr.shutdown()
                await sio.shutdown()

            # Build the options for this call without touching the keyword
            # arguments captured by the decorator, which are shared by every
            # invocation of the decorated test.
            options = {'server_stats_interval': 0.25, **ikwargs}

            self.isvr = sio.instrument(auth=auth, **options)
            server = SocketIOWebServer(sio, on_shutdown=shutdown)
            server.start()

            # Uncomment to get client-side debug logging while the test runs:
            # import logging
            # logging.getLogger('engineio.client').setLevel(logging.DEBUG)
            # logging.getLogger('socketio.client').setLevel(logging.DEBUG)

            # Replace the ping scheduler with a no-op mock for the duration
            # of the test.
            original_schedule_ping = EngineIOSocket.schedule_ping
            EngineIOSocket.schedule_ping = mock.MagicMock()

            try:
                ret = f(self, *args, **kwargs)
            finally:
                try:
                    server.stop()
                    self.isvr.uninstrument()
                    self.isvr = None
                finally:
                    # Always undo the class-level patch, even when the test
                    # or the server shutdown fails; otherwise the mock would
                    # leak into every test that runs afterwards.
                    EngineIOSocket.schedule_ping = original_schedule_ping

            # Counterpart of the debug logging above:
            # import logging
            # logging.getLogger('engineio.client').setLevel(logging.NOTSET)
            # logging.getLogger('socketio.client').setLevel(logging.NOTSET)

            return ret
        return wrapped
    return decorator


def _custom_auth(auth):
    return auth == {'foo': 'bar'}


async def _async_custom_auth(auth):
    return auth == {'foo': 'bar'}


class TestAsyncAdmin:
    """Integration tests for the Admin UI instrumentation of
    ``socketio.AsyncServer``.

    Most tests are wrapped with ``with_instrumented_server`` and talk to the
    server at ``http://localhost:8900`` through ``socketio.SimpleClient``
    instances.
    """

    def setup_method(self):
        """Record the number of live threads before the test starts."""
        print('threads at start:', threading.enumerate())
        self.thread_count = threading.active_count()

    def teardown_method(self):
        """Check that the test did not leave any extra threads running."""
        print('threads at end:', threading.enumerate())
        assert self.thread_count == threading.active_count()

    def _expect(self, expected, admin_client):
        """Receive events from ``admin_client`` until all the expected ones
        have arrived.

        :param expected: dict mapping event names to the number of
                         occurrences to wait for. The dict is not modified.
        :param admin_client: client whose ``receive()`` method is polled,
                             with a timeout of 5 seconds per call.

        Returns a dict mapping each event name to the first argument of its
        last awaited occurrence. Events that are not in ``expected`` are
        discarded.
        """
        # Count down on a private copy so that the caller's dict survives.
        pending = dict(expected)
        events = {}
        while pending:
            data = admin_client.receive(timeout=5)
            name = data[0]
            if name not in pending:
                continue
            if pending[name] == 1:
                events[name] = data[1]
                del pending[name]
            else:
                pending[name] -= 1
        return events

    def test_missing_auth(self):
        """``instrument()`` without arguments raises ``ValueError``."""
        sio = socketio.AsyncServer(async_mode='asgi')
        with pytest.raises(ValueError):
            sio.instrument()

    @with_instrumented_server(auth=False)
    def test_admin_connect_with_no_auth(self):
        """With ``auth=False`` the admin namespace accepts connections both
        without and with credentials.
        """
        with socketio.SimpleClient() as admin_client:
            admin_client.connect('http://localhost:8900', namespace='/admin')
        with socketio.SimpleClient() as admin_client:
            admin_client.connect('http://localhost:8900', namespace='/admin',
                                 auth={'foo': 'bar'})

    @with_instrumented_server(auth={'foo': 'bar'})
    def test_admin_connect_with_dict_auth(self):
        """With a dict as ``auth`` only the matching credentials are
        accepted; wrong or missing credentials are rejected.
        """
        with socketio.SimpleClient() as admin_client:
            admin_client.connect('http://localhost:8900', namespace='/admin',
                                 auth={'foo': 'bar'})
        with socketio.SimpleClient() as admin_client:
            with pytest.raises(ConnectionError):
                admin_client.connect(
                    'http://localhost:8900', namespace='/admin',
                    auth={'foo': 'baz'})
        with socketio.SimpleClient() as admin_client:
            with pytest.raises(ConnectionError):
                admin_client.connect(
                    'http://localhost:8900', namespace='/admin')

    @with_instrumented_server(auth=[{'foo': 'bar'},
                                    {'u': 'admin', 'p': 'secret'}])
    def test_admin_connect_with_list_auth(self):
        """With a list as ``auth`` any of the listed credentials is
        accepted; wrong or missing credentials are rejected.
        """
        with socketio.SimpleClient() as admin_client:
            admin_client.connect('http://localhost:8900', namespace='/admin',
                                 auth={'foo': 'bar'})
        with socketio.SimpleClient() as admin_client:
            admin_client.connect('http://localhost:8900', namespace='/admin',
                                 auth={'u': 'admin', 'p': 'secret'})
        with socketio.SimpleClient() as admin_client:
            with pytest.raises(ConnectionError):
                admin_client.connect('http://localhost:8900',
                                     namespace='/admin', auth={'foo': 'baz'})
        with socketio.SimpleClient() as admin_client:
            with pytest.raises(ConnectionError):
                admin_client.connect('http://localhost:8900',
                                     namespace='/admin')

    @with_instrumented_server(auth=_custom_auth)
    def test_admin_connect_with_function_auth(self):
        """With a function as ``auth`` the connection is accepted only when
        the function approves the credentials.
        """
        with socketio.SimpleClient() as admin_client:
            admin_client.connect('http://localhost:8900', namespace='/admin',
                                 auth={'foo': 'bar'})
        with socketio.SimpleClient() as admin_client:
            with pytest.raises(ConnectionError):
                admin_client.connect('http://localhost:8900',
                                     namespace='/admin', auth={'foo': 'baz'})
        with socketio.SimpleClient() as admin_client:
            with pytest.raises(ConnectionError):
                admin_client.connect('http://localhost:8900',
                                     namespace='/admin')

    @with_instrumented_server(auth=_async_custom_auth)
    def test_admin_connect_with_async_function_auth(self):
        """With a coroutine function as ``auth`` the connection is accepted
        only when the coroutine approves the credentials.
        """
        with socketio.SimpleClient() as admin_client:
            admin_client.connect('http://localhost:8900', namespace='/admin',
                                 auth={'foo': 'bar'})
        with socketio.SimpleClient() as admin_client:
            with pytest.raises(ConnectionError):
                admin_client.connect('http://localhost:8900',
                                     namespace='/admin', auth={'foo': 'baz'})
        with socketio.SimpleClient() as admin_client:
            with pytest.raises(ConnectionError):
                admin_client.connect('http://localhost:8900',
                                     namespace='/admin')

    @with_instrumented_server()
    def test_admin_connect_only_admin(self):
        """A lone admin client receives the config, the socket list and the
        server stats, all of which describe just that client.
        """
        with socketio.SimpleClient() as admin_client:
            admin_client.connect('http://localhost:8900', namespace='/admin')
            sid = admin_client.sid
            # wait for the second server_stats event before checking
            events = self._expect({'config': 1, 'all_sockets': 1,
                                  'server_stats': 2}, admin_client)

        assert 'supportedFeatures' in events['config']
        assert 'ALL_EVENTS' in events['config']['supportedFeatures']
        assert 'AGGREGATED_EVENTS' in events['config']['supportedFeatures']
        assert 'EMIT' in events['config']['supportedFeatures']
        assert len(events['all_sockets']) == 1
        assert events['all_sockets'][0]['id'] == sid
        assert events['all_sockets'][0]['rooms'] == [sid]
        assert events['server_stats']['clientsCount'] == 1
        assert events['server_stats']['pollingClientsCount'] == 0
        assert len(events['server_stats']['namespaces']) == 3
        assert {'name': '/', 'socketsCount': 0} in \
            events['server_stats']['namespaces']
        assert {'name': '/foo', 'socketsCount': 0} in \
            events['server_stats']['namespaces']
        assert {'name': '/admin', 'socketsCount': 1} in \
            events['server_stats']['namespaces']

    @with_instrumented_server()
    def test_admin_connect_with_others(self):
        """An admin client that connects after three other clients sees all
        four sockets, their rooms and the per-namespace counts.
        """
        with socketio.SimpleClient() as client1, \
                socketio.SimpleClient() as client2, \
                socketio.SimpleClient() as client3, \
                socketio.SimpleClient() as admin_client:
            client1.connect('http://localhost:8900')
            client1.emit('enter_room', 'room')
            sid1 = client1.sid

            # Stub out the upgrade check of the instrumentation while the
            # polling-only client connects, and put it back even if the
            # connection attempt fails.
            saved_check_for_upgrade = self.isvr._check_for_upgrade
            self.isvr._check_for_upgrade = mock.AsyncMock()
            try:
                client2.connect('http://localhost:8900', namespace='/foo',
                                transports=['polling'])
                sid2 = client2.sid
            finally:
                self.isvr._check_for_upgrade = saved_check_for_upgrade

            client3.connect('http://localhost:8900', namespace='/admin')
            sid3 = client3.sid

            admin_client.connect('http://localhost:8900', namespace='/admin')
            sid = admin_client.sid
            # wait for the second server_stats event before checking
            events = self._expect({'config': 1, 'all_sockets': 1,
                                  'server_stats': 2}, admin_client)

        assert 'supportedFeatures' in events['config']
        assert 'ALL_EVENTS' in events['config']['supportedFeatures']
        assert 'AGGREGATED_EVENTS' in events['config']['supportedFeatures']
        assert 'EMIT' in events['config']['supportedFeatures']
        assert len(events['all_sockets']) == 4
        assert events['server_stats']['clientsCount'] == 4
        assert events['server_stats']['pollingClientsCount'] == 1
        assert len(events['server_stats']['namespaces']) == 3
        assert {'name': '/', 'socketsCount': 1} in \
            events['server_stats']['namespaces']
        assert {'name': '/foo', 'socketsCount': 1} in \
            events['server_stats']['namespaces']
        assert {'name': '/admin', 'socketsCount': 2} in \
            events['server_stats']['namespaces']

        # check the rooms reported for each of the known sockets
        for socket in events['all_sockets']:
            if socket['id'] == sid:
                assert socket['rooms'] == [sid]
            elif socket['id'] == sid1:
                assert socket['rooms'] == [sid1, 'room']
            elif socket['id'] == sid2:
                assert socket['rooms'] == [sid2]
            elif socket['id'] == sid3:
                assert socket['rooms'] == [sid3]

    @with_instrumented_server(mode='production', read_only=True)
    def test_admin_connect_production(self):
        """In read-only production mode the config does not advertise the
        ``ALL_EVENTS`` and ``EMIT`` features, while server stats are still
        reported.
        """
        with socketio.SimpleClient() as admin_client:
            admin_client.connect('http://localhost:8900', namespace='/admin')
            # wait for the second server_stats event before checking
            events = self._expect({'config': 1, 'server_stats': 2},
                                  admin_client)

        assert 'supportedFeatures' in events['config']
        assert 'ALL_EVENTS' not in events['config']['supportedFeatures']
        assert 'AGGREGATED_EVENTS' in events['config']['supportedFeatures']
        assert 'EMIT' not in events['config']['supportedFeatures']
        assert events['server_stats']['clientsCount'] == 1
        assert events['server_stats']['pollingClientsCount'] == 0
        assert len(events['server_stats']['namespaces']) == 3
        assert {'name': '/', 'socketsCount': 0} in \
            events['server_stats']['namespaces']
        assert {'name': '/foo', 'socketsCount': 0} in \
            events['server_stats']['namespaces']
        assert {'name': '/admin', 'socketsCount': 1} in \
            events['server_stats']['namespaces']

    @with_instrumented_server()
    def test_admin_features(self):
        """Exercise the admin commands: emit to a socket and to a room, join
        and leave a room, and disconnect a client.
        """
        with socketio.SimpleClient() as client1, \
                socketio.SimpleClient() as client2, \
                socketio.SimpleClient() as admin_client:
            client1.connect('http://localhost:8900')
            client2.connect('http://localhost:8900')
            admin_client.connect('http://localhost:8900', namespace='/admin')

            # emit from admin
            admin_client.emit(
                'emit', ('/', client1.sid, 'foo', {'bar': 'baz'}, 'extra'))
            data = client1.receive(timeout=5)
            assert data == ['foo', {'bar': 'baz'}, 'extra']

            # emit from regular client
            client1.emit('emit', 'foo')
            data = client2.receive(timeout=5)
            assert data == ['foo']

            # join and leave
            admin_client.emit('join', ('/', 'room', client1.sid))
            # give the server time to process the join before emitting
            time.sleep(0.2)
            admin_client.emit(
                'emit', ('/', 'room', 'foo', {'bar': 'baz'}))
            data = client1.receive(timeout=5)
            assert data == ['foo', {'bar': 'baz'}]
            admin_client.emit('leave', ('/', 'room'))

            # disconnect
            admin_client.emit('_disconnect', ('/', False, client1.sid))
            # poll for up to two seconds until the client notices
            for _ in range(10):
                if not client1.connected:
                    break
                time.sleep(0.2)
            assert not client1.connected