Skip to content

Commit

Permalink
Try again
Browse files Browse the repository at this point in the history
  • Loading branch information
EricLBuehler committed Mar 12, 2024
1 parent 4460683 commit fb83bf6
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions candle-nn/src/layer_norm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ pub use candle::cuda_backend::kernels;
use candle::{
backend::BackendStorage,

Check warning on line 35 in candle-nn/src/layer_norm.rs

View workflow job for this annotation

GitHub Actions / Check (ubuntu-latest, stable)

unused import: `backend::BackendStorage`

Check warning on line 35 in candle-nn/src/layer_norm.rs

View workflow job for this annotation

GitHub Actions / Test Suite (ubuntu-latest, stable)

unused import: `backend::BackendStorage`

Check failure on line 35 in candle-nn/src/layer_norm.rs

View workflow job for this annotation

GitHub Actions / Clippy

unused import: `backend::BackendStorage`
cuda_backend::{

Check failure on line 36 in candle-nn/src/layer_norm.rs

View workflow job for this annotation

GitHub Actions / Check (ubuntu-latest, stable)

failed to resolve: could not find `cuda_backend` in `candle`

Check failure on line 36 in candle-nn/src/layer_norm.rs

View workflow job for this annotation

GitHub Actions / Check (ubuntu-latest, stable)

unresolved import `candle::cuda_backend`

Check failure on line 36 in candle-nn/src/layer_norm.rs

View workflow job for this annotation

GitHub Actions / Test Suite (ubuntu-latest, stable)

failed to resolve: could not find `cuda_backend` in `candle`

Check failure on line 36 in candle-nn/src/layer_norm.rs

View workflow job for this annotation

GitHub Actions / Test Suite (ubuntu-latest, stable)

unresolved import `candle::cuda_backend`

Check failure on line 36 in candle-nn/src/layer_norm.rs

View workflow job for this annotation

GitHub Actions / Clippy

failed to resolve: could not find `cuda_backend` in `candle`

Check failure on line 36 in candle-nn/src/layer_norm.rs

View workflow job for this annotation

GitHub Actions / Clippy

unresolved import `candle::cuda_backend`
cudarc::driver::{DeviceRepr, LaunchAsync, LaunchConfig},
cudarc::driver::{DevicePtr, DeviceRepr, LaunchAsync, LaunchConfig},
kernel_name, CudaDType, WrapErr,
},
CudaDevice, CudaStorage, DType, Device, Result, Storage, Tensor, WithDType, D,

Check warning on line 40 in candle-nn/src/layer_norm.rs

View workflow job for this annotation

GitHub Actions / Check (ubuntu-latest, stable)

unused imports: `CudaDevice`, `CudaStorage`, `Device`, `Storage`, `WithDType`

Check warning on line 40 in candle-nn/src/layer_norm.rs

View workflow job for this annotation

GitHub Actions / Test Suite (ubuntu-latest, stable)

unused imports: `CudaDevice`, `CudaStorage`, `Device`, `Storage`, `WithDType`

Check failure on line 40 in candle-nn/src/layer_norm.rs

View workflow job for this annotation

GitHub Actions / Clippy

unused imports: `CudaDevice`, `CudaStorage`, `Device`, `Storage`, `WithDType`
Expand Down Expand Up @@ -159,30 +159,30 @@ impl crate::Module for LayerNorm {
let (m,n) = x.dims2()?;

let x_storage = match &*x.storage_and_layout().0 {
Storage::Cuda(s) => s,
Storage::Cuda(s) => s.as_cuda_slice()?.device_ptr(),
_ => unreachable!(),
}.slice;
};

let mean = x.zeros_like()?;
let rstd = x.zeros_like()?;

let mean_storage = match &*mean.storage_and_layout().0 {
Storage::Cuda(s) => s,
Storage::Cuda(s) => s.as_cuda_slice()?.device_ptr(),
_ => unreachable!(),
}.slice;
};

let rstd_storage = match &*rstd.storage_and_layout().0 {
Storage::Cuda(s) => s,
Storage::Cuda(s) => s.as_cuda_slice()?.device_ptr(),
_ => unreachable!(),
}.slice;
};

let cfg_1 = LaunchConfig {
grid_dim: (m as u32,1,1),
block_dim: (K_CUDABLOCK_REDUCE_NUM_THREADS,1,1),
shared_mem_bytes: 0,
};
let rowwisemoments = cuda_dev.get_or_load_func(&format!("rowwisemoments_{}", x.dtype().as_str()), kernels::LAYERNORM)?;
let params = (n, self.eps, x_storage, mean_storage, rstd_storage);
let params = (n, self.eps, *x_storage, *mean_storage, &rstd_storage);
unsafe { rowwisemoments.launch(cfg_1, params) };

panic!("Done!");
Expand Down

0 comments on commit fb83bf6

Please sign in to comment.