diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index c77a93a..ee64334 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -43,11 +43,16 @@ jobs: uses: actions-rs/cargo@v1.0.3 with: command: build - args: --all + + - name: Tests + uses: actions-rs/cargo@v1.0.3 + with: + command: test + - name: Examples uses: actions-rs/cargo@v1.0.3 env: RUST_LOG: "trace" with: command: run - args: --all-features --example mutex -- 2 + args: --all-features --example mutex -- 15 diff --git a/Cargo.toml b/Cargo.toml index ef0efe1..0e01a3f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "shared_memory" description = "A user friendly crate that allows you to share memory between processes" -version = "0.12.3" +version = "0.12.4" authors = ["ElasT0ny "] license = "MIT OR Apache-2.0" edition = "2018" @@ -35,7 +35,7 @@ nix = "0.23" libc = "0.2" [target.'cfg(windows)'.dependencies] -win-sys = "0.2" +win-sys = "0.3" [dev-dependencies] raw_sync = "0.1" diff --git a/examples/mutex.rs b/examples/mutex.rs index d46fdc5..47554bb 100644 --- a/examples/mutex.rs +++ b/examples/mutex.rs @@ -58,7 +58,6 @@ fn increment_value(shmem_flink: &str, thread_num: usize) { let mut raw_ptr = shmem.as_ptr(); let is_init: &mut AtomicU8; - let mutex: Box; unsafe { is_init = &mut *(raw_ptr as *mut u8 as *mut AtomicU8); @@ -66,7 +65,7 @@ fn increment_value(shmem_flink: &str, thread_num: usize) { }; // Initialize or wait for initialized mutex - if shmem.is_owner() { + let mutex = if shmem.is_owner() { is_init.store(0, Ordering::Relaxed); // Initialize the mutex let (lock, _bytes_used) = unsafe { @@ -77,7 +76,7 @@ fn increment_value(shmem_flink: &str, thread_num: usize) { .unwrap() }; is_init.store(1, Ordering::Relaxed); - mutex = lock; + lock } else { // wait until mutex is initialized while is_init.load(Ordering::Relaxed) != 1 {} @@ -89,8 +88,8 @@ fn increment_value(shmem_flink: &str, thread_num: usize) { ) .unwrap() }; - mutex = lock; - } + lock + }; // Loop until mutex data reaches 10 loop { diff --git a/src/lib.rs b/src/lib.rs index 33e1ee9..3f05937 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -52,6 +52,7 @@ pub struct ShmemConf { overwrite_flink: bool, flink_path: Option, size: usize, + ext: os_impl::ShmemConfExt, } impl Drop for ShmemConf { fn drop(&mut self) { @@ -184,6 +185,7 @@ impl ShmemConf { let mut retry = 0; loop { let unique_id = if let Some(ref unique_id) = self.os_id { + retry = 5; unique_id.as_str() } else { let flink_path = self.flink_path.as_ref().unwrap(); @@ -202,7 +204,7 @@ impl ShmemConf { flink_uid.as_str() }; - match os_impl::open_mapping(unique_id, self.size) { + match os_impl::open_mapping(unique_id, self.size, &self.ext) { Ok(m) => { self.size = m.map_size; self.owner = false; diff --git a/src/unix.rs b/src/unix.rs index 6d301ee..695b5f5 100644 --- a/src/unix.rs +++ b/src/unix.rs @@ -9,6 +9,9 @@ use nix::unistd::{close, ftruncate}; use crate::ShmemError; +#[derive(Clone, Default)] +pub struct ShmemConfExt; + pub struct MapData { //On linux, you must shm_unlink() the object created for the mapping. It wont disappear automatically. owner: bool, @@ -145,7 +148,11 @@ pub fn create_mapping(unique_id: &str, map_size: usize) -> Result Result { +pub fn open_mapping( + unique_id: &str, + _map_size: usize, + _ext: &ShmemConfExt, +) -> Result { //Open shared memory debug!("Openning persistent mapping at {}", unique_id); let shmem_fd = match shm_open( diff --git a/src/windows.rs b/src/windows.rs index f5dd9ca..96f4a59 100644 --- a/src/windows.rs +++ b/src/windows.rs @@ -3,11 +3,24 @@ use std::io::ErrorKind; use std::os::windows::{fs::OpenOptionsExt, io::AsRawHandle}; use std::path::PathBuf; -use crate::log::*; +use crate::{log::*, ShmemConf}; use win_sys::*; use crate::ShmemError; +#[derive(Clone, Default)] +pub struct ShmemConfExt { + allow_raw: bool, +} + +impl ShmemConf { + /// If set to true, enables openning raw shared memory that is not managed by this crate + pub fn allow_raw(mut self, allow: bool) -> Self { + self.ext.allow_raw = allow; + self + } +} + pub struct MapData { owner: bool, @@ -118,7 +131,12 @@ fn get_tmp_dir() -> Result { } } -fn new_map(unique_id: &str, map_size: usize, create: bool) -> Result { +fn new_map( + unique_id: &str, + mut map_size: usize, + create: bool, + allow_raw: bool, +) -> Result { // Create file to back the shared memory let mut file_path = get_tmp_dir()?; file_path.push(unique_id.trim_start_matches('/')); @@ -188,6 +206,8 @@ fn new_map(unique_id: &str, map_size: usize, create: bool) -> Result { if create { return Err(ShmemError::MapCreateFailed(e.raw_os_error().unwrap() as _)); + } else if !allow_raw { + return Err(ShmemError::MapOpenFailed(ERROR_FILE_NOT_FOUND.0)); } // This may be a mapping that isnt managed by this crate @@ -216,7 +236,7 @@ fn new_map(unique_id: &str, map_size: usize, create: bool) -> Result v as _, + Ok(v) => v, Err(e) => { return Err(if create { ShmemError::MapCreateFailed(e.win32_error().unwrap().0) @@ -227,33 +247,35 @@ fn new_map(unique_id: &str, map_size: usize, create: bool) -> Result Result { - new_map(unique_id, map_size, true) + new_map(unique_id, map_size, true, false) } //Opens an existing mapping specified by its uid -pub fn open_mapping(unique_id: &str, map_size: usize) -> Result { - new_map(unique_id, map_size, false) +pub fn open_mapping( + unique_id: &str, + map_size: usize, + ext: &ShmemConfExt, +) -> Result { + new_map(unique_id, map_size, false, ext.allow_raw) } diff --git a/tests/general.rs b/tests/general.rs new file mode 100644 index 0000000..2812711 --- /dev/null +++ b/tests/general.rs @@ -0,0 +1,117 @@ +use std::path::Path; + +use shared_memory::ShmemConf; + +#[test] +fn create_new() { + let mut s = ShmemConf::new().size(4090).create().unwrap(); + + assert!(s.is_owner()); + assert!(!s.get_os_id().is_empty()); + assert!(s.len() >= 4090); + assert!(!s.as_ptr().is_null()); + unsafe { + assert_eq!(s.as_slice().len(), s.len()); + assert_eq!(s.as_slice_mut().len(), s.len()); + } +} + +#[test] +fn create_with_flink() { + let flink = Path::new("create_new1"); + + let mut s = ShmemConf::new().flink(flink).size(4090).create().unwrap(); + + assert!(s.is_owner()); + assert!(!s.get_os_id().is_empty()); + assert!(flink.is_file()); + assert!(s.len() >= 4090); + assert!(!s.as_ptr().is_null()); + unsafe { + assert_eq!(s.as_slice().len(), s.len()); + assert_eq!(s.as_slice_mut().len(), s.len()); + } + + drop(s); + + assert!(!flink.is_file()); +} + +#[test] +fn open_os_id() { + let s1 = ShmemConf::new().size(4090).create().unwrap(); + + // Open with the unique os id + let os_id = s1.get_os_id().to_string(); + let mut s2 = ShmemConf::new().os_id(&os_id).open().unwrap(); + + assert!(!s2.is_owner()); + assert!(!s2.get_os_id().is_empty()); + assert!(s2.len() >= 4090); + assert!(!s2.as_ptr().is_null()); + unsafe { + assert_eq!(s2.as_slice().len(), s2.len()); + assert_eq!(s2.as_slice_mut().len(), s2.len()); + } + + // Drop the owner of the mapping + drop(s1); + + // Make sure it can be openned again + assert!(ShmemConf::new().os_id(&os_id).open().is_err()); + + drop(s2); +} + +#[test] +fn open_flink() { + let flink = Path::new("create_new2"); + let s1 = ShmemConf::new().flink(flink).size(4090).create().unwrap(); + + // Open with file base link + let mut s2 = ShmemConf::new().flink(&flink).open().unwrap(); + + assert!(!s2.is_owner()); + assert!(!s2.get_os_id().is_empty()); + assert!(flink.is_file()); + assert!(s2.len() >= 4090); + assert!(!s2.as_ptr().is_null()); + unsafe { + assert_eq!(s2.as_slice().len(), s2.len()); + assert_eq!(s2.as_slice_mut().len(), s2.len()); + } + + // Drop the owner of the mapping + drop(s1); + + // Make sure it can be openned again + assert!(ShmemConf::new().flink(&flink).open().is_err()); + + drop(s2); +} + +#[test] +fn share_data() { + let s1 = ShmemConf::new() + .size(core::mem::size_of::()) + .create() + .unwrap(); + + // Open with the unique os id + let os_id = s1.get_os_id().to_string(); + let s2 = ShmemConf::new().os_id(&os_id).open().unwrap(); + + let ptr1 = s1.as_ptr() as *mut u32; + let ptr2 = s2.as_ptr() as *mut u32; + + // Confirm that the two pointers are different + assert_ne!(ptr1, ptr2); + + // Write a value from s1 and read it from s2 + unsafe { + let shared_val = 0xBADC0FEE; + ptr1.write_volatile(shared_val); + let read_val = ptr2.read_volatile(); + assert_eq!(read_val, shared_val); + } +}