Skip to content

Commit

Permalink
Add lovelace i8 kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
Varun Sundar Rabindranath committed Jul 29, 2024
1 parent 766435e commit da7ba9d
Show file tree
Hide file tree
Showing 4 changed files with 395 additions and 43 deletions.
27 changes: 10 additions & 17 deletions csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

#include "scaled_mm_c2x.cuh"
#include "scaled_mm_c2x_sm80_dispatch.cuh"
#include "scaled_mm_c2x_sm89_dispatch.cuh"
#include "scaled_mm_c2x_sm89_fp8_dispatch.cuh"
#include "scaled_mm_c2x_sm89_int8_dispatch.cuh"

/*
This file defines quantized GEMM operations using the CUTLASS 2.x API, for
Expand Down Expand Up @@ -98,39 +99,31 @@ template <template <typename, typename> typename Epilogue,
void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... epilogue_args) {
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;

if (a.dtype() == torch::kInt8) {
TORCH_CHECK(b.dtype() == torch::kInt8);

if (out.dtype() == torch::kBFloat16) {
return vllm::cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
int8_t, cutlass::bfloat16_t, Epilogue,
TileShape, WarpShape, InstructionShape, 5>>(
return vllm::cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::bfloat16_t,
Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
assert(out.dtype() == torch::kFloat16);
return vllm::cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
int8_t, cutlass::half_t, Epilogue, TileShape,
WarpShape, InstructionShape, 5>>(
return vllm::cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::half_t,
Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
} else {
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);

if (out.dtype() == torch::kBFloat16) {
return vllm::cutlass_gemm_sm89_dispatch<cutlass::float_e4m3_t,
cutlass::bfloat16_t, Epilogue>(
return vllm::cutlass_gemm_sm89_fp8_dispatch<
cutlass::float_e4m3_t, cutlass::bfloat16_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return vllm::cutlass_gemm_sm89_dispatch<cutlass::float_e4m3_t,
cutlass::half_t, Epilogue>(
return vllm::cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
#include "cutlass/float8.h"

/**
* This file defines Gemm kernel configurations for SM89 based on the Gemm
* This file defines Gemm kernel configurations for SM89 (FP8) based on the Gemm
* shape.
*/

namespace vllm {

template <typename InType, typename OutType,
template <typename, typename> typename Epilogue>
struct sm89_fallback_gemm {
struct sm89_fp8_fallback_gemm {
// Shared Memory required by this Gemm - 61440 bytes
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using TileShape = typename cutlass::gemm::GemmShape<64, 128, 64>;
Expand All @@ -25,7 +25,7 @@ struct sm89_fallback_gemm {
FP8MathOperator>;
};

struct sm89_config_default {
struct sm89_fp8_config_default {
// M in (256, inf)
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
Expand All @@ -40,7 +40,8 @@ struct sm89_config_default {
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);

using FallbackGemm =
typename sm89_fallback_gemm<InType, OutType, Epilogue>::Cutlass2xGemm;
typename sm89_fp8_fallback_gemm<InType, OutType,
Epilogue>::Cutlass2xGemm;

uint32_t const n = out.size(1);
uint32_t const np2 = next_pow_2(n);
Expand Down Expand Up @@ -74,7 +75,7 @@ struct sm89_config_default {
}
};

struct sm89_config_M256 {
struct sm89_fp8_config_M256 {
// M in (128, 256]
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
Expand All @@ -89,7 +90,8 @@ struct sm89_config_M256 {
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);

using FallbackGemm =
typename sm89_fallback_gemm<InType, OutType, Epilogue>::Cutlass2xGemm;
typename sm89_fp8_fallback_gemm<InType, OutType,
Epilogue>::Cutlass2xGemm;

uint32_t const n = out.size(1);
uint32_t const np2 = next_pow_2(n);
Expand All @@ -114,7 +116,7 @@ struct sm89_config_M256 {
}
};

struct sm89_config_M128 {
struct sm89_fp8_config_M128 {
// M in (64, 128]
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
Expand All @@ -129,7 +131,8 @@ struct sm89_config_M128 {
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);

using FallbackGemm =
typename sm89_fallback_gemm<InType, OutType, Epilogue>::Cutlass2xGemm;
typename sm89_fp8_fallback_gemm<InType, OutType,
Epilogue>::Cutlass2xGemm;

uint32_t const n = out.size(1);
uint32_t const np2 = next_pow_2(n);
Expand Down Expand Up @@ -163,7 +166,7 @@ struct sm89_config_M128 {
}
};

struct sm89_config_M64 {
struct sm89_fp8_config_M64 {
// M in (32, 64]
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;

Expand All @@ -176,7 +179,8 @@ struct sm89_config_M64 {
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);

using FallbackGemm =
typename sm89_fallback_gemm<InType, OutType, Epilogue>::Cutlass2xGemm;
typename sm89_fp8_fallback_gemm<InType, OutType,
Epilogue>::Cutlass2xGemm;

uint32_t const n = out.size(1);
uint32_t const np2 = next_pow_2(n);
Expand Down Expand Up @@ -215,7 +219,7 @@ struct sm89_config_M64 {
}
};

struct sm89_config_M32 {
struct sm89_fp8_config_M32 {
// M in (16, 32]
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
Expand All @@ -229,7 +233,8 @@ struct sm89_config_M32 {
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);

using FallbackGemm =
typename sm89_fallback_gemm<InType, OutType, Epilogue>::Cutlass2xGemm;
typename sm89_fp8_fallback_gemm<InType, OutType,
Epilogue>::Cutlass2xGemm;

uint32_t const n = out.size(1);
uint32_t const np2 = next_pow_2(n);
Expand Down Expand Up @@ -265,7 +270,7 @@ struct sm89_config_M32 {
}
};

struct sm89_config_M16 {
struct sm89_fp8_config_M16 {
// M in [1, 16]
using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
Expand All @@ -281,7 +286,8 @@ struct sm89_config_M16 {
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);

using FallbackGemm =
typename sm89_fallback_gemm<InType, OutType, Epilogue>::Cutlass2xGemm;
typename sm89_fp8_fallback_gemm<InType, OutType,
Epilogue>::Cutlass2xGemm;

uint32_t const n = out.size(1);
uint32_t const np2 = next_pow_2(n);
Expand Down Expand Up @@ -320,10 +326,10 @@ struct sm89_config_M16 {
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
inline void cutlass_gemm_sm89_dispatch(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... args) {
inline void cutlass_gemm_sm89_fp8_dispatch(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
Expand All @@ -334,27 +340,27 @@ inline void cutlass_gemm_sm89_dispatch(torch::Tensor& out,

if (mp2 <= 16) {
// M in [1, 16]
return sm89_config_M16::dispatch<InType, OutType, Epilogue>(
return sm89_fp8_config_M16::dispatch<InType, OutType, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 32) {
// M in (16, 32]
return sm89_config_M32::dispatch<InType, OutType, Epilogue>(
return sm89_fp8_config_M32::dispatch<InType, OutType, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 64) {
// M in (32, 64]
return sm89_config_M64::dispatch<InType, OutType, Epilogue>(
return sm89_fp8_config_M64::dispatch<InType, OutType, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 128) {
// M in (64, 128]
return sm89_config_M128::dispatch<InType, OutType, Epilogue>(
return sm89_fp8_config_M128::dispatch<InType, OutType, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 256) {
// M in (128, 256]
return sm89_config_M256::dispatch<InType, OutType, Epilogue>(
return sm89_fp8_config_M256::dispatch<InType, OutType, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
// M in (256, inf)
return sm89_config_default::dispatch<InType, OutType, Epilogue>(
return sm89_fp8_config_default::dispatch<InType, OutType, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
Expand Down
Loading

0 comments on commit da7ba9d

Please sign in to comment.