diff --git a/src/pty.rs b/src/pty.rs index cee5cd0..e516527 100644 --- a/src/pty.rs +++ b/src/pty.rs @@ -290,4 +290,9 @@ impl PTY { pub fn get_fd(&self) -> isize { self.pty.get_fd() } + + /// Wait for the process to exit/finish. + pub fn wait_for_exit(&self) -> Result { + self.pty.wait_for_exit() + } } diff --git a/src/pty/base.rs b/src/pty/base.rs index 176be45..30ba02a 100644 --- a/src/pty/base.rs +++ b/src/pty/base.rs @@ -1,12 +1,13 @@ /// Base struct used to generalize some of the PTY I/O operations. -use windows::Win32::Foundation::{HANDLE, S_OK, STATUS_PENDING, CloseHandle, WAIT_FAILED, WAIT_TIMEOUT}; +use windows::Win32::Foundation::{CloseHandle, HANDLE, STATUS_PENDING, S_OK, WAIT_FAILED, WAIT_OBJECT_0, WAIT_TIMEOUT}; use windows::Win32::Storage::FileSystem::{GetFileSizeEx, ReadFile, WriteFile}; use windows::Win32::System::Pipes::PeekNamedPipe; use windows::Win32::System::IO::CancelIoEx; use windows::Win32::System::Threading::{GetExitCodeProcess, GetProcessId, WaitForSingleObject}; use windows::Win32::Globalization::{MultiByteToWideChar, WideCharToMultiByte, CP_UTF8, MULTI_BYTE_TO_WIDE_CHAR_FLAGS}; use windows::core::{HRESULT, Error, PCSTR}; +use windows::Win32::System::Threading::INFINITE; use std::ptr; use std::sync::mpsc; @@ -150,6 +151,9 @@ pub trait PTYImpl: Sync + Send { /// Retrieve the process handle ID of the spawned program. fn get_fd(&self) -> isize; + + /// Wait for the process to exit/finish. + fn wait_for_exit(&self) -> Result; } @@ -284,6 +288,23 @@ fn is_alive(process: HANDLE) -> Result { } } +fn wait_for_exit(process: HANDLE) -> Result { + unsafe { + let wait_status = WaitForSingleObject(process, INFINITE); + let succ = wait_status != WAIT_FAILED; + if succ { + let dead = wait_status == WAIT_OBJECT_0; + Ok(dead) + } else { + let err: HRESULT = Error::from_win32().into(); + let result_msg = err.message(); + let string = OsString::from(result_msg); + Err(string) + } + } +} + + fn get_exitstatus(process: HANDLE) -> Result, OsString> { let mut exit = MaybeUninit::::uninit(); unsafe { @@ -388,7 +409,7 @@ impl PTYProcess { // let mut alive = reader_alive_rx.recv_timeout(Duration::from_millis(300)).unwrap_or(true); // alive = alive && !is_eof(process, conout).unwrap(); - while reader_alive_rx.try_recv().unwrap_or(true) { + while reader_alive_rx.recv_timeout(Duration::from_millis(100)).unwrap_or(true) { if !is_eof(process.into(), conout.into()).unwrap() { let result = read(4096, true, conout.into(), using_pipes); reader_out_tx.send(Some(result)).unwrap(); @@ -642,6 +663,11 @@ impl PTYProcess { self.process.0 as isize } + /// Wait for the process to exit + pub fn wait_for_exit(&self) -> Result { + wait_for_exit(self.process.into()) + } + } impl Drop for PTYProcess { diff --git a/src/pty/conpty/default_impl.rs b/src/pty/conpty/default_impl.rs index 80fae0c..5704f7b 100644 --- a/src/pty/conpty/default_impl.rs +++ b/src/pty/conpty/default_impl.rs @@ -46,4 +46,8 @@ impl PTYImpl for ConPTY { fn get_fd(&self) -> isize { -1 } + + fn wait_for_exit(&self) -> Result { + Err(OsString::from("pty_rs was compiled without ConPTY enabled")) + } } diff --git a/src/pty/conpty/pty_impl.rs b/src/pty/conpty/pty_impl.rs index e55f2bc..8a43374 100644 --- a/src/pty/conpty/pty_impl.rs +++ b/src/pty/conpty/pty_impl.rs @@ -358,6 +358,10 @@ impl PTYImpl for ConPTY { fn get_fd(&self) -> isize { self.process.get_fd() } + + fn wait_for_exit(&self) -> Result { + self.process.wait_for_exit() + } } impl Drop for ConPTY { diff --git a/src/pty/winpty/default_impl.rs b/src/pty/winpty/default_impl.rs index c9cb49f..31dfee5 100644 --- a/src/pty/winpty/default_impl.rs +++ b/src/pty/winpty/default_impl.rs @@ -44,4 +44,8 @@ impl PTYImpl for WinPTY { fn get_fd(&self) -> isize { -1 } + + fn wait_for_exit(&self) -> Result { + Err(OsString::from("winpty_rs was compiled without WinPTY enabled")) + } } diff --git a/src/pty/winpty/pty_impl.rs b/src/pty/winpty/pty_impl.rs index f02ac4c..ab3bf46 100644 --- a/src/pty/winpty/pty_impl.rs +++ b/src/pty/winpty/pty_impl.rs @@ -306,6 +306,10 @@ impl PTYImpl for WinPTY { fn get_fd(&self) -> isize { self.process.get_fd() } + + fn wait_for_exit(&self) -> Result { + self.process.wait_for_exit() + } } unsafe impl Send for WinPTY {} diff --git a/tests/conpty.rs b/tests/conpty.rs index 7d3b23a..08dcca1 100644 --- a/tests/conpty.rs +++ b/tests/conpty.rs @@ -189,3 +189,28 @@ fn is_alive_exitstatus_conpty() { assert!(!pty.is_alive().unwrap()); assert_eq!(pty.get_exitstatus().unwrap(), Some(0)) } + +#[test] +fn wait_for_exit() { + let pty_args = PTYArgs { + cols: 80, + rows: 25, + mouse_mode: MouseMode::WINPTY_MOUSE_MODE_NONE, + timeout: 10000, + agent_config: AgentConfig::WINPTY_FLAG_COLOR_ESCAPES + }; + + let appname = OsString::from("C:\\Windows\\System32\\cmd.exe"); + let mut pty = PTY::new_with_backend(&pty_args, PTYBackend::ConPTY).unwrap(); + pty.spawn(appname, None, None, None).unwrap(); + + pty.write("echo wait\r\n".into()).unwrap(); + assert!(pty.is_alive().unwrap()); + assert_eq!(pty.get_exitstatus().unwrap(), None); + + pty.write("exit\r\n".into()).unwrap(); + pty.wait_for_exit(); + + assert!(!pty.is_alive().unwrap()); + assert_eq!(pty.get_exitstatus().unwrap(), Some(0)) +} diff --git a/tests/winpty.rs b/tests/winpty.rs index 8c402ec..9fd7573 100644 --- a/tests/winpty.rs +++ b/tests/winpty.rs @@ -171,3 +171,29 @@ fn is_alive_exitstatus_winpty() { assert!(!pty.is_alive().unwrap()); assert_eq!(pty.get_exitstatus().unwrap(), Some(0)) } + + +#[test] +fn wait_for_exit() { + let pty_args = PTYArgs { + cols: 80, + rows: 25, + mouse_mode: MouseMode::WINPTY_MOUSE_MODE_NONE, + timeout: 10000, + agent_config: AgentConfig::WINPTY_FLAG_COLOR_ESCAPES + }; + + let appname = OsString::from("C:\\Windows\\System32\\cmd.exe"); + let mut pty = PTY::new_with_backend(&pty_args, PTYBackend::WinPTY).unwrap(); + pty.spawn(appname, None, None, None).unwrap(); + + pty.write("echo wait\r\n".into()).unwrap(); + assert!(pty.is_alive().unwrap()); + assert_eq!(pty.get_exitstatus().unwrap(), None); + + pty.write("exit\r\n".into()).unwrap(); + pty.wait_for_exit(); + + assert!(!pty.is_alive().unwrap()); + assert_eq!(pty.get_exitstatus().unwrap(), Some(0)) +}