A Summary of Recurrent Neural Networks (RNNs)

I am working closely on RNNs these days, trying to open up the ``black box'' and see how an RNN learns to use its hidden states and gates.

After running a lot of experiments, I suddenly realized that analyzing these models mathematically first might give some clues for better visualization.

There are already many good articles introducing RNNs and their variants (LSTMs, GRUs) on the internet. So this is just a post for myself, to sum up what I know about RNNs.


So first, what is a Recurrent Neural Network (RNN)?

In short, an RNN is a type of neural network that deals with sequence data. Classical neural networks, e.g. the Multi-layer Perceptron (MLP) or the Convolutional Neural Network (CNN), take a fixed-size input and produce a fixed-size output. For CNNs you can resize images of different sizes to a standard size so that the model can work with variable-size input, but the CNN part itself still only accepts fixed-size input. So an RNN is a more flexible form of neural network, one that extends the computational power of neural networks. Indeed, some research has proved that RNNs are Turing-complete.

Model Definition

Now, let's dive into the details of how an RNN works.

The basic idea that differentiates an RNN from a CNN or a classical NN is that it brings in information about previous inputs, and a dependency on them, when predicting the output for the current input. At each time step \(t\), given previous inputs \([x_0, \cdots, x_{t-1}]\) and current input \(x_t\), we can express the output of the model as a conditional probability:

\[ P(y_t \mid x_t, x_{t-1}, \cdots, x_0 ) = P (y_t \mid x_t, i_t) \]

where \(i_t\) is the model's learned representation of the previous inputs \(x_0, \cdots, x_{t-1}\).
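
To make the factorization above concrete, here is a minimal toy sketch in plain Python. It is not any particular RNN: the names `update_summary`, `predict` and `run`, the scalar inputs, and the specific update and output rules are all assumptions made up for illustration. The only point is that the output at step \(t\) is computed from \(x_t\) and a running summary \(i_t\), and never looks at the earlier inputs directly.

```python
import math


def update_summary(i_t, x_t, decay=0.5):
    """Hypothetical update rule: fold the input just seen into the summary."""
    return math.tanh(decay * i_t + x_t)


def predict(x_t, i_t):
    """Hypothetical output rule: a sigmoid score that depends only on x_t and i_t."""
    return 1.0 / (1.0 + math.exp(-(x_t + i_t)))


def run(xs):
    """Process a sequence of any length, one time step at a time."""
    i_t = 0.0  # before t = 0 there are no previous inputs to summarize
    outputs = []
    for x_t in xs:
        # y_t is computed from the current input and the summary of x_0..x_{t-1}
        outputs.append(predict(x_t, i_t))
        # the summary used at step t+1 now also covers x_t
        i_t = update_summary(i_t, x_t)
    return outputs


if __name__ == "__main__":
    # Same current input (0.5) at the last step, different histories.
    print(run([1.0, 0.5])[-1])
    print(run([-1.0, 0.5])[-1])
```

The two calls at the bottom end on the same current input but give different final outputs, because the summaries of their histories differ. The same `run` function also accepts sequences of any length, which is the flexibility mentioned earlier.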

Vanilla RNN

The simplest version of the RNN models the previous information as a vector of hidden states,