Skip to content

Commit

Permalink
Support and_try_compute_if_nobody_else
Browse files Browse the repository at this point in the history
It will only evaluate one closure for a certain
entry, and other closures will be canceled, returning
`CompResult::Unchanged`.

Add `ValueInitializer::post_init_for_try_compute_with_if_nobody_else`.
Add `Cache::try_compute_if_nobody_else_with_hash_and_fun`.
Add `RefKeyEntrySelector::and_try_compute_if_nobody_else`.
  • Loading branch information
xuehaonan27 committed Oct 9, 2024
1 parent 84acc73 commit 29b00c2
Show file tree
Hide file tree
Showing 3 changed files with 180 additions and 0 deletions.
17 changes: 17 additions & 0 deletions src/future/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1879,6 +1879,23 @@ where
.await
}

pub(crate) async fn try_compute_if_nobody_else_with_hash_and_fun<F, Fut, E>(
&self,
key: Arc<K>,
hash: u64,
f: F,
) -> Result<compute::CompResult<K, V>, E>
where
F: FnOnce(Option<Entry<K, V>>) -> Fut,
Fut: Future<Output = Result<compute::Op<V>, E>>,
E: Send + Sync + 'static,
{
let post_init = ValueInitializer::<K, V, S>::post_init_for_try_compute_with_if_nobody_else;
self.value_initializer
.try_compute_if_nobody_else(key, hash, self, f, post_init, true)
.await
}

pub(crate) async fn upsert_with_hash_and_fun<F, Fut>(
&self,
key: Arc<K>,
Expand Down
15 changes: 15 additions & 0 deletions src/future/entry_selector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -878,6 +878,21 @@ where
.await
}

pub async fn and_try_compute_if_nobody_else<F, Fut, E>(
self,
f: F,
) -> Result<compute::CompResult<K, V>, E>
where
F: FnOnce(Option<Entry<K, V>>) -> Fut,
Fut: Future<Output = Result<compute::Op<V>, E>>,
E: Send + Sync + 'static,
{
let key = Arc::new(self.ref_key.to_owned());
self.cache
.try_compute_if_nobody_else_with_hash_and_fun(key, self.hash, f)
.await
}

/// Performs an upsert of an [`Entry`] by using the given closure `f`. The word
/// "upsert" here means "update" or "insert".
///
Expand Down
148 changes: 148 additions & 0 deletions src/future/value_initializer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,144 @@ where
// The lock will be unlocked here.
}

pub(crate) async fn try_compute_if_nobody_else<'a, F, Fut, O, E>(
&'a self,
c_key: Arc<K>,
c_hash: u64,
cache: &Cache<K, V, S>,
f: F,
post_init: fn(O) -> Result<Op<V>, E>,
allow_nop: bool,
) -> Result<CompResult<K, V>, E>
where
F: FnOnce(Option<Entry<K, V>>) -> Fut,
Fut: Future<Output = O> + 'a,
E: Send + Sync + 'static,
{
use std::panic::{resume_unwind, AssertUnwindSafe};

let type_id = TypeId::of::<ComputeNone>();
let (w_key, w_hash) = waiter_key_hash(&self.waiters, &c_key, type_id);
let waiter = TrioArc::new(RwLock::new(WaiterValue::Computing));
// NOTE: We have to acquire a write lock before `try_insert_waiter`,
// so that any concurrent attempt will get our lock and wait on it.
let lock = waiter.write().await;

if let Some(_existing_waiter) =
try_insert_waiter(&self.waiters, w_key.clone(), w_hash, &waiter)
{
// There's already a waiter computing for this entry, cancel this computation.

// Get the current value.
let ignore_if = None as Option<&mut fn(&V) -> bool>;
let maybe_entry = cache
.base
.get_with_hash(&c_key, c_hash, ignore_if, true, true)
.await;
let maybe_value = maybe_entry.as_ref().map(|ent| ent.value().clone());

return if let Some(value) = maybe_value {
Ok(CompResult::Unchanged(Entry::new(
Some(c_key),
value,
false,
false,
)))
} else {
Ok(CompResult::StillNone(c_key))
};
// The lock will be unlocked here.
} else {
// Inserted.
}

// Our waiter was inserted.

// Create a guard. This will ensure to remove our waiter when the
// enclosing future has been aborted:
// https://github.com/moka-rs/moka/issues/59
let waiter_guard = WaiterGuard::new(w_key, w_hash, &self.waiters, lock);

// Get the current value.
let ignore_if = None as Option<&mut fn(&V) -> bool>;
let maybe_entry = cache
.base
.get_with_hash(&c_key, c_hash, ignore_if, true, true)
.await;
let maybe_value = if allow_nop {
maybe_entry.as_ref().map(|ent| ent.value().clone())
} else {
None
};
let entry_existed = maybe_entry.is_some();

// Evaluate the `f` closure and get a future. Catching panic is safe here as
// we will not evaluate the closure again.
let fut = match std::panic::catch_unwind(AssertUnwindSafe(|| f(maybe_entry))) {
// Evaluated.
Ok(fut) => fut,
Err(payload) => {
waiter_guard.set_waiter_value(WaiterValue::InitFuturePanicked);
resume_unwind(payload);
}
};

// Resolve the `fut` future. Catching panic is safe here as we will not
// resolve the future again.
let output = match AssertUnwindSafe(fut).catch_unwind().await {
// Resolved.
Ok(output) => {
waiter_guard.set_waiter_value(WaiterValue::ReadyNone);
output
}
// Panicked.
Err(payload) => {
waiter_guard.set_waiter_value(WaiterValue::InitFuturePanicked);
resume_unwind(payload);
}
};

match post_init(output)? {
Op::Nop => {
if let Some(value) = maybe_value {
Ok(CompResult::Unchanged(Entry::new(
Some(c_key),
value,
false,
false,
)))
} else {
Ok(CompResult::StillNone(c_key))
}
}
Op::Put(value) => {
cache
.insert_with_hash(Arc::clone(&c_key), c_hash, value.clone())
.await;
if entry_existed {
crossbeam_epoch::pin().flush();
let entry = Entry::new(Some(c_key), value, true, true);
Ok(CompResult::ReplacedWith(entry))
} else {
let entry = Entry::new(Some(c_key), value, true, false);
Ok(CompResult::Inserted(entry))
}
}
Op::Remove => {
let maybe_prev_v = cache.invalidate_with_hash(&c_key, c_hash, true).await;
if let Some(prev_v) = maybe_prev_v {
crossbeam_epoch::pin().flush();
let entry = Entry::new(Some(c_key), prev_v, false, false);
Ok(CompResult::Removed(entry))
} else {
Ok(CompResult::StillNone(c_key))
}
}
}

// The lock will be unlocked here.
}

/// The `post_init` function for the `get_with` method of cache.
pub(crate) fn post_init_for_get_with(value: V) -> Result<V, ()> {
Ok(value)
Expand Down Expand Up @@ -437,6 +575,16 @@ where
op
}

/// The `post_init` function for the `and_try_compute_if_nobody_else` method of cache.
pub(crate) fn post_init_for_try_compute_with_if_nobody_else<E>(
op: Result<Op<V>, E>,
) -> Result<Op<V>, E>
where
E: Send + Sync + 'static,
{
op
}

/// Returns the `type_id` for `get_with` method of cache.
pub(crate) fn type_id_for_get_with() -> TypeId {
// NOTE: We use a regular function here instead of a const fn because TypeId
Expand Down

0 comments on commit 29b00c2

Please sign in to comment.