SRU is a recurrent unit that can run over 10 times faster than cuDNN LSTM, without loss of accuracy on the many tasks we tested.
Figure: Average processing time of LSTM, conv2d and SRU, tested on a GTX 1070.
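The speed claim comes from how little of the computation is sequential. Below is a minimal, slow reference sketch of one SRU layer in NumPy. It follows the recurrence as we read it in the EMNLP 2018 paper: forget gate f_t = sigmoid(W_f x_t + v_f * c_{t-1} + b_f), reset gate r_t = sigmoid(W_r x_t + v_r * c_{t-1} + b_r), cell c_t = f_t * c_{t-1} + (1 - f_t) * (W x_t), output h_t = r_t * c_t + (1 - r_t) * x_t, where * is element-wise. This is not the package's CUDA kernel. NumPy, the helper names sigmoid and sru_layer, and the parameter layout are assumptions made for illustration. The sketch also omits the scaling correction, dropout, layer normalization and bidirectionality, and it assumes input size equals hidden size so x_t can feed the highway term directly.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def sru_layer(x, W, Wf, Wr, vf, vr, bf, br, c0=None):
    """Hypothetical slow reference SRU layer (one direction, no dropout/rescale).

    x: (length, batch, d) inputs; input size == hidden size is assumed so the
       highway term can reuse x_t directly.
    W, Wf, Wr: (d, d) weight matrices; vf, vr, bf, br: (d,) vectors.
    Returns h with shape (length, batch, d) and the last cell state c, (batch, d).
    """
    length, batch, d = x.shape
    # Parallel part: one matrix product per weight covers every time step
    # (row-vector convention, so x @ W plays the role of W x_t).
    xw = x @ W    # candidate values
    xf = x @ Wf   # forget-gate pre-activations
    xr = x @ Wr   # reset-gate pre-activations
    c = np.zeros((batch, d)) if c0 is None else c0
    h = np.empty_like(x)
    # Sequential part: only element-wise operations depend on c_{t-1}.
    for t in range(length):
        f = sigmoid(xf[t] + vf * c + bf)   # forget gate, uses c_{t-1}
        r = sigmoid(xr[t] + vr * c + br)   # reset gate, uses c_{t-1}
        c = f * c + (1.0 - f) * xw[t]      # new cell state c_t
        h[t] = r * c + (1.0 - r) * x[t]    # output h_t with highway connection
    return h, c
```

The three matrix products run over all time steps at once, so the loop only does cheap element-wise work. An LSTM cannot split its work this way because its gates depend on h_{t-1} through a matrix product at every step.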
Reference: Simple Recurrent Units for Highly Parallelizable Recurrence (EMNLP 2018)
@inproceedings{lei2018sru,
title={Simple Recurrent Units for Highly Parallelizable Recurrence},
author={Tao Lei and Yu Zhang and Sida I. Wang and Hui Dai and Yoav Artzi},
booktitle={Empirical Methods in Natural Language Processing (EMNLP)},
year={2018}
}
Install the requirements via pip install -r requirements.txt. CuPy and pynvrtc are needed to support training / testing on GPU.
SRU can be installed as a regular package via either python setup.py install or pip install . (both run from the repository root).
Alternatively, install from PyPI:
pip install sru
pip install sru[cuda] additionally installs CuPy and pynvrtc.
pip install sru[cpu] additionally installs ninja.
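The GPU path depends on CuPy and pynvrtc being importable. A minimal sketch for checking that, using only the standard library, follows. The helper name missing_modules and the GPU_DEPS tuple are hypothetical. Finding a module only shows it is on the import path. It does not prove that CUDA itself is set up correctly.

```python
import importlib.util

# Optional GPU dependencies named in this README (import names, not PyPI names).
GPU_DEPS = ("cupy", "pynvrtc")


def missing_modules(names=GPU_DEPS):
    """Return the module names that cannot be found on the current import path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules()
    if missing:
        print("Missing GPU dependencies:", ", ".join(missing))
    else:
        print("CuPy and pynvrtc are importable.")
```

The same helper can check the CPU extra by passing ("ninja",).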
Make sure this repo and the CUDA libraries can be found by the system, e.g.:
export PYTHONPATH=path_to_repo/sru
export LD_LIBRARY_PATH=/usr/local/cuda/lib64
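If imports still fail, it can help to confirm that the exports actually took effect in the environment Python sees. Here is a minimal sketch under that assumption. The helper name missing_search_paths is hypothetical, and path_to_repo is the same placeholder used in the export lines, not a real path.

```python
import os


def missing_search_paths(required, env=None):
    """Return {variable: path} for each required path not listed in that variable.

    `required` maps an environment variable name to a directory that should
    appear in its os.pathsep-separated list of entries.
    """
    env = os.environ if env is None else env
    missing = {}
    for var, path in required.items():
        entries = [os.path.normpath(p) for p in env.get(var, "").split(os.pathsep) if p]
        if os.path.normpath(path) not in entries:
            missing[var] = path
    return missing


if __name__ == "__main__":
    # "path_to_repo" is a placeholder; substitute the real checkout location.
    print(missing_search_paths({
        "PYTHONPATH": "path_to_repo/sru",
        "LD_LIBRARY_PATH": "/usr/local/cuda/lib64",
    }))
```

An empty dict means both directories are listed. The check does not verify that the directories exist.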
The usage of SRU is similar to nn.LSTM. SRU likely requires more stacked layers than LSTM. We recommend starting with 2 layers and using more if necessary (see our report for more experimental details).
import torch
from sru import SRU, SRUCell
# input has length 20, batch size 32 and dimension 128
# (torch.randn gives initialized data; Variable wrappers are unnecessary since PyTorch 0.4)
x = torch.randn(20, 32, 128).cuda()
input_size, hidden_size = 128, 128
rnn = SRU(input_size, hidden_size,
    num_layers=2,          # number of stacked RNN layers
    dropout=0.0,           # dropout applied between RNN layers
    bidirectional=False,   # bidirectional RNN
    layer_norm=False,      # apply layer normalization on the output of each layer
    highway_bias=0,        # initial bias of highway gate (<= 0)
    rescale=True,          # whether to use scaling correction
)
rnn.cuda()
output_states, c_states = rnn(x)  # forward pass
# output_states is (length, batch size, number of directions * hidden size)
# c_states is (layers, batch size, number of directions * hidden size)
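As a worked check of the two shape comments in the example, this small helper computes the expected shapes from the constructor arguments. It is a sketch that encodes only those comments. The name sru_output_shapes is hypothetical and is not part of the sru package.

```python
def sru_output_shapes(length, batch_size, hidden_size, num_layers=2, bidirectional=False):
    """Expected (output_states, c_states) shapes, following the README comments."""
    num_directions = 2 if bidirectional else 1
    output_shape = (length, batch_size, num_directions * hidden_size)
    c_shape = (num_layers, batch_size, num_directions * hidden_size)
    return output_shape, c_shape


if __name__ == "__main__":
    # Same settings as the example: length 20, batch 32, hidden 128, 2 layers.
    print(sru_output_shapes(20, 32, 128))
```

After a forward pass, tuple(output_states.shape) and tuple(c_states.shape) can be compared against these tuples to catch a mis-set num_layers or bidirectional flag.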
Contributors: https://github.com/taolei87/sru/graphs/contributors
@musyoku had a very nice SRU implementation in Chainer.
@adrianbg implemented the first CPU version.