Skip to content

Commit

Permalink
Set default for the safe_interval option of verdi computer configure
Browse files Browse the repository at this point in the history
The `safe_interval` is a common option for all `Transport` types but its
default value is class dependent. Since it is a common option it was
defined in the `Transport` base class, but since it is defined as a
class attribute, the default cannot be set yet. Keeping this design, the
only option to insert the class specific default is in `auth_options`
the class property that returns the options for the CLI. By adding the
default here, `verdi computer configure local localhost -n` will now
work without prompting for its only option `safe_interval` as it will
simply use the default.
  • Loading branch information
sphuber committed Nov 30, 2019
1 parent 76d680e commit 34bc1b9
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 32 deletions.
6 changes: 5 additions & 1 deletion aiida/transports/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,14 @@ def create_option(name, spec):
if spec.pop('switch', False):
option_name = '--{name}/--no-{name}'.format(name=name_dashed)
kwargs = {}
if 'default' not in spec:

if 'default' in spec:
kwargs['show_default'] = True
else:
kwargs['contextual_default'] = interactive_default(
'ssh', name, also_noninteractive=spec.pop('non_interactive_default', False)
)

kwargs['cls'] = InteractiveOption
kwargs.update(spec)
if existing_option:
Expand Down
9 changes: 2 additions & 7 deletions aiida/transports/plugins/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,21 +55,16 @@ class LocalTransport(Transport):
# where the remote computer will rate limit the number of connections.
_DEFAULT_SAFE_OPEN_INTERVAL = 0.0

def __init__(self, **kwargs):
super(LocalTransport, self).__init__()
def __init__(self, *args, **kwargs):
super(LocalTransport, self).__init__(*args, **kwargs)
# The `_internal_dir` will emulate the concept of working directory, as the real current working directory is
# not to be changed to prevent bug-prone situations
self._is_open = False
self._internal_dir = None

# Just to avoid errors
self._machine = kwargs.pop('machine', None)
if self._machine and self._machine != 'localhost':
self.logger.debug('machine was passed, but it is not localhost')
self._safe_open_interval = kwargs.pop('safe_interval', self._DEFAULT_SAFE_OPEN_INTERVAL)

if kwargs:
raise ValueError('the following keywords are not recognized: {}'.format(kwargs))

def open(self):
"""
Expand Down
17 changes: 3 additions & 14 deletions aiida/transports/plugins/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,6 @@ class SshTransport(Transport):
('key_policy', {'type': click.Choice(['RejectPolicy', 'WarningPolicy', 'AutoAddPolicy']), 'prompt': 'Key policy', 'help': 'SSH key policy', 'non_interactive_default': True})
]

# I set the (default) value here to 5 secs between consecutive SSH checks.
# This should be incremented to 30, probably.
_DEFAULT_SAFE_OPEN_INTERVAL = 5

@classmethod
def _get_username_suggestion_string(cls, computer):
"""
Expand Down Expand Up @@ -252,7 +248,7 @@ def _get_gss_host_suggestion_string(cls, computer):
def _get_safe_interval_suggestion_string(cls, computer):
return cls._DEFAULT_SAFE_OPEN_INTERVAL

def __init__(self, machine, **kwargs):
def __init__(self, *args, **kwargs):
"""
Initialize the SshTransport class.
Expand All @@ -268,21 +264,18 @@ def __init__(self, machine, **kwargs):
accepted paramiko.SSHClient.connect() params.
"""
import paramiko
super(SshTransport, self).__init__()
super(SshTransport, self).__init__(*args, **kwargs)

self._is_open = False
self._sftp = None
self._proxy = None

self._machine = machine
self._machine = kwargs.pop('machine')

self._client = paramiko.SSHClient()
self._load_system_host_keys = kwargs.pop('load_system_host_keys', False)
if self._load_system_host_keys:
self._client.load_system_host_keys()

self._safe_open_interval = kwargs.pop('safe_interval', self._DEFAULT_SAFE_OPEN_INTERVAL)

self._missing_key_policy = kwargs.pop('key_policy', 'RejectPolicy') # This is paramiko default
if self._missing_key_policy == 'RejectPolicy':
self._client.set_missing_host_key_policy(paramiko.RejectPolicy())
Expand All @@ -301,10 +294,6 @@ def __init__(self, machine, **kwargs):
except KeyError:
pass

if kwargs:
raise ValueError('The following parameters were not accepted by '
'the transport: {}'.format(','.join(str(k) for k in kwargs)))

def open(self):
"""
Open a SSHClient to the machine possibly using the parameters given
Expand Down
4 changes: 0 additions & 4 deletions aiida/transports/plugins/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,6 @@ def test_closed_connection(self):
t = LocalTransport()
t.listdir()

def test_invalid_param(self):
with self.assertRaises(ValueError):
LocalTransport(unrequired_var='something')

def test_basic(self):
with LocalTransport():
pass
Expand Down
4 changes: 0 additions & 4 deletions aiida/transports/plugins/test_ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,6 @@ def test_closed_connection_sftp(self):
t = SshTransport(machine='localhost')
t.listdir()

def test_invalid_param(self):
with self.assertRaises(ValueError):
SshTransport(machine='localhost', invalid_param=True)

def test_auto_add_policy(self):
with SshTransport(machine='localhost', timeout=30, load_system_host_keys=True, key_policy='AutoAddPolicy'):
pass
Expand Down
15 changes: 13 additions & 2 deletions aiida/transports/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ class Transport(object):
"""
# pylint: disable=too-many-public-methods

_DEFAULT_SAFE_OPEN_INTERVAL = DEFAULT_TRANSPORT_INTERVAL

# To be defined in the subclass
# See the ssh or local plugin to see the format
_valid_auth_params = None
Expand All @@ -68,12 +70,11 @@ def __init__(self, *args, **kwargs): # pylint: disable=unused-argument
__init__ method of the Transport base class.
"""
from aiida.common import AIIDA_LOGGER

self._safe_open_interval = kwargs.pop('safe_interval', self._DEFAULT_SAFE_OPEN_INTERVAL)
self._logger = AIIDA_LOGGER.getChild('transport').getChild(self.__class__.__name__)
self._logger_extra = None
self._is_open = False
self._enters = 0
self._safe_open_interval = DEFAULT_TRANSPORT_INTERVAL

def __enter__(self):
"""
Expand Down Expand Up @@ -186,6 +187,16 @@ def get_valid_auth_params(cls):

@classproperty
def auth_options(cls): # pylint: disable=no-self-argument
"""Return the authentication options to be used for building the CLI.
:return: `OrderedDict` of tuples, with first element option name and second dictionary of kwargs
"""
# The common auth options are currently defined as class members, but the default for `safe_interval` is sub
# class specific. With the current design the default cannot already be specified directly but has to be added
# manually here.
for option in cls._common_auth_options:
if option[0] == 'safe_interval':
option[1]['default'] = cls._DEFAULT_SAFE_OPEN_INTERVAL
return OrderedDict(cls._valid_auth_options + cls._common_auth_options)

@classmethod
Expand Down

0 comments on commit 34bc1b9

Please sign in to comment.