diff --git a/guide/README.md b/guide/README.md
index 95f9b1da8f8e..f79687eba374 100644
--- a/guide/README.md
+++ b/guide/README.md
@@ -29,16 +29,17 @@ template<typename Device, int dimension, typename DType> struct Tensor {
   DType *dptr_;
   Shape<dimension> shape_;
+  Stream<Device> stream_;
   index_t stride_;
 };
 // this is how shape object declaration look like
 Shape<2> shape2;
 // this is how tensor object declaration look like
-// you can 
+// you can
 Tensor<cpu, 2> ts2;
 Tensor<gpu, 3> ts3;
 ```
-```Tensor<cpu, 2>``` means a two dimensional tensor in CPU, while ```Tensor<gpu, 3>``` means three dimensional tensor in GPU. 
+```Tensor<cpu, 2>``` means a two dimensional tensor in CPU, while ```Tensor<gpu, 3>``` means three dimensional tensor in GPU.
 ```Shape<k>``` gives the shape information of k-dimensional tensor. The declaration use template, and can be specialized into tensor of specific device and dimension. This is what two dimensional tensor will look like:
 ```c++
@@ -50,8 +51,8 @@ struct Tensor<cpu, 2, float> {
   Shape<2> shape_;
   index_t stride_;
 };
-``` 
+```
-* ```Tensor<cpu, 2>``` contains ```dptr_```, which points to the space that backup the tensor. 
+* ```Tensor<cpu, 2>``` contains ```dptr_```, which points to the space that backup the tensor.
 * ```Shape<2>``` is a structure that stores shape information, the convention is same as numpy
 * ```stride_``` gives the number of cell space allocated in the smallest dimension (if we use numpy convention, the dimension corresponds to shape_[-1]). This is introduced when we introduce some padding cells in lowest dimension to make sure memory is aligned.
@@ -64,7 +65,7 @@ Tensor<cpu, 2> ts;
 ts.dptr_ = data;
 ts.shape_ = mshadow::Shape2(3, 2);
 ts.stride_ = 3;
-// now: ts[0][0] == 0, ts[0][1] == 1 , ts[1][0] == 3, ts[1][1] == 4 
+// now: ts[0][0] == 0, ts[0][1] == 1 , ts[1][0] == 3, ts[1][1] == 4
 for (index_t i = 0; i < ts.size(0); ++i) {
   for (index_t j = 0; j < ts.size(1), ++j) {
     printf("ts[%u][%u]=%f\n", i, j, ts[i][j]);
@@ -73,10 +74,12 @@ for (index_t i = 0; i < ts.size(0); ++i) {
 ```
 The result ts should be a 3 * 2 matrix, where data[2], data[5], data[8] are padding cells that are ignored. If you want a continuous memory, set ```stride_=shape_[1]```.
+NOTICE: We highly recommend using a stream in ```gpu``` mode; an error will be thrown if no stream is set. Check [basic_stream.cu](basic_stream.cu) for more detail.
+
 Memory Allocation
 ====
 An important design choice about mshadow is that the data structure is a **whitebox**:
-it works so long as we set the space pointer ```dptr_```, corresponding ```shape_``` and ```stride_```: 
+it works so long as we set the space pointer ```dptr_```, corresponding ```shape_``` and ```stride_```:
 * For ```Tensor<cpu, 2>```, the space can be created by ```new float[]```, or pointer to some existing space such as float array in last example.
 * For ```Tensor<gpu, 2>```, the space need to lie in GPU, created by ```cudaMallocPitch```
@@ -197,17 +200,17 @@ int main(void) {
   // Tensor object is only a handle, assignment means they have same data content
   // we can specify content type of a Tensor, if not specified, it is float bydefault
   Tensor<cpu, 2, float> mat2 = mat;
-  
+
   // shaape of matrix, note size order is same as numpy
   printf("%u X %u matrix\n", mat.size(1), mat.size(1));
-  
+
   // initialize all element to zero
   mat = 0.0f;
   // assign some values
   mat[0][1] = 1.0f; mat[1][0] = 2.0f;
   // elementwise operations
   mat += (mat + 10.0f) / 10.0f + 2.0f;
-  
+
   // print out matrix, note: mat2 and mat1 are handles(pointers)
   for (index_t i = 0; i < mat.size(0); ++i) {
     for (index_t j = 0; j < mat.size(1); ++j) {
diff --git a/guide/neuralnet/config.mk b/guide/neuralnet/config.mk
index 112396d5557b..6c10b79903bf 100644
--- a/guide/neuralnet/config.mk
+++ b/guide/neuralnet/config.mk
@@ -3,10 +3,10 @@
 #
 #  This is configuration script that you can use to compile mshadow
 #  Usage:
-# 
+#
 #  include config.mk in your Makefile, or directly include the definition of variables
 #  include mshadow.mk after the variables are set
-# 
+#
 #  Add MSHADOW_CFLAGS to the compile flags
 #  Add MSHADOW_LDFLAGS to the linker flags
 #  Add MSHADOW_NVCCFLAGS to the nvcc compile flags
@@ -22,11 +22,11 @@ USE_CUDA_PATH = NONE
 #
 # choose the version of blas you want to use
 # can be: mkl, blas, atlas, openblas, apple
-USE_BLAS = mkl
+USE_BLAS = openblas
 #
 # add path to intel library, you may need it
 # for MKL, if you did not add the path to enviroment variable
-# 
+#
 USE_INTEL_PATH = NONE
 
 # whether compile with parameter server
diff --git a/guide/neuralnet/convnet.cu b/guide/neuralnet/convnet.cu
index 97b6a03fc416..21983fb103a8 100644
--- a/guide/neuralnet/convnet.cu
+++ b/guide/neuralnet/convnet.cu
@@ -9,12 +9,12 @@ using namespace mshadow;
 // this namespace contains all operator overloads
 using namespace mshadow::expr;
-// define operations 
+// define operations
 struct relu{
   MSHADOW_XINLINE static real_t Map(real_t a) {
     using namespace std;
     return max(a, 0.0f);
-  } 
+  }
 };
 struct relu_grad {
   MSHADOW_XINLINE static real_t Map(real_t a) {
@@ -26,12 +26,12 @@ struct relu_grad {
 class INNet{
  public:
   virtual void Forward(const Tensor<cpu, 4, real_t>& inbatch, Tensor<cpu, 2, real_t> &oubatch) = 0;
-  virtual void Backprop(const Tensor<cpu, 2, real_t>& gradout) = 0;  
+  virtual void Backprop(const Tensor<cpu, 2, real_t>& gradout) = 0;
   virtual void Update(void) = 0;
   virtual ~INNet() {}
 };
-/*! 
+/*!
* \brief simple two layer conv-net conv-pool-flat-fullc * this implementation is device invariant */ @@ -41,9 +41,24 @@ class ConvNet : public INNet { // initialize the network ConvNet(int batch_size, int insize, int nchannel, int ksize, int kstride, int psize, int num_out) :rnd(0), ksize(ksize), kstride(kstride), psize(psize) { + // setup stream + Stream *stream = NewStream(); + ninput.set_stream(stream); + nhidden.set_stream(stream); + nhiddenbak.set_stream(stream); + npool.set_stream(stream); + npoolbak.set_stream(stream); + nflat.set_stream(stream); + nout.set_stream(stream); + hbias.set_stream(stream); g_hbias.set_stream(stream); + obias.set_stream(stream); g_obias.set_stream(stream); + Ki2h.set_stream(stream); g_Ki2h.set_stream(stream); + Wh2o.set_stream(stream); g_Wh2o.set_stream(stream); + tmp_col.set_stream(stream); + tmp_dst.set_stream(stream); // setup nodes ninput.Resize(Shape4(batch_size, 1, insize, insize)); - nhidden.Resize(Shape4(batch_size, nchannel, (insize - ksize)/kstride+1, (insize -ksize)/kstride+1)); + nhidden.Resize(Shape4(batch_size, nchannel, (insize - ksize)/kstride+1, (insize -ksize)/kstride+1)); nhiddenbak.Resize(nhidden.shape_); npool.Resize(Shape4(batch_size, nchannel, (nhidden.size(2)+1-psize)/psize, (nhidden.size(3)+1-psize)/psize)); npoolbak.Resize(npool.shape_); @@ -58,25 +73,25 @@ class ConvNet : public INNet { Wh2o.Resize(Shape2(nflat.size(1), num_out)); g_Wh2o.Resize(Wh2o.shape_); rnd.SampleGaussian(&Ki2h, 0, 0.01f); rnd.SampleGaussian(&Wh2o, 0, 0.01f); - + printf("conv=%d, pool=%d\n", nhidden.size(3), npool.size(3)); } virtual ~ConvNet() {} // forward propagation virtual void Forward(const Tensor& inbatch, Tensor &oubatch) { index_t batch_size = inbatch.size(0); - // copy data to input layer - Copy(ninput, inbatch); + // copy data to input layer + Copy(ninput, inbatch, ninput.stream_); // first layer, conv, use stride=2 ConvForward(ninput, Ki2h, nhidden, ksize, kstride, tmp_col, tmp_dst); // add bias nhidden += broadcast<1>(hbias, nhidden.shape_); - // activation, relu, backup activation in nhidden + // activation, relu, backup activation in nhidden nhidden = F(nhidden); - Copy(nhiddenbak, nhidden); - // max pooling + Copy(nhiddenbak, nhidden, nhiddenbak.stream_); + // max pooling npool = pool(nhiddenbak, npool[0][0].shape_, psize, psize, psize); - Copy(npoolbak, npool); + Copy(npoolbak, npool, npoolbak.stream_); // flat nflat = reshape(npool, nflat.shape_); // second layer fullc @@ -85,12 +100,12 @@ class ConvNet : public INNet { // softmax calculation Softmax(nout, nout); // copy result out - Copy(oubatch, nout); + Copy(oubatch, nout, nout.stream_); } // back propagation virtual void Backprop(const Tensor& gradout) { // copy gradient to output layer - Copy(nout, gradout); + Copy(nout, gradout, nout.stream_); // calc grad of final layer g_obias = sum_rows(nout); g_Wh2o = dot(nflat.T(), nout); @@ -98,7 +113,7 @@ class ConvNet : public INNet { nflat = dot(nout, Wh2o.T()); npool = reshape(nflat, npool.shape_); // backprop pooling layer - nhiddenbak = unpool(nhiddenbak, npoolbak, npool, psize, psize, psize); + nhiddenbak = unpool(nhiddenbak, npoolbak, npool, psize, psize, psize); // calculate gradient of relu layer nhidden = F(nhidden) * nhiddenbak; // calc grad of layer 1 @@ -118,10 +133,10 @@ class ConvNet : public INNet { obias-= eta * g_obias; } private: - // forward convolution, tmp_col and tmp_dst are helper structure + // forward convolution, tmp_col and tmp_dst are helper structure inline static void ConvForward(const Tensor &in, const Tensor 
&kernel, - Tensor &out, + Tensor &out, int ksize, int kstride, TensorContainer &tmp_col, TensorContainer &tmp_dst) { @@ -130,20 +145,20 @@ class ConvNet : public INNet { index_t nbatch = in.size(0); index_t nchannel = out.size(1); // we directly unpack all local patches and do a dot product - // this cost lots of memory, normally for large image, only unpack several image at a time + // this cost lots of memory, normally for large image, only unpack several image at a time tmp_col.Resize(Shape2(in.size(1)*ksize*ksize, nbatch*oheight*owidth)); tmp_dst.Resize(Shape2(nchannel, nbatch*oheight*owidth)); // unpack local patches , stride=1 tmp_col = unpack_patch2col(in, ksize, ksize, kstride); tmp_dst = dot(kernel, tmp_col); - // reshape, then swap axis, we chain equations together + // reshape, then swap axis, we chain equations together out = swapaxis<1,0>(reshape(tmp_dst, Shape4(nchannel, nbatch, oheight, owidth))); - } + } // backward convolution, calculate gradient of kernel, and backprop back to in inline static void ConvBackWard(const Tensor &out, - const Tensor &kernel, + const Tensor &kernel, Tensor &g_kernel, - Tensor &in, + Tensor &in, int ksize, int kstride, TensorContainer &tmp_col, TensorContainer &tmp_dst) { @@ -152,13 +167,13 @@ class ConvNet : public INNet { index_t nbatch = in.size(0); index_t nchannel = out.size(1); // we directly unpack all local patches and do a dot product - // this cost lots of memory, normally for large image, only unpack several image at a time + // this cost lots of memory, normally for large image, only unpack several image at a time tmp_col.Resize(Shape2(in.size(1) * ksize * ksize, nbatch * oheight * owidth)); tmp_dst.Resize(Shape2(nchannel, nbatch * oheight * owidth)); - // unpack local patches - tmp_col = unpack_patch2col(in, ksize, ksize, kstride); - tmp_dst = reshape(swapaxis<1,0>(out), tmp_dst.shape_); + // unpack local patches + tmp_col = unpack_patch2col(in, ksize, ksize, kstride); + tmp_dst = reshape(swapaxis<1,0>(out), tmp_dst.shape_); g_kernel = dot(tmp_dst, tmp_col.T()); // backpropgation: not necessary for first layer, but included anyway tmp_col = dot(kernel.T(), tmp_dst); @@ -193,7 +208,7 @@ int main(int argc, char *argv[]) { if(argc < 2) { printf("Usage: cpu or gpu\n"); return 0; } - srand(0); + srand(0); // settings int batch_size = 100; int insize = 28; @@ -202,7 +217,7 @@ int main(int argc, char *argv[]) { int kstride = 1; int psize = 2; int num_out = 10; - + // choose which version to use INNet *net; if (!strcmp(argv[1], "gpu")) { @@ -212,27 +227,27 @@ int main(int argc, char *argv[]) { InitTensorEngine(); net = new ConvNet(batch_size, insize, nchannel, ksize, kstride, psize, num_out); } - + // temp output layer - TensorContainer pred; + TensorContainer pred; pred.Resize(Shape2(batch_size, num_out)); - - // label + + // label std::vector ytrain, ytest; // data TensorContainer xtrain_, xtest_; LoadMNIST("train-images-idx3-ubyte", "train-labels-idx1-ubyte", ytrain, xtrain_, true); LoadMNIST("t10k-images-idx3-ubyte", "t10k-labels-idx1-ubyte", ytest, xtest_, false); - + TensorContainer xtrain(Shape4(xtrain_.size(0), 1, insize, insize)); TensorContainer xtest(Shape4(xtest_.size(0), 1, insize, insize)); xtrain = reshape(xtrain_, xtrain.shape_); xtest = reshape(xtest_, xtest.shape_); - + int num_iter = 20; - + for (int i = 0; i < num_iter; ++ i) { - // training + // training for (index_t j = 0; j + batch_size <= xtrain.size(0); j += batch_size) { net->Forward(xtrain.Slice(j, j + batch_size), pred); // set gradient into pred @@ -249,15 
+264,15 @@ int main(int argc, char *argv[]) { // evaluation long nerr = 0; for (index_t j = 0; j + batch_size <= xtest.size(0); j += batch_size) { - net->Forward(xtest.Slice(j, j + batch_size), pred); - for (int k = 0; k < batch_size; ++ k) { - nerr += MaxIndex(pred[k]) != ytest[j+k]; + net->Forward(xtest.Slice(j, j + batch_size), pred); + for (int k = 0; k < batch_size; ++ k) { + nerr += MaxIndex(pred[k]) != ytest[j+k]; } } printf("round %d: test-err=%f\n", i, (float)nerr/xtest.size(0)); - } + } delete net; - + if (!strcmp(argv[1], "gpu")) { ShutdownTensorEngine(); } else { diff --git a/guide/neuralnet/nnet.cu b/guide/neuralnet/nnet.cu index 8e79cf608f3c..6ef8b0db3f64 100644 --- a/guide/neuralnet/nnet.cu +++ b/guide/neuralnet/nnet.cu @@ -21,13 +21,13 @@ struct sigmoid{ class INNet{ public: virtual void Forward(const Tensor& inbatch, Tensor &oubatch) = 0; - virtual void Backprop(const Tensor& gradout) = 0; + virtual void Backprop(const Tensor& gradout) = 0; virtual void Update(void) = 0; virtual ~INNet() {} }; -/*! - * \brief simple two layer neural net +/*! + * \brief simple two layer neural net * this implementation is device invariant */ template @@ -35,6 +35,20 @@ class NNet : public INNet { public: // initialize the network NNet(int batch_size, int num_in, int num_hidden, int num_out) : rnd(0) { + // setup stream + Stream *stream = NewStream(); + ninput.set_stream(stream); + nhidden.set_stream(stream); + nhiddenbak.set_stream(stream); + nout.set_stream(stream); + hbias.set_stream(stream); + g_hbias.set_stream(stream); + g_obias.set_stream(stream); + obias.set_stream(stream); + Wi2h.set_stream(stream); + Wh2o.set_stream(stream); + g_Wi2h.set_stream(stream); + g_Wh2o.set_stream(stream); // setup nodes ninput.Resize(Shape2(batch_size, num_in)); nhidden.Resize(Shape2(batch_size, num_hidden)); @@ -48,7 +62,7 @@ class NNet : public INNet { Wi2h.Resize(Shape2(num_in, num_hidden)); g_Wi2h.Resize(Wi2h.shape_); Wh2o.Resize(Shape2(num_hidden, num_out)); g_Wh2o.Resize(Wh2o.shape_); rnd.SampleGaussian(&Wi2h, 0, 0.01f); - rnd.SampleGaussian(&Wh2o, 0, 0.01f); + rnd.SampleGaussian(&Wh2o, 0, 0.01f); } virtual ~NNet() {} // forward propagation @@ -57,35 +71,35 @@ class NNet : public INNet { // size is same conventsion as numpy index_t batch_size = inbatch.size(0); // copy data to input layer - Copy(ninput, inbatch); + Copy(ninput, inbatch, ninput.stream_); // first layer, fullc nhidden = dot(ninput, Wi2h); nhidden+= repmat(hbias, batch_size); - // activation, sigmloid, backup activation in nhidden + // activation, sigmloid, backup activation in nhidden nhidden = F(nhidden); - Copy(nhiddenbak, nhidden); + Copy(nhiddenbak, nhidden, nhiddenbak.stream_); // second layer fullc nout = dot(nhiddenbak, Wh2o); nout += repmat(obias, batch_size); // softmax calculation Softmax(nout, nout); // copy result out - Copy(oubatch, nout); + Copy(oubatch, nout, nout.stream_); } // back propagation virtual void Backprop(const Tensor& gradout) { // copy gradient to output layer - Copy(nout, gradout); + Copy(nout, gradout, nout.stream_); // calc grad of layer 2 g_obias = sum_rows(nout); g_Wh2o = dot(nhiddenbak.T(), nout); - // backprop to layer 1 + // backprop to layer 1 nhiddenbak = dot(nout, Wh2o.T()); // calculate gradient of sigmoid layer nhidden = nhidden * (1.0f-nhidden) * nhiddenbak; // calc grad of layer 1 g_hbias = sum_rows(nhidden); - g_Wi2h = dot(ninput.T(), nhidden); + g_Wi2h = dot(ninput.T(), nhidden); } // update weight virtual void Update(void) { @@ -107,7 +121,7 @@ class NNet : public INNet { // hidden bias, 
gradient TensorContainer hbias, obias, g_hbias, g_obias; // weight gradient - TensorContainer Wi2h, Wh2o, g_Wi2h, g_Wh2o; + TensorContainer Wi2h, Wh2o, g_Wi2h, g_Wh2o; }; // helper function to get the max inde inline int MaxIndex(Tensor pred) { @@ -123,12 +137,12 @@ int main(int argc, char *argv[]) { printf("Usage: cpu or gpu\n"); return 0; } srand(0); - + // settings int batch_size = 100; int num_in = 28 * 28; int num_hidden = 100; - int num_out = 10; + int num_out = 10; // choose which version to use INNet *net; if (!strcmp(argv[1], "gpu")) { @@ -138,22 +152,22 @@ int main(int argc, char *argv[]) { InitTensorEngine(); net = new NNet(batch_size, num_in, num_hidden, num_out); } - + // temp output layer - TensorContainer pred; + TensorContainer pred; pred.Resize(Shape2(batch_size, num_out)); - - // label + + // label std::vector ytrain, ytest; // data TensorContainer xtrain, xtest; LoadMNIST("train-images-idx3-ubyte", "train-labels-idx1-ubyte", ytrain, xtrain, true); LoadMNIST("t10k-images-idx3-ubyte", "t10k-labels-idx1-ubyte", ytest, xtest, false); - + int num_iter = 20; - + for (int i = 0; i < num_iter; ++ i) { - // training + // training for (index_t j = 0; j + batch_size <= xtrain.size(0); j += batch_size) { net->Forward(xtrain.Slice(j, j + batch_size), pred); // set gradient into pred @@ -170,10 +184,10 @@ int main(int argc, char *argv[]) { // evaluation long nerr = 0; for (index_t j = 0; j + batch_size <= xtest.size(0); j += batch_size) { - net->Forward(xtest.Slice(j, j + batch_size), pred); - for (int k = 0; k < batch_size; ++ k) { + net->Forward(xtest.Slice(j, j + batch_size), pred); + for (int k = 0; k < batch_size; ++ k) { nerr += MaxIndex(pred[k]) != ytest[j+k]; - + } } printf("round %d: test-err=%f\n", i, (float)nerr/xtest.size(0)); diff --git a/mshadow/base.h b/mshadow/base.h index fd5dc1bc0c55..44e000a19572 100644 --- a/mshadow/base.h +++ b/mshadow/base.h @@ -20,7 +20,7 @@ // macro defintiions /*! * \brief if this macro is define to be 1, - * mshadow should compile without any of other libs + * mshadow should compile without any of other libs */ #ifndef MSHADOW_STAND_ALONE #define MSHADOW_STAND_ALONE 0 @@ -30,9 +30,9 @@ #define MSHADOW_ALLOC_PAD true #endif /*! - * \brief + * \brief * x dimension of data must be bigger pad_size * ratio to be alloced padded memory, - * otherwise use tide allocation + * otherwise use tide allocation * for example, if pad_ratio=2, GPU memory alignement size is 32, * then we will only allocate padded memory if x dimension > 64 * set it to 0 then we will always allocate padded memory @@ -52,7 +52,7 @@ * error will be shot when default stream NULL is used */ #ifndef MSHADOW_FORCE_STREAM -#define MSHADOW_FORCE_STREAM 0 +#define MSHADOW_FORCE_STREAM 1 #endif /*! \brief use CBLAS for CBLAS */ @@ -71,6 +71,13 @@ #define MSHADOW_USE_CUDA 1 #endif +/*! + * \brief use CUDNN support, must ensure that the cudnn include path is correct + */ +#ifndef MSHADOW_USE_CUDNN + #define MSHADOW_USE_CUDNN 0 +#endif + /*! 
* \brief seems CUDAARCH is deprecated in future NVCC * set this to 1 if you want to use CUDA version smaller than 2.0 @@ -112,10 +119,16 @@ extern "C" { #endif #if MSHADOW_USE_CUDA - #include + #include #include #endif +#if MSHADOW_USE_CUDNN + #ifdef __CUDACC__ + #include + #endif +#endif + #if MSHADOW_USE_NVML #include #endif @@ -128,7 +141,7 @@ extern "C" { #define MSHADOW_FORCE_INLINE __forceinline #pragma warning( disable : 4068 ) #else -#define MSHADOW_FORCE_INLINE inline __attribute__((always_inline)) +#define MSHADOW_FORCE_INLINE inline __attribute__((always_inline)) #endif #ifdef __CUDACC__ #define MSHADOW_XINLINE MSHADOW_FORCE_INLINE __device__ __host__ @@ -292,7 +305,7 @@ struct divto { namespace red { namespace limits { /*! - * \brief minimum value of certain types + * \brief minimum value of certain types * \tparam DType data type */ template @@ -321,9 +334,9 @@ struct sum { MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src) { dst += src; } - /*! + /*! *\brief calculate gradient of redres with respect to redsrc, - * redres: reduced result, redsrc: one of reduction element + * redres: reduced result, redsrc: one of reduction element */ template MSHADOW_XINLINE static DType PartialGrad(DType redres, DType redsrc) { @@ -331,7 +344,7 @@ struct sum { } /*! *\brief set the initial value during reduction - */ + */ template MSHADOW_XINLINE static void SetInitValue(DType &initv) { initv = 0; @@ -355,7 +368,7 @@ struct maximum { } /*! *\brief set the initial value during reduction - */ + */ template MSHADOW_XINLINE static void SetInitValue(DType &initv) { initv = limits::MinValue(); diff --git a/mshadow/dot_engine-inl.h b/mshadow/dot_engine-inl.h index 5ffcde36d455..4804d0f8a41b 100644 --- a/mshadow/dot_engine-inl.h +++ b/mshadow/dot_engine-inl.h @@ -30,40 +30,46 @@ struct BLASEngine { } inline static void SetStream(Stream *stream) { } - inline static void gemm(bool transa, bool transb, + inline static void gemm(Stream *stream, + bool transa, bool transb, int m, int n, int k, float alpha, const float *A, int lda, const float *B, int ldb, float beta, float *C, int ldc) { cblas_sgemm(CblasColMajor, GetT(transa), GetT(transb), m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); } - inline static void gemm(bool transa, bool transb, + inline static void gemm(Stream *stream, + bool transa, bool transb, int m, int n, int k, double alpha, const double *A, int lda, const double *B, int ldb, double beta, double *C, int ldc) { cblas_dgemm(CblasColMajor, GetT(transa), GetT(transb), m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); } - inline static void gemv(bool trans, int m, int n, + inline static void gemv(Stream *stream, + bool trans, int m, int n, float alpha, const float *A, int lda, const float *X, int incX, float beta, float *Y, int incY) { cblas_sgemv(CblasColMajor, GetT(trans), m, n, alpha, A, lda, X, incX, beta, Y, incY); } - inline static void gemv(bool trans, int m, int n, double alpha, + inline static void gemv(Stream *stream, + bool trans, int m, int n, double alpha, const double *A, int lda, const double *X, int incX, double beta, double *Y, int incY) { cblas_dgemv(CblasColMajor, GetT(trans), m, n, alpha, A, lda, X, incX, beta, Y, incY); } - inline static void ger(int m, int n, float alpha, + inline static void ger(Stream *stream, + int m, int n, float alpha, const float *X, int incX, const float *Y, int incY, float *A, int lda) { cblas_sger(CblasColMajor, m, n, alpha, X, incX, Y, incY, A, lda); } - inline static void ger(int m, int n, double alpha, + inline static 
void ger(Stream *stream, + int m, int n, double alpha, const double *X, int incX, const double *Y, int incY, double *A, int lda) { cblas_dger(CblasColMajor, m, n, alpha, X, incX, Y, incY, A, lda); @@ -75,49 +81,70 @@ struct BLASEngine { // All CuBLAS goes to here, use legacy API: not threadsafe template<> struct BLASEngine { - inline static char GetT(bool t) { - return t ? 'T' : 'N'; + inline static cublasOperation_t GetT(bool t) { + return t ? CUBLAS_OP_T : CUBLAS_OP_N; } inline static void SetStream(Stream *stream) { - cublasSetKernelStream(Stream::GetStream(stream)); + cublasStatus_t err = cublasSetStream(Stream::GetBlasHandle(stream), + Stream::GetStream(stream)); + utils::Check(err == CUBLAS_STATUS_SUCCESS, + "cublas: set stream fail, set stream for tensor before use cublas"); } - inline static void gemm(bool transa, bool transb, + inline static void gemm(Stream *stream, + bool transa, bool transb, int m, int n, int k, float alpha, const float *A, int lda, const float *B, int ldb, float beta, float *C, int ldc) { - cublasSgemm(GetT(transa), GetT(transb), m, n, k, alpha, - A, lda, B, ldb, beta, C, ldc); + cublasStatus_t err = cublasSgemm(Stream::GetBlasHandle(stream), + GetT(transa), GetT(transb), m, n, k, &alpha, + A, lda, B, ldb, &beta, C, ldc); + utils::Check(err == CUBLAS_STATUS_SUCCESS, "cublas: Sgemm fail"); } - inline static void gemm(bool transa, bool transb, + inline static void gemm(Stream *stream, + bool transa, bool transb, int m, int n, int k, double alpha, const double *A, int lda, const double *B, int ldb, double beta, double *C, int ldc) { - cublasDgemm(GetT(transa), GetT(transb), m, n, k, alpha, - A, lda, B, ldb, beta, C, ldc); + cublasStatus_t err = cublasDgemm(Stream::GetBlasHandle(stream), + GetT(transa), GetT(transb), m, n, k, &alpha, + A, lda, B, ldb, &beta, C, ldc); + utils::Check(err == CUBLAS_STATUS_SUCCESS, "cublas: Dgemm fail"); } - inline static void gemv(bool trans, int m, int n, float alpha, + inline static void gemv(Stream *stream, + bool trans, int m, int n, float alpha, const float *A, int lda, const float *X, int incX, float beta, float *Y, int incY) { - cublasSgemv(GetT(trans), m, n, alpha, A, lda, X, incX, beta, Y, incY); + cublasStatus_t err = cublasSgemv(Stream::GetBlasHandle(stream), + GetT(trans), m, n, &alpha, A, lda, X, incX, &beta, Y, incY); + utils::Check(err == CUBLAS_STATUS_SUCCESS, "cublas: Sgemv fail"); } - inline static void gemv(bool trans, int m, int n, double alpha, + inline static void gemv(Stream *stream, + bool trans, int m, int n, double alpha, const double *A, int lda, const double *X, int incX, double beta, double *Y, int incY) { - cublasDgemv(GetT(trans), m, n, alpha, A, lda, X, incX, beta, Y, incY); + cublasStatus_t err = cublasDgemv(Stream::GetBlasHandle(stream), + GetT(trans), m, n, &alpha, A, lda, X, incX, &beta, Y, incY); + utils::Check(err == CUBLAS_STATUS_SUCCESS, "cublas: Dgemv fail"); } - inline static void ger(int m, int n, float alpha, + inline static void ger(Stream *stream, + int m, int n, float alpha, const float *X, int incX, const float *Y, int incY, float *A, int lda) { - cublasSger(m, n, alpha, X, incX, Y, incY, A, lda); + cublasStatus_t err = cublasSger(Stream::GetBlasHandle(stream), + m, n, &alpha, X, incX, Y, incY, A, lda); + utils::Check(err == CUBLAS_STATUS_SUCCESS, "cublas: Sger fail"); } - inline static void ger(int m, int n, double alpha, + inline static void ger(Stream *stream, + int m, int n, double alpha, const double *X, int incX, const double *Y, int incY, double *A, int lda) { - cublasDger(m, n, 
alpha, X, incX, Y, incY, A, lda); + cublasStatus_t err = cublasDger(Stream::GetBlasHandle(stream), + m, n, &alpha, X, incX, Y, incY, A, lda); + utils::Check(err == CUBLAS_STATUS_SUCCESS, "cublas: Dger fail"); } }; #endif // MSHADOW_USE_CUDA @@ -135,6 +162,7 @@ struct DotEngine { DType scale) { Tensor &dst = *p_dst; // set kernel stream + // if there is no stream, crush BLASEngine::SetStream(dst.stream_); Shape<2> sleft = GetShape(lhs.shape_, transpose_left); Shape<2> sright = GetShape(rhs.shape_, transpose_right); @@ -143,7 +171,8 @@ struct DotEngine { "dot-gemm: matrix shape mismatch"); // use column major argument to compatible with most BLAS BLASEngine::gemm - (transpose_right , transpose_left, + (dst.stream_, + transpose_right , transpose_left, transpose_right ? rhs.size(0) : rhs.size(1), transpose_left ? lhs.size(1) : lhs.size(0), transpose_right ? rhs.size(1) : rhs.size(0), @@ -162,12 +191,14 @@ struct DotEngine { DType scale) { Tensor &dst = *p_dst; // set kernel stream + // if there is no stream, crush BLASEngine::SetStream(dst.stream_); Shape<2> sright = GetShape(rhs.shape, transpose_right); utils::Check(dst.size(0) == sright[1] && lhs.size(0) == sright[0], "dot-gemv: matrix shape mismatch"); BLASEngine::gemv - (transpose_right, + (dst.stream_, + transpose_right, rhs.size(1), rhs.size(0), scale * SV::AlphaBLAS(), rhs.dptr_, rhs.stride_, lhs.dptr_, 1, SV::BetaBLAS(), @@ -182,12 +213,13 @@ struct DotEngine { DType scale) { Tensor &dst = *p_dst; // set kernel stream + // if there is no stream, crush BLASEngine::SetStream(dst.stream_); utils::Check(dst.size(0) == lhs.size(0) && dst.size(1) == rhs.size(0), "dot-ger: matrix shape mismatch"); if (SV::BetaBLAS() == 0.0f) { BLASEngine::ger - (rhs.size(0), lhs.size(0), scale * SV::AlphaBLAS(), + (dst.stream_, rhs.size(0), lhs.size(0), scale * SV::AlphaBLAS(), rhs.dptr_, 1, lhs.dptr_, 1, dst.dptr_, dst.stride_); } else { DotEngine struct Stream { + /*! \brief handle state */ + enum HandleState { + NoHandle = 0, + OwnHandle = 1, + }; /*! \brief cudaStream */ cudaStream_t stream_; - Stream(void) : stream_(0) {} + /*! \brief cublas handle */ + cublasHandle_t blas_handle_; + /*! \brief cublas handle ownership */ + HandleState blas_handle_ownership_; + /*! \brief cudnn handle ownership */ + HandleState dnn_handle_ownership_; +#if MSHADOW_USE_CUDNN == 1 + /*! \brief cudnn handle */ + cudnnHandle_t dnn_handle_; +#endif + Stream(void) : stream_(0), + blas_handle_ownership_(NoHandle), + dnn_handle_ownership_(NoHandle) {} /*! * \brief wait for all the computation associated * with this stream to complete @@ -30,7 +47,7 @@ struct Stream { /*! * \brief query whether the the stream is idle * \return true if the stream is idle and all the job have been completed - */ + */ inline bool CheckIdle(void) { cudaError_t err = cudaStreamQuery(stream_); if (err == cudaSuccess) return true; @@ -51,11 +68,74 @@ struct Stream { } else return stream->stream_; } + /*! + * \brief return actual cublasHandle + * \param pointer to GPU stream + */ + inline static cublasHandle_t GetBlasHandle(Stream *stream) { + if (stream == NULL) { + return 0; + } else { + utils::Check(stream->blas_handle_ownership_ != NoHandle, + "No handle exist in source stream"); + return stream->blas_handle_; + } + } + /*! 
\brief Destory cublas handle if own it */ + inline void DestoryBlasHandle() { + if (blas_handle_ownership_ == OwnHandle) { + cublasStatus_t err = cublasDestroy(blas_handle_); + blas_handle_ownership_ = NoHandle; + utils::Check(err == CUBLAS_STATUS_SUCCESS, "Destory cublas handle failed"); + } + } + /*! \brief Destory original blas handle and create a new one */ + inline void CreateBlasHandle() { + this->DestoryBlasHandle(); + cublasStatus_t err = cublasCreate(&blas_handle_); + blas_handle_ownership_ = OwnHandle; + utils::Check(err == CUBLAS_STATUS_SUCCESS, "Create cublas handle failed"); + } +#if MSHADOW_USE_CUDNN && defined(__CUDACC__) + inline static cudnnHandle_t GetDnnHandle(Stream *stream) { + if (stream == NULL) { + return 0; + } else { + utils::Check(stream->dnn_handle_ownership_ != NoHandle, + "No handle exist in source stream"); + return stream->dnn_handle_; + } + } +#endif + inline void DestroyDnnHandle() { +#if MSHADOW_USE_CUDNN && defined(__CUDACC__) + if (dnn_handle_ownership_ == OwnHandle) { + cudnnStatus_t err = cudnnDestroy(dnn_handle_); + utils::Check(err == CUDNN_STATUS_SUCCESS, + "Destroy cudnn handle failed"); + } +#endif + } + inline void CreateDnnHandle() { +#if MSHADOW_USE_CUDNN && defined(__CUDACC__) + this->DestroyDnnHandle(); + cudnnStatus_t err = cudnnCreate(&dnn_handle_); + utils::Check(err == CUDNN_STATUS_SUCCESS, + "Create cudnn handle failed"); +#endif + } }; template<> -inline Stream *NewStream(void) { +inline Stream *NewStream(bool create_blas_handle, + bool create_dnn_handle) { Stream *st = new Stream(); cudaError_t err = cudaStreamCreate(&st->stream_); + if (create_blas_handle) { + st->CreateBlasHandle(); + } + if (create_dnn_handle) { + st->CreateDnnHandle(); + } utils::Check(err == cudaSuccess, cudaGetErrorString(err)); return st; } @@ -63,8 +143,10 @@ template<> inline void DeleteStream(Stream *stream) { cudaError_t err = cudaStreamDestroy(stream->stream_); utils::Check(err == cudaSuccess, cudaGetErrorString(err)); + stream->DestoryBlasHandle(); + stream->DestroyDnnHandle(); delete stream; } -#endif +#endif } #endif // MSHADOW_STREAM_GPU_INL_H_ diff --git a/mshadow/tensor.h b/mshadow/tensor.h index 89df8dc4a60a..97a0dfbdb2a0 100644 --- a/mshadow/tensor.h +++ b/mshadow/tensor.h @@ -52,7 +52,7 @@ struct Shape { for (int i = 0; i < kDimension; ++i) { this->shape_[i] = s[i]; } - } + } /*! * \brief get corresponding index * \param idx dimension index @@ -70,7 +70,7 @@ struct Shape { return shape_[idx]; } /*! - * \return whether two shape equals + * \return whether two shape equals * \param s the shape to compare against */ MSHADOW_XINLINE bool operator==(const Shape &s) const { @@ -220,6 +220,8 @@ struct Stream { inline bool CheckIdle(void) { return true; } + /*! \brief create a blas handle */ + inline void CreateBlasHandle() {} }; /*! * \brief Tensor RValue, this is the super type of all kinds of possible tensors @@ -466,11 +468,19 @@ template inline void SetDevice(int devid); /*! * \brief create a new stream from system + * \param create_blas_handle whether create blas handle in stream + * \param create_dnn_handle whether create cudnn handle in stream * \return a pointer to the created stream * \tparam Device the device type */ template -inline Stream *NewStream(void); +inline Stream *NewStream(bool create_blas_handle, + bool create_dnn_handle); +/*! \brief default behavior: create cublas handle */ +template +inline Stream *NewStream() { + return NewStream(true, false); +} /*! 
* \brief delete the computing stream * \param stream the stream parameter to be deleted diff --git a/mshadow/tensor_cpu-inl.h b/mshadow/tensor_cpu-inl.h index 1ec5fa2c3ad5..2f8c3edfaf39 100644 --- a/mshadow/tensor_cpu-inl.h +++ b/mshadow/tensor_cpu-inl.h @@ -23,7 +23,8 @@ template<> inline void SetDevice(int devid) { } template<> -inline Stream *NewStream(void) { +inline Stream *NewStream(bool create_blas_handle, + bool create_dnn_handle) { return new Stream(); } template<> diff --git a/mshadow/tensor_gpu-inl.h b/mshadow/tensor_gpu-inl.h index ffd203d33a1a..3c724250f08c 100644 --- a/mshadow/tensor_gpu-inl.h +++ b/mshadow/tensor_gpu-inl.h @@ -28,15 +28,13 @@ inline void InitTensorEngine(int dev_id) { utils::Check(cudaSetDevice(device_id) == cudaSuccess, "cannot set device"); cudaGetDeviceProperties(&prop, device_id); printf("Use CUDA Device %d: %s\n", device_id, prop.name); - cublasInit(); } template<> inline void ShutdownTensorEngine(void) { - cublasShutdown(); } template<> inline void SetDevice(int devid) { - utils::Check(cudaSetDevice(devid) == cudaSuccess, "cannot set device"); + utils::Check(cudaSetDevice(devid) == cudaSuccess, "cannot set device"); } template inline void AllocSpace(Tensor *obj, bool pad) { @@ -132,7 +130,7 @@ inline void MapReduceKeepLowest(TRValue *dst, ::Error_TypeCheck_Not_Pass_For_Reduce_Exp(); Shape<2> eshape = expr::ShapeCheck::kDim, E> ::Check(exp.self()).FlatTo2D(); - Shape<1> dshape = expr::ShapeCheck<1, R>::Check(dst->self()); + Shape<1> dshape = expr::ShapeCheck<1, R>::Check(dst->self()); utils::Check(eshape[1] == dshape[0], "MapReduceKeepLowest::reduction dimension do not match"); utils::Check(eshape[0] != 0, "can not reduce over empty tensor"); @@ -151,7 +149,7 @@ inline void MapReduceKeepHighDim(TRValue *dst, typedef Shape::kDim> EShape; EShape eshape = expr::ShapeCheck::kDim, E> ::Check(exp.self()); - Shape<1> dshape = expr::ShapeCheck<1, R>::Check(dst->self()); + Shape<1> dshape = expr::ShapeCheck<1, R>::Check(dst->self()); utils::Check(eshape[dimkeep] == dshape[0], "MapReduceKeepHighDim::reduction dimension do not match"); // use equvalent form
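
Taken together, the changes above make streams mandatory on the GPU path: ```MSHADOW_FORCE_STREAM``` now defaults to 1, ```NewStream<gpu>()``` creates a cuBLAS handle by default, and every BLAS call is dispatched through the handle attached to the destination tensor's stream. Below is a minimal sketch of the resulting calling convention; it is not part of the patch, the shapes and values are illustrative only, and it assumes the guide's usual setup (a ```.cu``` file compiled by nvcc against the mshadow headers).

```c++
// sketch: stream-aware GPU usage after this patch (illustrative, not taken from the diff)
#include "mshadow/tensor.h"
using namespace mshadow;
using namespace mshadow::expr;

int main(void) {
  InitTensorEngine<gpu>();
  // NewStream<gpu>() now also creates a cuBLAS handle (it forwards to NewStream<gpu>(true, false)),
  // so the dot() below can bind its gemm call to this stream
  Stream<gpu> *stream = NewStream<gpu>();
  TensorContainer<gpu, 2, float> lhs, rhs, dst;
  // with MSHADOW_FORCE_STREAM = 1, every GPU tensor must carry a stream
  // before it is used, otherwise an error is raised
  lhs.set_stream(stream); rhs.set_stream(stream); dst.set_stream(stream);
  lhs.Resize(Shape2(2, 3)); rhs.Resize(Shape2(3, 2)); dst.Resize(Shape2(2, 2));
  lhs = 1.0f; rhs = 2.0f;
  dst = dot(lhs, rhs);  // BLASEngine<gpu>::SetStream uses the handle carried by dst.stream_
  stream->Wait();       // wait for the asynchronous work on this stream to finish
  DeleteStream<gpu>(stream);
  ShutdownTensorEngine<gpu>();
  return 0;
}
```

Keeping the cuBLAS (and optionally cuDNN) handle inside ```Stream<gpu>``` is what lets ```BLASEngine<gpu>::SetStream``` call ```cublasSetStream``` on a per-stream handle instead of the legacy global ```cublasSetKernelStream```, which is also why ```cublasInit()```/```cublasShutdown()``` disappear from ```InitTensorEngine```/```ShutdownTensorEngine```.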