Skip to content
This repository has been archived by the owner on Aug 11, 2020. It is now read-only.

Commit

Permalink
Tensor shape overflow checking in BLAS Engine
Browse files Browse the repository at this point in the history
  • Loading branch information
larroy committed Apr 2, 2019
1 parent 95ebe0f commit 72335fc
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 102 deletions.
12 changes: 12 additions & 0 deletions mshadow/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -1102,5 +1102,17 @@ inline size_t mshadow_sizeof(int type) {
return size;
}

/*!
 * \brief Check whether a * b overflows the integer type T.
 *
 * Overflow is detected *before* multiplying: for signed T an overflowing
 * `a * b` is undefined behavior, so the multiply-then-divide-back check
 * (`(a * b) / a != b`) is itself UB and may be optimized away. The bounds
 * below follow the CERT INT32-C pattern and are valid for both signed and
 * unsigned integer types (for unsigned T the negative branches are dead).
 *
 * \param a first factor
 * \param b second factor
 * \param result optional out-pointer; receives a * b only when no overflow
 *               occurs (left untouched on overflow, matching prior behavior)
 * \return true if a * b is representable in T, false on overflow
 */
template<typename T>
inline bool mult_not_overflow(T a, T b, T *result = nullptr) {
  static_assert(std::numeric_limits<T>::is_integer,
                "mult_not_overflow is only supported for integer types");
  // A zero factor can never overflow; it also guards the divisions below.
  if (a != 0 && b != 0) {
    if (a > 0) {
      if (b > 0) {
        if (a > std::numeric_limits<T>::max() / b)
          return false;
      } else {
        if (b < std::numeric_limits<T>::min() / a)
          return false;
      }
    } else {
      if (b > 0) {
        if (a < std::numeric_limits<T>::min() / b)
          return false;
      } else {
        // both negative: product is positive; note max() / a <= 0 here
        if (b < std::numeric_limits<T>::max() / a)
          return false;
      }
    }
  }
  if (result)
    *result = a * b;
  return true;
}

} // namespace mshadow
#endif // MSHADOW_BASE_H_
225 changes: 123 additions & 102 deletions mshadow/dot_engine-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,49 +65,49 @@ struct BLASEngine {
}
inline static void gemm(Stream<Device> *stream,
bool transa, bool transb,
int m, int n, int k, DType alpha,
const DType *A, int lda, const DType *B, int ldb,
DType beta, DType *C, int ldc) {
index_t m, index_t n, index_t k, DType alpha,
const DType *A, index_t lda, const DType *B, index_t ldb,
DType beta, DType *C, index_t ldc) {
LOG(FATAL) << "Not implmented!";
}
inline static void batched_gemm(Stream<Device> *stream,
bool transa, bool transb,
int m, int n, int k, DType alpha,
const DType *A, int lda, const DType *B, int ldb,
DType beta, DType *C, int ldc, int batch_count,
index_t m, index_t n, index_t k, DType alpha,
const DType *A, index_t lda, const DType *B, index_t ldb,
DType beta, DType *C, index_t ldc, index_t batch_count,
DType **workspace) {
LOG(FATAL) << "Not implmented!";
}
inline static void gemv(Stream<Device> *stream,
bool trans, int m, int n,
DType alpha, const DType *A, int lda,
const DType *X, int incX,
DType beta, DType *Y, int incY) {
bool trans, index_t m, index_t n,
DType alpha, const DType *A, index_t lda,
const DType *X, index_t incX,
DType beta, DType *Y, index_t incY) {
LOG(FATAL) << "Not implmented!";
}
inline static void batched_gemv(Stream<Device> *stream,
bool trans, int m, int n,
DType alpha, const DType *A, int lda,
const DType *X, int incX,
DType beta, DType *Y, int incY, int batch_count) {
bool trans, index_t m, index_t n,
DType alpha, const DType *A, index_t lda,
const DType *X, index_t incX,
DType beta, DType *Y, index_t incY, index_t batch_count) {
LOG(FATAL) << "Not implmented!";
}
inline static void ger(Stream<Device> *stream,
int m, int n, DType alpha,
const DType *X, int incX,
const DType *Y, int incY, DType *A, int lda) {
index_t m, index_t n, DType alpha,
const DType *X, index_t incX,
const DType *Y, index_t incY, DType *A, index_t lda) {
LOG(FATAL) << "Not implmented!";
}
inline static void batched_ger(Stream<Device> *stream,
int m, int n, DType alpha,
const DType *X, int incX,
const DType *Y, int incY, DType *A, int lda, int batch_count) {
index_t m, index_t n, DType alpha,
const DType *X, index_t incX,
const DType *Y, index_t incY, DType *A, index_t lda, index_t batch_count) {
LOG(FATAL) << "Not implmented!";
}
inline static void dot(Stream<Device> *stream,
int n,
const DType* X, int incX,
const DType* Y, int incY,
index_t n,
const DType* X, index_t incX,
const DType* Y, index_t incY,
DType* ret) {
LOG(FATAL) << "Not implmented!";
}
Expand All @@ -123,9 +123,9 @@ struct BLASEngine<cpu, float> {
}
inline static void gemm(Stream<cpu> *stream,
bool transa, bool transb,
int m, int n, int k, float alpha,
const float *A, int lda, const float *B, int ldb,
float beta, float *C, int ldc) {
index_t m, index_t n, index_t k, float alpha,
const float *A, index_t lda, const float *B, index_t ldb,
float beta, float *C, index_t ldc) {
if (alpha == 1.0f && beta == 0.0f) {
bool transpose_left = transb;
bool transpose_right = transa;
Expand All @@ -147,46 +147,46 @@ struct BLASEngine<cpu, float> {
}
inline static void batched_gemm(Stream<cpu> *stream,
bool transa, bool transb,
int m, int n, int k, float alpha,
const float *A, int lda, const float *B, int ldb,
float beta, float *C, int ldc, int batch_count,
index_t m, index_t n, index_t k, float alpha,
const float *A, index_t lda, const float *B, index_t ldb,
float beta, float *C, index_t ldc, index_t batch_count,
float **workspace) {
for (int i = 0; i < batch_count; ++i) {
for (index_t i = 0; i < batch_count; ++i) {
gemm(stream, transa, transb, m, n, k, alpha,
A + i * m * k, lda, B + i * k * n, ldb,
beta, C + i * m * n, ldc);
}
}
inline static void gemv(Stream<cpu> *stream,
bool trans, int m, int n,
float alpha, const float *A, int lda,
const float *X, int incX,
float beta, float *Y, int incY) {
bool trans, index_t m, index_t n,
float alpha, const float *A, index_t lda,
const float *X, index_t incX,
float beta, float *Y, index_t incY) {
LOG(FATAL) << "Not implmented!";
}
inline static void batched_gemv(Stream<cpu> *stream,
bool trans, int m, int n,
float alpha, const float *A, int lda,
const float *X, int incX,
float beta, float *Y, int incY, int batch_count) {
bool trans, index_t m, index_t n,
float alpha, const float *A, index_t lda,
const float *X, index_t incX,
float beta, float *Y, index_t incY, index_t batch_count) {
LOG(FATAL) << "Not implmented!";
}
inline static void ger(Stream<cpu> *stream,
int m, int n, float alpha,
const float *X, int incX,
const float *Y, int incY, float *A, int lda) {
index_t m, index_t n, float alpha,
const float *X, index_t incX,
const float *Y, index_t incY, float *A, index_t lda) {
LOG(FATAL) << "Not implmented!";
}
inline static void batched_ger(Stream<cpu> *stream,
int m, int n, float alpha,
const float *X, int incX,
const float *Y, int incY, float *A, int lda, int batch_count) {
index_t m, index_t n, float alpha,
const float *X, index_t incX,
const float *Y, index_t incY, float *A, index_t lda, index_t batch_count) {
LOG(FATAL) << "Not implmented!";
}
inline static void dot(Stream<cpu> *stream,
int n,
const float* X, int incX,
const float* Y, int incY,
index_t n,
const float* X, index_t incX,
const float* Y, index_t incY,
float* ret) {
LOG(FATAL) << "Not implmented!";
}
Expand All @@ -201,9 +201,9 @@ struct BLASEngine<cpu, double> {
}
inline static void gemm(Stream<cpu> *stream,
bool transa, bool transb,
int m, int n, int k, double alpha,
const double *A, int lda, const double *B, int ldb,
double beta, double *C, int ldc) {
index_t m, index_t n, index_t k, double alpha,
const double *A, index_t lda, const double *B, index_t ldb,
double beta, double *C, index_t ldc) {
if (alpha == 1.0f && beta == 0.0f) {
bool transpose_left = transb;
bool transpose_right = transa;
Expand All @@ -225,46 +225,47 @@ struct BLASEngine<cpu, double> {
}
inline static void batched_gemm(Stream<cpu> *stream,
bool transa, bool transb,
int m, int n, int k, double alpha,
const double *A, int lda, const double *B, int ldb,
double beta, double *C, int ldc, int batch_count,
index_t m, index_t n, index_t k, double alpha,
const double *A, index_t lda, const double *B, index_t ldb,
double beta, double *C, index_t ldc, index_t batch_count,
double **workspace) {
for (int i = 0; i < batch_count; ++i) {
CHECK(batch_count >= 0LL);
for (index_t i = 0; i < batch_count; ++i) {
gemm(stream, transa, transb, m, n, k, alpha,
A + i * m * k, lda, B + i * k * n, ldb,
beta, C + i * m * n, ldc);
}
}
inline static void gemv(Stream<cpu> *stream,
bool trans, int m, int n,
double alpha, const double *A, int lda,
const double *X, int incX,
double beta, double *Y, int incY) {
bool trans, index_t m, index_t n,
double alpha, const double *A, index_t lda,
const double *X, index_t incX,
double beta, double *Y, index_t incY) {
LOG(FATAL) << "Not implmented!";
}
inline static void batched_gemv(Stream<cpu> *stream,
bool trans, int m, int n,
double alpha, const double *A, int lda,
const double *X, int incX,
double beta, double *Y, int incY, int batch_count) {
bool trans, index_t m, index_t n,
double alpha, const double *A, index_t lda,
const double *X, index_t incX,
double beta, double *Y, index_t incY, index_t batch_count) {
LOG(FATAL) << "Not implmented!";
}
inline static void ger(Stream<cpu> *stream,
int m, int n, double alpha,
const double *X, int incX,
const double *Y, int incY, double *A, int lda) {
index_t m, index_t n, double alpha,
const double *X, index_t incX,
const double *Y, index_t incY, double *A, index_t lda) {
LOG(FATAL) << "Not implmented!";
}
inline static void batched_ger(Stream<cpu> *stream,
int m, int n, double alpha,
const double *X, int incX,
const double *Y, int incY, double *A, int lda, int batch_count) {
index_t m, index_t n, double alpha,
const double *X, index_t incX,
const double *Y, index_t incY, double *A, index_t lda, index_t batch_count) {
LOG(FATAL) << "Not implmented!";
}
inline static void dot(Stream<cpu> *stream,
int n,
const double* X, int incX,
const double* Y, int incY,
index_t n,
const double* X, index_t incX,
const double* Y, index_t incY,
double* ret) {
LOG(FATAL) << "Not implmented!";
}
Expand All @@ -280,55 +281,75 @@ struct BLASEngine<cpu, float> {
}
inline static void gemm(Stream<cpu> *stream,
bool transa, bool transb,
int m, int n, int k, float alpha,
const float *A, int lda, const float *B, int ldb,
float beta, float *C, int ldc) {
index_t m, index_t n, index_t k, float alpha,
const float *A, index_t lda, const float *B, index_t ldb,
float beta, float *C, index_t ldc) {
cblas_sgemm(CblasColMajor, GetT(transa), GetT(transb),
m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
}
inline static void batched_gemm(Stream<cpu> *stream,
bool transa, bool transb,
int m, int n, int k, float alpha,
const float *A, int lda, const float *B, int ldb,
float beta, float *C, int ldc, int batch_count,
index_t m, index_t n, index_t k, float alpha,
const float *A, index_t lda, const float *B, index_t ldb,
float beta, float *C, index_t ldc, index_t batch_count,
float **workspace) {
#if (MSHADOW_USE_MKL && INTEL_MKL_VERSION >= 20160000)
std::vector<int> p_m(batch_count, m);
std::vector<int> p_n(batch_count, n);
std::vector<int> p_k(batch_count, k);
std::vector<int> p_lda(batch_count, lda);
std::vector<int> p_ldb(batch_count, ldb);
std::vector<int> p_ldc(batch_count, ldc);
std::vector<float> p_alpha(batch_count, alpha);
std::vector<float> p_beta(batch_count, beta);
std::vector<const float*> pp_A;
std::vector<const float*> pp_B;
std::vector<float*> pp_C;
std::vector<int> p_m(batch_count, m);
std::vector<int> p_n(batch_count, n);
std::vector<int> p_k(batch_count, k);
std::vector<int> p_lda(batch_count, lda);
std::vector<int> p_ldb(batch_count, ldb);
std::vector<int> p_ldc(batch_count, ldc);
std::vector<float> p_alpha(batch_count, alpha);
std::vector<float> p_beta(batch_count, beta);
std::vector<const float*> pp_A;
std::vector<const float*> pp_B;
std::vector<float*> pp_C;

CBLAS_TRANSPOSE cblas_a_trans = GetT(transa);
CBLAS_TRANSPOSE cblas_b_trans = GetT(transb);
CBLAS_TRANSPOSE cblas_a_trans = GetT(transa);
CBLAS_TRANSPOSE cblas_b_trans = GetT(transb);

std::vector<int> p_group_sizeb(batch_count, batch_count);
std::vector<CBLAS_TRANSPOSE> p_transa(batch_count, cblas_a_trans);
std::vector<CBLAS_TRANSPOSE> p_transb(batch_count, cblas_b_trans);
std::vector<int> p_group_sizeb(batch_count, batch_count);
std::vector<CBLAS_TRANSPOSE> p_transa(batch_count, cblas_a_trans);
std::vector<CBLAS_TRANSPOSE> p_transb(batch_count, cblas_b_trans);

auto m_k = m * k;
auto k_n = k * n;
auto m_n = m * n;

for (int i = 0; i < batch_count; i++) {
pp_A.push_back(A + i * m_k);
pp_B.push_back(B + i * k_n);
pp_C.push_back(C + i * m_n);
}
int m_k = 0;
CHECK(mult_not_overflow(m,k, &m_k));
int k_n = 0;
CHECK(mult_not_overflow(k,n, &k_n));
int m_n = 0;
CHECK(mult_not_overflow(m,n, &m_n));

CHECK(mult_not_overflow(batch_count, m_k));
CHECK(mult_not_overflow(batch_count, k_n));
CHECK(mult_not_overflow(batch_count, m_n));
for (index_t i = 0; i < batch_count; ++i) {
pp_A.push_back(A + i * m_k);
pp_B.push_back(B + i * k_n);
pp_C.push_back(C + i * m_n);
}

cblas_sgemm_batch(CblasColMajor, p_transa.data(), p_transb.data(),
p_m.data(), p_n.data(), p_k.data(),
p_alpha.data(), pp_A.data(), p_lda.data(), pp_B.data(),
p_ldb.data(), p_beta.data(), pp_C.data(), p_ldc.data(),
1, p_group_sizeb.data());
#else
for (int i = 0; i < batch_count; ++i) {
index_t m_k = 0;
CHECK(mult_not_overflow(m, k, &m_k));
index_t b_m_k = 0;
CHECK(mult_not_overflow(batch_count, m_k, &b_m_k));
index_t k_n = 0;
CHECK(mult_not_overflow(k, n, &k_n));
index_t b_k_n = 0;
CHECK(mult_not_overflow(batch_count, k_n, &b_k_n));
index_t m_n = 0;
CHECK(mult_not_overflow(m, n, &m_n));
index_t b_m_n = 0;
CHECK(mult_not_overflow(batch_count, m_n, &b_m_n));

for (index_t i = 0; i < batch_count; ++i) {
gemm(stream, transa, transb, m, n, k, alpha,
A + i * m * k, lda, B + i * k * n, ldb,
beta, C + i * m * n, ldc);
Expand Down

0 comments on commit 72335fc

Please sign in to comment.