This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-107] Add Fused Vanilla RNN and dropout for CPU #11399

Merged · 1 commit merged on Jun 26, 2018
16 changes: 15 additions & 1 deletion example/rnn/bucketing/cudnn_rnn_bucketing.py
@@ -66,7 +66,7 @@
parser.add_argument('--dropout', type=float, default='0.0',
help='dropout probability (1.0 - keep probability)')
parser.add_argument('--rnntype', type=str, default='lstm',
help='rnn type: gru and lstm are supported')
help='rnn type: gru, lstm, rnn_tanh and rnn_relu are supported')

#buckets = [32]
buckets = [10, 20, 30, 40, 50, 60]
@@ -188,6 +188,20 @@ def test(args):
cell,
mx.rnn.GRUCell(num_hidden=args.num_hidden, prefix='%s_%dr0_'%(args.rnntype,i)),
output_prefix='bi_%s_%d'%(args.rnntype,i))
elif args.rnntype == 'rnn_tanh':
cell = mx.rnn.RNNCell(num_hidden=args.num_hidden, activation='tanh', prefix='%s_%dl0_'%(args.rnntype,i))
if args.bidirectional:
cell = mx.rnn.BidirectionalCell(
cell,
mx.rnn.RNNCell(num_hidden=args.num_hidden, activation='tanh', prefix='%s_%dr0_'%(args.rnntype,i)),
output_prefix='bi_%s_%d'%(args.rnntype,i))
elif args.rnntype == 'rnn_relu':
cell = mx.rnn.RNNCell(num_hidden=args.num_hidden, activation='relu', prefix='%s_%dl0_'%(args.rnntype,i))
if args.bidirectional:
cell = mx.rnn.BidirectionalCell(
cell,
mx.rnn.RNNCell(num_hidden=args.num_hidden, activation='relu', prefix='%s_%dr0_'%(args.rnntype,i)),
output_prefix='bi_%s_%d'%(args.rnntype,i))

stack.add(cell)

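The test() path above rebuilds the network from unfused RNNCell/BidirectionalCell stacks; the training path in this script builds the fused cell directly, which this PR now also implements on CPU. A minimal sketch of the fused counterpart of the new rnn_tanh/rnn_relu branches, assuming illustrative values for num_hidden, num_layers, bidirectional and dropout (these values are not taken from the PR):

import mxnet as mx

# Fused counterpart of the unfused rnn_relu branch above; with this PR the
# fused kernel is also available on CPU, and dropout is applied between
# stacked layers.
fused_cell = mx.rnn.FusedRNNCell(
    num_hidden=200,        # illustrative stand-in for args.num_hidden
    num_layers=2,          # illustrative stand-in for args.num_layers
    mode='rnn_relu',       # or 'rnn_tanh', 'lstm', 'gru'
    bidirectional=True,
    dropout=0.2,
    prefix='rnn_relu_')

# fused_cell.unfuse() yields a stack equivalent to the RNNCell/BidirectionalCell
# construction shown in test() above.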
74 changes: 43 additions & 31 deletions src/operator/rnn-inl.h
@@ -99,17 +99,17 @@ inline size_t GetRNNWorkspaceSize(int seq_length,
int mode) {
size_t size = 0;
switch (mode) {
case rnn_enum::kRnnRelu:
case rnn_enum::kRnnTanh:
LOG(FATAL) << "Only LSTM and GRU are supported at the moment";
break;
case rnn_enum::kLstm:
size = (seq_length + 1) * batch_size * hidden_size * 4 + batch_size * hidden_size * 2
+ seq_length * batch_size * hidden_size * direction + hidden_size * seq_length * 8;
break;
case rnn_enum::kGru:
size = seq_length * batch_size * hidden_size * direction * 4 + batch_size * hidden_size * 8;
break;
case rnn_enum::kRnnRelu:
case rnn_enum::kRnnTanh:
size = seq_length * batch_size * hidden_size * direction * 2 + batch_size * hidden_size * 4;
break;
default:
LOG(FATAL) << "unknown RNN mode " << mode;
break;
@@ -125,18 +125,20 @@ inline size_t GetRNNReserveSpaceSize(int num_layer,
int mode) {
size_t size = 0;
switch (mode) {
case rnn_enum::kRnnRelu:
case rnn_enum::kRnnTanh:
LOG(FATAL) << "Only LSTM and GRU are supported at the moment";
break;
case rnn_enum::kLstm:
size = num_layer * direction * seq_length * batch_size * hidden_size * 6;
size = direction * seq_length * batch_size * hidden_size * (num_layer * 7 - 1);
break;
case rnn_enum::kGru:
size = seq_length * batch_size * hidden_size * direction * num_layer * 8 +
size = seq_length * batch_size * hidden_size * direction * (num_layer * 9 - 1) +
batch_size * hidden_size * direction * 9 + hidden_size * seq_length * 6 +
seq_length * batch_size * 7 * hidden_size * direction;
break;
case rnn_enum::kRnnRelu:
case rnn_enum::kRnnTanh:
size = seq_length * batch_size * hidden_size * direction * (num_layer * 6 - 1) +
batch_size * hidden_size * direction * 3 + hidden_size * seq_length * 2 +
seq_length * batch_size * 2 * hidden_size * direction;
break;
default:
LOG(FATAL) << "unknown RNN mode " << mode;
break;
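For a concrete sense of the new vanilla-RNN branches, the following sketch (Python, not part of the PR) mirrors the workspace and reserve-space formulas above; sizes are in DType elements and the dimensions are illustrative:

def vanilla_rnn_workspace_size(seq_length, batch_size, hidden_size, direction):
    # Mirrors the kRnnRelu/kRnnTanh branch of GetRNNWorkspaceSize.
    return (seq_length * batch_size * hidden_size * direction * 2
            + batch_size * hidden_size * 4)

def vanilla_rnn_reserve_size(num_layer, direction, seq_length, batch_size, hidden_size):
    # Mirrors the kRnnRelu/kRnnTanh branch of GetRNNReserveSpaceSize.
    return (seq_length * batch_size * hidden_size * direction * (num_layer * 6 - 1)
            + batch_size * hidden_size * direction * 3
            + hidden_size * seq_length * 2
            + seq_length * batch_size * 2 * hidden_size * direction)

# Example: single-layer unidirectional RNN, seq_length=10, batch_size=32, hidden_size=100
print(vanilla_rnn_workspace_size(10, 32, 100, 1))    # 76800 elements
print(vanilla_rnn_reserve_size(1, 1, 10, 32, 100))   # 235600 elements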
@@ -223,21 +225,24 @@ void RNNForwardTraining(DType* ws,
DType* y_ptr,
DType* hy_ptr,
DType* cy_ptr,
const float dropout,
int mode) {
switch (mode) {
case rnn_enum::kRnnTanh:
case rnn_enum::kRnnRelu:
LOG(FATAL) << "Only LSTM and GRU are supported at the moment";
break;
case rnn_enum::kLstm:
LstmForwardTraining<DType>(ws, rs, state_outputs, num_layers, direction, seq_length,
batch_size, input_size, state_size, x_ptr, hx_ptr, cx_ptr,
w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr);
w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr, dropout);
break;
case rnn_enum::kGru:
GruForwardTraining<DType>(ws, rs, state_outputs, num_layers, direction, seq_length,
batch_size, input_size, state_size, x_ptr, hx_ptr,
w_ptr, y_ptr, hy_ptr);
w_ptr, y_ptr, hy_ptr, dropout);
break;
case rnn_enum::kRnnTanh:
case rnn_enum::kRnnRelu:
VanillaRNNForwardTraining<DType>(ws, rs, state_outputs, num_layers, direction, seq_length,
batch_size, input_size, state_size, x_ptr, hx_ptr,
w_ptr, y_ptr, hy_ptr, dropout, mode);
break;
default:
LOG(FATAL) << "unknown RNN mode " << mode;
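These switch branches are what the fused mx.sym.RNN operator dispatches to on CPU. A minimal usage sketch of that operator with a vanilla-RNN mode and dropout; the shapes, values and variable names are illustrative, not taken from this PR:

import mxnet as mx

data = mx.sym.Variable('data')          # (seq_length, batch_size, input_size)
params = mx.sym.Variable('rnn_params')  # flattened weights and biases
state = mx.sym.Variable('rnn_state')    # (num_layers * direction, batch_size, state_size)

# Fused vanilla RNN with tanh activation and inter-layer dropout (p), which
# this PR enables on CPU; previously only LSTM and GRU modes were fused there.
out = mx.sym.RNN(data=data, parameters=params, state=state,
                 mode='rnn_tanh', state_size=100, num_layers=2,
                 p=0.2, state_outputs=True, name='rnn')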
Expand All @@ -264,10 +269,6 @@ void RNNForwardInference(DType* ws,
DType* cy_ptr,
int mode) {
switch (mode) {
case rnn_enum::kRnnRelu:
case rnn_enum::kRnnTanh:
LOG(FATAL) << "Only LSTM and GRU are supported at the moment";
break;
case rnn_enum::kLstm:
LstmForwardInference<DType>(ws, state_outputs, num_layers, direction, seq_length,
batch_size, input_size, state_size, x_ptr, hx_ptr, cx_ptr,
@@ -278,6 +279,12 @@
batch_size, input_size, state_size, x_ptr, hx_ptr,
w_ptr, y_ptr, hy_ptr);
break;
case rnn_enum::kRnnTanh:
case rnn_enum::kRnnRelu:
VanillaRNNForwardInference<DType>(ws, state_outputs, num_layers, direction, seq_length,
batch_size, input_size, state_size, x_ptr, hx_ptr,
w_ptr, y_ptr, hy_ptr, mode);
break;
default:
LOG(FATAL) << "unknown RNN mode" << mode;
break;
@@ -310,22 +317,27 @@ void RNNBackward(DType* ws,
int req_params,
int req_state,
int req_statecell,
const float dropout,
int mode) {
switch (mode) {
case rnn_enum::kRnnRelu:
case rnn_enum::kRnnTanh:
break;
case rnn_enum::kLstm:
LstmBackward<DType>(ws, rs, num_layers, direction, seq_length, batch_size,
input_size, state_size, x_ptr, hx_ptr, cx_ptr, w_ptr, y_ptr,
dy_ptr, dhy_ptr, dcy_ptr, dx_ptr, dhx_ptr, dcx_ptr, dw_ptr, db_ptr,
req_data, req_params, req_state, req_statecell);
req_data, req_params, req_state, req_statecell, dropout);
break;
case rnn_enum::kGru:
GruBackward<DType>(ws, rs, num_layers, direction, seq_length, batch_size,
input_size, state_size, x_ptr, hx_ptr, w_ptr,
dy_ptr, dhy_ptr, dx_ptr, dhx_ptr, dw_ptr,
req_data, req_params, req_state);
req_data, req_params, req_state, dropout);
break;
case rnn_enum::kRnnTanh:
case rnn_enum::kRnnRelu:
VanillaRNNBackward<DType>(ws, rs, num_layers, direction, seq_length, batch_size,
input_size, state_size, x_ptr, hx_ptr, w_ptr,
dy_ptr, dhy_ptr, dx_ptr, dhx_ptr, dw_ptr,
req_data, req_params, req_state, dropout, mode);
break;
default:
LOG(FATAL) << "unknown RNN mode" << mode;
@@ -354,9 +366,8 @@ class RNNOp : public Operator{
const std::vector<TBlob> &aux_args) {
using namespace mshadow;
using namespace mshadow::expr;
CHECK(param_.mode == rnn_enum::kLstm || param_.mode == rnn_enum::kGru)
<< "Only lstm and gru mode are supported at the moment.";
CHECK_EQ(param_.p, 0) << "Dropout is not supported at the moment.";
CHECK(param_.p >= 0.0f && param_.p < 1.0f)
<< "unsupported dropout value, should be 0 <= dropout < 1";

size_t in_expected = (param_.mode == rnn_enum::kLstm) ? 4 : 3;
size_t out_expected = (param_.mode == rnn_enum::kLstm) ? 3 : 2;
@@ -436,6 +447,7 @@ class RNNOp : public Operator{
y.dptr_,
hy_ptr,
cy_ptr,
param_.p,
param_.mode);
} else {
RNNForwardInference<DType>(workspace.dptr_,
@@ -467,9 +479,8 @@ class RNNOp : public Operator{
const std::vector<TBlob> &aux_args) {
using namespace mshadow;
using namespace mshadow::expr;
CHECK(param_.mode == rnn_enum::kLstm || param_.mode == rnn_enum::kGru)
<< "Only lstm and gru mode are supported at the moment.";
CHECK_EQ(param_.p, 0) << "Dropout is not supported at the moment.";
CHECK(param_.p >= 0.0f && param_.p < 1.0f)
<< "unsupported dropout value, should be 0 <= dropout < 1";

size_t in_expected = (param_.mode == rnn_enum::kLstm) ? 4 : 3;
size_t out_expected = (param_.mode == rnn_enum::kLstm) ? 3 : 2;
@@ -566,6 +577,7 @@ class RNNOp : public Operator{
req[rnn_enum::kParams],
req[rnn_enum::kState],
req[rnn_enum::kStateCell],
param_.p,
param_.mode);
}
