Merge pull request apache#36 from dmlc/cublasv2
switch to Cublasv2
antinucleon committed Jul 2, 2015
2 parents d9300d1 + 9169d5c commit 8970c83
Showing 10 changed files with 296 additions and 128 deletions.
21 changes: 12 additions & 9 deletions guide/README.md
@@ -29,16 +29,17 @@ template<typename Device, int dimension, typename DType = float>
struct Tensor {
DType *dptr_;
Shape<dimension> shape_;
Stream<Device> stream_;
index_t stride_;
};
// this is what a shape declaration looks like
Shape<2> shape2;
// this is what tensor declarations look like;
// you can specify device, dimension, and element type
Tensor<cpu, 2> ts2;
Tensor<gpu, 3, float> ts3;
```
```Tensor<cpu, 2>``` means a two-dimensional tensor in CPU memory, while ```Tensor<gpu, 3>``` means a three-dimensional tensor on the GPU.
```Shape<k>``` gives the shape information of a k-dimensional tensor. The declaration is a template, and
can be specialized into a tensor of a specific device and dimension. This is what a two-dimensional tensor looks like:
```c++
@@ -50,8 +51,8 @@ struct Tensor<cpu, 2, float> {
Shape<2> shape_;
index_t stride_;
};
```
* ```Tensor<cpu, 2>``` contains ```dptr_```, which points to the space that backs the tensor.
* ```Shape<2>``` is a structure that stores the shape information; the convention is the same as numpy's.
* ```stride_``` gives the number of cells allocated in the lowest dimension (in numpy convention, the dimension corresponding to ```shape_[-1]```).
This is introduced when padding cells are added in the lowest dimension to keep memory aligned, as in the following example.
@@ -64,7 +65,7 @@ Tensor<cpu, 2> ts;
ts.dptr_ = data;
ts.shape_ = mshadow::Shape2(3, 2);
ts.stride_ = 3;
// now: ts[0][0] == 0, ts[0][1] == 1, ts[1][0] == 3, ts[1][1] == 4
for (index_t i = 0; i < ts.size(0); ++i) {
for (index_t j = 0; j < ts.size(1); ++j) {
printf("ts[%u][%u]=%f\n", i, j, ts[i][j]);
@@ -73,10 +74,12 @@ for (index_t i = 0; i < ts.size(0); ++i) {
```
The resulting ts is a 3 * 2 matrix, where data[2], data[5], data[8] are padding cells that are ignored. If you want continuous memory, set ```stride_ = shape_[1]```.
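For instance, a minimal sketch of the continuous layout over the same ```data``` buffer (using only the fields shown above):

```c++
// packed view: stride_ == shape_[1], so there are no padding cells
Tensor<cpu, 2> packed;
packed.dptr_ = data;
packed.shape_ = mshadow::Shape2(3, 3);
packed.stride_ = 3;
// every element of data is now addressable: packed[0][2] == 2, packed[2][2] == 8
```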

NOTICE: We highly recommend using streams in ```gpu``` mode; an error will be thrown if no stream is set. Check [basic_stream.cu](basic_stream.cu) for more detail.
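A minimal sketch of what that looks like, modeled on the stream setup in convnet.cu below (```DeleteStream``` is assumed from the same stream API):

```c++
// create a stream, attach it before any gpu computation, then clean up
Stream<gpu> *stream = NewStream<gpu>();
Tensor<gpu, 2, float> tg(Shape2(3, 2));
tg.set_stream(stream);
AllocSpace(&tg);
tg = 1.0f;  // this assignment now runs on the attached stream
FreeSpace(&tg);
DeleteStream(stream);
```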

Memory Allocation
====
An important design choice about mshadow is that the data structure is a **whitebox**:
it works so long as we set the space pointer ```dptr_```, the corresponding ```shape_```, and ```stride_```:
* For ```Tensor<cpu, k>```, the space can be created by ```new float[]```, or can point to some existing space such as the float array in the last example.
* For ```Tensor<gpu, k>```, the space needs to lie on the GPU, created by ```cudaMallocPitch```. A sketch of both paths follows.
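A minimal sketch of both paths (assuming the guide's headers; ```AllocSpace```/```FreeSpace``` are the library helpers that set ```dptr_``` and ```stride_``` for you):

```c++
// path 1: wrap memory you already own (cpu)
float buf[6] = {0, 1, 2, 3, 4, 5};
Tensor<cpu, 2> wrapped;
wrapped.dptr_ = buf;
wrapped.shape_ = Shape2(3, 2);
wrapped.stride_ = 2;  // continuous rows: stride_ == shape_[1]

// path 2: let mshadow allocate (cudaMallocPitch underneath for gpu tensors)
Tensor<gpu, 2> device_side(Shape2(3, 2));
device_side.set_stream(NewStream<gpu>());
AllocSpace(&device_side);
// ... compute ...
FreeSpace(&device_side);
```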

Expand Down Expand Up @@ -197,17 +200,17 @@ int main(void) {
// a Tensor object is only a handle; assignment means both refer to the same data
// we can specify the content type of a Tensor; if not specified, it is float by default
Tensor<cpu, 2, float> mat2 = mat;
// shape of matrix, note the size order is the same as numpy
printf("%u X %u matrix\n", mat.size(0), mat.size(1));
// initialize all elements to zero
mat = 0.0f;
// assign some values
mat[0][1] = 1.0f; mat[1][0] = 2.0f;
// elementwise operations
mat += (mat + 10.0f) / 10.0f + 2.0f;
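// note: the right-hand side builds an expression template that is evaluated
// lazily on assignment, so no temporary matrices are allocated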
// print out matrix, note: mat2 and mat are handles (pointers)
for (index_t i = 0; i < mat.size(0); ++i) {
for (index_t j = 0; j < mat.size(1); ++j) {
    printf("%.2f ", mat[i][j]);
  }
  printf("\n");
}
return 0;
}
8 changes: 4 additions & 4 deletions guide/neuralnet/config.mk
@@ -3,10 +3,10 @@
#
# This is a configuration script that you can use to compile mshadow
# Usage:
#
# include config.mk in your Makefile, or directly include the definition of variables
# include mshadow.mk after the variables are set
#
# Add MSHADOW_CFLAGS to the compile flags
# Add MSHADOW_LDFLAGS to the linker flags
# Add MSHADOW_NVCCFLAGS to the nvcc compile flags
@@ -22,11 +22,11 @@ USE_CUDA_PATH = NONE
#
# choose the version of blas you want to use
# can be: mkl, blas, atlas, openblas, apple
USE_BLAS = openblas
#
# add path to intel library, you may need it
# for MKL, if you did not add the path to the environment variable
#
USE_INTEL_PATH = NONE

# whether compile with parameter server
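The usage notes at the top of config.mk translate into a wrapper Makefile roughly like this (a sketch only; the target names and source files are illustrative):

```make
include config.mk
include mshadow.mk

basic: basic.cpp
	$(CXX) $(MSHADOW_CFLAGS) -o $@ $^ $(MSHADOW_LDFLAGS)

convnet: convnet.cu
	nvcc $(MSHADOW_NVCCFLAGS) -o $@ $^ $(MSHADOW_LDFLAGS)
```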
97 changes: 56 additions & 41 deletions guide/neuralnet/convnet.cu
@@ -9,12 +9,12 @@ using namespace mshadow;
// this namespace contains all operator overloads
using namespace mshadow::expr;

// define operations
struct relu{
MSHADOW_XINLINE static real_t Map(real_t a) {
using namespace std;
return max(a, 0.0f);
  }
};
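// functors like relu plug into mshadow's F<op>(...) expression template;
// Forward below applies it elementwise as nhidden = F<relu>(nhidden)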
struct relu_grad {
MSHADOW_XINLINE static real_t Map(real_t a) {
    return a > 0.0f ? 1.0f : 0.0f;
  }
};
class INNet{
public:
virtual void Forward(const Tensor<cpu, 4, real_t>& inbatch, Tensor<cpu, 2, real_t> &oubatch) = 0;
virtual void Backprop(const Tensor<cpu, 2, real_t>& gradout) = 0;
virtual void Update(void) = 0;
virtual ~INNet() {}
};

/*!
* \brief simple two layer conv-net conv-pool-flat-fullc
* this implementation is device invariant
*/
template<typename xpu>
class ConvNet : public INNet {
 public:
// initialize the network
ConvNet(int batch_size, int insize, int nchannel, int ksize, int kstride, int psize, int num_out)
:rnd(0), ksize(ksize), kstride(kstride), psize(psize) {
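// gpu mode errors out if computation runs without a stream (see the NOTICE
// in guide/README.md), so attach one shared stream to every container up front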
// setup stream
Stream<xpu> *stream = NewStream<xpu>();
ninput.set_stream(stream);
nhidden.set_stream(stream);
nhiddenbak.set_stream(stream);
npool.set_stream(stream);
npoolbak.set_stream(stream);
nflat.set_stream(stream);
nout.set_stream(stream);
hbias.set_stream(stream); g_hbias.set_stream(stream);
obias.set_stream(stream); g_obias.set_stream(stream);
Ki2h.set_stream(stream); g_Ki2h.set_stream(stream);
Wh2o.set_stream(stream); g_Wh2o.set_stream(stream);
tmp_col.set_stream(stream);
tmp_dst.set_stream(stream);
// setup nodes
ninput.Resize(Shape4(batch_size, 1, insize, insize));
nhidden.Resize(Shape4(batch_size, nchannel, (insize - ksize)/kstride+1, (insize - ksize)/kstride+1));
nhiddenbak.Resize(nhidden.shape_);
npool.Resize(Shape4(batch_size, nchannel, (nhidden.size(2)+1-psize)/psize, (nhidden.size(3)+1-psize)/psize));
npoolbak.Resize(npool.shape_);
@@ -58,25 +73,25 @@ class ConvNet : public INNet {
Wh2o.Resize(Shape2(nflat.size(1), num_out)); g_Wh2o.Resize(Wh2o.shape_);
rnd.SampleGaussian(&Ki2h, 0, 0.01f);
rnd.SampleGaussian(&Wh2o, 0, 0.01f);

printf("conv=%d, pool=%d\n", nhidden.size(3), npool.size(3));
}
virtual ~ConvNet() {}
// forward propagation
virtual void Forward(const Tensor<cpu, 4, real_t>& inbatch, Tensor<cpu, 2, real_t> &oubatch) {
index_t batch_size = inbatch.size(0);
// copy data to input layer
Copy(ninput, inbatch, ninput.stream_);
// first layer: convolution with stride kstride
ConvForward(ninput, Ki2h, nhidden, ksize, kstride, tmp_col, tmp_dst);
// add bias
nhidden += broadcast<1>(hbias, nhidden.shape_);
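// broadcast<1> replicates the 1D bias along axis 1 (the channel axis) of nhidden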
// activation, relu, backup activation in nhidden
nhidden = F<relu>(nhidden);
Copy(nhiddenbak, nhidden, nhiddenbak.stream_);
// max pooling
npool = pool<red::maximum>(nhiddenbak, npool[0][0].shape_, psize, psize, psize);
Copy(npoolbak, npool, npoolbak.stream_);
// flat
nflat = reshape(npool, nflat.shape_);
// second layer fullc
@@ -85,20 +100,20 @@ class ConvNet : public INNet {
// softmax calculation
Softmax(nout, nout);
// copy result out
Copy(oubatch, nout, nout.stream_);
}
// back propagation
virtual void Backprop(const Tensor<cpu, 2, real_t>& gradout) {
// copy gradient to output layer
Copy(nout, gradout, nout.stream_);
// calc grad of final layer
g_obias = sum_rows(nout);
g_Wh2o = dot(nflat.T(), nout);
// backprop to previous layer
nflat = dot(nout, Wh2o.T());
npool = reshape(nflat, npool.shape_);
// backprop pooling layer
nhiddenbak = unpool<red::maximum>(nhiddenbak, npoolbak, npool, psize, psize, psize);
// calculate gradient of relu layer
nhidden = F<relu_grad>(nhidden) * nhiddenbak;
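// chain rule: elementwise relu gradient times the gradient arriving from unpool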
// calc grad of layer 1
@@ -118,10 +133,10 @@ class ConvNet : public INNet {
obias -= eta * g_obias;
}
private:
// forward convolution; tmp_col and tmp_dst are helper structures
inline static void ConvForward(const Tensor<xpu, 4, real_t> &in,
const Tensor<xpu, 2, real_t> &kernel,
Tensor<xpu, 4, real_t> &out,
int ksize, int kstride,
TensorContainer<xpu, 2, real_t> &tmp_col,
TensorContainer<xpu, 2, real_t> &tmp_dst) {
@@ -130,20 +145,20 @@ class ConvNet : public INNet {
index_t nbatch = in.size(0);
index_t nchannel = out.size(1);
// we directly unpack all local patches and do a dot product
// this costs lots of memory; normally, for large images, only unpack several images at a time
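// tmp_col: one row per (in-channel, ky, kx) entry of a patch, one column per
// output location across the batch; dot(kernel, tmp_col) then yields one row
// per output channel in tmp_dst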
tmp_col.Resize(Shape2(in.size(1)*ksize*ksize, nbatch*oheight*owidth));
tmp_dst.Resize(Shape2(nchannel, nbatch*oheight*owidth));
// unpack local patches; patch stride equals kstride
tmp_col = unpack_patch2col(in, ksize, ksize, kstride);
tmp_dst = dot(kernel, tmp_col);
// reshape, then swap axes; we chain the equations together
out = swapaxis<1,0>(reshape(tmp_dst, Shape4(nchannel, nbatch, oheight, owidth)));
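// swapaxis<1,0> moves the batch axis to the front:
// (nchannel, nbatch, oheight, owidth) -> (nbatch, nchannel, oheight, owidth)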
  }
// backward convolution, calculate gradient of kernel, and backprop back to in
inline static void ConvBackWard(const Tensor<xpu, 4, real_t> &out,
const Tensor<xpu, 2, real_t> &kernel,
Tensor<xpu, 2, real_t> &g_kernel,
Tensor<xpu, 4, real_t> &in,
int ksize, int kstride,
TensorContainer<xpu, 2, real_t> &tmp_col,
TensorContainer<xpu, 2, real_t> &tmp_dst) {
@@ -152,13 +167,13 @@ class ConvNet : public INNet {
index_t nbatch = in.size(0);
index_t nchannel = out.size(1);
// we directly unpack all local patches and do a dot product
// this costs lots of memory; normally, for large images, only unpack several images at a time
tmp_col.Resize(Shape2(in.size(1) * ksize * ksize,
nbatch * oheight * owidth));
tmp_dst.Resize(Shape2(nchannel, nbatch * oheight * owidth));
// unpack local patches
tmp_col = unpack_patch2col(in, ksize, ksize, kstride);
tmp_dst = reshape(swapaxis<1,0>(out), tmp_dst.shape_);
g_kernel = dot(tmp_dst, tmp_col.T());
// backpropagation: not necessary for first layer, but included anyway
tmp_col = dot(kernel.T(), tmp_dst);
@@ -193,7 +208,7 @@ int main(int argc, char *argv[]) {
if(argc < 2) {
printf("Usage: cpu or gpu\n"); return 0;
}
srand(0);
// settings
int batch_size = 100;
int insize = 28;
@@ -202,7 +217,7 @@ int main(int argc, char *argv[]) {
int kstride = 1;
int psize = 2;
int num_out = 10;

// choose which version to use
INNet *net;
if (!strcmp(argv[1], "gpu")) {
InitTensorEngine<gpu>();
net = new ConvNet<gpu>(batch_size, insize, nchannel, ksize, kstride, psize, num_out);
} else {
InitTensorEngine<cpu>();
net = new ConvNet<cpu>(batch_size, insize, nchannel, ksize, kstride, psize, num_out);
}

// temp output layer
TensorContainer<cpu, 2, real_t> pred;
pred.Resize(Shape2(batch_size, num_out));

// label
std::vector<int> ytrain, ytest;
// data
TensorContainer<cpu, 2, real_t> xtrain_, xtest_;
LoadMNIST("train-images-idx3-ubyte", "train-labels-idx1-ubyte", ytrain, xtrain_, true);
LoadMNIST("t10k-images-idx3-ubyte", "t10k-labels-idx1-ubyte", ytest, xtest_, false);

TensorContainer<cpu, 4, real_t> xtrain(Shape4(xtrain_.size(0), 1, insize, insize));
TensorContainer<cpu, 4, real_t> xtest(Shape4(xtest_.size(0), 1, insize, insize));
xtrain = reshape(xtrain_, xtrain.shape_);
xtest = reshape(xtest_, xtest.shape_);

int num_iter = 20;

for (int i = 0; i < num_iter; ++ i) {
// training
for (index_t j = 0; j + batch_size <= xtrain.size(0); j += batch_size) {
net->Forward(xtrain.Slice(j, j + batch_size), pred);
// set gradient into pred
@@ -249,15 +264,15 @@ int main(int argc, char *argv[]) {
// evaluation
long nerr = 0;
for (index_t j = 0; j + batch_size <= xtest.size(0); j += batch_size) {
net->Forward(xtest.Slice(j, j + batch_size), pred);
for (int k = 0; k < batch_size; ++ k) {
nerr += MaxIndex(pred[k]) != ytest[j+k];
}
}
printf("round %d: test-err=%f\n", i, (float)nerr/xtest.size(0));
}
delete net;

if (!strcmp(argv[1], "gpu")) {
ShutdownTensorEngine<gpu>();
} else {
ShutdownTensorEngine<cpu>();
}
return 0;
}