
Commit 960e0fd

Fixes
EricLBuehler committed Oct 6, 2024
1 parent 121bdfd commit 960e0fd
Showing 2 changed files with 90 additions and 14 deletions.
6 changes: 5 additions & 1 deletion candle-core/src/display.rs
@@ -514,7 +514,11 @@ impl std::fmt::Display for Tensor {
}
}
DType::F8E4M3 => {
-    return write!(f, "F8E4M3 does not support display.");
+    if let Ok(tf) = FloatFormatter::<F8E4M3>::new(&to_display, &po) {
+        let max_w = tf.max_width(&to_display);
+        tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
+        writeln!(f)?;
+    }
}
};

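With this change, an F8E4M3 tensor is rendered through the same FloatFormatter path as the other float dtypes instead of printing the fixed "does not support display" placeholder. A minimal usage sketch, assuming this fork's DType::F8E4M3 cast is available on the chosen device; the values and shape here are purely illustrative:

    use candle_core::{DType, Device, Result, Tensor};

    fn main() -> Result<()> {
        // Build a small f32 tensor, narrow it to FP8 E4M3, then print it;
        // the Display impl now formats the values rather than writing a placeholder message.
        let t = Tensor::new(&[0.5f32, 1.0, 2.0, 4.0], &Device::Cpu)?;
        let t8 = t.to_dtype(DType::F8E4M3)?;
        println!("{t8}");
        Ok(())
    }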
98 changes: 85 additions & 13 deletions candle-kernels/src/cast.cu
@@ -24,6 +24,53 @@ __device__ void cast_(
}
}

#define F8E4M3_TO_FLOAT(x) __half2float(__nv_cvt_fp8_to_halfraw(x.__x, __NV_E4M3))

template <typename T>
__device__ void cast_fp8_(
const size_t numel,
const size_t num_dims,
const size_t *info,
const __nv_fp8_e4m3 *inp,
T *out
) {
const size_t *dims = info;
const size_t *strides = info + num_dims;
if (info == nullptr || is_contiguous(num_dims, dims, strides)) {
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
out[i] = F8E4M3_TO_FLOAT(inp[i]);
}
}
else {
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
unsigned strided_i = get_strided_index(i, num_dims, dims, strides);
out[i] = F8E4M3_TO_FLOAT(inp[strided_i]);
}
}
}
template <typename S>
__device__ void cast_fp8_into_(
const size_t numel,
const size_t num_dims,
const size_t *info,
const S *inp,
__nv_fp8_e4m3 *out
) {
const size_t *dims = info;
const size_t *strides = info + num_dims;
if (info == nullptr || is_contiguous(num_dims, dims, strides)) {
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
out[i] = __nv_fp8_e4m3((float)inp[i]);
}
}
else {
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
unsigned strided_i = get_strided_index(i, num_dims, dims, strides);
out[i] = __nv_fp8_e4m3((float)inp[strided_i]);
}
}
}
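Both helpers follow the contract of the existing cast_ kernel above: info points at num_dims dimension sizes followed by num_dims strides, a null info pointer or a contiguous layout takes the plain grid-stride loop, and the strided branch maps each output index back through get_strided_index. F8E4M3_TO_FLOAT widens the FP8 byte to __half via __nv_cvt_fp8_to_halfraw before the final float conversion, while the reverse direction relies on __nv_fp8_e4m3's converting constructor from float. As a rough Rust-side illustration (assuming the candle CUDA backend forwards the layout to these kernels instead of making the input contiguous first), a transposed tensor would exercise the strided branch:

    use candle_core::{DType, Device, Result, Tensor};

    fn strided_cast_demo() -> Result<Tensor> {
        let dev = Device::new_cuda(0)?;
        // A transpose produces a non-contiguous view, so the kernel sees info != nullptr.
        let x8 = Tensor::randn(0f32, 1f32, (4, 8), &dev)?.to_dtype(DType::F8E4M3)?;
        x8.t()?.to_dtype(DType::F32)
    }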

template <typename S, typename T, typename I>
__device__ void cast_through(
const size_t numel,
@@ -59,6 +106,30 @@ extern "C" __global__ void FN_NAME( \
cast_<SRC_TYPENAME, DST_TYPENAME>(numel, num_dims, info, inp, out); \
} \


#define CAST_OP_FP8(SRC_TYPENAME, DST_TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const size_t numel, \
const size_t num_dims, \
const size_t *info, \
const SRC_TYPENAME *inp, \
DST_TYPENAME *out \
) { \
cast_fp8_<DST_TYPENAME>(numel, num_dims, info, inp, out); \
} \


#define CAST_OP_FP8_INTO(SRC_TYPENAME, DST_TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const size_t numel, \
const size_t num_dims, \
const size_t *info, \
const SRC_TYPENAME *inp, \
DST_TYPENAME *out \
) { \
cast_fp8_into_<SRC_TYPENAME>(numel, num_dims, info, inp, out); \
} \

#define CAST_THROUGH_OP(SRC_TYPENAME, DST_TYPENAME, INT_TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const size_t numel, \
@@ -99,22 +170,23 @@ CAST_THROUGH_OP(uint8_t, __nv_bfloat16, float, cast_u8_bf16)
CAST_THROUGH_OP(int32_t, __nv_bfloat16, float, cast_i32_bf16)
CAST_THROUGH_OP(__nv_bfloat16, int32_t, float, cast_bf16_i32)
CAST_THROUGH_OP(__nv_bfloat16, __nv_fp8_e4m3, float, cast_bf16_f8_e4m3)

-CAST_OP(__nv_fp8_e4m3, float, cast_f8_e4m3_f32)
-CAST_OP(float, __nv_fp8_e4m3, cast_f32_f8_e4m3)
-CAST_THROUGH_OP(__nv_fp8_e4m3, uint8_t, float, cast_f8_e4m3_u8)
-CAST_THROUGH_OP(__nv_fp8_e4m3, __half, float, cast_f8_e4m3_f16)
-CAST_THROUGH_OP(__nv_fp8_e4m3, double, float, cast_f8_e4m3_f64)
-CAST_THROUGH_OP(__half, __nv_fp8_e4m3, float, cast_f16_f8_e4m3)
-CAST_THROUGH_OP(double, __nv_fp8_e4m3, float, cast_f64_f8_e4m3)
-CAST_THROUGH_OP(uint8_t, __nv_fp8_e4m3, float, cast_u8_f8_e4m3)
-CAST_THROUGH_OP(int32_t, __nv_fp8_e4m3, float, cast_i32_f8_e4m3)
-CAST_THROUGH_OP(__nv_fp8_e4m3, int32_t, float, cast_f8_e4m3_i32)
-CAST_THROUGH_OP(__nv_fp8_e4m3, __nv_bfloat16, float, cast_f8_e4m3_bf16)
-CAST_THROUGH_OP(__nv_bfloat16, __nv_fp8_e4m3, float, cast_bf16_f8_e4m3)
#endif
#endif


+CAST_OP_FP8(__nv_fp8_e4m3, float, cast_f8_e4m3_f32)
+CAST_OP_FP8_INTO(float, __nv_fp8_e4m3, cast_f32_f8_e4m3)
+// CAST_THROUGH_OP(__nv_fp8_e4m3, uint8_t, float, cast_f8_e4m3_u8)
+// CAST_THROUGH_OP(__nv_fp8_e4m3, __half, float, cast_f8_e4m3_f16)
+// CAST_THROUGH_OP(__nv_fp8_e4m3, double, float, cast_f8_e4m3_f64)
+// CAST_THROUGH_OP(__half, __nv_fp8_e4m3, float, cast_f16_f8_e4m3)
+// CAST_THROUGH_OP(double, __nv_fp8_e4m3, float, cast_f64_f8_e4m3)
+// CAST_THROUGH_OP(uint8_t, __nv_fp8_e4m3, float, cast_u8_f8_e4m3)
+// CAST_THROUGH_OP(int32_t, __nv_fp8_e4m3, float, cast_i32_f8_e4m3)
+// CAST_THROUGH_OP(__nv_fp8_e4m3, int32_t, float, cast_f8_e4m3_i32)
+// CAST_THROUGH_OP(__nv_fp8_e4m3, __nv_bfloat16, float, cast_f8_e4m3_bf16)
+// CAST_THROUGH_OP(__nv_bfloat16, __nv_fp8_e4m3, float, cast_bf16_f8_e4m3)

#if __CUDA_ARCH__ >= 530
CAST_OP(__half, __half, cast_f16_f16)

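Taken together, the new entry points cast_f32_f8_e4m3 and cast_f8_e4m3_f32 back the f32 to F8E4M3 round trip on CUDA. A minimal sketch of the Rust-side call path, assuming a CUDA-enabled build of this fork; the kernel names in the comments refer to the instantiations above:

    use candle_core::{DType, Device, Result, Tensor};

    fn main() -> Result<()> {
        let dev = Device::new_cuda(0)?;
        let x = Tensor::randn(0f32, 1f32, (2, 3), &dev)?;
        let x8 = x.to_dtype(DType::F8E4M3)?;   // launches cast_f32_f8_e4m3 (cast_fp8_into_)
        let back = x8.to_dtype(DType::F32)?;   // launches cast_f8_e4m3_f32 (cast_fp8_)
        // FP8 E4M3 keeps only 3 mantissa bits, so expect visible rounding in the round trip.
        println!("original:\n{x}\nround-tripped:\n{back}");
        Ok(())
    }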
