diff --git a/rust/kernel/lib.rs b/rust/kernel/lib.rs index 1e7c2262ab26af..d86455955bae60 100644 --- a/rust/kernel/lib.rs +++ b/rust/kernel/lib.rs @@ -47,6 +47,7 @@ pub mod file_operations; pub mod miscdev; pub mod pages; pub mod str; +pub mod traits; pub mod linked_list; mod raw_list; diff --git a/rust/kernel/prelude.rs b/rust/kernel/prelude.rs index fad94708fa6f2d..f0835fb19b2f7b 100644 --- a/rust/kernel/prelude.rs +++ b/rust/kernel/prelude.rs @@ -22,3 +22,5 @@ pub use super::{pr_alert, pr_cont, pr_crit, pr_emerg, pr_err, pr_info, pr_notice pub use super::static_assert; pub use super::{KernelModule, Result}; + +pub use crate::traits::TryPin; diff --git a/rust/kernel/traits.rs b/rust/kernel/traits.rs new file mode 100644 index 00000000000000..79e121027b3046 --- /dev/null +++ b/rust/kernel/traits.rs @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Traits useful to drivers, and their implementations for common types. + +use core::{ops::Deref, pin::Pin}; + +use alloc::{alloc::AllocError, sync::Arc}; + +/// Trait which provides a fallible version of `pin()` for pointer types. +/// +/// Common pointer types which implement a `pin()` method include [`Box`], [`Arc`] and [`Rc`]. +pub trait TryPin { + /// Constructs a new `Pin>`. If `T` does not implement [`Unpin`], then data + /// will be pinned in memory and unable to be moved. An error will be returned + /// if allocation fails. + fn try_pin(data: P::Target) -> core::result::Result, AllocError>; +} + +impl TryPin> for Arc { + fn try_pin(data: T) -> core::result::Result>, AllocError> { + // SAFETY: the data `T` is exposed only through a `Pin>`, which + // does not allow data to move out of the `Arc`. Therefore it can + // never be moved. + Ok(unsafe { Pin::new_unchecked(Arc::try_new(data)?) }) + } +} diff --git a/samples/rust/rust_miscdev.rs b/samples/rust/rust_miscdev.rs index f3293b4800904b..0f96e3bec23e5e 100644 --- a/samples/rust/rust_miscdev.rs +++ b/samples/rust/rust_miscdev.rs @@ -39,19 +39,16 @@ struct SharedState { impl SharedState { fn try_new() -> Result>> { - // SAFETY: `state` is pinning `Arc`, which implements `Unpin`. - let state = unsafe { - Pin::new_unchecked(Arc::try_new(Self { - // SAFETY: `condvar_init!` is called below. - state_changed: CondVar::new(), - // SAFETY: `mutex_init!` is called below. - inner: Mutex::new(SharedStateInner { token_count: 0 }), - })?) - }; - // SAFETY: `state_changed` is pinned behind `Arc`. + let state = Arc::try_pin(Self { + // SAFETY: `condvar_init!` is called below. + state_changed: unsafe { CondVar::new() }, + // SAFETY: `mutex_init!` is called below. + inner: unsafe { Mutex::new(SharedStateInner { token_count: 0 }) }, + })?; + // SAFETY: `state_changed` is pinned behind `Pin`. let state_changed = unsafe { Pin::new_unchecked(&state.state_changed) }; kernel::condvar_init!(state_changed, "SharedState::state_changed"); - // SAFETY: `inner` is pinned behind `Arc`. + // SAFETY: `inner` is pinned behind `Pin`. let inner = unsafe { Pin::new_unchecked(&state.inner) }; kernel::mutex_init!(inner, "SharedState::inner"); Ok(state)