Skip to content

Commit

Permalink
Update SoftMax to work in spatial mode
Browse files Browse the repository at this point in the history
This updates SoftMax to work in spatial mode thus it accepts 1D/2D/3D/4D
tensor as input.
  • Loading branch information
jhjin committed Sep 19, 2015
1 parent bf25188 commit 80929fe
Showing 1 changed file with 70 additions and 57 deletions.
127 changes: 70 additions & 57 deletions SoftMax.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
#define MINUS_LOG_THRESHOLD -18.42
#define SOFTMAX_THREADS 128

__global__ void cunn_SoftMax_updateOutput_kernel(float *output, float *input, int nframe, int dim)
__global__ void cunn_SoftMax_updateOutput_kernel(float *output, float *input,
int nframe, int dim, int stride)
{
__shared__ float buffer[SOFTMAX_THREADS+1];
int k = blockIdx.x;
float *input_k = input + k*dim;
float *output_k = output + k*dim;
float *input_k = input + blockIdx.x*dim*stride + blockIdx.y;
float *output_k = output + blockIdx.x*dim*stride + blockIdx.y;

int i_start = threadIdx.x;
int i_end = dim;
Expand All @@ -18,7 +18,7 @@ __global__ void cunn_SoftMax_updateOutput_kernel(float *output, float *input, in
buffer[threadIdx.x] = -FLT_MAX;
for (int i=i_start; i<i_end; i+=i_step)
{
float z = input_k[i];
float z = input_k[i*stride];
if(buffer[threadIdx.x] < z)
buffer[threadIdx.x] = z;
}
Expand All @@ -43,9 +43,9 @@ __global__ void cunn_SoftMax_updateOutput_kernel(float *output, float *input, in
float max_k = buffer[SOFTMAX_THREADS];
buffer[threadIdx.x] = 0;
for (int i=i_start; i<i_end; i+=i_step) {
float z = __expf(input_k[i]-max_k);
float z = __expf(input_k[i*stride]-max_k);
buffer[threadIdx.x] += z;
output_k[i] = z;
output_k[i*stride] = z;
}

__syncthreads();
Expand All @@ -64,17 +64,17 @@ __global__ void cunn_SoftMax_updateOutput_kernel(float *output, float *input, in
// softmax
float sum_k = buffer[SOFTMAX_THREADS];
for (int i=i_start; i<i_end; i+=i_step)
output_k[i] = output_k[i] / sum_k;
output_k[i*stride] = output_k[i*stride] / sum_k;
}


__global__ void cunn_SoftMax_updateGradInput_kernel(float *gradInput, float *output, float *gradOutput, int nframe, int dim)
__global__ void cunn_SoftMax_updateGradInput_kernel(float *gradInput, float *output, float *gradOutput
int nframe, int dim, int stride)
{
__shared__ float buffer[SOFTMAX_THREADS];
int k = blockIdx.x;
float *gradInput_k = gradInput + k*dim;
float *output_k = output + k*dim;
float *gradOutput_k = gradOutput + k*dim;
float *gradInput_k = gradInput + blockIdx.x*dim*stride + blockIdx.y;
float *output_k = output + blockIdx.x*dim*stride + blockIdx.y;
float *gradOutput_k = gradOutput + blockIdx.x*dim*stride + blockIdx.y;

int i_start = threadIdx.x;
int i_end = dim;
Expand All @@ -83,7 +83,7 @@ __global__ void cunn_SoftMax_updateGradInput_kernel(float *gradInput, float *out
// sum?
buffer[threadIdx.x] = 0;
for (int i=i_start; i<i_end; i+=i_step)
buffer[threadIdx.x] += gradOutput_k[i] * output_k[i];
buffer[threadIdx.x] += gradOutput_k[i*stride] * output_k[i*stride];

__syncthreads();

Expand All @@ -100,7 +100,7 @@ __global__ void cunn_SoftMax_updateGradInput_kernel(float *gradInput, float *out

float sum_k = buffer[0];
for (int i=i_start; i<i_end; i+=i_step)
gradInput_k[i] = output_k[i] * (gradOutput_k[i] - sum_k);
gradInput_k[i*stride] = output_k[i*stride] * (gradOutput_k[i*stride] - sum_k);
}

static int cunn_SoftMax_updateOutput(lua_State *L)
Expand All @@ -112,27 +112,41 @@ static int cunn_SoftMax_updateOutput(lua_State *L)

input = THCudaTensor_newContiguous(state, input);
THCudaTensor_resizeAs(state, output, input);
long batchSize, dim, stride;

if(input->nDimension == 1)
{
dim3 blocks(1);
dim3 threads(SOFTMAX_THREADS);
cunn_SoftMax_updateOutput_kernel<<<blocks,threads,
0, THCState_getCurrentStream(state)>>>(THCudaTensor_data(state, output),
THCudaTensor_data(state, input),
1, input->size[0]);
batchSize = 1;
dim = input->size[0];
stride = 1;
}
else if(input->nDimension == 2)
{
dim3 blocks(input->size[0]);
dim3 threads(SOFTMAX_THREADS);
cunn_SoftMax_updateOutput_kernel<<<blocks,threads,
0, THCState_getCurrentStream(state)>>>(THCudaTensor_data(state, output),
THCudaTensor_data(state, input),
input->size[0], input->size[1]);
batchSize = input->size[0];
dim = input->size[1];
stride = 1;
}
else if(input->nDimension == 3)
{
batchSize = 1;
dim = input->size[0];
stride = input->size[1]*input->size[2];
}
else if(input->nDimension == 4)
{
batchSize = input->size[0];
dim = input->size[1];
stride = input->size[2]*input->size[3];
}
else
THError("vector or matrix expected");
THError("1D, 2D, 3D or 4D tensor expected");

dim3 blocks(batchSize, stride);
dim3 threads(SOFTMAX_THREADS);
cunn_SoftMax_updateOutput_kernel<<<blocks,threads,
0, THCState_getCurrentStream(state)>>>(THCudaTensor_data(state, output),
THCudaTensor_data(state, input),
batchSize, dim, stride);

cudaError errcode = cudaGetLastError();
if(errcode != cudaSuccess)
Expand All @@ -142,18 +156,6 @@ static int cunn_SoftMax_updateOutput(lua_State *L)
return 1;
}

struct softmaxupdateGradInput_functor
{
float value;

softmaxupdateGradInput_functor(float value_) : value(value_) {}

__host__ __device__ float operator()(const float& output, const float& gradOutput) const
{
return gradOutput - exp(output)*value;
}
};

static int cunn_SoftMax_updateGradInput(lua_State *L)
{
THCState *state = getCutorchState(L);
Expand All @@ -166,31 +168,42 @@ static int cunn_SoftMax_updateGradInput(lua_State *L)
gradOutput = THCudaTensor_newContiguous(state, gradOutput);

THCudaTensor_resizeAs(state, gradInput, output);
long batchSize, dim, stride;

if(gradInput->nDimension == 1)
{
dim3 blocks(1);
dim3 threads(SOFTMAX_THREADS);

cunn_SoftMax_updateGradInput_kernel<<<blocks,threads,
0, THCState_getCurrentStream(state)>>>(THCudaTensor_data(state, gradInput),
THCudaTensor_data(state, output),
THCudaTensor_data(state, gradOutput),
1, gradInput->size[0]);
batchSize = 1;
dim = gradInput->size[0];
stride = 1;
}
else if(gradInput->nDimension == 2)
{
dim3 blocks(gradInput->size[0]);
dim3 threads(SOFTMAX_THREADS);

cunn_SoftMax_updateGradInput_kernel<<<blocks,threads,
0, THCState_getCurrentStream(state)>>>(THCudaTensor_data(state, gradInput),
THCudaTensor_data(state, output),
THCudaTensor_data(state, gradOutput),
gradInput->size[0], gradInput->size[1]);
batchSize = gradInput->size[0];
dim = gradInput->size[1];
stride = 1;
}
else if(gradInput->nDimension == 3)
{
batchSize = 1;
dim = gradInput->size[0];
stride = gradInput->size[1]*gradInput->size[2];
}
else if(gradInput->nDimension == 4)
{
batchSize = gradInput->size[0];
dim = gradInput->size[1];
stride = gradInput->size[2]*gradInput->size[3];
}
else
THError("vector or matrix expected");
THError("1D, 2D, 3D or 4D tensor expected");

dim3 blocks(batchSize, stride);
dim3 threads(SOFTMAX_THREADS);
cunn_SoftMax_updateGradInput_kernel<<<blocks,threads,
0, THCState_getCurrentStream(state)>>>(THCudaTensor_data(state, gradInput),
THCudaTensor_data(state, output),
THCudaTensor_data(state, gradOutput),
batchSize, dim, stride);

cudaError errcode = cudaGetLastError();
if(errcode != cudaSuccess)
Expand Down

0 comments on commit 80929fe

Please sign in to comment.