diff --git a/.vscode/settings.json b/.vscode/settings.json index f9b6ef02f..6abf0d3d6 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -9,6 +9,6 @@ "python.testing.unittestEnabled": false, "python.testing.pytestEnabled": true, "rust-analyzer.cargo.features": [ - "cuda", "flash-attn", + "cuda", ], } \ No newline at end of file diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs index 40fd63f1e..655c7894d 100644 --- a/candle-core/src/backend.rs +++ b/candle-core/src/backend.rs @@ -158,6 +158,7 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone { fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result; fn set_seed(&self, _: u64) -> Result<()>; + fn get_current_seed(&self) -> Result; /// Synchronize should block until all the operations on the device are completed. fn synchronize(&self) -> Result<()>; diff --git a/candle-core/src/cpu_backend/mod.rs b/candle-core/src/cpu_backend/mod.rs index 75c6e7bd3..512f3c053 100644 --- a/candle-core/src/cpu_backend/mod.rs +++ b/candle-core/src/cpu_backend/mod.rs @@ -3313,6 +3313,10 @@ impl BackendDevice for CpuDevice { crate::bail!("cannot seed the CPU rng with set_seed") } + fn get_current_seed(&self) -> Result { + crate::bail!("cannot get the CPU rng seed with get_current_seed") + } + fn rand_uniform(&self, shape: &Shape, dtype: DType, min: f64, max: f64) -> Result { use rand::prelude::*; diff --git a/candle-core/src/cuda_backend/device.rs b/candle-core/src/cuda_backend/device.rs index 9e0b64067..1890496e8 100644 --- a/candle-core/src/cuda_backend/device.rs +++ b/candle-core/src/cuda_backend/device.rs @@ -4,6 +4,7 @@ pub use candle_kernels as kernels; pub use cudarc; use cudarc::driver::{CudaFunction, LaunchAsync, LaunchConfig}; use half::{bf16, f16}; +use std::cell::Cell; use std::sync::{Arc, Mutex}; use super::{CudaError, CudaStorage, CudaStorageSlice, WrapErr}; @@ -30,6 +31,7 @@ pub struct CudaDevice { device: Arc, pub(crate) blas: Arc, curand: Arc>, + seed_value: Cell, } impl std::fmt::Debug for CudaDevice { @@ -168,6 +170,7 @@ impl BackendDevice for CudaDevice { device, blas: Arc::new(blas), curand: Arc::new(Mutex::new(CudaRng(curand))), + seed_value: Cell::new(299792458), }) } @@ -176,9 +179,14 @@ impl BackendDevice for CudaDevice { // state will be identical and the same random numbers will be generated. let mut curand = self.curand.lock().unwrap(); curand.0 = cudarc::curand::CudaRng::new(seed, self.device.clone()).w()?; + self.seed_value.set(seed); Ok(()) } + fn get_current_seed(&self) -> Result { + Ok(self.seed_value.get()) + } + fn location(&self) -> crate::DeviceLocation { crate::DeviceLocation::Cuda { gpu_id: self.device.ordinal(), diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs index 52e3e2281..55384d0b7 100644 --- a/candle-core/src/device.rs +++ b/candle-core/src/device.rs @@ -142,6 +142,15 @@ impl Device { } } + /// Get the current seed for the device RNG. + pub fn get_current_seed(&self) -> Result { + match self { + Self::Cpu => CpuDevice.get_current_seed(), + Self::Cuda(c) => c.get_current_seed(), + Self::Metal(m) => m.get_current_seed(), + } + } + pub fn same_device(&self, rhs: &Self) -> bool { match (self, rhs) { (Self::Cpu, Self::Cpu) => true, diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs index 26d9b2f62..ccd34756e 100644 --- a/candle-core/src/dummy_cuda_backend.rs +++ b/candle-core/src/dummy_cuda_backend.rs @@ -208,6 +208,10 @@ impl crate::backend::BackendDevice for CudaDevice { Err(Error::NotCompiledWithCudaSupport) } + fn get_current_seed(&self) -> Result { + Err(Error::NotCompiledWithCudaSupport) + } + fn location(&self) -> crate::DeviceLocation { fail!() } diff --git a/candle-core/src/dummy_metal_backend.rs b/candle-core/src/dummy_metal_backend.rs index 2ec89f97a..364311a9b 100644 --- a/candle-core/src/dummy_metal_backend.rs +++ b/candle-core/src/dummy_metal_backend.rs @@ -220,6 +220,10 @@ impl crate::backend::BackendDevice for MetalDevice { Err(Error::NotCompiledWithMetalSupport) } + fn get_current_seed(&self) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + fn location(&self) -> crate::DeviceLocation { fail!() } diff --git a/candle-core/src/metal_backend/device.rs b/candle-core/src/metal_backend/device.rs index 07210c68c..e22c7829a 100644 --- a/candle-core/src/metal_backend/device.rs +++ b/candle-core/src/metal_backend/device.rs @@ -70,6 +70,8 @@ pub struct MetalDevice { pub(crate) buffers: AllocatedBuffers, /// Seed for random number generation. pub(crate) seed: Arc>, + /// Value of the current seed + pub(crate) seed_value: Cell, } impl std::fmt::Debug for MetalDevice { diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index 194c5a625..6c2f673cb 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -1944,6 +1944,7 @@ impl BackendDevice for MetalDevice { buffers, kernels, seed, + seed_value: Cell::new(299792458), }) } @@ -2105,9 +2106,15 @@ impl BackendDevice for MetalDevice { } seed_buffer.did_modify_range(metal::NSRange::new(0, 4)); + self.seed_value.set(seed); + Ok(()) } + fn get_current_seed(&self) -> Result { + Ok(self.seed_value.get()) + } + fn synchronize(&self) -> Result<()> { self.wait_until_completed() }