Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new async fn Sftp::from_session_with_check_connection #117

Merged
merged 3 commits into from
Apr 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 62 additions & 59 deletions .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@ on:
branches:
- main
paths-ignore:
- 'README.md'
- 'LICENSE'
- '.gitignore'
- "README.md"
- "LICENSE"
- ".gitignore"
pull_request:
paths-ignore:
- 'README.md'
- 'LICENSE'
- '.gitignore'
- "README.md"
- "LICENSE"
- ".gitignore"

jobs:
os-check:
Expand All @@ -34,7 +34,10 @@ jobs:
fail-fast: false
matrix:
include:
- { target: x86_64-pc-windows-msvc, args: "--exclude-features openssh" }
- {
target: x86_64-pc-windows-msvc,
args: "--exclude-features openssh",
}
- { target: x86_64-apple-darwin }
- { target: x86_64-unknown-linux-gnu }
steps:
Expand All @@ -57,65 +60,65 @@ jobs:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v4
with:
submodules: 'recursive'
- name: Install toolchain
run: |
rustup toolchain install stable --component rustfmt,clippy --no-self-update --profile minimal
rustup toolchain install nightly --no-self-update --profile minimal
- uses: actions/checkout@v4
with:
submodules: "recursive"
- name: Install toolchain
run: |
rustup toolchain install stable --component rustfmt,clippy --no-self-update --profile minimal
rustup toolchain install nightly --no-self-update --profile minimal

- name: Create Cargo.lock for caching
run: cargo update
- uses: Swatinem/rust-cache@v2
- name: Create Cargo.lock for caching
run: cargo update
- uses: Swatinem/rust-cache@v2

- run: ./check.sh
- run: ./check.sh

build:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v4
with:
submodules: 'recursive'

- name: Install toolchain
run: rustup toolchain install stable --no-self-update --profile minimal

- name: Create Cargo.lock for caching
run: cargo update
- uses: Swatinem/rust-cache@v2

- name: Compile tests
run: cargo test --all-features --workspace --no-run

- name: Test ssh connectivity
run: |
# Wait for startup of openssh-server
timeout 15 ./wait_for_sshd_start_up.sh
chmod 600 .test-key
mkdir /tmp/openssh-rs
ssh -i .test-key -v -p 2222 -l test-user 127.0.0.1 -o StrictHostKeyChecking=accept-new -o UserKnownHostsFile=/tmp/openssh-rs/known_hosts whoami
- name: Set up ssh-agent
run: |
eval $(ssh-agent)
echo "SSH_AUTH_SOCK=$SSH_AUTH_SOCK" >> $GITHUB_ENV
echo "SSH_AGENT_PID=$SSH_AGENT_PID" >> $GITHUB_ENV
cat .test-key | ssh-add -

- name: Run tests
run: ./run_tests.sh
env:
XDG_RUNTIME_DIR: /tmp

- name: ssh container log
run: docker logs $(docker ps | grep openssh-server | awk '{print $1}')
if: ${{ failure() }}
- run: docker exec $(docker ps | grep openssh-server | awk '{print $1}') ls -R /config/logs/
if: ${{ failure() }}
- run: docker exec $(docker ps | grep openssh-server | awk '{print $1}') cat /config/logs/openssh/current
name: ssh server log
if: ${{ failure() }}
- uses: actions/checkout@v4
with:
submodules: "recursive"

- name: Install toolchain
run: rustup toolchain install stable --no-self-update --profile minimal

- name: Create Cargo.lock for caching
run: cargo update
- uses: Swatinem/rust-cache@v2

- name: Compile tests
run: cargo test --all-features --workspace --no-run

- name: Test ssh connectivity
run: |
# Wait for startup of openssh-server
timeout 15 ./wait_for_sshd_start_up.sh
chmod 600 .test-key
mkdir /tmp/openssh-rs
ssh -i .test-key -v -p 2222 -l test-user 127.0.0.1 -o StrictHostKeyChecking=accept-new -o UserKnownHostsFile=/tmp/openssh-rs/known_hosts whoami
- name: Set up ssh-agent
run: |
eval $(ssh-agent)
echo "SSH_AUTH_SOCK=$SSH_AUTH_SOCK" >> $GITHUB_ENV
echo "SSH_AGENT_PID=$SSH_AGENT_PID" >> $GITHUB_ENV
cat .test-key | ssh-add -

- name: Run tests
run: ./run_tests.sh
env:
XDG_RUNTIME_DIR: /tmp

- name: ssh container log
run: docker logs $(docker ps | grep openssh-server | awk '{print $1}')
if: ${{ failure() }}
- run: docker exec $(docker ps | grep openssh-server | awk '{print $1}') ls -R /config/logs/
if: ${{ failure() }}
- run: docker exec $(docker ps | grep openssh-server | awk '{print $1}') cat /config/logs/openssh/current
name: ssh server log
if: ${{ failure() }}
services:
openssh:
image: linuxserver/openssh-server:amd64-latest
Expand Down
1 change: 1 addition & 0 deletions check.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ cd "$(dirname "$(realpath "$0")")"

cargo fmt --all -- --check
cargo clippy --all-features --all --no-deps
cargo test --doc --all-features

export RUSTDOCFLAGS="--cfg docsrs"
exec cargo +nightly doc \
Expand Down
2 changes: 2 additions & 0 deletions src/changelog.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#[allow(unused_imports)]
use crate::*;

/// # Added
/// - [`Sftp::from_session_with_check_connection`] for checking connection
#[doc(hidden)]
pub mod unreleased {}

Expand Down
7 changes: 5 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,14 @@ mod unix_timestamp;
pub use unix_timestamp::UnixTimeStamp;

mod sftp;
#[cfg(feature = "openssh")]
pub use sftp::OpensshSession;
use sftp::SftpHandle;
#[cfg(feature = "openssh")]
pub use sftp::{CheckOpensshConnection, OpensshSession};
pub use sftp::{Sftp, SftpAuxiliaryData};

#[cfg(feature = "openssh")]
pub use openssh;

mod options;
pub use options::SftpOptions;

Expand Down
2 changes: 1 addition & 1 deletion src/sftp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use tokio_io_utility::assert_send;
mod openssh_session;

#[cfg(feature = "openssh")]
pub use openssh_session::OpensshSession;
pub use openssh_session::{CheckOpensshConnection, OpensshSession};

#[derive(Debug, destructure)]
pub(super) struct SftpHandle(SharedData);
Expand Down
127 changes: 118 additions & 9 deletions src/sftp/openssh_session.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::sync::Arc;
use std::{future::Future, pin::Pin, sync::Arc};

use openssh::{ChildStdin, ChildStdout, Error as OpensshError, Session, Stdio};
use tokio::{sync::oneshot, task::JoinHandle};
Expand All @@ -9,6 +9,32 @@ use crate::{utils::ErrorExt, Error, Sftp, SftpAuxiliaryData, SftpOptions};
#[derive(Debug)]
pub struct OpensshSession(JoinHandle<Option<Error>>);

/// Check for openssh connection to be alive
pub trait CheckOpensshConnection {
/// This function should only return on `Err()`.
/// Once the sftp session is closed, the future will be cancelled (dropped).
fn check_connection<'session>(
self: Box<Self>,
session: &'session Session,
) -> Pin<Box<dyn Future<Output = Result<(), OpensshError>> + Send + Sync + 'session>>;
}

impl<F> CheckOpensshConnection for F
where
F: for<'session> FnOnce(
&'session Session,
) -> Pin<
Box<dyn Future<Output = Result<(), OpensshError>> + Send + Sync + 'session>,
>,
{
fn check_connection<'session>(
self: Box<Self>,
session: &'session Session,
) -> Pin<Box<dyn Future<Output = Result<(), OpensshError>> + Send + Sync + 'session>> {
(self)(session)
}
}

impl Drop for OpensshSession {
fn drop(&mut self) {
self.0.abort();
Expand All @@ -17,11 +43,12 @@ impl Drop for OpensshSession {

#[cfg_attr(
feature = "tracing",
tracing::instrument(name = "session_task", skip(tx))
tracing::instrument(name = "session_task", skip(tx, check_openssh_connection))
)]
async fn create_session_task(
session: Session,
tx: oneshot::Sender<Result<(ChildStdin, ChildStdout), OpensshError>>,
check_openssh_connection: Option<Box<dyn CheckOpensshConnection + Send + Sync>>,
) -> Option<Error> {
#[cfg(feature = "tracing")]
tracing::info!("Connecting to sftp subsystem, session = {session:?}");
Expand Down Expand Up @@ -54,15 +81,48 @@ async fn create_session_task(
let stdout = child.stdout().take().unwrap();
tx.send(Ok((stdin, stdout))).unwrap(); // Ok

let original_error = match child.wait().await {
Ok(exit_status) => {
if !exit_status.success() {
Some(Error::SftpServerFailure(exit_status))
let original_error = {
let check_conn_future = async {
if let Some(checker) = check_openssh_connection {
checker
.check_connection(&session)
.await
.err()
.map(Error::from)
} else {
None
}
};

let wait_on_child_future = async {
match child.wait().await {
Ok(exit_status) => {
if !exit_status.success() {
Some(Error::SftpServerFailure(exit_status))
} else {
None
}
}
Err(err) => Some(err.into()),
}
};
tokio::pin!(wait_on_child_future);

tokio::select! {
biased;

original_error = check_conn_future => {
let occuring_error = wait_on_child_future.await;
match (original_error, occuring_error) {
(Some(original_error), Some(occuring_error)) => {
Some(original_error.error_on_cleanup(occuring_error))
}
(Some(err), None) | (None, Some(err)) => Some(err),
(None, None) => None,
}
}
original_error = &mut wait_on_child_future => original_error,
}
Err(err) => Some(err.into()),
};

#[cfg(feature = "tracing")]
Expand Down Expand Up @@ -99,10 +159,59 @@ impl Sftp {
pub async fn from_session(
session: openssh::Session,
options: SftpOptions,
) -> Result<Self, Error> {
Self::from_session_with_check_connection_inner(session, options, None).await
}

/// Similar to [`Sftp::from_session`], but takes an additional parameter
/// for checking if the connection is still alive.
///
/// # Example
///
/// ```rust,no_run
///
/// fn check_connection<'session>(
/// session: &'session openssh::Session,
/// ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<(), openssh::Error>> + Send + Sync + 'session>> {
/// Box::pin(async move {
/// loop {
/// tokio::time::sleep(std::time::Duration::from_secs(10)).await;
/// session.check().await?;
/// }
/// Ok(())
/// })
/// }
/// # #[tokio::main(flavor = "current_thread")]
/// # async fn main() -> Result<(), openssh_sftp_client::Error> {
/// openssh_sftp_client::Sftp::from_session_with_check_connection(
/// openssh::Session::connect_mux("[email protected]", openssh::KnownHosts::Strict).await?,
/// openssh_sftp_client::SftpOptions::default(),
/// check_connection,
/// ).await?;
/// # Ok(())
/// # }
/// ```
pub async fn from_session_with_check_connection(
session: openssh::Session,
options: SftpOptions,
check_openssh_connection: impl CheckOpensshConnection + Send + Sync + 'static,
) -> Result<Self, Error> {
Self::from_session_with_check_connection_inner(
session,
options,
Some(Box::new(check_openssh_connection)),
)
.await
}

async fn from_session_with_check_connection_inner(
session: openssh::Session,
options: SftpOptions,
check_openssh_connection: Option<Box<dyn CheckOpensshConnection + Send + Sync>>,
) -> Result<Self, Error> {
let (tx, rx) = oneshot::channel();

let handle = tokio::spawn(create_session_task(session, tx));
let handle = tokio::spawn(create_session_task(session, tx, check_openssh_connection));

let msg = "Task failed without sending anything, so it must have panicked";

Expand All @@ -111,7 +220,7 @@ impl Sftp {
Err(_) => return Err(handle.await.expect_err(msg).into()),
};

Sftp::new_with_auxiliary(
Self::new_with_auxiliary(
stdin,
stdout,
options,
Expand Down