From 72335fce06ba33e86373ff3812e1ab939c55bcd7 Mon Sep 17 00:00:00 2001 From: Pedro Larroy Date: Mon, 1 Apr 2019 18:43:52 -0700 Subject: [PATCH] Tensor shape overflow checking in Blas Engine Fixes https://github.com/apache/incubator-mxnet/issues/14522 --- mshadow/base.h | 12 +++ mshadow/dot_engine-inl.h | 225 +++++++++++++++++++++------------------ 2 files changed, 135 insertions(+), 102 deletions(-) diff --git a/mshadow/base.h b/mshadow/base.h index 4cdab74d..53f36e06 100755 --- a/mshadow/base.h +++ b/mshadow/base.h @@ -1102,5 +1102,17 @@ inline size_t mshadow_sizeof(int type) { return size; } +template +inline bool mult_not_overflow(T a, T b, T *result = nullptr) { + static_assert(std::numeric_limits::is_integer, "mult_not_overflow is only supported for integer types"); + T res = {}; + res = a * b; + if (a != 0 && (res / a) != b) + return false; + if (result) + *result = res; + return true; +} + } // namespace mshadow #endif // MSHADOW_BASE_H_ diff --git a/mshadow/dot_engine-inl.h b/mshadow/dot_engine-inl.h index ed943b4e..5a02e83a 100644 --- a/mshadow/dot_engine-inl.h +++ b/mshadow/dot_engine-inl.h @@ -65,49 +65,49 @@ struct BLASEngine { } inline static void gemm(Stream *stream, bool transa, bool transb, - int m, int n, int k, DType alpha, - const DType *A, int lda, const DType *B, int ldb, - DType beta, DType *C, int ldc) { + index_t m, index_t n, index_t k, DType alpha, + const DType *A, index_t lda, const DType *B, index_t ldb, + DType beta, DType *C, index_t ldc) { LOG(FATAL) << "Not implmented!"; } inline static void batched_gemm(Stream *stream, bool transa, bool transb, - int m, int n, int k, DType alpha, - const DType *A, int lda, const DType *B, int ldb, - DType beta, DType *C, int ldc, int batch_count, + index_t m, index_t n, index_t k, DType alpha, + const DType *A, index_t lda, const DType *B, index_t ldb, + DType beta, DType *C, index_t ldc, index_t batch_count, DType **workspace) { LOG(FATAL) << "Not implmented!"; } inline static void gemv(Stream *stream, - bool trans, int m, int n, - DType alpha, const DType *A, int lda, - const DType *X, int incX, - DType beta, DType *Y, int incY) { + bool trans, index_t m, index_t n, + DType alpha, const DType *A, index_t lda, + const DType *X, index_t incX, + DType beta, DType *Y, index_t incY) { LOG(FATAL) << "Not implmented!"; } inline static void batched_gemv(Stream *stream, - bool trans, int m, int n, - DType alpha, const DType *A, int lda, - const DType *X, int incX, - DType beta, DType *Y, int incY, int batch_count) { + bool trans, index_t m, index_t n, + DType alpha, const DType *A, index_t lda, + const DType *X, index_t incX, + DType beta, DType *Y, index_t incY, index_t batch_count) { LOG(FATAL) << "Not implmented!"; } inline static void ger(Stream *stream, - int m, int n, DType alpha, - const DType *X, int incX, - const DType *Y, int incY, DType *A, int lda) { + index_t m, index_t n, DType alpha, + const DType *X, index_t incX, + const DType *Y, index_t incY, DType *A, index_t lda) { LOG(FATAL) << "Not implmented!"; } inline static void batched_ger(Stream *stream, - int m, int n, DType alpha, - const DType *X, int incX, - const DType *Y, int incY, DType *A, int lda, int batch_count) { + index_t m, index_t n, DType alpha, + const DType *X, index_t incX, + const DType *Y, index_t incY, DType *A, index_t lda, index_t batch_count) { LOG(FATAL) << "Not implmented!"; } inline static void dot(Stream *stream, - int n, - const DType* X, int incX, - const DType* Y, int incY, + index_t n, + const DType* X, index_t incX, + const DType* Y, index_t incY, DType* ret) { LOG(FATAL) << "Not implmented!"; } @@ -123,9 +123,9 @@ struct BLASEngine { } inline static void gemm(Stream *stream, bool transa, bool transb, - int m, int n, int k, float alpha, - const float *A, int lda, const float *B, int ldb, - float beta, float *C, int ldc) { + index_t m, index_t n, index_t k, float alpha, + const float *A, index_t lda, const float *B, index_t ldb, + float beta, float *C, index_t ldc) { if (alpha == 1.0f && beta == 0.0f) { bool transpose_left = transb; bool transpose_right = transa; @@ -147,46 +147,46 @@ struct BLASEngine { } inline static void batched_gemm(Stream *stream, bool transa, bool transb, - int m, int n, int k, float alpha, - const float *A, int lda, const float *B, int ldb, - float beta, float *C, int ldc, int batch_count, + index_t m, index_t n, index_t k, float alpha, + const float *A, index_t lda, const float *B, index_t ldb, + float beta, float *C, index_t ldc, index_t batch_count, float **workspace) { - for (int i = 0; i < batch_count; ++i) { + for (index_t i = 0; i < batch_count; ++i) { gemm(stream, transa, transb, m, n, k, alpha, A + i * m * k, lda, B + i * k * n, ldb, beta, C + i * m * n, ldc); } } inline static void gemv(Stream *stream, - bool trans, int m, int n, - float alpha, const float *A, int lda, - const float *X, int incX, - float beta, float *Y, int incY) { + bool trans, index_t m, index_t n, + float alpha, const float *A, index_t lda, + const float *X, index_t incX, + float beta, float *Y, index_t incY) { LOG(FATAL) << "Not implmented!"; } inline static void batched_gemv(Stream *stream, - bool trans, int m, int n, - float alpha, const float *A, int lda, - const float *X, int incX, - float beta, float *Y, int incY, int batch_count) { + bool trans, index_t m, index_t n, + float alpha, const float *A, index_t lda, + const float *X, index_t incX, + float beta, float *Y, index_t incY, index_t batch_count) { LOG(FATAL) << "Not implmented!"; } inline static void ger(Stream *stream, - int m, int n, float alpha, - const float *X, int incX, - const float *Y, int incY, float *A, int lda) { + index_t m, index_t n, float alpha, + const float *X, index_t incX, + const float *Y, index_t incY, float *A, index_t lda) { LOG(FATAL) << "Not implmented!"; } inline static void batched_ger(Stream *stream, - int m, int n, float alpha, - const float *X, int incX, - const float *Y, int incY, float *A, int lda, int batch_count) { + index_t m, index_t n, float alpha, + const float *X, index_t incX, + const float *Y, index_t incY, float *A, index_t lda, index_t batch_count) { LOG(FATAL) << "Not implmented!"; } inline static void dot(Stream *stream, - int n, - const float* X, int incX, - const float* Y, int incY, + index_t n, + const float* X, index_t incX, + const float* Y, index_t incY, float* ret) { LOG(FATAL) << "Not implmented!"; } @@ -201,9 +201,9 @@ struct BLASEngine { } inline static void gemm(Stream *stream, bool transa, bool transb, - int m, int n, int k, double alpha, - const double *A, int lda, const double *B, int ldb, - double beta, double *C, int ldc) { + index_t m, index_t n, index_t k, double alpha, + const double *A, index_t lda, const double *B, index_t ldb, + double beta, double *C, index_t ldc) { if (alpha == 1.0f && beta == 0.0f) { bool transpose_left = transb; bool transpose_right = transa; @@ -225,46 +225,47 @@ struct BLASEngine { } inline static void batched_gemm(Stream *stream, bool transa, bool transb, - int m, int n, int k, double alpha, - const double *A, int lda, const double *B, int ldb, - double beta, double *C, int ldc, int batch_count, + index_t m, index_t n, index_t k, double alpha, + const double *A, index_t lda, const double *B, index_t ldb, + double beta, double *C, index_t ldc, index_t batch_count, double **workspace) { - for (int i = 0; i < batch_count; ++i) { + CHECK(batch_count >= 0LL); + for (index_t i = 0; i < batch_count; ++i) { gemm(stream, transa, transb, m, n, k, alpha, A + i * m * k, lda, B + i * k * n, ldb, beta, C + i * m * n, ldc); } } inline static void gemv(Stream *stream, - bool trans, int m, int n, - double alpha, const double *A, int lda, - const double *X, int incX, - double beta, double *Y, int incY) { + bool trans, index_t m, index_t n, + double alpha, const double *A, index_t lda, + const double *X, index_t incX, + double beta, double *Y, index_t incY) { LOG(FATAL) << "Not implmented!"; } inline static void batched_gemv(Stream *stream, - bool trans, int m, int n, - double alpha, const double *A, int lda, - const double *X, int incX, - double beta, double *Y, int incY, int batch_count) { + bool trans, index_t m, index_t n, + double alpha, const double *A, index_t lda, + const double *X, index_t incX, + double beta, double *Y, index_t incY, index_t batch_count) { LOG(FATAL) << "Not implmented!"; } inline static void ger(Stream *stream, - int m, int n, double alpha, - const double *X, int incX, - const double *Y, int incY, double *A, int lda) { + index_t m, index_t n, double alpha, + const double *X, index_t incX, + const double *Y, index_t incY, double *A, index_t lda) { LOG(FATAL) << "Not implmented!"; } inline static void batched_ger(Stream *stream, - int m, int n, double alpha, - const double *X, int incX, - const double *Y, int incY, double *A, int lda, int batch_count) { + index_t m, index_t n, double alpha, + const double *X, index_t incX, + const double *Y, index_t incY, double *A, index_t lda, index_t batch_count) { LOG(FATAL) << "Not implmented!"; } inline static void dot(Stream *stream, - int n, - const double* X, int incX, - const double* Y, int incY, + index_t n, + const double* X, index_t incX, + const double* Y, index_t incY, double* ret) { LOG(FATAL) << "Not implmented!"; } @@ -280,47 +281,54 @@ struct BLASEngine { } inline static void gemm(Stream *stream, bool transa, bool transb, - int m, int n, int k, float alpha, - const float *A, int lda, const float *B, int ldb, - float beta, float *C, int ldc) { + index_t m, index_t n, index_t k, float alpha, + const float *A, index_t lda, const float *B, index_t ldb, + float beta, float *C, index_t ldc) { cblas_sgemm(CblasColMajor, GetT(transa), GetT(transb), m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); } inline static void batched_gemm(Stream *stream, bool transa, bool transb, - int m, int n, int k, float alpha, - const float *A, int lda, const float *B, int ldb, - float beta, float *C, int ldc, int batch_count, + index_t m, index_t n, index_t k, float alpha, + const float *A, index_t lda, const float *B, index_t ldb, + float beta, float *C, index_t ldc, index_t batch_count, float **workspace) { #if (MSHADOW_USE_MKL && INTEL_MKL_VERSION >= 20160000) - std::vector p_m(batch_count, m); - std::vector p_n(batch_count, n); - std::vector p_k(batch_count, k); - std::vector p_lda(batch_count, lda); - std::vector p_ldb(batch_count, ldb); - std::vector p_ldc(batch_count, ldc); - std::vector p_alpha(batch_count, alpha); - std::vector p_beta(batch_count, beta); - std::vector pp_A; - std::vector pp_B; - std::vector pp_C; + std::vector p_m(batch_count, m); + std::vector p_n(batch_count, n); + std::vector p_k(batch_count, k); + std::vector p_lda(batch_count, lda); + std::vector p_ldb(batch_count, ldb); + std::vector p_ldc(batch_count, ldc); + std::vector p_alpha(batch_count, alpha); + std::vector p_beta(batch_count, beta); + std::vector pp_A; + std::vector pp_B; + std::vector pp_C; - CBLAS_TRANSPOSE cblas_a_trans = GetT(transa); - CBLAS_TRANSPOSE cblas_b_trans = GetT(transb); + CBLAS_TRANSPOSE cblas_a_trans = GetT(transa); + CBLAS_TRANSPOSE cblas_b_trans = GetT(transb); - std::vector p_group_sizeb(batch_count, batch_count); - std::vector p_transa(batch_count, cblas_a_trans); - std::vector p_transb(batch_count, cblas_b_trans); + std::vector p_group_sizeb(batch_count, batch_count); + std::vector p_transa(batch_count, cblas_a_trans); + std::vector p_transb(batch_count, cblas_b_trans); - auto m_k = m * k; - auto k_n = k * n; - auto m_n = m * n; - for (int i = 0; i < batch_count; i++) { - pp_A.push_back(A + i * m_k); - pp_B.push_back(B + i * k_n); - pp_C.push_back(C + i * m_n); - } + int m_k = 0; + CHECK(mult_not_overflow(m,k, &m_k)); + int k_n = 0; + CHECK(mult_not_overflow(k,n, &k_n)); + int m_n = 0; + CHECK(mult_not_overflow(m,n, &m_n)); + + CHECK(mult_not_overflow(batch_count, m_k)); + CHECK(mult_not_overflow(batch_count, k_n)); + CHECK(mult_not_overflow(batch_count, m_n)); + for (index_t i = 0; i < batch_count; ++i) { + pp_A.push_back(A + i * m_k); + pp_B.push_back(B + i * k_n); + pp_C.push_back(C + i * m_n); + } cblas_sgemm_batch(CblasColMajor, p_transa.data(), p_transb.data(), p_m.data(), p_n.data(), p_k.data(), @@ -328,7 +336,20 @@ struct BLASEngine { p_ldb.data(), p_beta.data(), pp_C.data(), p_ldc.data(), 1, p_group_sizeb.data()); #else - for (int i = 0; i < batch_count; ++i) { + index_t m_k = 0; + CHECK(mult_not_overflow(m, k, &m_k)); + index_t b_m_k = 0; + CHECK(mult_not_overflow(batch_count, m_k, &b_m_k)); + index_t k_n = 0; + CHECK(mult_not_overflow(k, n, &k_n)); + index_t b_k_n = 0; + CHECK(mult_not_overflow(batch_count, k_n, &b_k_n)); + index_t m_n = 0; + CHECK(mult_not_overflow(m, n, &m_n)); + index_t b_m_n = 0; + CHECK(mult_not_overflow(batch_count, m_n, &b_m_n)); + + for (index_t i = 0; i < batch_count; ++i) { gemm(stream, transa, transb, m, n, k, alpha, A + i * m * k, lda, B + i * k * n, ldb, beta, C + i * m * n, ldc);