diff --git a/openwisp_controller/connection/connectors/ssh.py b/openwisp_controller/connection/connectors/ssh.py index 02a7fa797..2b16b4173 100644 --- a/openwisp_controller/connection/connectors/ssh.py +++ b/openwisp_controller/connection/connectors/ssh.py @@ -106,21 +106,47 @@ def connect(self): if not addresses: raise ValueError('No valid IP addresses to initiate connections found') for address in addresses: + try: + self._connect(address) + except Exception as e: + exception = e + else: + success = True + break + if not success: + self.disconnect() + raise exception + + def _connect(self, address): + """ + Tries to instantiate the SSH connection, + if the connection fails, it tries again + by disabling the new deafult HostKeyAlgorithms + used by newer versions of Paramiko + """ + params = self.params + for attempt in [1, 2]: try: self.shell.connect( address, auth_timeout=app_settings.SSH_AUTH_TIMEOUT, banner_timeout=app_settings.SSH_BANNER_TIMEOUT, timeout=app_settings.SSH_CONNECTION_TIMEOUT, - **self.params + **params ) - except Exception as e: - exception = e + except paramiko.ssh_exception.AuthenticationException as e: + # the authentication failure may be caused by the issue + # described at https://github.com/paramiko/paramiko/issues/1961 + # let's retry by disabling the new default HostKeyAlgorithms, + # which can work on older systems. + if e.args == ('Authentication failed.',) and attempt == 1: + params['disabled_algorithms'] = { + 'pubkeys': ['rsa-sha2-512', 'rsa-sha2-256'] + } + continue + raise e else: - success = True break - if not success: - raise exception def disconnect(self): self.shell.close() diff --git a/openwisp_controller/connection/tests/test_ssh.py b/openwisp_controller/connection/tests/test_ssh.py index fc28117d2..ef774a35d 100644 --- a/openwisp_controller/connection/tests/test_ssh.py +++ b/openwisp_controller/connection/tests/test_ssh.py @@ -1,8 +1,10 @@ import os +import sys from unittest import mock from django.conf import settings from django.test import TestCase +from paramiko.ssh_exception import AuthenticationException from swapper import load_model from ..connectors.ssh import logger as ssh_logger @@ -40,6 +42,22 @@ def test_connection_connect(self, mocked_debug): [mock.call('Executing command: echo test'), mock.call('test\n')] ) + @mock.patch('paramiko.SSHClient.close') + def test_connection_connect_auth_failure(self, mocked_ssh_close): + ckey = self._create_credentials_with_key(port=self.ssh_server.port) + dc = self._create_device_connection(credentials=ckey) + auth_failed = AuthenticationException('Authentication failed.') + with mock.patch( + 'paramiko.SSHClient.connect', side_effect=auth_failed + ) as mocked_connect: + dc.connect() + self.assertEqual(mocked_connect.call_count, 2) + self.assertFalse(dc.is_working) + mocked_ssh_close.assert_called_once() + if sys.version_info[0:2] > (3, 7): + self.assertNotIn('disabled_algorithms', mocked_connect.mock_calls[0].kwargs) + self.assertIn('disabled_algorithms', mocked_connect.mock_calls[1].kwargs) + @mock.patch.object(ssh_logger, 'info') @mock.patch.object(ssh_logger, 'debug') def test_connection_failed_command(self, mocked_debug, mocked_info):