From 0a0dce2d0acb4c173139183e8fc1274c388652ef Mon Sep 17 00:00:00 2001 From: Hao Li Date: Wed, 27 Jun 2018 01:43:57 +0800 Subject: [PATCH] add vRNN and dropout (#11399) --- example/rnn/bucketing/cudnn_rnn_bucketing.py | 16 +- src/operator/rnn-inl.h | 74 +- src/operator/rnn_impl.h | 947 ++++++++++++++++++- tests/python/unittest/test_operator.py | 116 ++- 4 files changed, 1099 insertions(+), 54 deletions(-) diff --git a/example/rnn/bucketing/cudnn_rnn_bucketing.py b/example/rnn/bucketing/cudnn_rnn_bucketing.py index 29a66a8f4843..5825290e73ec 100644 --- a/example/rnn/bucketing/cudnn_rnn_bucketing.py +++ b/example/rnn/bucketing/cudnn_rnn_bucketing.py @@ -66,7 +66,7 @@ parser.add_argument('--dropout', type=float, default='0.0', help='dropout probability (1.0 - keep probability)') parser.add_argument('--rnntype', type=str, default='lstm', - help='rnn type: gru and lstm are supported') + help='rnn type: gru, lstm, rnn_tanh and rnn_relu are supported') #buckets = [32] buckets = [10, 20, 30, 40, 50, 60] @@ -188,6 +188,20 @@ def test(args): cell, mx.rnn.GRUCell(num_hidden=args.num_hidden, prefix='%s_%dr0_'%(args.rnntype,i)), output_prefix='bi_%s_%d'%(args.rnntype,i)) + elif args.rnntype == 'rnn_tanh': + cell = mx.rnn.RNNCell(num_hidden=args.num_hidden, activation='tanh', prefix='%s_%dl0_'%(args.rnntype,i)) + if args.bidirectional: + cell = mx.rnn.BidirectionalCell( + cell, + mx.rnn.RNNCell(num_hidden=args.num_hidden, activation='tanh', prefix='%s_%dr0_'%(args.rnntype,i)), + output_prefix='bi_%s_%d'%(args.rnntype,i)) + elif args.rnntype == 'rnn_relu': + cell = mx.rnn.RNNCell(num_hidden=args.num_hidden, activation='relu', prefix='%s_%dl0_'%(args.rnntype,i)) + if args.bidirectional: + cell = mx.rnn.BidirectionalCell( + cell, + mx.rnn.RNNCell(num_hidden=args.num_hidden, activation='relu', prefix='%s_%dr0_'%(args.rnntype,i)), + output_prefix='bi_%s_%d'%(args.rnntype,i)) stack.add(cell) diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h index 99531739afa6..1f905eda4a92 100644 --- a/src/operator/rnn-inl.h +++ b/src/operator/rnn-inl.h @@ -99,10 +99,6 @@ inline size_t GetRNNWorkspaceSize(int seq_length, int mode) { size_t size = 0; switch (mode) { - case rnn_enum::kRnnRelu: - case rnn_enum::kRnnTanh: - LOG(FATAL) << "Only LSTM and GRU are supported at the moment"; - break; case rnn_enum::kLstm: size = (seq_length + 1) * batch_size * hidden_size * 4 + batch_size * hidden_size * 2 + seq_length * batch_size * hidden_size * direction + hidden_size * seq_length * 8; @@ -110,6 +106,10 @@ inline size_t GetRNNWorkspaceSize(int seq_length, case rnn_enum::kGru: size = seq_length * batch_size * hidden_size * direction * 4 + batch_size * hidden_size * 8; break; + case rnn_enum::kRnnRelu: + case rnn_enum::kRnnTanh: + size = seq_length * batch_size * hidden_size * direction * 2 + batch_size * hidden_size * 4; + break; default: LOG(FATAL) << "unknown RNN mode " << mode; break; @@ -125,18 +125,20 @@ inline size_t GetRNNReserveSpaceSize(int num_layer, int mode) { size_t size = 0; switch (mode) { - case rnn_enum::kRnnRelu: - case rnn_enum::kRnnTanh: - LOG(FATAL) << "Only LSTM and GRU are supported at the moment"; - break; case rnn_enum::kLstm: - size = num_layer * direction * seq_length * batch_size * hidden_size * 6; + size = direction * seq_length * batch_size * hidden_size * (num_layer * 7 - 1); break; case rnn_enum::kGru: - size = seq_length * batch_size * hidden_size * direction * num_layer * 8 + + size = seq_length * batch_size * hidden_size * direction * (num_layer * 9 - 1) + batch_size * hidden_size * direction * 9 + hidden_size * seq_length * 6 + seq_length * batch_size * 7 * hidden_size * direction; break; + case rnn_enum::kRnnRelu: + case rnn_enum::kRnnTanh: + size = seq_length * batch_size * hidden_size * direction * (num_layer * 6 - 1) + + batch_size * hidden_size * direction * 3 + hidden_size * seq_length * 2 + + seq_length * batch_size * 2 * hidden_size * direction; + break; default: LOG(FATAL) << "unknown RNN mode " << mode; break; @@ -223,21 +225,24 @@ void RNNForwardTraining(DType* ws, DType* y_ptr, DType* hy_ptr, DType* cy_ptr, + const float dropout, int mode) { switch (mode) { - case rnn_enum::kRnnTanh: - case rnn_enum::kRnnRelu: - LOG(FATAL) << "Only LSTM and GRU are supported at the moment"; - break; case rnn_enum::kLstm: LstmForwardTraining(ws, rs, state_outputs, num_layers, direction, seq_length, batch_size, input_size, state_size, x_ptr, hx_ptr, cx_ptr, - w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr); + w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr, dropout); break; case rnn_enum::kGru: GruForwardTraining(ws, rs, state_outputs, num_layers, direction, seq_length, batch_size, input_size, state_size, x_ptr, hx_ptr, - w_ptr, y_ptr, hy_ptr); + w_ptr, y_ptr, hy_ptr, dropout); + break; + case rnn_enum::kRnnTanh: + case rnn_enum::kRnnRelu: + VanillaRNNForwardTraining(ws, rs, state_outputs, num_layers, direction, seq_length, + batch_size, input_size, state_size, x_ptr, hx_ptr, + w_ptr, y_ptr, hy_ptr, dropout, mode); break; default: LOG(FATAL) << "unknown RNN mode " << mode; @@ -264,10 +269,6 @@ void RNNForwardInference(DType* ws, DType* cy_ptr, int mode) { switch (mode) { - case rnn_enum::kRnnRelu: - case rnn_enum::kRnnTanh: - LOG(FATAL) << "Only LSTM and GRU are supported at the moment"; - break; case rnn_enum::kLstm: LstmForwardInference(ws, state_outputs, num_layers, direction, seq_length, batch_size, input_size, state_size, x_ptr, hx_ptr, cx_ptr, @@ -278,6 +279,12 @@ void RNNForwardInference(DType* ws, batch_size, input_size, state_size, x_ptr, hx_ptr, w_ptr, y_ptr, hy_ptr); break; + case rnn_enum::kRnnTanh: + case rnn_enum::kRnnRelu: + VanillaRNNForwardInference(ws, state_outputs, num_layers, direction, seq_length, + batch_size, input_size, state_size, x_ptr, hx_ptr, + w_ptr, y_ptr, hy_ptr, mode); + break; default: LOG(FATAL) << "unknown RNN mode" << mode; break; @@ -310,22 +317,27 @@ void RNNBackward(DType* ws, int req_params, int req_state, int req_statecell, + const float dropout, int mode) { switch (mode) { - case rnn_enum::kRnnRelu: - case rnn_enum::kRnnTanh: - break; case rnn_enum::kLstm: LstmBackward(ws, rs, num_layers, direction, seq_length, batch_size, input_size, state_size, x_ptr, hx_ptr, cx_ptr, w_ptr, y_ptr, dy_ptr, dhy_ptr, dcy_ptr, dx_ptr, dhx_ptr, dcx_ptr, dw_ptr, db_ptr, - req_data, req_params, req_state, req_statecell); + req_data, req_params, req_state, req_statecell, dropout); break; case rnn_enum::kGru: GruBackward(ws, rs, num_layers, direction, seq_length, batch_size, input_size, state_size, x_ptr, hx_ptr, w_ptr, dy_ptr, dhy_ptr, dx_ptr, dhx_ptr, dw_ptr, - req_data, req_params, req_state); + req_data, req_params, req_state, dropout); + break; + case rnn_enum::kRnnTanh: + case rnn_enum::kRnnRelu: + VanillaRNNBackward(ws, rs, num_layers, direction, seq_length, batch_size, + input_size, state_size, x_ptr, hx_ptr, w_ptr, + dy_ptr, dhy_ptr, dx_ptr, dhx_ptr, dw_ptr, + req_data, req_params, req_state, dropout, mode); break; default: LOG(FATAL) << "unknown RNN mode" << mode; @@ -354,9 +366,8 @@ class RNNOp : public Operator{ const std::vector &aux_args) { using namespace mshadow; using namespace mshadow::expr; - CHECK(param_.mode == rnn_enum::kLstm || param_.mode == rnn_enum::kGru) - << "Only lstm and gru mode are supported at the moment."; - CHECK_EQ(param_.p, 0) << "Dropout is not supported at the moment."; + CHECK(param_.p >= 0.0f && param_.p < 1.0f) + << "unsupported dropout value, should be 0 <= dropout < 1"; size_t in_expected = (param_.mode == rnn_enum::kLstm) ? 4 : 3; size_t out_expected = (param_.mode == rnn_enum::kLstm) ? 3 : 2; @@ -436,6 +447,7 @@ class RNNOp : public Operator{ y.dptr_, hy_ptr, cy_ptr, + param_.p, param_.mode); } else { RNNForwardInference(workspace.dptr_, @@ -467,9 +479,8 @@ class RNNOp : public Operator{ const std::vector &aux_args) { using namespace mshadow; using namespace mshadow::expr; - CHECK(param_.mode == rnn_enum::kLstm || param_.mode == rnn_enum::kGru) - << "Only lstm and gru mode are supported at the moment."; - CHECK_EQ(param_.p, 0) << "Dropout is not supported at the moment."; + CHECK(param_.p >= 0.0f && param_.p < 1.0f) + << "unsupported dropout value, should be 0 <= dropout < 1"; size_t in_expected = (param_.mode == rnn_enum::kLstm) ? 4 : 3; size_t out_expected = (param_.mode == rnn_enum::kLstm) ? 3 : 2; @@ -566,6 +577,7 @@ class RNNOp : public Operator{ req[rnn_enum::kParams], req[rnn_enum::kState], req[rnn_enum::kStateCell], + param_.p, param_.mode); } diff --git a/src/operator/rnn_impl.h b/src/operator/rnn_impl.h index fa8d671a2007..e1b4a2b79c0a 100644 --- a/src/operator/rnn_impl.h +++ b/src/operator/rnn_impl.h @@ -49,6 +49,11 @@ inline DType sigmoid(DType x) { return 1.0f / (1.0f + exp(-x)); } +template +inline DType relu(DType x) { + return x > 0.0f ? static_cast(x) : 0.0f; +} + template void LstmForwardTrainingSingleLayer(DType* ws, DType* rs, @@ -133,7 +138,10 @@ void LstmForwardTraining(DType* ws, DType* b_ptr, DType* y_ptr, DType* hy_ptr, - DType* cy_ptr) { + DType* cy_ptr, + const float dropout) { + DType* dropout_random = rs; + DType* rs2 = dropout_random + (L - 1) * D * T * N * H; const int total_layers = D * L; Tensor hx(hx_ptr, Shape3(total_layers, N, H)); Tensor cx(cx_ptr, Shape3(total_layers, N, H)); @@ -141,14 +149,15 @@ void LstmForwardTraining(DType* ws, const int r_size = D * T * N * H * 6; const int y_offset = T * N * H * 5; const int cell_size = N * H; + unsigned int seed_ = 17 + rand() % 4096; // NOLINT(runtime/threadsafe_fn) int idx = 0; // state & cell state's idx; const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); for (int i = 0; i < L; ++i) { const int input_size = i ? H * D : I; const int w_size = (input_size + H) * H * 4; Tensor x(x_ptr, Shape2(T * N, input_size)); - Tensor y(rs + y_offset, Shape3(T, N, H * D)); - LstmForwardTrainingSingleLayer(ws, rs, state_outputs, false, T, N, input_size, H, x, + Tensor y(rs2 + y_offset, Shape3(T, N, H * D)); + LstmForwardTrainingSingleLayer(ws, rs2, state_outputs, false, T, N, input_size, H, x, hx[idx], cx[idx], y, w_ptr, b_ptr, hy_ptr, cy_ptr); if (D == 2) { w_ptr += w_size; @@ -158,14 +167,27 @@ void LstmForwardTraining(DType* ws, hy_ptr += cell_size; cy_ptr += cell_size; } - LstmForwardTrainingSingleLayer(ws, rs, state_outputs, true, T, N, input_size, H, x, + LstmForwardTrainingSingleLayer(ws, rs2, state_outputs, true, T, N, input_size, H, x, hx[idx], cx[idx], y, w_ptr, b_ptr, hy_ptr, cy_ptr); } if (i != L - 1) { w_ptr += w_size; b_ptr += b_size; + if (dropout > 0.0f) { + #pragma omp parallel for num_threads(omp_threads) + for (int j = 0; j < T * N * H * D; j++) { + int rand_data = rand_r(&seed_); + if (static_cast(rand_data % 1000) < static_cast(1000 * dropout)) { + dropout_random[i * T * N * H * D + j] = 0; + y.dptr_[j] = 0; + } else { + dropout_random[i * T * N * H * D + j] = 1.0f - dropout; + y.dptr_[j] = y.dptr_[j] / (1.0f - dropout); + } + } + } x_ptr = y.dptr_; - rs += r_size; + rs2 += r_size; ++idx; if (state_outputs) { hy_ptr += cell_size; @@ -175,7 +197,7 @@ void LstmForwardTraining(DType* ws, } #pragma omp parallel for num_threads(omp_threads) for (int i = 0; i < T * N * H * D; ++i) { - y_ptr[i] = (rs + y_offset)[i]; + y_ptr[i] = (rs2 + y_offset)[i]; } } @@ -498,7 +520,10 @@ void LstmBackward(DType* ws, int req_data, int req_params, int req_state, - int req_statecell) { + int req_statecell, + const float dropout) { + DType* dropout_random = rs + (L - 1) * D * T * N * H; + DType* rs2 = rs + (L - 1) * D * T * N * H; DType* tmp_buf = ws; DType* ws2 = tmp_buf + 8 * T * H; const int total_layers = D * L; @@ -520,7 +545,7 @@ void LstmBackward(DType* ws, DType* w_cur_ptr = i ? w_ptr + (w_size1 + (i - 1) * w_size2) * D : w_ptr; DType* dw_cur_ptr = i ? dw_ptr + (w_size1 + (i - 1) * w_size2) * D : dw_ptr; DType* db_cur_ptr = db_ptr + i * b_size * D; - DType* rs_cur_ptr = rs + i * r_size; + DType* rs_cur_ptr = rs2 + i * r_size; DType* dhy_cur_ptr = dhy_ptr ? dhy_ptr + i * cell_size * D : NULL; DType* dcy_cur_ptr = dcy_ptr ? dcy_ptr + i * cell_size * D : NULL; Tensor y(rs_cur_ptr + y_offset, Shape3(T, N, H * D)); @@ -543,6 +568,18 @@ void LstmBackward(DType* ws, dhy_cur_ptr, dcy_cur_ptr, w_cur_ptr, dw_cur_ptr, db_cur_ptr, req_data, req_params, req_state, req_statecell); } + if (dropout > 0.0f && i > 0 && req_data != kNullOp) { + dropout_random = dropout_random - T * N * D * H; + const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); + #pragma omp parallel for num_threads(omp_threads) + for (int j = 0; j < T * N * D * H; j++) { + if (dropout_random[j] == 0) { + dx.dptr_[j] = 0; + } else { + dx.dptr_[j] = dx.dptr_[j] / (1.0f - dropout); + } + } + } dy_ptr = dx.dptr_; } } @@ -935,7 +972,8 @@ void GruForwardTraining(DType* ws, DType* hx_ptr, DType* w_ptr, DType* y_ptr, - DType* hy_ptr) { + DType* hy_ptr, + const float dropout) { DType* wx = w_ptr; DType* wh = wx + I * H * 3; DType* bx = wh + H * H * 3 + (D - 1) * (H * H * 3 + I * H * 3) @@ -948,19 +986,34 @@ void GruForwardTraining(DType* ws, DType* gateN_l = gateZ_l + L * T * D * N * H; DType* y_l = gateN_l + L * T * D * N * H; DType* Mnh_l = y_l + L * T * N * H * D; - DType* tmp_buf = Mnh_l + L * D * T * N * H; - DType* ws2 = Mnh_l + L * D * T * N * H + D * H * N; + DType* dropout_random = Mnh_l + L * D * T * N * H; + DType* tmp_buf = dropout_random + (L - 1) * D * T * N * H; + DType* ws2 = tmp_buf + D * N * H; DType* wx_l = wx; DType* wh_l = wh; DType* bx_l = bx; DType* bh_l = bh; DType* y_tmp = x_ptr; - + unsigned int seed_ = 17 + rand() % 4096; // NOLINT(runtime/threadsafe_fn) for (int l = 0; l < L; l++) { if (l != 0) { y_tmp = y_l; y_l = y_l + T * N * H * D; } + if (dropout > 0.0f && l > 0) { + const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < T * N * I; i++) { + int rand_data = rand_r(&seed_); + if (static_cast(rand_data % 1000) < static_cast(1000 * dropout)) { + dropout_random[(l - 1) * T * N * I + i] = 0; + y_tmp[i] = 0; + } else { + dropout_random[(l - 1) * T * N * I + i] = 1.0f - dropout; + y_tmp[i] = y_tmp[i] / (1.0f - dropout); + } + } + } Tensor x_l(y_tmp, Shape2(T * N, I)); Tensor hx_l = hx[D * l]; GruForwardTrainingSingleLayer(ws2, tmp_buf, state_outputs, D, T, N, I, H, @@ -1349,7 +1402,8 @@ void GruBackward(DType* ws, DType* dw_ptr, int req_data, int req_params, - int req_state) { + int req_state, + const float dropout) { DType* wx = w_ptr; DType* dwx = dw_ptr; DType* dwh = dwx + I * H * 3; @@ -1360,7 +1414,8 @@ void GruBackward(DType* ws, DType* gateN_l = gateZ_l + L * T * D * N * H; DType* y_l = gateN_l + L * T * D * N * H; DType* Mnh_l = y_l + L * T * N * H * D; - DType* tmp_buf = Mnh_l + L * D * T * N * H; + DType* dropout_random = Mnh_l + L * D * T * N * H; + DType* tmp_buf = dropout_random + (L - 1) * D * T * N * H; DType* dx_l = tmp_buf + T * N * D * H + 3 * H * T * 2; DType* ws2 = dx_l + T * N * D * H; DType* wx_l = (L == 1)? wx : wx + (L - 2) * D * (D + 1) * H * 3 * H @@ -1403,6 +1458,17 @@ void GruBackward(DType* ws, GruBackwardSingleLayer(ws2, tmp_buf, D, T, N, I, H, x_l, hx_l, wx_l, wh_l, y_l, dy_l, dhy_l, gateR_l, gateZ_l, gateN_l, Mnh_l, dx_l, dhx_l, dwx_l, dwh_l, dbx_l, dbh_l, req_data, req_params, req_state); + if (dropout > 0.0f && l > 0 && req_data != kNullOp) { + dropout_random = dropout_random - T * N * D * H; + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < T * N * I; i++) { + if (dropout_random[i] == 0) { + dx_l[i] = 0; + } else { + dx_l[i] = dx_l[i] / (1.0f - dropout); + } + } + } if (l > 0) { #pragma omp parallel for num_threads(omp_threads) for (int i = 0; i < T * N * H * D; ++i) { @@ -1433,6 +1499,859 @@ void GruBackward(DType* ws, } } } + +template +void VanillaRNNForwardInferenceSingleLayer(DType* ws, + DType* tmp_buf, + bool state_outputs, + const int D, + const int T, + const int N, + const int I, + const int H, + const Tensor &x, + const Tensor &hx, + DType* wx_ptr, + DType* wh_ptr, + DType* bx_ptr, + DType* bh_ptr, + DType* y_ptr, + DType* hy_ptr, + int mode) { + DType* ht = y_ptr; + DType* ht_1 = y_ptr; + DType* back_ht_1 = y_ptr + (T-1) * N * H * D + H; + DType* back_ht = back_ht_1; + DType* gemmC1 = ws; // [D, T, N, H] + DType* gemmC2 = gemmC1 + D * T * N * H; // N * H + DType* back_wx_ptr = wx_ptr + I * H + H * H; + DType* back_wh_ptr = wh_ptr + I * H + H * H; + DType* back_bx_ptr = (bx_ptr != NULL)? bx_ptr + H * 2 : NULL; + DType* back_bh_ptr = (bh_ptr != NULL)? bh_ptr + H * 2: NULL; + DType* back_gemmC1 = gemmC1 + T * N * H; + DType* gemmC1_t = gemmC1; + + const Tensor wx(wx_ptr, Shape2(H, I)); + const Tensor wh(wh_ptr, Shape2(H, H)); + const Tensor bx(bx_ptr, Shape2(1, H)); + const Tensor bh(bh_ptr, Shape2(1, H)); + const Tensor back_wx(back_wx_ptr, Shape2(H, I)); + const Tensor back_wh(back_wh_ptr, Shape2(H, H)); + const Tensor back_bx(back_bx_ptr, Shape2(1, H)); + const Tensor back_bh(back_bh_ptr, Shape2(1, H)); + const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); + if (D == 1) { + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < N; i++) + for (int j = 0; j < H; j++) { + y_ptr[i * H + j] = hx[i][j]; + } + } else { + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < N; i++) + for (int j = 0; j < H; j++) { + y_ptr[i * D * H + j] = hx[i][j]; + back_ht_1[i * D * H + j] = hx[N + i][j]; + } + } + Tensor dgemmC1(ws, Shape2(T * N, H)); + Tensor dgemmC2(gemmC2, Shape2(N, H)); + Tensor dback_gemmC1(back_gemmC1, Shape2(T * N, H)); + + // x * wx.T : [T * N, I] * [I, H] + DType alpha = 1.0; + DType beta = 0.0; + linalg_gemm(x, wx, dgemmC1, alpha, beta, false, true); + if (D == 2) { + linalg_gemm(x, back_wx, dback_gemmC1, alpha, beta, false, true); + } + + for (int t = 0; t < T; t++) { + // perform the first direction, X * wx and H * wh for each step + // ht-1 * wh, ht-1:[N, H] wh:[H, H] + Tensor dht_1(ht_1, Shape2(N, D * H)); + if (D == 1) { + linalg_gemm(dht_1, wh, dgemmC2, alpha, beta, false, true); + } else { + Tensor dht_1_tmp = Tensor(reinterpret_cast(tmp_buf), + Shape3(D, H, N)); + dht_1_tmp = reshape(dht_1.T(), Shape3(D, H, N)); + linalg_gemm(dht_1_tmp[0], wh, dgemmC2, alpha, beta, true, true); + } + gemmC1_t = gemmC1 + t * N * H; + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < N; ++i) { + for (int j = 0; j < H; ++j) { + int tb = i * H; + if (mode == 1) { + ht[i * D * H + j] = tanh(gemmC1_t[tb + j] + bx[0][j] + + gemmC2[tb + j] + bh[0][j]); + } else { + ht[i * D * H + j] = relu(gemmC1_t[tb + j] + bx[0][j] + + gemmC2[tb + j] + bh[0][j]); + } + } + } + ht_1 = ht; + ht = ht + D * H * N; + // perform the second direction + if (D == 2) { + gemmC1_t = back_gemmC1 + (T - 1 - t) * N * H; + Tensor dback_ht_1(back_ht_1 - H, Shape2(N, D * H)); + Tensor dback_ht_1_tmp = Tensor + (reinterpret_cast(tmp_buf), Shape3(D, H, N)); + dback_ht_1_tmp = reshape(dback_ht_1.T(), Shape3(D, H, N)); + linalg_gemm(dback_ht_1_tmp[1], back_wh, dgemmC2, alpha, beta, true, true); + + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < N; ++i) { + for (int j = 0; j < H; ++j) { + int tb = i * H; + if (mode == 1) { + back_ht[i * D * H + j] = tanh(gemmC1_t[tb + j] + back_bx[0][j] + + gemmC2[tb + j] + back_bh[0][j]); + } else { + back_ht[i * D * H + j] = relu(gemmC1_t[tb + j] + back_bx[0][j] + + gemmC2[tb + j] + back_bh[0][j]); + } + } + } + back_ht_1 = back_ht; + back_ht = back_ht - D * H * N; + } + } + // copy last state to hy, from(N, H * D) to (D, N, H) + if (state_outputs) { + if (D == 1) { + DType* y_start = y_ptr + (T - 1) * N * H; + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < N; i++) + for (int j = 0; j < H; j++) { + hy_ptr[i * H + j] = y_start[i * H + j]; + } + } else { + DType* y_start = y_ptr + (T - 1) * N * H * D; + DType* y_back_start = y_ptr + H; + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < N; i++) + for (int j = 0; j < H; j++) { + hy_ptr[i * H + j] = y_start[i * D * H + j]; + hy_ptr[N * H + i * H + j] = y_back_start[i * D * H + j]; + } + } + } +} + +template +void VanillaRNNForwardInference(DType* ws, + bool state_outputs, + const int L, + const int D, + const int T, + const int N, + int I, + const int H, + DType* x_ptr, + DType* hx_ptr, + DType* w_ptr, + DType* y_ptr, + DType* hy_ptr, + int mode) { + DType* wx = w_ptr; + DType* wh = wx + I * H; + DType* bx = wh + H * H + (D - 1) * (H * H + I * H) + + (L - 1) * ((D + 1) * H) * H * D; + DType* bh = bx + H; + + DType* y_tmp = ws; + DType* y_l = x_ptr; + DType* tmp_buf = y_tmp + D * T * N * H; + DType* ws2 = y_tmp + D * T * N * H + D * H * N; + + DType* wx_l = wx; + DType* wh_l = wh; + DType* bx_l = bx; + DType* bh_l = bh; + Tensor hx(hx_ptr, Shape3(D * L, N, H)); + DType* hy_l = hy_ptr; + for (int l = 0; l < L; l++) { + Tensor x_l(y_l, Shape2(T * N, I)); + if ((L + l) % 2) { + y_l = y_ptr; + } else { + y_l = y_tmp; + } + Tensor hx_l = hx[D * l]; + VanillaRNNForwardInferenceSingleLayer(ws2, tmp_buf, state_outputs, D, T, N, I, H, + x_l, hx_l, wx_l, wh_l, bx_l, bh_l, y_l, + hy_l, mode); + hy_l = hy_l + D * N * H; + bx_l = bx_l + H * D * 2; + bh_l = bh_l + H * D * 2; + wx_l = wx_l + I * H * D + H * H * D; + if (l == 0) { + I = D * H; + } + wh_l = wx_l + I * H; + } +} + + +template +void VanillaRNNForwardTrainingSingleLayer(DType* ws, + DType* tmp_buf, + bool state_outputs, + const int D, + const int T, + const int N, + const int I, + const int H, + const Tensor &x, + const Tensor &hx, + DType* wx_ptr, + DType* wh_ptr, + DType* bx_ptr, + DType* bh_ptr, + DType* gateN, + DType* y_ptr, + DType* hy_ptr, + int mode) { + DType* ht = y_ptr; + DType* ht_1 = y_ptr; + DType* back_ht_1 = y_ptr + (T - 1)* N * H * D + H; + DType* back_ht = back_ht_1; + + DType* gemmC1 = ws; // [D, T, N, H] + DType* gemmC2 = gemmC1 + D * T * N * H; // N * H + DType* nt = gateN; + DType* back_wx_ptr = wx_ptr + I * H + H * H; + DType* back_wh_ptr = wh_ptr + I * H + H * H; + DType* back_bx_ptr = (bx_ptr != NULL)? bx_ptr + H * 2 : NULL; + DType* back_bh_ptr = (bh_ptr != NULL)? bh_ptr + H * 2 : NULL; + DType* back_gateN = gateN + T * N * H; + DType* back_gemmC1 = gemmC1 + T * N * H; + DType* gemmC1_t = gemmC1; + + const Tensor wx(wx_ptr, Shape2(H, I)); + const Tensor wh(wh_ptr, Shape2(H, H)); + const Tensor bx(bx_ptr, Shape2(1, H)); + const Tensor bh(bh_ptr, Shape2(1, H)); + const Tensor back_wx(back_wx_ptr, Shape2(H * 1, I)); + const Tensor back_wh(back_wh_ptr, Shape2(H * 1, H)); + const Tensor back_bx(back_bx_ptr, Shape2(1, H)); + const Tensor back_bh(back_bh_ptr, Shape2(1, H)); + const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); + if (D == 1) { + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < N; i++) + for (int j = 0; j < H; j++) { + y_ptr[i * H + j] = hx[i][j]; + } + } else { + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < N; i++) + for (int j = 0; j < H; j++) { + y_ptr[i * D * H + j] = hx[i][j]; + back_ht_1[i * D * H + j] = hx[N + i][j]; + } + } + + Tensor dgemmC1(ws, Shape2(T * N, H)); + Tensor dgemmC2(gemmC2, Shape2(N, H)); + Tensor dback_gemmC1(back_gemmC1, Shape2(T * N, H)); + + // x * wx.T : [T * N, I] * [I, H] + DType alpha = 1.0; + DType beta = 0.0; + linalg_gemm(x, wx, dgemmC1, alpha, beta, false, true); + if (D == 2) { + linalg_gemm(x, back_wx, dback_gemmC1, alpha, beta, false, true); + } + + for (int t = 0; t < T; t++) { + // perform the first direction, X * wx and H * wh for each step + // ht-1 * wh, ht-1:[N, H] wh:[H, H] + Tensor dht_1(ht_1, Shape2(N, D * H)); + if (D == 1) { + linalg_gemm(dht_1, wh, dgemmC2, alpha, beta, false, true); + } else { + Tensor dht_1_tmp = Tensor(reinterpret_cast(tmp_buf), + Shape3(D, H, N)); + dht_1_tmp = reshape(dht_1.T(), Shape3(D, H, N)); + linalg_gemm(dht_1_tmp[0], wh, dgemmC2, alpha, beta, true, true); + } + nt = gateN + t * N * H; + gemmC1_t = gemmC1 + t * N * H; + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < N; ++i) { + for (int j = 0; j < H; ++j) { + int tb = i * H; + if (mode == 1) { + nt[tb + j] = ht[i * D * H + j] = tanh(gemmC1_t[tb + j] + bx[0][j] + + gemmC2[tb + j] + bh[0][j]); + } else { + nt[tb + j] = gemmC1_t[tb + j] + bx[0][j] + gemmC2[tb + j] + bh[0][j]; + ht[i * D * H + j] = relu(nt[tb + j]); + } + } + } + ht_1 = ht; + ht = ht + D * H * N; + // perform the second direction + if (D == 2) { + nt = back_gateN + (T - 1 - t) * N * H; + gemmC1_t = back_gemmC1 + (T - 1 - t) * N * H; + Tensor dback_ht_1(back_ht_1 - H, Shape2(N, D * H)); + Tensor dback_ht_1_tmp = Tensor + (reinterpret_cast(tmp_buf), Shape3(D, H, N)); + dback_ht_1_tmp = reshape(dback_ht_1.T(), Shape3(D, H, N)); + linalg_gemm(dback_ht_1_tmp[1], back_wh, dgemmC2, alpha, beta, true, true); + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < N; ++i) { + for (int j = 0; j < H; ++j) { + int tb = i * H; + if (mode == 1) { + nt[tb + j] = back_ht[i * D * H + j] = tanh(gemmC1_t[tb + j] + back_bx[0][j] + + gemmC2[tb + j] + back_bh[0][j]); + } else { + nt[tb + j] = gemmC1_t[tb + j] + back_bx[0][j] + gemmC2[tb + j] + back_bh[0][j]; + back_ht[i * D * H + j] = relu(nt[tb + j]); + } + } + } + back_ht_1 = back_ht; + back_ht = back_ht - D * H * N; + } + } + + // copy last state to hy, from(N, H * D) to (D, N, H) + if (state_outputs) { + if (D == 1) { + DType* y_start = y_ptr + (T - 1) * N * H; + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < N; i++) + for (int j = 0; j < H; j++) { + hy_ptr[i * H + j] = y_start[i * H + j]; + } + } else { + DType* y_start = y_ptr + (T - 1) * N * H * D; + DType* y_back_start = y_ptr + H; + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < N; i++) + for (int j = 0; j < H; j++) { + hy_ptr[i * H + j] = y_start[i * D * H + j]; + hy_ptr[N * H + i * H + j] = y_back_start[i * D * H + j]; + } + } + } +} + +template +void VanillaRNNForwardTraining(DType* ws, + DType* rs, + bool state_outputs, + const int L, + const int D, + const int T, + const int N, + int I, + const int H, + DType* x_ptr, + DType* hx_ptr, + DType* w_ptr, + DType* y_ptr, + DType* hy_ptr, + const float dropout, + int mode) { + DType* wx = w_ptr; + DType* wh = wx + I * H; + DType* bx = wh + H * H + (D - 1) * (H * H + I * H) + + (L - 1) * ((D + 1) * H) * H * D; + DType* bh = bx + H; + Tensor hx(hx_ptr, Shape3(D * L, N, H)); + DType* hy_l = hy_ptr; + DType* gateN_l = rs; + DType* y_l = gateN_l + L * T * D * N * H; + DType* dropout_random = y_l + L * D * T * N * H; + DType* tmp_buf = dropout_random + (L - 1) * D * T * N * H; + DType* ws2 = tmp_buf + D * N * H; + DType* wx_l = wx; + DType* wh_l = wh; + DType* bx_l = bx; + DType* bh_l = bh; + DType* y_tmp = x_ptr; + const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); + unsigned int seed_ = 17 + rand() % 4096; // NOLINT(runtime/threadsafe_fn) + for (int l = 0; l < L; l++) { + if (l != 0) { + y_tmp = y_l; + y_l = y_l + T * N * H * D; + } + if (dropout > 0.0f && l > 0) { + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < T * N * I; i++) { + int rand_data = rand_r(&seed_); + if (static_cast(rand_data % 1000) < static_cast(1000 * dropout)) { + dropout_random[(l - 1) * T * N * I + i] = 0; + y_tmp[i] = 0; + } else { + dropout_random[(l - 1) * T * N * I + i] = 1.0f - dropout; + y_tmp[i] = y_tmp[i] / (1.0f - dropout); + } + } + } + Tensor x_l(y_tmp, Shape2(T * N, I)); + Tensor hx_l = hx[D * l]; + VanillaRNNForwardTrainingSingleLayer(ws2, tmp_buf, state_outputs, D, T, N, I, H, + x_l, hx_l, wx_l, wh_l, bx_l, bh_l, + gateN_l, y_l, hy_l, mode); + gateN_l = gateN_l + T * D * N * H; + hy_l = hy_l + D * N * H; + bx_l = bx_l + H * D * 2; + bh_l = bh_l + H * D * 2; + + wx_l = wx_l + I * H * D + H * H * D; + if (l == 0) { + I = D * H; + } + wh_l = wx_l + I * H; + } + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < T * N * H * D; ++i) { + y_ptr[i] = y_l[i]; + } +} + +template +void VanillaRNNBackwardSingleLayer(DType* ws, + DType* tmp_buf, + const int D, + const int T, + const int N, + const int I, + const int H, + const Tensor &x, + const Tensor &hx, + DType* wx_ptr, + DType* wh_ptr, + DType* y_ptr, + DType* dy_ptr, + DType* dhy_ptr, + DType* gateN, + DType* dx, + DType* dhx, + DType* dwx, + DType* dwh, + DType* dbx, + DType* dbh, + int req_data, + int req_params, + int req_state, + int mode) { + DType* dyt; + DType* ht1; // [N, D, H] + DType* dart; + DType* nt; + DType* dar = ws; // [T, N, H] + DType* dht1 = dar + T * N * H; // [D, N, H] + DType* hx_ = dht1 + D * N * H; // [N, D, H] + + DType* back_ht1; + DType* back_dht1 = dht1 + N * H; // [N, H] + DType* back_gateN = gateN + T * N * H; + DType* back_wx_ptr = wx_ptr + I * H + H * H; + DType* back_wh_ptr = wh_ptr + I * H + H * H; + DType* back_dwx = dwx + I * H + H * H; + DType* back_dwh = dwh + I * H + H * H; + DType* back_dbx = dbx + H * 2; + DType* back_dbh = dbh + H * 2; + + DType alpha = 1.0; + DType beta = 0.0; + const Tensor wx(wx_ptr, Shape2(H, I)); + const Tensor wh(wh_ptr, Shape2(H, H)); + const Tensor back_wx(back_wx_ptr, Shape2(H, I)); + const Tensor back_wh(back_wh_ptr, Shape2(H, H)); + const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); + if (req_params != kNullOp && req_params != kAddTo) { + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < D * H * H; ++i) { + dwh[i] = 0; + } + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < D * H; ++i) { + dbx[i] = 0; + dbh[i] = 0; + } + } + + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < N * H; ++i) { + if (dhy_ptr) { + dht1[i] = dhy_ptr[i]; + } else { + dht1[i] = 0; + } + } + + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < N; ++i) { + for (int j = 0; j < H; ++j) { + hx_[i * D * H + j] = hx[i][j]; + } + } + + if (D == 2) { + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < N * H; ++i) { + if (dhy_ptr) { + back_dht1[i] = dhy_ptr[N * H + i]; + } else { + back_dht1[i] = 0; + } + } + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < N; ++i) { + for (int j = 0; j < H; ++j) { + hx_[i * D * H + H + j] = hx[N + i][j]; + } + } + } + for (int t = T - 1; t >= 0; --t) { + if (t) { + ht1 = y_ptr + (t - 1) * N * D * H; + } else { + ht1 = hx_; + } + // add dy[T, N, D, H] to dhy[D, N, H] + dyt = dy_ptr + t * N * D * H; + + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < N; ++i) { + for (int j = 0; j < H; ++j) { + dht1[i * H + j] += dyt[i * D * H + j]; + } + } + + nt = gateN + t * N * H; + dart = dar + t * N * H; + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < N; ++i) { + for (int j = 0; j < H; ++j) { + int id = i * H + j; + if (mode == 1) { + dart[id] = dht1[id] * (1 - nt[id] * nt[id]); + } else { + dart[id] = nt[id] > 0.0f ? static_cast(dht1[id]) : 0.0f; + } + dht1[id] = 0; + } + } + if (req_params != kNullOp) { + alpha = 1.0; + beta = 1.0; + // dht1 = dart * wh [N, H] = [N, H] * [H, H] + Tensor d_dht1(dht1, Shape2(N, H)); + Tensor d_dart(dart, Shape2(N, H)); + linalg_gemm(d_dart, wh, d_dht1, alpha, beta, false, false); + + if (req_params == kAddTo) { + beta = 2.0; + // dwx = da.T * x [H, I] = [H, N] * [N, I] for AddTo + Tensor d_xt(x.dptr_ + t * N * I, Shape2(N, I)); + Tensor d_dwx(dwx, Shape2(H, I)); + linalg_gemm(d_dart, d_xt, d_dwx, alpha, beta, true, false); + } + // dwh = dart.T * ht1 [H, H] = [H, N] * [N, H] + Tensor d_ht1(ht1, Shape2(N, D * H)); + Tensor d_dwh(dwh, Shape2(H, H)); + Tensor d_ht1_tmp = Tensor + (reinterpret_cast(tmp_buf), Shape3(D, H, N)); + d_ht1_tmp = reshape(d_ht1.T(), Shape3(D, H, N)); + linalg_gemm(d_dart, d_ht1_tmp[0], d_dwh, alpha, beta, true, true); + } + } + + if (req_params != kNullOp) { + // dbx = e * da [1, H] = [1, N] * [N, H] + if (req_params != kAddTo) { + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < H; ++i) { + for (int j = 0; j < N * T; ++j) { + dbx[i] += dar[j * H + i]; + dbh[i] = dbx[i]; + } + } + } else { + const Tensor tmp_dbx(tmp_buf + T * N * D * H, Shape2(H, T)); + const Tensor tmp_dbh(tmp_buf + T * N * D * H + H * T, Shape2(H, T)); + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < H * T; ++i) { + tmp_dbx.dptr_[i] = 0; + tmp_dbh.dptr_[i] = 0; + } + + for (int t = T - 1; t >= 0; --t) { + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < H; ++i) { + for (int j = 0; j < N; ++j) { + tmp_dbx[i][t] += dar[t * N * H + j * H + i]; + tmp_dbh[i][t] = tmp_dbx[i][t]; + } + } + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < H; ++i) { + dbx[i] += tmp_dbx[i][t] + dbx[i]; + dbh[i] = dbx[i]; + } + } + } + } + alpha = 1.0; + beta = 0.0; + + // dx = da * wx [T * N, I] = [T * N, H] * [H, I] + Tensor d_dar(dar, Shape2(T * N, H)); + if (req_data != kNullOp) { + Tensor d_dx(dx, Shape2(T * N, I)); + linalg_gemm(d_dar, wx, d_dx, alpha, beta, false, false); + } + + // dwx = da.T * x [H, I] = [H, T * N] * [T * N, I] + if (req_params != kNullOp && req_params != kAddTo) { + Tensor d_dwx(dwx, Shape2(H, I)); + linalg_gemm(d_dar, x, d_dwx, alpha, beta, true, false); + } + + if (D == 2) { + for (int t = 0; t < T; ++t) { + if (t == T-1) { + back_ht1 = hx_; + } else { + back_ht1 = y_ptr + (t + 1) * N * D * H; + } + + // add dy[T, N, D, H] to dhy[D, N, H] + dyt = dy_ptr + t * N * D * H; + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < N; ++i) { + for (int j = 0; j < H; ++j) { + back_dht1[i * H + j] += dyt[i * D * H + H + j]; + } + } + + nt = back_gateN + t * N * H; + dart = dar + t * N * H; + + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < N; ++i) { + for (int j = 0; j < H; ++j) { + int id = i * H + j; + if (mode == 1) { + dart[id] = back_dht1[id] * (1 - nt[id] * nt[id]); + } else { + dart[id] = nt[id] > 0.0f ? static_cast(back_dht1[id]) : 0.0f; + } + back_dht1[id] = 0; + } + } + + if (req_params != kNullOp) { + alpha = 1.0; + beta = 1.0; + // dht1 = da * wh [N, H] = [N, H] * [H, H] + Tensor d_dart(dart, Shape2(N, H)); + Tensor d_back_dht1(back_dht1, Shape2(N, H)); + linalg_gemm(d_dart, back_wh, d_back_dht1, alpha, beta, false, false); + + // dwh = da.T * ht1 [H, H] = [H, N] * [N, H] + Tensor d_back_dwh(back_dwh, Shape2(H, H)); + Tensor d_back_ht1(back_ht1 + H, Shape2(N, D * H)); + Tensor d_back_ht1_tmp = Tensor + (reinterpret_cast(tmp_buf), Shape3(D, H, N)); + d_back_ht1_tmp = reshape(d_back_ht1.T(), Shape3(D, H, N)); + if (req_params == kAddTo) { + beta = 2.0; + // dwx = da.T * x [ H, I] = [H, N] * [N, I] for AddTo + Tensor d_xt(x.dptr_ + t * N * I, Shape2(N, I)); + Tensor d_back_dwx(back_dwx, Shape2(H, I)); + linalg_gemm(d_dart, d_xt, d_back_dwx, alpha, beta, true, false); + } + linalg_gemm(d_dart, d_back_ht1_tmp[0], d_back_dwh, alpha, beta, true, true); + } + } + + if (req_params != kNullOp) { + // dbx = e * da [1, H] = [1, N] * [N, H] + if (req_params != kAddTo) { + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < H; ++i) { + for (int j = 0; j < N * T; ++j) { + back_dbx[i] += dar[j * H + i]; + back_dbh[i] = back_dbx[i]; + } + } + } else { + const Tensor tmp_dbx(tmp_buf + T * N * D * H, Shape2(H, T)); + const Tensor tmp_dbh(tmp_buf + T * N * D * H + H * T, Shape2(H, T)); + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < H * T; ++i) { + tmp_dbx.dptr_[i] = 0; + tmp_dbh.dptr_[i] = 0; + } + + for (int t = T - 1; t >= 0; --t) { + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < H; ++i) { + for (int j = 0; j < N; ++j) { + tmp_dbx[i][t] += dar[t * N * H + j * H + i]; + tmp_dbh[i][t] = tmp_dbx[i][t]; + } + } + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < H; ++i) { + back_dbx[i] += tmp_dbx[i][t] + back_dbx[i]; + back_dbh[i] = back_dbx[i]; + } + } + } + } + alpha = 1.0; + beta = 1.0; + // dxt = da * wx [T * N, I] = [T * N, H] * [H, I] + Tensor d_dar2(dar, Shape2(T * N, H)); + if (req_data != kNullOp) { + Tensor d_dx(dx, Shape2(T * N, I)); + linalg_gemm(d_dar2, back_wx, d_dx, alpha, beta, false, false); + } + alpha = 1.0; + beta = 0.0; + // dwx = da.T * x [H, I] = [H, T * N] * [T * N, I] + if (req_params != kNullOp && req_params != kAddTo) { + Tensor d_back_dwx(back_dwx, Shape2(H, I)); + linalg_gemm(d_dar2, x, d_back_dwx, alpha, beta, true, false); + } + } + if (req_state != kNullOp) { + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < N * H * D; ++i) { + dhx[i] = dht1[i]; + } + } +} + +template +void VanillaRNNBackward(DType* ws, + DType* rs, + const int L, + const int D, + const int T, + const int N, + int I, + const int H, + DType* x_ptr, + DType* hx_ptr, + DType* w_ptr, + DType* dy_ptr, + DType* dhy_ptr, + DType* dx_ptr, + DType* dhx_ptr, + DType* dw_ptr, + int req_data, + int req_params, + int req_state, + const float dropout, + int mode) { + DType* wx = w_ptr; + DType* dwx = dw_ptr; + DType* dwh = dwx + I * H; + DType* dbx = dwh + H * H + (D - 1) * (H * H + I * H) + + (L - 1) * ((D + 1) * H) * H * D; + DType* gateN_l = rs + (L - 1) * T * D * N * H; + DType* y_l = gateN_l + L * T * D * N * H; + DType* dropout_random = y_l + L * D * T * N * H; + DType* tmp_buf = dropout_random + (L - 1) * D * T * N * H; + DType* dx_l = tmp_buf + T * N * D * H + H * T * 2; + DType* ws2 = dx_l + T * N * D * H; + DType* wx_l = (L == 1)? wx : wx + (L - 2) * D * (D + 1) * H * H + + D * I * H + D * H * H; + DType* wh_l = wx_l; + if (L == 1) { + wh_l = wh_l + I * H; + } else { + wh_l = wh_l + (D * H) * H; + } + DType* dhy_l = NULL; + if (dhy_ptr) + dhy_l = dhy_ptr + (L - 1) * D * N * H; + DType* dwx_l = (L == 1)? dwx : dwx + (L - 2) * D * (D + 1) * H * H + + D * I * H + D * H * H; + DType* dwh_l = NULL; + if (L == 1) { + dwh_l = dwx_l + I * H; + } else { + dwh_l = dwx_l + (D * H) * H; + } + DType* dbx_l = dbx + (L - 1) * D * H * 2; + DType* dbh_l = dbx_l + H; + DType* dhx_l = dhx_ptr + (L - 1) * D * N * H; + DType* dy_l = dy_ptr; + Tensor hx(hx_ptr, Shape3(L, D * N, H)); + int inputsize = I; + DType* y_tmp = y_l - T * N * H * D; + const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); + for (int l = L - 1; l >= 0; --l) { + if (l == 0) { + I = inputsize; + y_tmp = x_ptr; + dx_l = dx_ptr; + } else { + I = D * H; + } + Tensor hx_l = hx[l]; + Tensor x_l(y_tmp, Shape2(T * N, I)); + VanillaRNNBackwardSingleLayer(ws2, tmp_buf, D, T, N, I, H, x_l, hx_l, wx_l, wh_l, + y_l, dy_l, dhy_l, gateN_l, dx_l, dhx_l, dwx_l, dwh_l, + dbx_l, dbh_l, req_data, req_params, req_state, mode); + if (dropout > 0.0f && l > 0 && req_data != kNullOp) { + dropout_random = dropout_random - T * N * D * H; + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < T * N * I; i++) { + if (dropout_random[i] == 0) { + dx_l[i] = 0; + } else { + dx_l[i] = dx_l[i] / (1.0f - dropout); + } + } + } + if (l > 0) { + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < T * N * H * D; ++i) { + dy_l[i] = dx_l[i]; + } + gateN_l = gateN_l - T * D * N * H; + dhx_l = dhx_l - D * N * H; + if (dhy_l) + dhy_l = dhy_l - D * N * H; + y_l = y_l - T * N * H * D; + y_tmp = y_l; + if (l == 1) { + wx_l = wx_l - (inputsize + H) * H * D; + wh_l = wx_l + inputsize * H; + dwx_l = dwx_l - (inputsize + H) * H * D; + dwh_l = dwx_l + inputsize * H; + } else { + wx_l = wx_l - (I + H) * H * D; + wh_l = wx_l + I * H; + dwx_l = dwx_l - (I + H) * H * D; + dwh_l = dwx_l + I * H; + } + dbx_l = dbx_l - D * H * 2; + dbh_l = dbx_l + H; + } + } +} + } // namespace op } // namespace mxnet #endif // MXNET_OPERATOR_RNN_IMPL_H_ diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 3de30f21e16f..e07a602b8c18 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -137,8 +137,76 @@ def test_gru_bidirectional(): check_rnn_consistency(fused, stack, T, N, I, H, 'add') check_rnn_consistency(fused, stack, T, N, I, H, 'null') -# Currently, fused LSTM operator doesn't support dropout. -# Will change this test after dropout is supported +@with_seed() +def test_rnntanh_sym(): + T, N, I, H = 5, 32, 800, 800 + + fused = mx.rnn.FusedRNNCell(H, num_layers=3, mode='rnn_tanh', get_next_state=True, prefix='') + stack = mx.rnn.SequentialRNNCell() + stack.add(mx.rnn.RNNCell(H, activation='tanh', prefix='l0_')) + stack.add(mx.rnn.RNNCell(H, activation='tanh', prefix='l1_')) + stack.add(mx.rnn.RNNCell(H, activation='tanh', prefix='l2_')) + + check_rnn_consistency(fused, stack, T, N, I, H, 'write') + check_rnn_consistency(fused, stack, T, N, I, H, 'add') + check_rnn_consistency(fused, stack, T, N, I, H, 'null') + +@with_seed() +def test_rnntanh_bidirectional(): + T, N, I, H = 5, 20, 800, 800 + + fused = mx.rnn.FusedRNNCell(H, num_layers=2, mode='rnn_tanh', + bidirectional=True, get_next_state=True, prefix='') + + stack = mx.rnn.SequentialRNNCell() + stack.add(mx.rnn.BidirectionalCell( + mx.rnn.RNNCell(H, activation='tanh', prefix='l0_'), + mx.rnn.RNNCell(H, activation='tanh', prefix='r0_'), + output_prefix='bi_rnntanh_0_')) + stack.add(mx.rnn.BidirectionalCell( + mx.rnn.RNNCell(H, activation='tanh', prefix='l1_'), + mx.rnn.RNNCell(H, activation='tanh', prefix='r1_'), + output_prefix='bi_rnntanh_1_')) + + check_rnn_consistency(fused, stack, T, N, I, H, 'write') + check_rnn_consistency(fused, stack, T, N, I, H, 'add') + check_rnn_consistency(fused, stack, T, N, I, H, 'null') + +@with_seed() +def test_rnnrelu_sym(): + T, N, I, H = 5, 32, 200, 200 + + fused = mx.rnn.FusedRNNCell(H, num_layers=3, mode='rnn_relu', get_next_state=True, prefix='') + stack = mx.rnn.SequentialRNNCell() + stack.add(mx.rnn.RNNCell(H, activation='relu', prefix='l0_')) + stack.add(mx.rnn.RNNCell(H, activation='relu', prefix='l1_')) + stack.add(mx.rnn.RNNCell(H, activation='relu', prefix='l2_')) + + check_rnn_consistency(fused, stack, T, N, I, H, 'write') + check_rnn_consistency(fused, stack, T, N, I, H, 'add') + check_rnn_consistency(fused, stack, T, N, I, H, 'null') + +@with_seed() +def test_rnnrelu_bidirectional(): + T, N, I, H = 5, 20, 200, 200 + + fused = mx.rnn.FusedRNNCell(H, num_layers=2, mode='rnn_relu', + bidirectional=True, get_next_state=True, prefix='') + + stack = mx.rnn.SequentialRNNCell() + stack.add(mx.rnn.BidirectionalCell( + mx.rnn.RNNCell(H, activation='relu', prefix='l0_'), + mx.rnn.RNNCell(H, activation='relu', prefix='r0_'), + output_prefix='bi_rnnrelu_0_')) + stack.add(mx.rnn.BidirectionalCell( + mx.rnn.RNNCell(H, activation='relu', prefix='l1_'), + mx.rnn.RNNCell(H, activation='relu', prefix='r1_'), + output_prefix='bi_rnnrelu_1_')) + + check_rnn_consistency(fused, stack, T, N, I, H, 'write') + check_rnn_consistency(fused, stack, T, N, I, H, 'add') + check_rnn_consistency(fused, stack, T, N, I, H, 'null') + @with_seed() def test_lstm_dropout(): X = mx.sym.Variable('x') @@ -149,12 +217,44 @@ def test_lstm_dropout(): rnn = mx.sym.RNN(data=X, parameters=Params, state=HX, state_cell=CX, state_size=H, num_layers=5, mode='lstm', p=0.5, state_outputs=True, name='LSTM') exe = rnn.simple_bind(ctx=mx.cpu(), x=(T, N, I)) - try: - out = exe.forward(is_train=False) - out[0].wait_to_read() - assert False # should not reach here - except mx.base.MXNetError as err: - assert str(err).find('Dropout is not supported at the moment') != -1 + out = exe.forward(is_train=True) + out[0].wait_to_read() + +@with_seed() +def test_gru_dropout(): + X = mx.sym.Variable('x') + Params = mx.sym.Variable('params') + HX = mx.sym.Variable('state') + T, N, I, H = 300, 20, 800, 800 + rnn = mx.sym.RNN(data=X, parameters=Params, state=HX, + state_size=H, num_layers=5, mode='gru', p=0.5, state_outputs=True, name='GRU') + exe = rnn.simple_bind(ctx=mx.cpu(), x=(T, N, I)) + out = exe.forward(is_train=True) + out[0].wait_to_read() + +@with_seed() +def test_rnntanh_dropout(): + X = mx.sym.Variable('x') + Params = mx.sym.Variable('params') + HX = mx.sym.Variable('state') + T, N, I, H = 300, 20, 800, 800 + rnn = mx.sym.RNN(data=X, parameters=Params, state=HX, + state_size=H, num_layers=5, mode='rnn_tanh', p=0.5, state_outputs=True, name='RNN_TANH') + exe = rnn.simple_bind(ctx=mx.cpu(), x=(T, N, I)) + out = exe.forward(is_train=True) + out[0].wait_to_read() + +@with_seed() +def test_rnnrelu_dropout(): + X = mx.sym.Variable('x') + Params = mx.sym.Variable('params') + HX = mx.sym.Variable('state') + T, N, I, H = 300, 20, 800, 800 + rnn = mx.sym.RNN(data=X, parameters=Params, state=HX, + state_size=H, num_layers=5, mode='rnn_relu', p=0.5, state_outputs=True, name='RNN_RELU') + exe = rnn.simple_bind(ctx=mx.cpu(), x=(T, N, I)) + out = exe.forward(is_train=True) + out[0].wait_to_read() def np_softmax(x, axis=-1): # fix for old numpy on Travis not supporting keepdims