LSTM

Vanilla Recurrent Neural Networks (RNNs) are defined as

$$ \begin{array}{l} y_{t}=\operatorname{softmax}\left(V \cdot s_{t}\right) \\ s_{t}=\tanh \left(U \cdot x_{t}+W \cdot s_{t-1}\right) \end{array} $$
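To make the definition concrete, here is a minimal NumPy sketch of one vanilla RNN step; the dimensions and random initialization are illustrative assumptions, not from the source.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_hid, d_out = 4, 8, 3                    # hypothetical sizes

U = rng.normal(scale=0.1, size=(d_hid, d_in))   # input-to-state weights
W = rng.normal(scale=0.1, size=(d_hid, d_hid))  # state-to-state weights
V = rng.normal(scale=0.1, size=(d_out, d_hid))  # state-to-output weights

def rnn_step(x_t, s_prev):
    """One step of the vanilla RNN defined above."""
    s_t = np.tanh(U @ x_t + W @ s_prev)         # s_t = tanh(U x_t + W s_{t-1})
    logits = V @ s_t
    y_t = np.exp(logits - logits.max())
    y_t /= y_t.sum()                            # y_t = softmax(V s_t)
    return y_t, s_t

y, s = rnn_step(rng.normal(size=d_in), np.zeros(d_hid))
```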

The key ideas behind LSTMs are:

Aim for $\frac{\partial s_{j}}{\partial s_{j-1}}=1$ to avoid vanishing and exploding gradients.

Remove the immediate nonlinear relation between $s_{t}$ and $s_{t-1}$, since squashing nonlinearities yield gradients smaller than 1 (see the derivation after the list below).

  • Replace the tanh between $s_{t}$ and $s_{t-1}$ with the identity
  • Apply the nonlinearity to the memory variable instead of the state variable
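To see why the gradients shrink, consider the Jacobian of the vanilla RNN state; this standard derivation follows directly from the definition above:

$$ \frac{\partial s_{t}}{\partial s_{t-1}} = \operatorname{diag}\left(1-s_{t}^{2}\right) W, \qquad \frac{\partial s_{t}}{\partial s_{k}} = \prod_{j=k+1}^{t} \operatorname{diag}\left(1-s_{j}^{2}\right) W $$

Since $1-\tanh^{2}(\cdot) \leq 1$, this product of $t-k$ factors shrinks exponentially in $t-k$ when $\|W\|$ is small (vanishing gradients) and can blow up when $\|W\|$ is large (exploding gradients).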

Also, avoid continuously overwriting the state (a minimal sketch of gating follows this list):

  • Modulate the importance of new input by a gate
  • Modulate the importance of new output by a gate
  • Modulate the importance of past memories by a gate
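For intuition, here is a tiny sketch of what a gate does; the vectors are hypothetical, chosen only to show the element-wise mechanics.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# A sigmoid gate in (0, 1) scales how much of a signal passes through.
gate = sigmoid(np.array([-4.0, 0.0, 4.0]))   # ~[0.02, 0.5, 0.98]
signal = np.array([1.0, 1.0, 1.0])
print(gate * signal)                         # ~0 blocks, ~1 passes through
```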

Putting all these things together, at each time step the LSTM computes:

Input gate - determines how important the input is and selects which information from the input to add to the new cell state.

$$ i_t = \sigma \left( W^{(i)}x_t + U^{(i)}h_{t-1} \right) $$

Forget gate - determines how important the past state is and deletes information from the cell state that is no longer needed.

$$ f_t = \sigma \left( W^{(f)}x_t + U^{(f)}h_{t-1} \right) $$

New memory cell - what could be relevant for the new memory? Extracts information from the previous hidden state and the input to create a candidate memory.

$$ \hat{c}_t = \tanh\left( W^{(c)}x_t + U^{(c)}h_{t-1} \right) $$

Final memory cell - computes the new cell state by combining the gated past memory with the gated candidate.

$$ c_t = f_t \odot c_{t-1} + i_t \odot \hat{c}_t $$

Output gate - determines how much of the new cell state is useful for the output.

$$ o_t = \sigma \left( W^{(o)}x_t + U^{(o)}h_{t-1} \right) $$

Final hidden state - updates the hidden state from the output-gated, squashed cell state.

$$ h_t = o_t \odot \tanh(c_t) $$
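Putting the six equations together, here is a minimal NumPy sketch of one LSTM step; the sizes, random initialization, and omission of bias terms are illustrative assumptions matching the bias-free equations above.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

d_in, d_hid = 4, 8                               # hypothetical sizes
rng = np.random.default_rng(0)
# One (W, U) pair per gate/candidate, as in the equations above.
params = {k: (rng.normal(scale=0.1, size=(d_hid, d_in)),
              rng.normal(scale=0.1, size=(d_hid, d_hid)))
          for k in ("i", "f", "c", "o")}

def lstm_step(x_t, h_prev, c_prev):
    W_i, U_i = params["i"]; W_f, U_f = params["f"]
    W_c, U_c = params["c"]; W_o, U_o = params["o"]
    i_t = sigmoid(W_i @ x_t + U_i @ h_prev)      # input gate
    f_t = sigmoid(W_f @ x_t + U_f @ h_prev)      # forget gate
    c_hat = np.tanh(W_c @ x_t + U_c @ h_prev)    # new memory cell (candidate)
    c_t = f_t * c_prev + i_t * c_hat             # final memory cell
    o_t = sigmoid(W_o @ x_t + U_o @ h_prev)      # output gate
    h_t = o_t * np.tanh(c_t)                     # final hidden state
    return h_t, c_t

h, c = lstm_step(rng.normal(size=d_in), np.zeros(d_hid), np.zeros(d_hid))
```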

LSTM insights

Comparing the state equations of the RNN and the LSTM (the LSTM's cell state $c_t$ plays the role of the RNN state $s_t$):
RNN: $s_{t}=\tanh \left(U \cdot x_{t}+W \cdot s_{t-1}\right)$
LSTM: $c_{t}=c_{t-1} \odot f_{t}+\hat{c}_{t} \odot i_{t}, \quad h_{t}=\tanh \left(c_{t}\right) \odot o_{t}$

  • The LSTM still has an indirect nonlinear relation between $c_{t}$ and $c_{t-1}$ via $h_{t-1}$, but there is also a direct linear relation $\rightarrow$ strong gradients are encouraged (see the derivation after this list).
  • Use sigmoids for gating, since they squash to $(0,1)$ values.
  • Use tanh as the module's recurrence nonlinearity instead.
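Concretely, differentiating the cell update $c_{t}=f_{t} \odot c_{t-1}+i_{t} \odot \hat{c}_{t}$ gives the following standard decomposition (not spelled out in the source):

$$ \frac{\partial c_{t}}{\partial c_{t-1}} = \operatorname{diag}\left(f_{t}\right) + \underbrace{\left(\text{terms through } i_{t}, f_{t}, \hat{c}_{t} \text{ via } h_{t-1}\right)}_{\text{indirect, nonlinear}} $$

The direct term $\operatorname{diag}(f_{t})$ involves no repeated squashing: when the forget gate is close to 1, gradients flow across many steps nearly unattenuated, unlike the product of $\operatorname{diag}(1-s_{j}^{2})\, W$ factors in the vanilla RNN.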
