Skip to content

Commit

Permalink
Check matmul types and error at compile-time if the backend doesn't s…
Browse files Browse the repository at this point in the history
…upport them
  • Loading branch information
cliffburdick committed Dec 17, 2023
1 parent b3804a8 commit 0bd925c
Showing 1 changed file with 45 additions and 2 deletions.
47 changes: 45 additions & 2 deletions include/matx/transforms/matmul.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,45 @@ union MatMulScaleType_t {
double cf64[2];
};

template <typename OpA, typename OpB, typename OpC, MatXMatMulProvider_t PROV = PROVIDER_TYPE_CUBLASLT>
constexpr bool CompatibleGemmTypes() {
if constexpr (!std::is_same_v<typename OpA::scalar_type, typename OpB::scalar_type> &&
!std::is_same_v<typename OpB::scalar_type, typename OpC::scalar_type> &&
!std::is_same_v<typename OpA::scalar_type, typename OpC::scalar_type>) {
return false;
}

if constexpr (PROV == PROVIDER_TYPE_CUBLASLT) {
if constexpr (std::is_same_v<typename OpA::scalar_type, typename OpB::scalar_type> &&
std::is_same_v<typename OpB::scalar_type, typename OpC::scalar_type>) {
// List of accepted types when A/B/C match
return std::is_same_v<typename OpA::scalar_type, matxFp16> ||
std::is_same_v<typename OpA::scalar_type, matxBf16> ||
std::is_same_v<typename OpA::scalar_type, float> ||
std::is_same_v<typename OpA::scalar_type, double> ||
std::is_same_v<typename OpA::scalar_type, cuda::std::complex<float>> ||
std::is_same_v<typename OpA::scalar_type, cuda::std::complex<double>> ||
std::is_same_v<typename OpA::scalar_type, int8_t> ||
std::is_same_v<typename OpA::scalar_type, matxFp16Complex> ||
std::is_same_v<typename OpA::scalar_type, matxBf16Complex>;

}
// Accumulator type different from A/B
else if constexpr ( std::is_same_v<typename OpA::scalar_type, typename OpB::scalar_type> &&
!std::is_same_v<typename OpB::scalar_type, typename OpC::scalar_type>) {
return (std::is_same_v<typename OpA::scalar_type, int8_t> && std::is_same_v<typename OpC::scalar_type, int32_t>) ||
(std::is_same_v<typename OpA::scalar_type, int8_t> && std::is_same_v<typename OpC::scalar_type, float>) ||
(std::is_same_v<typename OpA::scalar_type, matxBf16> && std::is_same_v<typename OpC::scalar_type, float>) ||
(std::is_same_v<typename OpA::scalar_type, matxFp16> && std::is_same_v<typename OpC::scalar_type, float>) ||
(std::is_same_v<typename OpA::scalar_type, int8_t> && std::is_same_v<typename OpC::scalar_type, float>);
}
}
else {
// For now return true for other providers until we support more
return true;
}
}

/**
* Parameters needed to execute a GEMM. For the most part, these are very
* similar to that of a standard GEMM call
Expand Down Expand Up @@ -834,7 +873,7 @@ class matxMatMulHandle_t {
static_cast<int>(
params_.ldc)}, // Tensor-ref for destination matrix D (may be
// different memory than source C matrix)
{alpha, beta}); // Scalars used in the Epilogue
{static_cast<T1>(alpha), static_cast<T1>(beta)}); // Scalars used in the Epilogue

CutlassGemm gemm_operator;
cutlass::Status status = gemm_operator(args, nullptr, stream);
Expand Down Expand Up @@ -895,7 +934,7 @@ class matxMatMulHandle_t {
params_.ldc)}, // Tensor-ref for destination matrix D (may
// be different memory than source C matrix)
c_adj.Stride(RANK - 3), // Batch Stride C
{alpha, beta},
{static_cast<T1>(alpha), static_cast<T1>(beta)},
params_.batch // Batch Dimension
); // Scalars used in the Epilogue

Expand Down Expand Up @@ -1118,6 +1157,10 @@ void matmul_impl(TensorTypeC C, const TensorTypeA A,
auto A_ = as_type<typename TensorTypeC::scalar_type>(A);
auto B_ = as_type<typename TensorTypeC::scalar_type>(B);

static_assert(detail::CompatibleGemmTypes<decltype(A_), decltype(B_), TensorTypeC, PROV>(),
"Combination of A/B/C types are not supported");


// CublasLt does not support operators and certain transpose modes.
// Grab a suppported tensor here and copy in if necessary.
auto c = getCublasSupportedTensor(C, stream);
Expand Down

0 comments on commit 0bd925c

Please sign in to comment.