Skip to content

Commit

Permalink
Add api to get current seed
Browse files Browse the repository at this point in the history
  • Loading branch information
EricLBuehler committed Sep 11, 2024
1 parent ad84486 commit ac84f01
Show file tree
Hide file tree
Showing 9 changed files with 40 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,6 @@
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true,
"rust-analyzer.cargo.features": [
"cuda", "flash-attn",
"cuda",
],
}
1 change: 1 addition & 0 deletions candle-core/src/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;

fn set_seed(&self, _: u64) -> Result<()>;
fn get_current_seed(&self) -> Result<u64>;

/// Synchronize should block until all the operations on the device are completed.
fn synchronize(&self) -> Result<()>;
Expand Down
4 changes: 4 additions & 0 deletions candle-core/src/cpu_backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3313,6 +3313,10 @@ impl BackendDevice for CpuDevice {
crate::bail!("cannot seed the CPU rng with set_seed")
}

fn get_current_seed(&self) -> Result<u64> {
crate::bail!("cannot get the CPU rng seed with get_current_seed")
}

fn rand_uniform(&self, shape: &Shape, dtype: DType, min: f64, max: f64) -> Result<CpuStorage> {
use rand::prelude::*;

Expand Down
8 changes: 8 additions & 0 deletions candle-core/src/cuda_backend/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ pub use candle_kernels as kernels;
pub use cudarc;
use cudarc::driver::{CudaFunction, LaunchAsync, LaunchConfig};
use half::{bf16, f16};
use std::cell::Cell;
use std::sync::{Arc, Mutex};

use super::{CudaError, CudaStorage, CudaStorageSlice, WrapErr};
Expand All @@ -30,6 +31,7 @@ pub struct CudaDevice {
device: Arc<cudarc::driver::CudaDevice>,
pub(crate) blas: Arc<cudarc::cublas::CudaBlas>,
curand: Arc<Mutex<CudaRng>>,
seed_value: Cell<u64>,
}

impl std::fmt::Debug for CudaDevice {
Expand Down Expand Up @@ -168,6 +170,7 @@ impl BackendDevice for CudaDevice {
device,
blas: Arc::new(blas),
curand: Arc::new(Mutex::new(CudaRng(curand))),
seed_value: Cell::new(299792458),
})
}

Expand All @@ -176,9 +179,14 @@ impl BackendDevice for CudaDevice {
// state will be identical and the same random numbers will be generated.
let mut curand = self.curand.lock().unwrap();
curand.0 = cudarc::curand::CudaRng::new(seed, self.device.clone()).w()?;
self.seed_value.set(seed);
Ok(())
}

fn get_current_seed(&self) -> Result<u64> {
Ok(self.seed_value.get())
}

fn location(&self) -> crate::DeviceLocation {
crate::DeviceLocation::Cuda {
gpu_id: self.device.ordinal(),
Expand Down
9 changes: 9 additions & 0 deletions candle-core/src/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,15 @@ impl Device {
}
}

/// Get the current seed for the device RNG.
pub fn get_current_seed(&self) -> Result<u64> {
match self {
Self::Cpu => CpuDevice.get_current_seed(),
Self::Cuda(c) => c.get_current_seed(),
Self::Metal(m) => m.get_current_seed(),
}
}

pub fn same_device(&self, rhs: &Self) -> bool {
match (self, rhs) {
(Self::Cpu, Self::Cpu) => true,
Expand Down
4 changes: 4 additions & 0 deletions candle-core/src/dummy_cuda_backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,10 @@ impl crate::backend::BackendDevice for CudaDevice {
Err(Error::NotCompiledWithCudaSupport)
}

fn get_current_seed(&self) -> Result<u64> {
Err(Error::NotCompiledWithCudaSupport)
}

fn location(&self) -> crate::DeviceLocation {
fail!()
}
Expand Down
4 changes: 4 additions & 0 deletions candle-core/src/dummy_metal_backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,10 @@ impl crate::backend::BackendDevice for MetalDevice {
Err(Error::NotCompiledWithMetalSupport)
}

fn get_current_seed(&self) -> Result<u64> {
Err(Error::NotCompiledWithMetalSupport)
}

fn location(&self) -> crate::DeviceLocation {
fail!()
}
Expand Down
2 changes: 2 additions & 0 deletions candle-core/src/metal_backend/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ pub struct MetalDevice {
pub(crate) buffers: AllocatedBuffers,
/// Seed for random number generation.
pub(crate) seed: Arc<Mutex<Buffer>>,
/// Value of the current seed
pub(crate) seed_value: Cell<u64>,
}

impl std::fmt::Debug for MetalDevice {
Expand Down
7 changes: 7 additions & 0 deletions candle-core/src/metal_backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1944,6 +1944,7 @@ impl BackendDevice for MetalDevice {
buffers,
kernels,
seed,
seed_value: Cell::new(299792458),
})
}

Expand Down Expand Up @@ -2105,9 +2106,15 @@ impl BackendDevice for MetalDevice {
}
seed_buffer.did_modify_range(metal::NSRange::new(0, 4));

self.seed_value.set(seed);

Ok(())
}

fn get_current_seed(&self) -> Result<u64> {
Ok(self.seed_value.get())
}

fn synchronize(&self) -> Result<()> {
self.wait_until_completed()
}
Expand Down

0 comments on commit ac84f01

Please sign in to comment.