diff --git a/talpid-core/src/split_tunnel/windows/event.rs b/talpid-core/src/split_tunnel/windows/event.rs new file mode 100644 index 000000000000..7350a1bb7772 --- /dev/null +++ b/talpid-core/src/split_tunnel/windows/event.rs @@ -0,0 +1,213 @@ +//! Handle events dispatched from win-split-tunnel. +//! +//! It follows a typical inverted-call model, in which we request the next event, and block until +//! such an event has been received, or until a quit event is signaled. + +use super::driver; +use std::{ + collections::HashMap, + io, + path::Path, + sync::{Arc, RwLock}, + time::Duration, +}; +use talpid_types::{split_tunnel::ExcludedProcess, ErrorExt}; +use talpid_windows::{io::Overlapped, sync::Event}; +use windows_sys::Win32::Foundation::ERROR_OPERATION_ABORTED; + +enum EventResult { + /// Result containing the next event. + Event(driver::EventId, driver::EventBody), + /// Quit event was signaled. + Quit, +} + +const DRIVER_EVENT_BUFFER_SIZE: usize = 2048; + +/// Spawns an event loop thread that processes events from the driver service. +pub fn spawn_listener( + handle: Arc, + excluded_processes: Arc>>, +) -> io::Result<(std::thread::JoinHandle<()>, Arc)> { + let mut event_overlapped = Overlapped::new(Some(Event::new(true, false)?))?; + + let quit_event = Arc::new(Event::new(true, false)?); + let quit_event_copy = quit_event.clone(); + + let event_thread = std::thread::spawn(move || { + log::debug!("Starting split tunnel event thread"); + let mut data_buffer = vec![]; + + loop { + // Wait until either the next event is received or the quit event is signaled. + let (event_id, event_body) = match fetch_next_event( + &handle, + &quit_event, + &mut event_overlapped, + &mut data_buffer, + ) { + Ok(EventResult::Event(event_id, event_body)) => (event_id, event_body), + Ok(EventResult::Quit) => break, + Err(error) => { + if error.raw_os_error() == Some(ERROR_OPERATION_ABORTED as i32) { + // The driver will normally abort the request if the driver state + // is reset. Give the driver service some time to recover before + // retrying. + std::thread::sleep(Duration::from_millis(500)); + } + continue; + } + }; + + handle_event(event_id, event_body, &excluded_processes); + } + + log::debug!("Stopping split tunnel event thread"); + }); + + Ok((event_thread, quit_event_copy)) +} + +fn fetch_next_event( + device: &Arc, + quit_event: &Event, + overlapped: &mut Overlapped, + data_buffer: &mut Vec, +) -> io::Result { + if unsafe { driver::wait_for_single_object(quit_event.as_raw(), Some(Duration::ZERO)) }.is_ok() + { + return Ok(EventResult::Quit); + } + + data_buffer.resize(DRIVER_EVENT_BUFFER_SIZE, 0u8); + + unsafe { + driver::device_io_control_buffer_async( + device, + driver::DriverIoctlCode::DequeEvent as u32, + None, + data_buffer.as_mut_ptr(), + u32::try_from(data_buffer.len()).expect("buffer must be smaller than u32"), + overlapped.as_mut_ptr(), + ) + } + .map_err(|error| { + log::error!( + "{}", + error.display_chain_with_msg("DeviceIoControl failed to deque event") + ); + error + })?; + + let event_objects = [ + overlapped.get_event().unwrap().as_raw(), + quit_event.as_raw(), + ]; + + let signaled_object = unsafe { driver::wait_for_multiple_objects(&event_objects[..], false) } + .map_err(|error| { + log::error!( + "{}", + error.display_chain_with_msg("wait_for_multiple_objects failed") + ); + error + })?; + + if signaled_object == quit_event.as_raw() { + // Quit event was signaled + return Ok(EventResult::Quit); + } + + let returned_bytes = driver::get_overlapped_result(device, overlapped).map_err(|error| { + if error.raw_os_error() != Some(ERROR_OPERATION_ABORTED as i32) { + log::error!( + "{}", + error.display_chain_with_msg("get_overlapped_result failed for dequeued event"), + ); + } + error + })?; + + data_buffer + .truncate(usize::try_from(returned_bytes).expect("usize must be no smaller than u32")); + + driver::parse_event_buffer(data_buffer) + .map(|(id, body)| EventResult::Event(id, body)) + .map_err(|error| { + log::error!( + "{}", + error.display_chain_with_msg("Failed to parse ST event buffer") + ); + io::Error::new(io::ErrorKind::Other, "Failed to parse ST event buffer") + }) +} + +fn handle_event( + event_id: driver::EventId, + event_body: driver::EventBody, + excluded_processes: &Arc>>, +) { + use driver::{EventBody, EventId}; + + let event_str = match &event_id { + EventId::StartSplittingProcess | EventId::ErrorStartSplittingProcess => { + "Start splitting process" + } + EventId::StopSplittingProcess | EventId::ErrorStopSplittingProcess => { + "Stop splitting process" + } + EventId::ErrorMessage => "ErrorMessage", + }; + + match event_body { + EventBody::SplittingEvent { + process_id, + reason, + image, + } => { + let mut pids = excluded_processes.write().unwrap(); + match event_id { + EventId::StartSplittingProcess => { + if let Some(prev_entry) = pids.get(&process_id) { + log::error!("PID collision: {process_id} is already in the list of excluded processes. New image: {:?}. Current image: {:?}", image, prev_entry); + } + pids.insert( + process_id, + ExcludedProcess { + pid: u32::try_from(process_id) + .expect("PID should be containable in a DWORD"), + image: Path::new(&image).to_path_buf(), + inherited: reason + .contains(driver::SplittingChangeReason::BY_INHERITANCE), + }, + ); + } + EventId::StopSplittingProcess => { + if pids.remove(&process_id).is_none() { + log::error!("Inconsistent process tree: {process_id} was not found"); + } + } + _ => (), + } + + log::trace!( + "{}:\n\tpid: {}\n\treason: {:?}\n\timage: {:?}", + event_str, + process_id, + reason, + image, + ); + } + EventBody::SplittingError { process_id, image } => { + log::error!( + "FAILED: {}:\n\tpid: {}\n\timage: {:?}", + event_str, + process_id, + image, + ); + } + EventBody::ErrorMessage { status, message } => { + log::error!("NTSTATUS {:#x}: {}", status, message.to_string_lossy()) + } + } +} diff --git a/talpid-core/src/split_tunnel/windows/mod.rs b/talpid-core/src/split_tunnel/windows/mod.rs index 68030691a99b..4194edb4b694 100644 --- a/talpid-core/src/split_tunnel/windows/mod.rs +++ b/talpid-core/src/split_tunnel/windows/mod.rs @@ -1,17 +1,22 @@ mod driver; +mod event; mod path_monitor; +mod request; mod service; mod volume_monitor; mod windows; use crate::{tunnel::TunnelMetadata, tunnel_state_machine::TunnelCommand}; +use driver::DeviceHandle; use futures::channel::{mpsc, oneshot}; +use path_monitor::PathMonitor; +use request::{Request, RequestDetails}; use std::{ collections::HashMap, ffi::{OsStr, OsString}, io, net::{IpAddr, Ipv4Addr, Ipv6Addr}, - path::{Path, PathBuf}, + path::PathBuf, sync::{ atomic::{AtomicBool, Ordering}, mpsc as sync_mpsc, Arc, Mutex, MutexGuard, RwLock, Weak, @@ -19,15 +24,13 @@ use std::{ time::Duration, }; use talpid_routing::{get_best_default_route, CallbackHandle, EventType, RouteManagerHandle}; -use talpid_types::{split_tunnel::ExcludedProcess, tunnel::ErrorStateCause, ErrorExt}; +use talpid_types::{split_tunnel::ExcludedProcess, ErrorExt}; use talpid_windows::{ - io::Overlapped, net::{get_ip_address_for_interface, AddressFamily}, sync::Event, }; -use windows_sys::Win32::Foundation::ERROR_OPERATION_ABORTED; +use volume_monitor::VolumeMonitor; -const DRIVER_EVENT_BUFFER_SIZE: usize = 2048; const RESERVED_IP_V4: Ipv4Addr = Ipv4Addr::new(192, 0, 2, 123); /// Errors that may occur in [`SplitTunnel`]. @@ -49,18 +52,10 @@ pub enum Error { #[error("Failed to set list of excluded applications")] SetConfiguration(#[source] io::Error), - /// Failed to obtain the current driver state - #[error("Failed to obtain the driver state")] - GetState(#[source] io::Error), - /// Failed to register interface IP addresses #[error("Failed to register IP addresses for exclusions")] RegisterIps(#[source] io::Error), - /// Failed to clear interface IP addresses - #[error("Failed to clear registered IP addresses")] - ClearIps(#[source] io::Error), - /// Failed to set up the driver event loop #[error("Failed to set up the driver event loop")] EventThreadError(#[source] io::Error), @@ -77,10 +72,6 @@ pub enum Error { #[error("Failed to register default route change callback")] RegisterRouteChangeCallback, - /// Unexpected IP parsing error - #[error("Failed to parse IP address")] - IpParseError, - /// The request handling thread is stuck #[error("The ST request thread is stuck")] RequestThreadStuck, @@ -96,33 +87,20 @@ pub enum Error { /// A previous path update has not yet completed #[error("A previous update is not yet complete")] AlreadySettingPaths, - - /// Resetting in the engaged state risks leaking into the tunnel - #[error("Failed to reset driver because it is engaged")] - CannotResetEngaged, } /// Manages applications whose traffic to exclude from the tunnel. pub struct SplitTunnel { runtime: tokio::runtime::Handle, - request_tx: RequestTx, + request_tx: sync_mpsc::Sender, event_thread: Option>, quit_event: Arc, excluded_processes: Arc>>, _route_change_callback: Option, - daemon_tx: Weak>, async_path_update_in_progress: Arc, route_manager: RouteManagerHandle, } -enum Request { - SetPaths(Vec), - RegisterIps(InterfaceAddresses), - Stop, -} -type RequestResponseTx = sync_mpsc::Sender>; -type RequestTx = sync_mpsc::Sender<(Request, RequestResponseTx)>; - const REQUEST_TIMEOUT: Duration = Duration::from_secs(5); #[derive(Default, PartialEq, Clone)] @@ -152,13 +130,6 @@ impl SplitTunnelHandle { } } -enum EventResult { - /// Result containing the next event. - Event(driver::EventId, driver::EventBody), - /// Quit event was signaled. - Quit, -} - impl SplitTunnel { /// Initialize the split tunnel device. pub fn new( @@ -170,11 +141,40 @@ impl SplitTunnel { ) -> Result { let excluded_processes = Arc::new(RwLock::new(HashMap::new())); - let (request_tx, handle) = - Self::spawn_request_thread(resource_dir, volume_update_rx, excluded_processes.clone())?; + let (refresh_paths_tx, refresh_paths_rx) = sync_mpsc::channel(); - let (event_thread, quit_event) = - Self::spawn_event_listener(handle, excluded_processes.clone())?; + let path_monitor = + PathMonitor::spawn(refresh_paths_tx.clone()).map_err(Error::StartPathMonitor)?; + + let monitored_paths = Arc::new(Mutex::new(vec![])); + let volume_monitor = VolumeMonitor::spawn( + path_monitor.clone(), + refresh_paths_tx, + monitored_paths.clone(), + volume_update_rx, + ); + + let (request_tx, handle) = request::spawn_request_thread( + resource_dir, + daemon_tx, + path_monitor, + volume_monitor, + monitored_paths.clone(), + excluded_processes.clone(), + )?; + + let handle_copy = Arc::downgrade(&handle); + std::thread::spawn(move || { + while let Ok(()) = refresh_paths_rx.recv() { + let Some(handle) = handle_copy.upgrade() else { + return; + }; + Self::handle_volume_monitor_update(&handle, &monitored_paths); + } + }); + + let (event_thread, quit_event) = event::spawn_listener(handle, excluded_processes.clone()) + .map_err(Error::EventThreadError)?; Ok(SplitTunnel { runtime, @@ -182,373 +182,42 @@ impl SplitTunnel { event_thread: Some(event_thread), quit_event, _route_change_callback: None, - daemon_tx, async_path_update_in_progress: Arc::new(AtomicBool::new(false)), excluded_processes, route_manager, }) } - /// Spawns an event loop thread that processes events from the driver service. - fn spawn_event_listener( - handle: Arc, - excluded_processes: Arc>>, - ) -> Result<(std::thread::JoinHandle<()>, Arc), Error> { - let mut event_overlapped = Overlapped::new(Some( - Event::new(true, false).map_err(Error::EventThreadError)?, - )) - .map_err(Error::EventThreadError)?; - - let quit_event = Arc::new(Event::new(true, false).map_err(Error::EventThreadError)?); - let quit_event_copy = quit_event.clone(); - - let event_thread = std::thread::spawn(move || { - log::debug!("Starting split tunnel event thread"); - let mut data_buffer = vec![]; - - loop { - // Wait until either the next event is received or the quit event is signaled. - let (event_id, event_body) = match Self::fetch_next_event( - &handle, - &quit_event, - &mut event_overlapped, - &mut data_buffer, - ) { - Ok(EventResult::Event(event_id, event_body)) => (event_id, event_body), - Ok(EventResult::Quit) => break, - Err(error) => { - if error.raw_os_error() == Some(ERROR_OPERATION_ABORTED as i32) { - // The driver will normally abort the request if the driver state - // is reset. Give the driver service some time to recover before - // retrying. - std::thread::sleep(Duration::from_millis(500)); - } - continue; - } - }; - - Self::handle_event(event_id, event_body, &excluded_processes); - } - - log::debug!("Stopping split tunnel event thread"); - }); - - Ok((event_thread, quit_event_copy)) - } - - fn fetch_next_event( - device: &Arc, - quit_event: &Event, - overlapped: &mut Overlapped, - data_buffer: &mut Vec, - ) -> io::Result { - if unsafe { driver::wait_for_single_object(quit_event.as_raw(), Some(Duration::ZERO)) } - .is_ok() - { - return Ok(EventResult::Quit); + fn handle_volume_monitor_update( + handle: &DeviceHandle, + monitored_paths: &Arc>>, + ) { + let paths = monitored_paths.lock().unwrap(); + if paths.len() == 0 { + return; } - data_buffer.resize(DRIVER_EVENT_BUFFER_SIZE, 0u8); - - unsafe { - driver::device_io_control_buffer_async( - device, - driver::DriverIoctlCode::DequeEvent as u32, - None, - data_buffer.as_mut_ptr(), - u32::try_from(data_buffer.len()).expect("buffer must be smaller than u32"), - overlapped.as_mut_ptr(), - ) - } - .map_err(|error| { + log::debug!("Re-resolving excluded paths"); + if let Err(error) = handle.set_config(&paths) { log::error!( "{}", - error.display_chain_with_msg("DeviceIoControl failed to deque event") + error.display_chain_with_msg("Failed to update excluded paths") ); - error - })?; - - let event_objects = [ - overlapped.get_event().unwrap().as_raw(), - quit_event.as_raw(), - ]; - - let signaled_object = - unsafe { driver::wait_for_multiple_objects(&event_objects[..], false) }.map_err( - |error| { - log::error!( - "{}", - error.display_chain_with_msg("wait_for_multiple_objects failed") - ); - error - }, - )?; - - if signaled_object == quit_event.as_raw() { - // Quit event was signaled - return Ok(EventResult::Quit); } - - let returned_bytes = - driver::get_overlapped_result(device, overlapped).map_err(|error| { - if error.raw_os_error() != Some(ERROR_OPERATION_ABORTED as i32) { - log::error!( - "{}", - error.display_chain_with_msg( - "get_overlapped_result failed for dequeued event" - ), - ); - } - error - })?; - - data_buffer - .truncate(usize::try_from(returned_bytes).expect("usize must be no smaller than u32")); - - driver::parse_event_buffer(data_buffer) - .map(|(id, body)| EventResult::Event(id, body)) - .map_err(|error| { - log::error!( - "{}", - error.display_chain_with_msg("Failed to parse ST event buffer") - ); - io::Error::new(io::ErrorKind::Other, "Failed to parse ST event buffer") - }) } - fn handle_event( - event_id: driver::EventId, - event_body: driver::EventBody, - excluded_processes: &Arc>>, - ) { - use driver::{EventBody, EventId}; - - let event_str = match &event_id { - EventId::StartSplittingProcess | EventId::ErrorStartSplittingProcess => { - "Start splitting process" - } - EventId::StopSplittingProcess | EventId::ErrorStopSplittingProcess => { - "Stop splitting process" - } - EventId::ErrorMessage => "ErrorMessage", - }; - - match event_body { - EventBody::SplittingEvent { - process_id, - reason, - image, - } => { - let mut pids = excluded_processes.write().unwrap(); - match event_id { - EventId::StartSplittingProcess => { - if let Some(prev_entry) = pids.get(&process_id) { - log::error!("PID collision: {process_id} is already in the list of excluded processes. New image: {:?}. Current image: {:?}", image, prev_entry); - } - pids.insert( - process_id, - ExcludedProcess { - pid: u32::try_from(process_id) - .expect("PID should be containable in a DWORD"), - image: Path::new(&image).to_path_buf(), - inherited: reason - .contains(driver::SplittingChangeReason::BY_INHERITANCE), - }, - ); - } - EventId::StopSplittingProcess => { - if pids.remove(&process_id).is_none() { - log::error!("Inconsistent process tree: {process_id} was not found"); - } - } - _ => (), - } - - log::trace!( - "{}:\n\tpid: {}\n\treason: {:?}\n\timage: {:?}", - event_str, - process_id, - reason, - image, - ); - } - EventBody::SplittingError { process_id, image } => { - log::error!( - "FAILED: {}:\n\tpid: {}\n\timage: {:?}", - event_str, - process_id, - image, - ); - } - EventBody::ErrorMessage { status, message } => { - log::error!("NTSTATUS {:#x}: {}", status, message.to_string_lossy()) - } - } - } - - fn spawn_request_thread( - resource_dir: PathBuf, - volume_update_rx: mpsc::UnboundedReceiver<()>, - excluded_processes: Arc>>, - ) -> Result<(RequestTx, Arc), Error> { - let (tx, rx): (RequestTx, _) = sync_mpsc::channel(); - let (init_tx, init_rx) = sync_mpsc::channel(); - - let monitored_paths = Arc::new(Mutex::new(vec![])); - let monitored_paths_copy = monitored_paths.clone(); - - let (monitor_tx, monitor_rx) = sync_mpsc::channel(); - - let path_monitor = path_monitor::PathMonitor::spawn(monitor_tx.clone()) - .map_err(Error::StartPathMonitor)?; - let volume_monitor = volume_monitor::VolumeMonitor::spawn( - path_monitor.clone(), - monitor_tx, - monitored_paths.clone(), - volume_update_rx, - ); - - std::thread::spawn(move || { - let init_fn = || { - service::install_driver_if_required(&resource_dir).map_err(Error::ServiceError)?; - driver::DeviceHandle::new() - .map(Arc::new) - .map_err(Error::InitializationError) - }; - - let handle = match init_fn() { - Ok(handle) => { - let _ = init_tx.send(Ok(handle.clone())); - handle - } - Err(error) => { - let _ = init_tx.send(Err(error)); - return; - } - }; - - let mut previous_addresses = InterfaceAddresses::default(); - - while let Ok((request, response_tx)) = rx.recv() { - let response = match request { - Request::SetPaths(paths) => { - let mut monitored_paths_guard = monitored_paths.lock().unwrap(); - - let result = if !paths.is_empty() { - handle.set_config(&paths).map_err(Error::SetConfiguration) - } else { - handle.clear_config().map_err(Error::SetConfiguration) - }; - - if result.is_ok() { - if let Err(error) = path_monitor.set_paths(&paths) { - log::error!( - "{}", - error.display_chain_with_msg("Failed to update path monitor") - ); - } - *monitored_paths_guard = paths.to_vec(); - } - - result - } - Request::RegisterIps(mut ips) => { - if ips.internet_ipv4.is_none() && ips.internet_ipv6.is_none() { - ips.tunnel_ipv4 = None; - ips.tunnel_ipv6 = None; - } - if previous_addresses == ips { - Ok(()) - } else { - let result = handle - .register_ips( - ips.tunnel_ipv4, - ips.tunnel_ipv6, - ips.internet_ipv4, - ips.internet_ipv6, - ) - .map_err(Error::RegisterIps); - if result.is_ok() { - previous_addresses = ips; - } - result - } - } - Request::Stop => { - if let Err(error) = handle.reset().map_err(Error::ResetError) { - let _ = response_tx.send(Err(error)); - continue; - } - - monitored_paths.lock().unwrap().clear(); - excluded_processes.write().unwrap().clear(); - - let _ = response_tx.send(Ok(())); - - // Stop listening to commands - break; - } - }; - if response_tx.send(response).is_err() { - log::error!("A response could not be sent for a completed request"); - } - } - - drop(volume_monitor); - if let Err(error) = path_monitor.shutdown() { - log::error!( - "{}", - error.display_chain_with_msg("Failed to shut down path monitor") - ); - } - - drop(handle); - - log::debug!("Stopping ST service"); - if let Err(error) = service::stop_driver_service() { - log::error!( - "{}", - error.display_chain_with_msg("Failed to stop ST service") - ); - } - }); - - let handle = init_rx - .recv_timeout(REQUEST_TIMEOUT) - .map_err(|_| Error::RequestThreadStuck)??; - - let handle_copy = handle.clone(); - - std::thread::spawn(move || { - while let Ok(()) = monitor_rx.recv() { - let paths = monitored_paths_copy.lock().unwrap(); - let result = if paths.len() > 0 { - log::debug!("Re-resolving excluded paths"); - handle_copy.set_config(&paths) - } else { - continue; - }; - if let Err(error) = result { - log::error!( - "{}", - error.display_chain_with_msg("Failed to update excluded paths") - ); - } - } - }); - - Ok((tx, handle)) - } - - fn send_request(&self, request: Request) -> Result<(), Error> { + fn send_request(&self, request: RequestDetails) -> Result<(), Error> { Self::send_request_inner(&self.request_tx, request) } - fn send_request_inner(request_tx: &RequestTx, request: Request) -> Result<(), Error> { + fn send_request_inner( + request_tx: &sync_mpsc::Sender, + request: RequestDetails, + ) -> Result<(), Error> { let (response_tx, response_rx) = sync_mpsc::channel(); request_tx - .send((request, response_tx)) + .send(Request::new(request).response_tx(response_tx)) .map_err(|_| Error::SplitTunnelDown)?; response_rx @@ -558,7 +227,7 @@ impl SplitTunnel { /// Set a list of applications to exclude from the tunnel. pub fn set_paths_sync>(&self, paths: &[T]) -> Result<(), Error> { - self.send_request(Request::SetPaths( + self.send_request(RequestDetails::SetPaths( paths .iter() .map(|path| path.as_ref().to_os_string()) @@ -580,7 +249,7 @@ impl SplitTunnel { return; } let (response_tx, response_rx) = sync_mpsc::channel(); - let request = Request::SetPaths( + let request = RequestDetails::SetPaths( paths .iter() .map(|path| path.as_ref().to_os_string()) @@ -590,7 +259,7 @@ impl SplitTunnel { let wait_task = move || { request_tx - .send((request, response_tx)) + .send(Request::new(request).response_tx(response_tx)) .map_err(|_| Error::SplitTunnelDown)?; response_rx.recv().map_err(|_| Error::SplitTunnelDown)? }; @@ -601,7 +270,7 @@ impl SplitTunnel { }); } - /// Instructs the driver to redirect traffic from sockets bound to 0.0.0.0, ::, or the + /// Instructs the driver to redirect connections for sockets bound to 0.0.0.0, ::, or the /// tunnel addresses (if any) to the default route. pub fn set_tunnel_addresses(&mut self, metadata: Option<&TunnelMetadata>) -> Result<(), Error> { let mut tunnel_ipv4 = None; @@ -620,7 +289,6 @@ impl SplitTunnel { let context_mutex = Arc::new(Mutex::new( SplitTunnelDefaultRouteChangeHandlerContext::new( self.request_tx.clone(), - self.daemon_tx.clone(), tunnel_ipv4, tunnel_ipv6, ), @@ -646,7 +314,7 @@ impl SplitTunnel { // could deadlock if the dropped callback is invoked (see `init_context`). .map_err(|_| Error::RegisterRouteChangeCallback)?; - Self::init_context(context)?; + Self::init_context(context, &self.request_tx)?; self._route_change_callback = callback; Ok(()) @@ -654,6 +322,7 @@ impl SplitTunnel { fn init_context( mut context: MutexGuard<'_, SplitTunnelDefaultRouteChangeHandlerContext>, + request_tx: &sync_mpsc::Sender, ) -> Result<(), Error> { // NOTE: This should remain a separate function. Dropping the context after `callback` // causes a deadlock if `split_tunnel_default_route_change_handler` is called at the same @@ -662,13 +331,16 @@ impl SplitTunnel { // to complete. context.initialize_internet_addresses()?; - context.register_ips() + SplitTunnel::send_request_inner( + request_tx, + RequestDetails::RegisterIps(context.addresses.clone()), + ) } - /// Instructs the driver to stop redirecting tunnel traffic and INADDR_ANY. + /// Instructs the driver to stop redirecting connections. pub fn clear_tunnel_addresses(&mut self) -> Result<(), Error> { self._route_change_callback = None; - self.send_request(Request::RegisterIps(InterfaceAddresses::default())) + self.send_request(RequestDetails::RegisterIps(InterfaceAddresses::default())) } /// Returns a handle used for interacting with the split tunnel module. @@ -691,7 +363,7 @@ impl Drop for SplitTunnel { // Not joining `event_thread`: It may be unresponsive. } - if let Err(error) = self.send_request(Request::Stop) { + if let Err(error) = self.send_request(RequestDetails::Stop) { log::error!( "{}", error.display_chain_with_msg("Failed to stop ST driver service") @@ -701,21 +373,18 @@ impl Drop for SplitTunnel { } struct SplitTunnelDefaultRouteChangeHandlerContext { - request_tx: RequestTx, - pub daemon_tx: Weak>, + request_tx: sync_mpsc::Sender, pub addresses: InterfaceAddresses, } impl SplitTunnelDefaultRouteChangeHandlerContext { pub fn new( - request_tx: RequestTx, - daemon_tx: Weak>, + request_tx: sync_mpsc::Sender, tunnel_ipv4: Option, tunnel_ipv6: Option, ) -> Self { SplitTunnelDefaultRouteChangeHandlerContext { request_tx, - daemon_tx, addresses: InterfaceAddresses { tunnel_ipv4, tunnel_ipv6, @@ -725,13 +394,6 @@ impl SplitTunnelDefaultRouteChangeHandlerContext { } } - pub fn register_ips(&self) -> Result<(), Error> { - SplitTunnel::send_request_inner( - &self.request_tx, - Request::RegisterIps(self.addresses.clone()), - ) - } - pub fn initialize_internet_addresses(&mut self) -> Result<(), Error> { // Identify IP address that gives us Internet access let internet_ipv4 = get_best_default_route(AddressFamily::Ipv4) @@ -782,14 +444,9 @@ fn split_tunnel_default_route_change_handler( // Update the "internet interface" IP when best default route changes let mut ctx = ctx_mutex.lock().expect("ST route handler mutex poisoned"); - let daemon_tx = ctx.daemon_tx.upgrade(); - let maybe_send = move |content| { - if let Some(tx) = daemon_tx { - let _ = tx.unbounded_send(content); - } - }; + let prev_addrs = ctx.addresses.clone(); - let result = match event_type { + match event_type { Updated(default_route) | UpdatedDetails(default_route) => { match get_ip_address_for_interface(address_family, default_route.iface) { Ok(Some(ip)) => match ip { @@ -814,32 +471,29 @@ fn split_tunnel_default_route_change_handler( "Failed to obtain default route interface address" ) ); - maybe_send(TunnelCommand::Block(ErrorStateCause::SplitTunnelError)); - return; } }; - - ctx.register_ips() } // no default route - Removed => { - match address_family { - AddressFamily::Ipv4 => { - ctx.addresses.internet_ipv4 = None; - } - AddressFamily::Ipv6 => { - ctx.addresses.internet_ipv6 = None; - } + Removed => match address_family { + AddressFamily::Ipv4 => { + ctx.addresses.internet_ipv4 = None; } - ctx.register_ips() - } - }; + AddressFamily::Ipv6 => { + ctx.addresses.internet_ipv6 = None; + } + }, + } - if let Err(error) = result { - log::error!( - "{}", - error.display_chain_with_msg("Failed to register new addresses in split tunnel driver") - ); - maybe_send(TunnelCommand::Block(ErrorStateCause::SplitTunnelError)); + if prev_addrs == ctx.addresses { + return; + } + + if ctx + .request_tx + .send(Request::new(RequestDetails::RegisterIps(ctx.addresses.clone())).must_succeed()) + .is_err() + { + log::error!("Split tunnel request thread is down"); } } diff --git a/talpid-core/src/split_tunnel/windows/request.rs b/talpid-core/src/split_tunnel/windows/request.rs new file mode 100644 index 000000000000..bd80d0c64736 --- /dev/null +++ b/talpid-core/src/split_tunnel/windows/request.rs @@ -0,0 +1,300 @@ +//! This module spawns a thread used to service request to the split tunnel device driver. +//! +//! We've chosen isolate all dealings with the device driver on a dedicated thread as we've +//! previously faced issues with other software fighting us over the global transaction lock in WFP +//! (Windows Filtering Platform). +//! +//! This design also makes the tunnel state machine relatively protected against driver failure. + +use crate::tunnel_state_machine::TunnelCommand; +use futures::channel::mpsc; +use std::{ + collections::HashMap, + ffi::OsString, + path::{Path, PathBuf}, + sync::{mpsc as sync_mpsc, Arc, Mutex, RwLock, Weak}, + time::Duration, +}; +use talpid_types::{split_tunnel::ExcludedProcess, tunnel::ErrorStateCause, ErrorExt}; + +use super::{ + driver::DeviceHandle, path_monitor::PathMonitorHandle, service, + volume_monitor::VolumeMonitorHandle, Error, InterfaceAddresses, +}; + +const INIT_TIMEOUT: Duration = Duration::from_secs(5); + +/// A request to the split tunnel monitor +pub struct Request { + /// Request details + details: RequestDetails, + /// Whether to block if the request fails + must_succeed: bool, + /// Response channel + response_tx: Option>>, +} + +/// The particular request to send +pub enum RequestDetails { + /// Update paths to exclude + SetPaths(Vec), + /// Update default and VPN tunnel addresses + RegisterIps(InterfaceAddresses), + /// Stop the split tunnel monitor + Stop, +} + +impl Request { + pub fn new(details: RequestDetails) -> Self { + Request { + details, + must_succeed: false, + response_tx: None, + } + } + + pub fn response_tx(mut self, response_tx: sync_mpsc::Sender>) -> Self { + self.response_tx = Some(response_tx); + self + } + + pub fn must_succeed(mut self) -> Self { + self.must_succeed = true; + self + } + + pub fn request_name(&self) -> &'static str { + match self.details { + RequestDetails::SetPaths(_) => "SetPaths", + RequestDetails::RegisterIps(_) => "RegisterIps", + RequestDetails::Stop => "Stop", + } + } +} + +/// Begin servicing requests sent on the returned channel +pub fn spawn_request_thread( + resource_dir: PathBuf, + daemon_tx: Weak>, + path_monitor: PathMonitorHandle, + volume_monitor: VolumeMonitorHandle, + monitored_paths: Arc>>, + excluded_processes: Arc>>, +) -> Result<(sync_mpsc::Sender, Arc), Error> { + let (tx, rx): (sync_mpsc::Sender, _) = sync_mpsc::channel(); + let (init_tx, init_rx) = sync_mpsc::channel(); + + std::thread::spawn(move || { + // Ensure that the device driver service is running and that we have a handle to it + let handle = match setup_and_create_device(&resource_dir) { + Ok(handle) => { + let _ = init_tx.send(Ok(handle.clone())); + handle + } + Err(error) => { + let _ = init_tx.send(Err(error)); + return; + } + }; + + request_loop( + handle.clone(), + rx, + daemon_tx, + monitored_paths, + path_monitor.clone(), + excluded_processes, + ); + + // Shut components down in a sane order + drop(volume_monitor); + if let Err(error) = path_monitor.shutdown() { + log::error!( + "{}", + error.display_chain_with_msg("Failed to shut down path monitor") + ); + } + + // The device handle must be dropped before stopping the service + debug_assert_eq!(Arc::strong_count(&handle), 1); + drop(handle); + + log::debug!("Stopping ST service"); + if let Err(error) = service::stop_driver_service() { + log::error!( + "{}", + error.display_chain_with_msg("Failed to stop ST service") + ); + } + }); + + let handle = init_rx + .recv_timeout(INIT_TIMEOUT) + .map_err(|_| Error::RequestThreadStuck)??; + + Ok((tx, handle)) +} + +/// Install the driver and open a handle for it +fn setup_and_create_device(resource_dir: &Path) -> Result, Error> { + service::install_driver_if_required(resource_dir).map_err(Error::ServiceError)?; + DeviceHandle::new() + .map(Arc::new) + .map_err(Error::InitializationError) +} + +/// Service requests to the device driver +fn request_loop( + handle: Arc, + cmd_rx: sync_mpsc::Receiver, + daemon_tx: Weak>, + monitored_paths: Arc>>, + path_monitor: PathMonitorHandle, + excluded_processes: Arc>>, +) { + let mut previous_addresses = InterfaceAddresses::default(); + + while let Ok(request) = cmd_rx.recv() { + let request_name = request.request_name(); + + let (should_stop, response) = handle_request( + request.details, + &handle, + &path_monitor, + &monitored_paths, + &excluded_processes, + &mut previous_addresses, + ); + + handle_request_result( + &daemon_tx, + response, + request.must_succeed, + request_name, + request.response_tx, + ); + + // Stop request loop + if should_stop { + break; + } + } +} + +/// Handle a request to the split tunnel device +fn handle_request( + request: RequestDetails, + handle: &DeviceHandle, + path_monitor: &PathMonitorHandle, + monitored_paths: &Arc>>, + excluded_processes: &Arc>>, + previous_addresses: &mut InterfaceAddresses, +) -> (bool, Result<(), Error>) { + let (should_stop, result) = match request { + RequestDetails::SetPaths(paths) => { + let mut monitored_paths_guard = monitored_paths.lock().unwrap(); + + let result = if !paths.is_empty() { + handle.set_config(&paths).map_err(Error::SetConfiguration) + } else { + handle.clear_config().map_err(Error::SetConfiguration) + }; + + if result.is_ok() { + if let Err(error) = path_monitor.set_paths(&paths) { + log::error!( + "{}", + error.display_chain_with_msg("Failed to update path monitor") + ); + } + *monitored_paths_guard = paths.to_vec(); + } + + (false, result) + } + RequestDetails::RegisterIps(mut ips) => { + if ips.internet_ipv4.is_none() && ips.internet_ipv6.is_none() { + ips.tunnel_ipv4 = None; + ips.tunnel_ipv6 = None; + } + if previous_addresses == &ips { + (false, Ok(())) + } else { + let result = handle + .register_ips( + ips.tunnel_ipv4, + ips.tunnel_ipv6, + ips.internet_ipv4, + ips.internet_ipv6, + ) + .map_err(Error::RegisterIps); + if result.is_ok() { + *previous_addresses = ips; + } + (false, result) + } + } + RequestDetails::Stop => { + if let Err(error) = handle.reset().map_err(Error::ResetError) { + // Shut down failed, so continue to live + return (false, Err(error)); + } + + // Clean up + monitored_paths.lock().unwrap().clear(); + excluded_processes.write().unwrap().clear(); + + // Stop listening to commands + (true, Ok(())) + } + }; + + (should_stop, result) +} + +/// Handle the result of a request +fn handle_request_result( + daemon_tx: &Weak>, + result: Result<(), Error>, + must_succeed: bool, + request_name: &str, + response_tx: Option>>, +) { + let log_request_failure = |response: &Result<(), Error>| { + if let Err(error) = response { + log::error!( + "Request/ioctl failed: {}\n{}", + request_name, + error.display_chain() + ); + } + }; + + let request_failed = result.is_err(); + + if let Some(response_tx) = response_tx { + if let Err(error) = response_tx.send(result) { + log::error!( + "A response could not be sent for completed request/ioctl: {}", + request_name + ); + log_request_failure(&error.0); + } + } else { + log_request_failure(&result); + } + + // Move to error state if the request failed but must succeed + if request_failed && must_succeed { + if let Some(daemon_tx) = daemon_tx.upgrade() { + log::debug!( + "Entering error state due to failed request/ioctl: {}", + request_name + ); + let _ = + daemon_tx.unbounded_send(TunnelCommand::Block(ErrorStateCause::SplitTunnelError)); + } else { + log::error!("Cannot handle failed request since tunnel state machine is down"); + } + } +}