Skip to content

Commit

Permalink
Remove unused SSHRemote (#6596)
Browse files Browse the repository at this point in the history
  • Loading branch information
achamayou authored Oct 29, 2024
1 parent 6e08bca commit 8ed3237
Show file tree
Hide file tree
Showing 4 changed files with 0 additions and 295 deletions.
3 changes: 0 additions & 3 deletions tests/infra/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import os
import time

import logging
from contextlib import contextmanager
from enum import Enum, IntEnum, auto
from infra.clients import flush_info
Expand Down Expand Up @@ -34,8 +33,6 @@
from cryptography.x509 import load_pem_x509_certificate
from cryptography.hazmat.backends import default_backend

logging.getLogger("paramiko").setLevel(logging.WARNING)

# JOIN_TIMEOUT should be greater than the worst case quote verification time (~ 25 secs)
JOIN_TIMEOUT = 40

Expand Down
290 changes: 0 additions & 290 deletions tests/infra/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,11 @@
import os
import time
from enum import Enum, auto
import paramiko
import subprocess
from contextlib import contextmanager
import infra.path
import ctypes
import signal
import re
import stat
import shutil
from jinja2 import Environment, FileSystemLoader, select_autoescape
import json
Expand Down Expand Up @@ -43,21 +40,6 @@ def popen(*args, **kwargs):
return subprocess.Popen(*args, **kwargs)


@contextmanager
def sftp_session(hostname):
client = paramiko.SSHClient()
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
client.connect(hostname)
try:
session = client.open_sftp()
try:
yield session
finally:
session.close()
finally:
client.close()


class CmdMixin(object):
def set_perf(self):
self.cmd = [
Expand All @@ -78,278 +60,6 @@ def _get_perf(self, lines):
raise ValueError(f"No performance result found (pattern is {pattern})")


class SSHRemote(CmdMixin):
def __init__(
self,
name,
hostname,
exe_files,
data_files,
cmd,
workspace,
common_dir,
env=None,
pid_file="pid.file",
**kwargs,
):
"""
Runs a command on a remote host, through an SSH connection. A temporary
directory is created, and some files can be shipped over. The command is
run out of that directory.
Note that the name matters, since the temporary directory that will be first
deleted, then created and populated is workspace/name. There is deliberately no
cleanup on shutdown, to make debugging/inspection possible.
setup() connects, creates the directory and ships over the files
start() runs the specified command
stop() disconnects, which shuts down the command via SIGHUP
"""
self.hostname = hostname
self.exe_files = exe_files
self.data_files = data_files
self.cmd = cmd
self.client = paramiko.SSHClient()
# this client (proc_client) is used to execute commands on the remote host since the main client uses pty
self.proc_client = paramiko.SSHClient()
self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
self.proc_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
self.common_dir = common_dir
self.root = os.path.join(workspace, name)
self.name = name
self.env = env or {}
self.out = os.path.join(self.root, "out")
self.err = os.path.join(self.root, "err")
self.suspension_proc = None
self.pid_file = pid_file
self._pid = None

@staticmethod
def make_host(host):
return host

@staticmethod
def get_node_address(addr):
return addr

def _rc(self, cmd):
LOG.info("[{}] {}".format(self.hostname, cmd))
_, stdout, _ = self.client.exec_command(cmd)
return stdout.channel.recv_exit_status()

def _connect(self):
LOG.debug("[{}] connect".format(self.hostname))
self.client.connect(self.hostname)
self.proc_client.connect(self.hostname)

def _setup_files(self):
assert self._rc("rm -rf {}".format(self.root)) == 0
assert self._rc("mkdir -p {}".format(self.root)) == 0
# For SSHRemote, both executable files (host and enclave) and data
# files (ledger, secrets) are copied to the remote
session = self.client.open_sftp()
for path in self.exe_files:
tgt_path = os.path.join(self.root, os.path.basename(path))
LOG.info("[{}] copy {} from {}".format(self.hostname, tgt_path, path))
session.put(path, tgt_path)
stat = os.stat(path)
session.chmod(tgt_path, stat.st_mode)
for path in self.data_files:
tgt_path = os.path.join(self.root, os.path.basename(path))
if os.path.isdir(path):
session.mkdir(tgt_path)
for f in os.listdir(path):
session.put(os.path.join(path, f), os.path.join(tgt_path, f))
else:
session.put(path, tgt_path)
LOG.info("[{}] copy {} from {}".format(self.hostname, tgt_path, path))
session.close()

def get(
self,
file_name,
dst_path,
timeout=FILE_TIMEOUT_S,
target_name=None,
pre_condition_func=lambda src_dir, _: True,
):
"""
Get file called `file_name` under the root of the remote. If the
file is missing, wait for timeout, and raise an exception.
If the file is present, it is copied to the CWD on the caller's
host, as `target_name` if it is set.
This call spins up a separate client because we don't want to interrupt
the main cmd that may be running.
"""
with sftp_session(self.hostname) as session:
end_time = time.time() + timeout
start_time = time.time()
while time.time() < end_time:
try:
target_name = target_name or file_name
fileattr = session.lstat(os.path.join(self.root, file_name))
if stat.S_ISDIR(fileattr.st_mode):
src_dir = os.path.join(self.root, file_name)
dst_dir = os.path.join(dst_path, file_name)
if os.path.exists(dst_dir):
shutil.rmtree(dst_dir)
os.makedirs(dst_dir)
if not pre_condition_func(src_dir, session.listdir):
raise RuntimeError(
"Pre-condition for getting remote files failed"
)
for f in session.listdir(src_dir):
session.get(
os.path.join(src_dir, f), os.path.join(dst_dir, f)
)
else:
session.get(
os.path.join(self.root, file_name),
os.path.join(dst_path, target_name),
)
LOG.debug(
"[{}] found {} after {}s".format(
self.hostname, file_name, int(time.time() - start_time)
)
)
break
except FileNotFoundError:
time.sleep(0.1)
else:
raise ValueError(file_name)

def list_files(self, timeout=FILE_TIMEOUT_S):
files = []
with sftp_session(self.hostname) as session:
end_time = time.time() + timeout
while time.time() < end_time:
try:
files = session.listdir(self.root)

break
except Exception:
time.sleep(0.1)

else:
raise ValueError(self.root)
return files

def get_logs(self):
with sftp_session(self.hostname) as session:
for filepath in (self.err, self.out):
try:
local_file_name = "{}_{}_{}".format(
self.hostname, self.name, os.path.basename(filepath)
)
dst_path = os.path.join(self.common_dir, local_file_name)
session.get(filepath, dst_path)
LOG.info("Downloaded {}".format(dst_path))
except FileNotFoundError:
LOG.warning(
"Failed to download {} to {} (host: {})".format(
filepath, dst_path, self.hostname
)
)
return os.path.join(
self.common_dir, "{}_{}_out".format(self.hostname, self.name)
), os.path.join(self.common_dir, "{}_{}_err".format(self.hostname, self.name))

def start(self):
"""
Start cmd on the remote host. stdout and err are captured to file locally.
We create a pty on the remote host under which to run the command, so as to
get a SIGHUP on disconnection.
"""
cmd = self.get_cmd()
LOG.info("[{}] {}".format(self.hostname, cmd))
self.client.exec_command(cmd, get_pty=True)
self.pid()

def pid(self):
if self._pid is None:
pid_path = os.path.join(self.root, self.pid_file)
time_left = 3
while time_left > 0:
_, stdout, _ = self.proc_client.exec_command(f'cat "{pid_path}"')
res = stdout.read().strip()
if res:
self._pid = int(res)
break
time_left = max(time_left - 0.1, 0)
if not time_left:
raise TimeoutError("Failed to read PID from file")
time.sleep(0.1)
return self._pid

def suspend(self):
_, stdout, _ = self.proc_client.exec_command(f"kill -STOP {self.pid()}")
if stdout.channel.recv_exit_status() != 0:
raise RuntimeError(f"Remote {self.name} could not be suspended")

def resume(self):
_, stdout, _ = self.proc_client.exec_command(f"kill -CONT {self.pid()}")
if stdout.channel.recv_exit_status() != 0:
raise RuntimeError(f"Could not resume remote {self.name} from suspension!")

def sigterm(self):
_, stdout, _ = self.proc_client.exec_command(f"kill {self.pid()}")
if stdout.channel.recv_exit_status() != 0:
raise RuntimeError(f"Remote {self.name} could not deliver SIGTERM")

def stop(self):
"""
Disconnect the client, and therefore shut down the command as well.
"""
LOG.info("[{}] closing".format(self.hostname))
self.client.close()
self.proc_client.close()

def setup(self, **kwargs):
"""
Connect to the remote host, empty the temporary directory if it exsits,
and populate it with the initial set of files.
"""
self._connect()
self._setup_files()

def get_cmd(self):
env = " ".join(f"{key}={value}" for key, value in self.env.items())
cmd = " ".join(self.cmd)
return f"cd {self.root} && {env} {cmd} 1> {self.out} 2> {self.err} 0< /dev/null"

def debug_node_cmd(self):
cmd = " ".join(self.cmd)
return f"cd {self.root} && {DBG} --args {cmd}"

def _connect_new(self):
client = paramiko.SSHClient()
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
client.connect(self.hostname)
return client

def check_done(self):
client = self._connect_new()
try:
_, stdout, _ = client.exec_command(f"ps -p {self.pid()}")
return stdout.channel.recv_exit_status() == 1
finally:
client.close()

def get_result(self, line_count):
client = self._connect_new()
try:
_, stdout, _ = client.exec_command(f"tail -{line_count} {self.out}")
if stdout.channel.recv_exit_status() == 0:
lines = stdout.read().splitlines()
result = lines[-line_count:]
return self._get_perf(result)
finally:
client.close()


class LocalRemote(CmdMixin):
def __init__(
self,
Expand Down
1 change: 0 additions & 1 deletion tests/infra/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from loguru import logger as LOG

logging.getLogger("matplotlib").setLevel(logging.WARNING)
logging.getLogger("paramiko").setLevel(logging.WARNING)


def minimum_number_of_local_nodes(args):
Expand Down
1 change: 0 additions & 1 deletion tests/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
wheel
paramiko
loguru
psutil
openapi-spec-validator
Expand Down

0 comments on commit 8ed3237

Please sign in to comment.