Skip to content

Commit

Permalink
Use different shfl fn
Browse the repository at this point in the history
  • Loading branch information
EricLBuehler committed Mar 12, 2024
1 parent 8045497 commit 6e43a72
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions candle-nn/src/layer_norm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,6 @@ impl crate::Module for LayerNorm {
Device::Cuda(dev) => dev
};
let (m,n) = x.dims2()?;
let cfg_1 = LaunchConfig {
grid_dim: m,
block_dim: K_CUDABLOCK_REDUCE_NUM_THREADS,
shared_mem_bytes: 0,
};

let x_storage = match x.storage_and_layout().0 {
Storage::Cuda(s) => s,
Expand All @@ -175,6 +170,11 @@ impl crate::Module for LayerNorm {
_ => unreachable!(),
}.slice;

let cfg_1 = LaunchConfig {
grid_dim: (m,1,1),
block_dim: (K_CUDABLOCK_REDUCE_NUM_THREADS,1,1),
shared_mem_bytes: 0,
};
let rowwisemoments = cuda_dev.get_or_load_func(&kernel_name::<T>("rowwisemoments"), kernels::LAYERNORM)?;
let params = (n, self.eps, x_storage, mean_storage, rstd_storage);
unsafe { rowwisemoments.launch(cfg_1, params) };
Expand Down

0 comments on commit 6e43a72

Please sign in to comment.