From 55cdd15b9cc26000dfe08009b4755bc863189515 Mon Sep 17 00:00:00 2001 From: Rossen Georgiev Date: Sat, 18 May 2019 17:22:40 +0100 Subject: [PATCH] cmclient: rework retry logic + updated tests --- steam/core/cm.py | 28 +++++++++------------ tests/test_core_cm.py | 57 ++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 65 insertions(+), 20 deletions(-) diff --git a/steam/core/cm.py b/steam/core/cm.py index 76681fe..02a7a2d 100644 --- a/steam/core/cm.py +++ b/steam/core/cm.py @@ -123,32 +123,28 @@ class CMClient(EventEmitter): self._LOG.debug("Connect initiated.") - i = count() + i = count(0) while len(self.cm_servers) == 0: - if self.auto_discovery: - if not self.cm_servers.bootstrap_from_webapi(): - self.cm_servers.bootstrap_from_dns() - else: - self._LOG.error("CM server list is empty. Auto discovery is off.") - - if not self.auto_discovery or (retry and next(i) > retry): + if not self.auto_discovery or (retry and next(i) >= retry): + if not self.auto_discovery: + self._LOG.error("CM server list is empty. Auto discovery is off.") self._connecting = False return False - for i, server_addr in enumerate(cycle(self.cm_servers)): - if retry and i > retry: + if not self.cm_servers.bootstrap_from_webapi(): + self.cm_servers.bootstrap_from_dns() + + for i, server_addr in enumerate(cycle(self.cm_servers), start=next(i)-1): + if retry and i >= retry: self._connecting = False return False start = time() - if server_addr: - if self.connection.connect(server_addr): - break - self._LOG.debug("Failed to connect. Retrying...") - else: - self._LOG.debug("No servers available. Retrying...") + if self.connection.connect(server_addr): + break + self._LOG.debug("Failed to connect. Retrying...") diff = time() - start diff --git a/tests/test_core_cm.py b/tests/test_core_cm.py index ad374a8..b2c8413 100644 --- a/tests/test_core_cm.py +++ b/tests/test_core_cm.py @@ -45,28 +45,77 @@ class CMClient_Scenarios(unittest.TestCase): patcher = patch('steam.core.cm.CMServerList', autospec=True) self.addCleanup(patcher.stop) self.server_list = patcher.start().return_value - - self.server_list.__iter__.return_value = [(127001, i+1) for i in range(10)] + self.server_list.__iter__.return_value = [(127001, 20000+i) for i in range(10)] + self.server_list.bootstrap_from_webapi.return_value = False + self.server_list.bootstrap_from_dns.return_value = False @patch.object(CMClient, 'emit') @patch.object(CMClient, '_recv_messages') def test_connect(self, mock_recv, mock_emit): # setup self.conn.connect.return_value = True + self.server_list.__len__.return_value = 10 # run cm = CMClient() with gevent.Timeout(2, False): - cm.connect() + cm.connect(retry=1) gevent.idle() # verify - self.conn.connect.assert_called_once_with((127001, 1)) + self.conn.connect.assert_called_once_with((127001, 20000)) mock_emit.assert_called_once_with('connected') mock_recv.assert_called_once_with() + @patch.object(CMClient, 'emit') + @patch.object(CMClient, '_recv_messages') + def test_connect_auto_discovery_failing(self, mock_recv, mock_emit): + # setup + self.conn.connect.return_value = True + self.server_list.__len__.return_value = 0 + + # run + cm = CMClient() + + with gevent.Timeout(3, False): + cm.connect(retry=1) + + gevent.idle() + + # verify + self.server_list.bootstrap_from_webapi.assert_called_once_with() + self.server_list.bootstrap_from_dns.assert_called_once_with() + self.conn.connect.assert_not_called() + + @patch.object(CMClient, 'emit') + @patch.object(CMClient, '_recv_messages') + def test_connect_auto_discovery_success(self, mock_recv, mock_emit): + # setup + self.conn.connect.return_value = True + self.server_list.__len__.return_value = 0 + + def fake_servers(*args, **kwargs): + self.server_list.__len__.return_value = 10 + return True + + self.server_list.bootstrap_from_webapi.side_effect = fake_servers + + # run + cm = CMClient() + + with gevent.Timeout(3, False): + cm.connect(retry=1) + + gevent.idle() + + # verify + self.server_list.bootstrap_from_webapi.assert_called_once_with() + self.server_list.bootstrap_from_dns.assert_not_called() + self.conn.connect.assert_called_once_with((127001, 20000)) + mock_emit.assert_called_once_with('connected') + mock_recv.assert_called_once_with() def test_channel_encrypt_sequence(self): # setup