Skip to content

Commit

Permalink
Add the i16 dtype (2) (#26)
Browse files Browse the repository at this point in the history
* Add the i16 dtype

* Added I16 and I32 to fix the missing arms issue (candle-onnx/eval)

* Update rust-ci.yml

* Update ci_cuda.yaml

* fmt adjustment

* Revert "Update rust-ci.yml"

This reverts commit f659d36.

* Revert "Update ci_cuda.yaml"

This reverts commit 62a4b39.
  • Loading branch information
ro99 authored Sep 15, 2024
1 parent 8a99f7c commit 9e31a19
Show file tree
Hide file tree
Showing 35 changed files with 586 additions and 27 deletions.
5 changes: 5 additions & 0 deletions candle-core/src/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,11 @@ impl Tensor {
f.write_u32::<LittleEndian>(v)?
}
}
DType::I16 => {
for v in vs.to_vec1::<i16>()? {
f.write_i16::<LittleEndian>(v)?
}
}
DType::I32 => {
for v in vs.to_vec1::<i32>()? {
f.write_i32::<LittleEndian>(v)?
Expand Down
11 changes: 11 additions & 0 deletions candle-core/src/cpu/kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,17 @@ impl VecOps for u32 {
<Self as Ord>::max(self, other)
}
}
/// `VecOps` implementation for `i16`: min/max delegate to the total
/// ordering on the primitive integer type.
///
/// NOTE: the fully-qualified `std::cmp::{min,max}` calls are deliberate —
/// a bare `self.min(other)` inside this impl would resolve back to
/// `VecOps::min` and recurse infinitely.
impl VecOps for i16 {
    /// Returns the smaller of the two values.
    #[inline(always)]
    fn min(self, rhs: Self) -> Self {
        std::cmp::min(self, rhs)
    }

    /// Returns the larger of the two values.
    #[inline(always)]
    fn max(self, rhs: Self) -> Self {
        std::cmp::max(self, rhs)
    }
}
impl VecOps for i32 {
#[inline(always)]
fn min(self, other: Self) -> Self {
Expand Down
124 changes: 121 additions & 3 deletions candle-core/src/cpu_backend/mod.rs

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions candle-core/src/cpu_backend/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ pub trait Map1 {
match vs {
C::U8(vs) => Ok(C::U8(self.f(vs, layout)?)),
C::U32(vs) => Ok(C::U32(self.f(vs, layout)?)),
C::I16(vs) => Ok(C::I16(self.f(vs, layout)?)),
C::I32(vs) => Ok(C::I32(self.f(vs, layout)?)),
C::I64(vs) => Ok(C::I64(self.f(vs, layout)?)),
C::BF16(vs) => Ok(C::BF16(self.f(vs, layout)?)),
Expand All @@ -27,6 +28,7 @@ pub trait Map1Any {
match vs {
C::U8(vs) => Ok(self.f(vs, layout, C::U8)?),
C::U32(vs) => Ok(self.f(vs, layout, C::U32)?),
C::I16(vs) => Ok(self.f(vs, layout, C::I16)?),
C::I32(vs) => Ok(self.f(vs, layout, C::I32)?),
C::I64(vs) => Ok(self.f(vs, layout, C::I64)?),
C::BF16(vs) => Ok(self.f(vs, layout, C::BF16)?),
Expand Down
64 changes: 50 additions & 14 deletions candle-core/src/cuda_backend/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,14 @@ impl CudaDevice {
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::U32(data)
}
DType::I16 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<i16>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_i16", kernels::FILL)?;
let params = (&data, v as i16, elem_count);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::I16(data)
}
DType::I32 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<i32>(elem_count) }.w()?;
Expand Down Expand Up @@ -207,6 +215,10 @@ impl BackendDevice for CudaDevice {
let data = self.alloc_zeros::<u32>(elem_count).w()?;
CudaStorageSlice::U32(data)
}
DType::I16 => {
let data = self.alloc_zeros::<i16>(elem_count).w()?;
CudaStorageSlice::I16(data)
}
DType::I32 => {
let data = self.alloc_zeros::<i32>(elem_count).w()?;
CudaStorageSlice::I32(data)
Expand Down Expand Up @@ -244,13 +256,17 @@ impl BackendDevice for CudaDevice {
let slice = match dtype {
// TODO: Add support for F16 and BF16 though this is likely to require some upstream
// cudarc changes.
DType::U8 | DType::U32 | DType::I64 | DType::I32 | DType::F16 | DType::BF16 => {
Err(CudaError::UnsupportedDtype {
dtype,
op: "rand_uniform",
})
.w()?
}
DType::U8
| DType::U32
| DType::I64
| DType::I32
| DType::I16
| DType::F16
| DType::BF16 => Err(CudaError::UnsupportedDtype {
dtype,
op: "rand_uniform",
})
.w()?,
DType::F32 => {
let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
curand.0.fill_with_uniform(&mut data).w()?;
Expand Down Expand Up @@ -288,13 +304,17 @@ impl BackendDevice for CudaDevice {
elem_count
};
let slice = match dtype {
DType::U8 | DType::U32 | DType::I32 | DType::I64 | DType::F16 | DType::BF16 => {
Err(CudaError::UnsupportedDtype {
dtype,
op: "rand_normal",
})
.w()?
}
DType::U8
| DType::U32
| DType::I16
| DType::I32
| DType::I64
| DType::F16
| DType::BF16 => Err(CudaError::UnsupportedDtype {
dtype,
op: "rand_normal",
})
.w()?,
DType::F32 => {
let mut data = unsafe { self.alloc::<f32>(elem_count_round) }.w()?;
curand
Expand Down Expand Up @@ -330,6 +350,10 @@ impl BackendDevice for CudaDevice {
let data = self.alloc::<u32>(elem_count).w()?;
CudaStorageSlice::U32(data)
}
DType::I16 => {
let data = self.alloc::<i16>(elem_count).w()?;
CudaStorageSlice::I16(data)
}
DType::I32 => {
let data = self.alloc::<i32>(elem_count).w()?;
CudaStorageSlice::I32(data)
Expand Down Expand Up @@ -371,6 +395,10 @@ impl BackendDevice for CudaDevice {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::U32(data)
}
CpuStorageRef::I16(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::I16(data)
}
CpuStorageRef::I32(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::I32(data)
Expand Down Expand Up @@ -412,6 +440,10 @@ impl BackendDevice for CudaDevice {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::U32(data)
}
CpuStorage::I16(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::I16(data)
}
CpuStorage::I32(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::I32(data)
Expand Down Expand Up @@ -453,6 +485,10 @@ impl BackendDevice for CudaDevice {
let data = self.htod_copy(storage).w()?;
CudaStorageSlice::U32(data)
}
CpuStorage::I16(storage) => {
let data = self.htod_copy(storage).w()?;
CudaStorageSlice::I16(data)
}
CpuStorage::I32(storage) => {
let data = self.htod_copy(storage).w()?;
CudaStorageSlice::I32(data)
Expand Down
55 changes: 50 additions & 5 deletions candle-core/src/cuda_backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ impl SlicePtrOrNull<usize> {
pub enum CudaStorageSlice {
U8(CudaSlice<u8>),
U32(CudaSlice<u32>),
I16(CudaSlice<i16>),
I32(CudaSlice<i32>),
I64(CudaSlice<i64>),
BF16(CudaSlice<bf16>),
Expand Down Expand Up @@ -364,14 +365,17 @@ impl<'a> Map1 for IndexSelect<'a> {
CudaStorageSlice::U8(slice) => {
("is_u8", *slice.slice(ids_l.start_offset()..).device_ptr())
}
CudaStorageSlice::I16(slice) => {
("is_i16", *slice.slice(ids_l.start_offset()..).device_ptr())
}
CudaStorageSlice::I32(slice) => {
("is_i32", *slice.slice(ids_l.start_offset()..).device_ptr())
}
CudaStorageSlice::I64(slice) => {
("is_i64", *slice.slice(ids_l.start_offset()..).device_ptr())
}
_ => Err(CudaError::UnexpectedDType {
msg: "index_select ids should be u8/u32/i32/i64",
msg: "index_select ids should be u8/u32/i16/i32/i64",
expected: DType::U32,
got: self.0.dtype(),
})
Expand Down Expand Up @@ -431,14 +435,17 @@ impl<'a> Map1 for Gather<'a> {
("gather_u32", *slice.slice(ids_o1..ids_o2).device_ptr())
}
CudaStorageSlice::U8(slice) => ("gather_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
CudaStorageSlice::I16(slice) => {
("gather_i16", *slice.slice(ids_o1..ids_o2).device_ptr())
}
CudaStorageSlice::I32(slice) => {
("gather_i32", *slice.slice(ids_o1..ids_o2).device_ptr())
}
CudaStorageSlice::I64(slice) => {
("gather_i64", *slice.slice(ids_o1..ids_o2).device_ptr())
}
_ => Err(CudaError::UnexpectedDType {
msg: "gather ids should be u8/u32/i32/i64",
msg: "gather ids should be u8/u32/i16/i32/i64",
expected: DType::U32,
got: ids.dtype(),
})?,
Expand Down Expand Up @@ -484,11 +491,12 @@ impl<'a> Map2InPlace for IndexAdd<'a> {
};
let (name, ids) = match &ids.slice {
CudaStorageSlice::U32(slice) => ("ia_u32", *slice.slice(ids_o1..ids_o2).device_ptr()),
CudaStorageSlice::I16(slice) => ("ia_i16", *slice.slice(ids_o1..ids_o2).device_ptr()),
CudaStorageSlice::I32(slice) => ("ia_i32", *slice.slice(ids_o1..ids_o2).device_ptr()),
CudaStorageSlice::I64(slice) => ("ia_i64", *slice.slice(ids_o1..ids_o2).device_ptr()),
CudaStorageSlice::U8(slice) => ("ia_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
_ => Err(CudaError::UnexpectedDType {
msg: "index-add ids should be u8/u32/i32/i64",
msg: "index-add ids should be u8/u32/i16/i32/i64",
expected: DType::U32,
got: ids.dtype(),
})?,
Expand Down Expand Up @@ -533,11 +541,12 @@ impl<'a> Map2InPlace for ScatterAdd<'a> {
};
let (name, ids) = match &ids.slice {
CudaStorageSlice::U32(slice) => ("sa_u32", *slice.slice(ids_o1..ids_o2).device_ptr()),
CudaStorageSlice::I16(slice) => ("sa_i16", *slice.slice(ids_o1..ids_o2).device_ptr()),
CudaStorageSlice::I32(slice) => ("sa_i32", *slice.slice(ids_o1..ids_o2).device_ptr()),
CudaStorageSlice::I64(slice) => ("sa_i64", *slice.slice(ids_o1..ids_o2).device_ptr()),
CudaStorageSlice::U8(slice) => ("sa_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
_ => Err(CudaError::UnexpectedDType {
msg: "scatter-add ids should be u8/u32/i32/i64",
msg: "scatter-add ids should be u8/u32/i16/i32/i64",
expected: DType::U32,
got: ids.dtype(),
})?,
Expand Down Expand Up @@ -876,6 +885,10 @@ impl<'a> Map2 for WhereCond<'a> {
let ptr = *slice.slice(ids_l.start_offset()..).device_ptr();
(ptr, "where_u32")
}
CudaStorageSlice::I16(slice) => {
let ptr = *slice.slice(ids_l.start_offset()..).device_ptr();
(ptr, "where_i16")
}
CudaStorageSlice::I32(slice) => {
let ptr = *slice.slice(ids_l.start_offset()..).device_ptr();
(ptr, "where_i32")
Expand All @@ -885,7 +898,7 @@ impl<'a> Map2 for WhereCond<'a> {
(ptr, "where_i64")
}
_ => Err(CudaError::UnexpectedDType {
msg: "where conditions should be u8/u32/i64",
msg: "where conditions should be u8/u32/i16/i32/i64",
expected: DType::U32,
got: self.0.dtype(),
})
Expand Down Expand Up @@ -1039,6 +1052,7 @@ macro_rules! cuda_dtype {
}
cuda_dtype!(u8, U8);
cuda_dtype!(u32, U32);
cuda_dtype!(i16, I16);
cuda_dtype!(i32, I32);
cuda_dtype!(i64, I64);
cuda_dtype!(f16, F16);
Expand Down Expand Up @@ -1162,6 +1176,7 @@ impl BackendStorage for CudaStorage {
match self.slice {
CudaStorageSlice::U8(_) => DType::U8,
CudaStorageSlice::U32(_) => DType::U32,
CudaStorageSlice::I16(_) => DType::I16,
CudaStorageSlice::I32(_) => DType::I32,
CudaStorageSlice::I64(_) => DType::I64,
CudaStorageSlice::BF16(_) => DType::BF16,
Expand Down Expand Up @@ -1189,6 +1204,7 @@ impl BackendStorage for CudaStorage {
let inp = match &self.slice {
CudaStorageSlice::U8(inp) => *inp.slice(start_o..).device_ptr(),
CudaStorageSlice::U32(inp) => *inp.slice(start_o..).device_ptr(),
CudaStorageSlice::I16(inp) => *inp.slice(start_o..).device_ptr(),
CudaStorageSlice::I32(inp) => *inp.slice(start_o..).device_ptr(),
CudaStorageSlice::I64(inp) => *inp.slice(start_o..).device_ptr(),
CudaStorageSlice::BF16(inp) => *inp.slice(start_o..).device_ptr(),
Expand All @@ -1213,6 +1229,12 @@ impl BackendStorage for CudaStorage {
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::U32(out)
}
DType::I16 => {
let out = unsafe { dev.alloc::<i16>(el) }.w()?;
let params = (el, dims.len(), &ds, *inp, &out);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::I16(out)
}
DType::I32 => {
let out = unsafe { dev.alloc::<i32>(el) }.w()?;
let params = (el, dims.len(), &ds, *inp, &out);
Expand Down Expand Up @@ -1315,6 +1337,11 @@ impl BackendStorage for CudaStorage {
let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
Ok(CpuStorage::U32(cpu_storage))
}
CudaStorageSlice::I16(slice) => {
let dev = slice.device();
let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
Ok(CpuStorage::I16(cpu_storage))
}
CudaStorageSlice::I32(slice) => {
let dev = slice.device();
let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
Expand Down Expand Up @@ -1587,6 +1614,7 @@ impl BackendStorage for CudaStorage {
S::F64(out)
}
(S::U32(_), S::U32(_)) => Err(CudaError::InternalError("conv2d does not support u32"))?,
(S::I16(_), S::I16(_)) => Err(CudaError::InternalError("conv2d does not support i16"))?,
(S::I32(_), S::I32(_)) => Err(CudaError::InternalError("conv2d does not support i32"))?,
(S::I64(_), S::I64(_)) => Err(CudaError::InternalError("conv2d does not support i64"))?,
_ => Err(CudaError::InternalError("dtype mismatch in conv2d"))?,
Expand Down Expand Up @@ -1854,6 +1882,11 @@ impl BackendStorage for CudaStorage {
*d.slice(dst_o..).device_ptr(),
"copy2d_u32",
),
(S::I16(s), S::I16(d)) => (
*s.slice(src_o..).device_ptr(),
*d.slice(dst_o..).device_ptr(),
"copy2d_i16",
),
(S::I32(s), S::I32(d)) => (
*s.slice(src_o..).device_ptr(),
*d.slice(dst_o..).device_ptr(),
Expand Down Expand Up @@ -1965,6 +1998,18 @@ impl BackendStorage for CudaStorage {
unsafe { func.launch(cfg, params) }.w()?
}
}
(CudaStorageSlice::I16(src), CudaStorageSlice::I16(dst)) => {
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
if src_l.is_contiguous() {
dev.dtod_copy(&src, &mut dst).w()?
} else {
let func = dev.get_or_load_func("ucopy_i16", kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let params = (el_count, dims.len(), &ds, &src, &mut dst);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?
}
}
(CudaStorageSlice::I32(src), CudaStorageSlice::I32(dst)) => {
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
if src_l.is_contiguous() {
Expand Down
2 changes: 2 additions & 0 deletions candle-core/src/cuda_backend/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ pub trait Map1 {
let out = match s {
S::U8(s) => S::U8(self.f(s, d, l)?),
S::U32(s) => S::U32(self.f(s, d, l)?),
S::I16(s) => S::I16(self.f(s, d, l)?),
S::I32(s) => S::I32(self.f(s, d, l)?),
S::I64(s) => S::I64(self.f(s, d, l)?),
S::BF16(s) => S::BF16(self.f(s, d, l)?),
Expand Down Expand Up @@ -137,6 +138,7 @@ pub trait Map1Any {
let out = match s {
S::U8(s) => self.f(s, d, l, S::U8)?,
S::U32(s) => self.f(s, d, l, S::U32)?,
S::I16(s) => self.f(s, d, l, S::I16)?,
S::I32(s) => self.f(s, d, l, S::I32)?,
S::I64(s) => self.f(s, d, l, S::I64)?,
S::BF16(s) => self.f(s, d, l, S::BF16)?,
Expand Down
7 changes: 7 additions & 0 deletions candle-core/src/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ impl std::fmt::Debug for Tensor {
match self.dtype() {
DType::U8 => self.fmt_dt::<u8>(f),
DType::U32 => self.fmt_dt::<u32>(f),
DType::I16 => self.fmt_dt::<i16>(f),
DType::I32 => self.fmt_dt::<i32>(f),
DType::I64 => self.fmt_dt::<i64>(f),
DType::BF16 => self.fmt_dt::<bf16>(f),
Expand Down Expand Up @@ -464,6 +465,12 @@ impl std::fmt::Display for Tensor {
tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
writeln!(f)?;
}
DType::I16 => {
let tf: IntFormatter<i16> = IntFormatter::new();
let max_w = tf.max_width(&to_display);
tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
writeln!(f)?;
}
DType::I32 => {
let tf: IntFormatter<i32> = IntFormatter::new();
let max_w = tf.max_width(&to_display);
Expand Down
Loading

0 comments on commit 9e31a19

Please sign in to comment.