Browse Source

cmclient: rework retry logic + updated tests

pull/168/head
Rossen Georgiev 6 years ago
parent
commit
55cdd15b9c
  1. 28
      steam/core/cm.py
  2. 57
      tests/test_core_cm.py

28
steam/core/cm.py

@ -123,32 +123,28 @@ class CMClient(EventEmitter):
self._LOG.debug("Connect initiated.") self._LOG.debug("Connect initiated.")
i = count() i = count(0)
while len(self.cm_servers) == 0: while len(self.cm_servers) == 0:
if self.auto_discovery: if not self.auto_discovery or (retry and next(i) >= retry):
if not self.cm_servers.bootstrap_from_webapi(): if not self.auto_discovery:
self.cm_servers.bootstrap_from_dns() self._LOG.error("CM server list is empty. Auto discovery is off.")
else:
self._LOG.error("CM server list is empty. Auto discovery is off.")
if not self.auto_discovery or (retry and next(i) > retry):
self._connecting = False self._connecting = False
return False return False
for i, server_addr in enumerate(cycle(self.cm_servers)): if not self.cm_servers.bootstrap_from_webapi():
if retry and i > retry: self.cm_servers.bootstrap_from_dns()
for i, server_addr in enumerate(cycle(self.cm_servers), start=next(i)-1):
if retry and i >= retry:
self._connecting = False self._connecting = False
return False return False
start = time() start = time()
if server_addr: if self.connection.connect(server_addr):
if self.connection.connect(server_addr): break
break self._LOG.debug("Failed to connect. Retrying...")
self._LOG.debug("Failed to connect. Retrying...")
else:
self._LOG.debug("No servers available. Retrying...")
diff = time() - start diff = time() - start

57
tests/test_core_cm.py

@ -45,28 +45,77 @@ class CMClient_Scenarios(unittest.TestCase):
patcher = patch('steam.core.cm.CMServerList', autospec=True) patcher = patch('steam.core.cm.CMServerList', autospec=True)
self.addCleanup(patcher.stop) self.addCleanup(patcher.stop)
self.server_list = patcher.start().return_value self.server_list = patcher.start().return_value
self.server_list.__iter__.return_value = [(127001, 20000+i) for i in range(10)]
self.server_list.__iter__.return_value = [(127001, i+1) for i in range(10)] self.server_list.bootstrap_from_webapi.return_value = False
self.server_list.bootstrap_from_dns.return_value = False
@patch.object(CMClient, 'emit') @patch.object(CMClient, 'emit')
@patch.object(CMClient, '_recv_messages') @patch.object(CMClient, '_recv_messages')
def test_connect(self, mock_recv, mock_emit): def test_connect(self, mock_recv, mock_emit):
# setup # setup
self.conn.connect.return_value = True self.conn.connect.return_value = True
self.server_list.__len__.return_value = 10
# run # run
cm = CMClient() cm = CMClient()
with gevent.Timeout(2, False): with gevent.Timeout(2, False):
cm.connect() cm.connect(retry=1)
gevent.idle() gevent.idle()
# verify # verify
self.conn.connect.assert_called_once_with((127001, 1)) self.conn.connect.assert_called_once_with((127001, 20000))
mock_emit.assert_called_once_with('connected') mock_emit.assert_called_once_with('connected')
mock_recv.assert_called_once_with() mock_recv.assert_called_once_with()
@patch.object(CMClient, 'emit')
@patch.object(CMClient, '_recv_messages')
def test_connect_auto_discovery_failing(self, mock_recv, mock_emit):
# setup
self.conn.connect.return_value = True
self.server_list.__len__.return_value = 0
# run
cm = CMClient()
with gevent.Timeout(3, False):
cm.connect(retry=1)
gevent.idle()
# verify
self.server_list.bootstrap_from_webapi.assert_called_once_with()
self.server_list.bootstrap_from_dns.assert_called_once_with()
self.conn.connect.assert_not_called()
@patch.object(CMClient, 'emit')
@patch.object(CMClient, '_recv_messages')
def test_connect_auto_discovery_success(self, mock_recv, mock_emit):
# setup
self.conn.connect.return_value = True
self.server_list.__len__.return_value = 0
def fake_servers(*args, **kwargs):
self.server_list.__len__.return_value = 10
return True
self.server_list.bootstrap_from_webapi.side_effect = fake_servers
# run
cm = CMClient()
with gevent.Timeout(3, False):
cm.connect(retry=1)
gevent.idle()
# verify
self.server_list.bootstrap_from_webapi.assert_called_once_with()
self.server_list.bootstrap_from_dns.assert_not_called()
self.conn.connect.assert_called_once_with((127001, 20000))
mock_emit.assert_called_once_with('connected')
mock_recv.assert_called_once_with()
def test_channel_encrypt_sequence(self): def test_channel_encrypt_sequence(self):
# setup # setup

Loading…
Cancel
Save