Skip to content

Commit

Permalink
Support thread pools
Browse files Browse the repository at this point in the history
Create a safe `ThreadPool` type wrapping `ZSTD_threadPool`. Allow
binding a `&ThreadPool` to a `CCtx`, which requires that the lifetime of
the `&ThreadPool` outlive the `CCtx`.
  • Loading branch information
joshtriplett committed Jun 2, 2024
1 parent 1c49557 commit fbc09c1
Showing 1 changed file with 90 additions and 0 deletions.
90 changes: 90 additions & 0 deletions zstd-safe/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -748,6 +748,38 @@ impl<'a> CCtx<'a> {
pub fn out_size() -> usize {
unsafe { zstd_sys::ZSTD_CStreamOutSize() }
}

/// Use a shared thread pool for this context.
///
/// Thread pool must outlive the context.
#[cfg(all(feature = "experimental", feature = "zstdmt"))]
#[cfg_attr(
feature = "doc-cfg",
doc(cfg(all(feature = "experimental", feature = "zstdmt")))
)]
pub fn ref_thread_pool<'b>(&mut self, pool: &'b ThreadPool) -> SafeResult
where
'b: 'a,
{
parse_code(unsafe {
zstd_sys::ZSTD_CCtx_refThreadPool(self.0.as_ptr(), pool.0.as_ptr())
})
}

/// Return to using a private thread pool for this context.
#[cfg(all(feature = "experimental", feature = "zstdmt"))]
#[cfg_attr(
feature = "doc-cfg",
doc(cfg(all(feature = "experimental", feature = "zstdmt")))
)]
pub fn disable_thread_pool(&mut self) -> SafeResult {
parse_code(unsafe {
zstd_sys::ZSTD_CCtx_refThreadPool(
self.0.as_ptr(),
core::ptr::null_mut(),
)
})
}
}

impl<'a> Drop for CCtx<'a> {
Expand Down Expand Up @@ -1355,6 +1387,64 @@ impl<'a> Drop for DDict<'a> {
unsafe impl<'a> Send for DDict<'a> {}
unsafe impl<'a> Sync for DDict<'a> {}

/// A shared thread pool for one or more compression contexts
#[cfg(all(feature = "experimental", feature = "zstdmt"))]
#[cfg_attr(
feature = "doc-cfg",
doc(cfg(all(feature = "experimental", feature = "zstdmt")))
)]
pub struct ThreadPool(NonNull<zstd_sys::ZSTD_threadPool>);

#[cfg(all(feature = "experimental", feature = "zstdmt"))]
#[cfg_attr(
feature = "doc-cfg",
doc(cfg(all(feature = "experimental", feature = "zstdmt")))
)]
impl ThreadPool {
/// Create a thread pool with the specified number of threads.
///
/// # Panics
///
/// If creating the thread pool failed.
pub fn new(num_threads: usize) -> Self {
Self::try_new(num_threads)
.expect("zstd returned null pointer when creating thread pool")
}

/// Create a thread pool with the specified number of threads.
pub fn try_new(num_threads: usize) -> Option<Self> {
Some(Self(NonNull::new(unsafe {
zstd_sys::ZSTD_createThreadPool(num_threads)
})?))
}
}

#[cfg(all(feature = "experimental", feature = "zstdmt"))]
#[cfg_attr(
feature = "doc-cfg",
doc(cfg(all(feature = "experimental", feature = "zstdmt")))
)]
impl Drop for ThreadPool {
fn drop(&mut self) {
unsafe {
zstd_sys::ZSTD_freeThreadPool(self.0.as_ptr());
}
}
}

#[cfg(all(feature = "experimental", feature = "zstdmt"))]
#[cfg_attr(
feature = "doc-cfg",
doc(cfg(all(feature = "experimental", feature = "zstdmt")))
)]
unsafe impl Send for ThreadPool {}
#[cfg(all(feature = "experimental", feature = "zstdmt"))]
#[cfg_attr(
feature = "doc-cfg",
doc(cfg(all(feature = "experimental", feature = "zstdmt")))
)]
unsafe impl Sync for ThreadPool {}

/// Wraps the `ZSTD_decompress_usingDDict()` function.
pub fn decompress_using_ddict(
dctx: &mut DCtx<'_>,
Expand Down

0 comments on commit fbc09c1

Please sign in to comment.