daehwannam/pytorch-rnn-library

A simple PyTorch RNN package including a general RNN frame.

Introduction

PyTorch is a flexible deep learning framework that lets you implement custom RNN and LSTM cells. However, custom cells cannot exploit the convenient options provided by PyTorch's standard RNN and LSTM modules, including:

  • Bidirectional processing
  • Multiple layers
  • Dropout

Implementing all of these options from scratch is therefore tedious.

This package eases that burden with RNNFrame, a general framework that adds those options to custom cells. It also provides LayerNormLSTM, an LSTM variant with layer normalization and recurrent dropout as options.
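
As a rough sketch of the intended drop-in usage (the module path and exact constructor signature below are assumptions; the r_dropout and layer_norm_enabled option names are those described under Documentation):

    # Sketch only: assumes LayerNormLSTM mirrors torch.nn.LSTM's constructor
    # and forward, plus the package-specific r_dropout and layer_norm_enabled
    # options. The import path is hypothetical.
    import torch
    from layer_norm_lstm import LayerNormLSTM  # hypothetical module name

    lstm = LayerNormLSTM(
        input_size=32,
        hidden_size=64,
        num_layers=2,
        dropout=0.2,              # inter-layer dropout, as in nn.LSTM
        r_dropout=0.2,            # recurrent dropout (package option)
        bidirectional=True,
        layer_norm_enabled=True,  # layer normalization (package option)
    )

    x = torch.randn(10, 4, 32)    # (seq_len, batch, input_size)
    output, (h_n, c_n) = lstm(x)  # call convention assumed to match nn.LSTM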

Requirements

This package requires PyTorch 1.7 or a more recent version.

Documentation

Important classes are described as follows:

  • RNNFrame: A general framework for customizing RNNs (or LSTMs when for_lstm=True is passed to __init__). Its forward method provides the same API as that of RNN and LSTM when a layout of RNN cells is passed to __init__. For example, the layout [(cell_0f, cell_0b), (cell_1f, cell_1b)] creates a bidirectional, two-layer RNN. You can also set the dropout and bidirectional options. Caution: it can process a batch of variable-length sequences (by converting a PackedSequence object to a tensor and lengths), but the computation time is proportional to the maximum sequence length in each batch.
  • LSTMFrame: A wrapper of RNNFrame that forces for_lstm=True in __init__.
  • LayerNormLSTMCell: An example of a custom LSTM cell that applies layer normalization and recurrent dropout. The implementation is based on tf.contrib.rnn.LayerNormBasicLSTMCell.
  • LayerNormLSTM: An application of LSTMFrame with LayerNormLSTMCell. The class provides the key options of LSTM plus two additional ones: r_dropout for recurrent dropout and layer_norm_enabled for layer normalization.

You can also check example.py to see how RNNFrame, LSTMFrame, and LayerNormLSTM are used.
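
For orientation, the following is a minimal sketch of building a bidirectional, two-layer LSTM with RNNFrame from torch.nn.LSTMCell instances, following the [(cell_0f, cell_0b), (cell_1f, cell_1b)] layout format described above. The import path and constructor keywords are assumptions, so defer to example.py for the exact API.

    # Sketch only: the import path and constructor keywords are assumptions;
    # the layout format follows the Documentation section above.
    import torch
    import torch.nn as nn
    from rnn_frame import RNNFrame  # hypothetical module name

    input_size, hidden_size = 32, 64

    def lstm_cell(in_size):
        return nn.LSTMCell(in_size, hidden_size)

    # One tuple per layer, one cell per direction: (forward, backward).
    # Layer 1 is assumed to consume the concatenated outputs of both
    # directions of layer 0, hence the 2 * hidden_size input size.
    layout = [
        (lstm_cell(input_size), lstm_cell(input_size)),            # layer 0
        (lstm_cell(2 * hidden_size), lstm_cell(2 * hidden_size)),  # layer 1
    ]

    rnn = RNNFrame(layout, for_lstm=True, dropout=0.2)

    x = torch.randn(10, 4, input_size)  # (seq_len, batch, input_size)
    output, (h_n, c_n) = rnn(x)         # forward API assumed to match nn.LSTM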

Note

  • RNNFrame has not been exhaustively tested across its various options, so I recommend using the standard RNN or LSTM first and replacing it with this package later.
  • If you only need LSTMFrame rather than RNNFrame, you can use the snapshot at tag v1.1.
