Normalization

Normalization layers are an essential ingredient for stable training dynamics in large neural networks.

Variants:

  • Batch normalization - the original paper
  • Layer Normalization - used extensively in large DL models
  • RMSNorm - heavily used in modern Transformers
  • Group Normalization
  • Weight Normalization

Normalization Layer Placement

The placement of the normalization layers in a transformer block affects stability and training dynamics. The original transformer placed the two normalization layers after the attention and MLP sublayers respectively (after the residual addition), a setup termed "Post-Norm". It was later shown that the "Pre-Norm" setup, in which each normalization layer is moved before its sublayer so that the attention and MLP modules are sandwiched between norm layers, results in smoother gradients.

Post-Norm (original):

  • $x \rightarrow \text{Attention} \rightarrow \text{Add \& Norm} \rightarrow \text{MLP} \rightarrow \text{Add \& Norm}$

Pre-Norm:

  • $x \rightarrow \text{LayerNorm} \rightarrow \text{Attention} \rightarrow \text{Add} \rightarrow \text{LayerNorm} \rightarrow \text{MLP} \rightarrow \text{Add}$

Pre-Norm provides more stable training with smoother gradients (Xiong et al., 2020) because normalization occurs before the potentially destabilizing operations, improving gradient flow during backpropagation.
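The two placements above can be sketched in NumPy. The attention and MLP sublayers are stood in for by arbitrary callables, an assumption purely for illustration; the point is where the normalization sits relative to the residual addition.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize over the feature dimension (last axis).
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def post_norm_block(x, attn, mlp):
    # Original Transformer: residual add first, THEN normalize.
    x = layer_norm(x + attn(x))
    x = layer_norm(x + mlp(x))
    return x

def pre_norm_block(x, attn, mlp):
    # Pre-Norm: normalize BEFORE each sublayer; the residual path
    # itself is never normalized, which helps gradient flow.
    x = x + attn(layer_norm(x))
    x = x + mlp(layer_norm(x))
    return x
```

Note that in the Pre-Norm block the input `x` reaches the output through an identity path, untouched by any norm layer.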

Batch normalization

The input distribution of each layer changes, especially during training

Batch normalization normalizes the layer inputs so that $a_{l} \sim N(0,1)$,
followed by an affine transformation $a_{l} \leftarrow \gamma a_{l}+\beta$, where the parameters $\gamma$ and $\beta$ are trainable.

i runs over mini-batch samples, j over the feature dimensions

$$ \begin{array}{l} \mu_{j} \leftarrow \frac{1}{m} \sum_{i=1}^{m} x_{i j} \\ \sigma_{j}^{2} \leftarrow \frac{1}{m} \sum_{i=1}^{m}\left(x_{i j}-\mu_{j}\right)^{2} \\ \hat{x}_{i j} \leftarrow \frac{x_{i j}-\mu_{j}}{\sqrt{\sigma_{j}^{2}+\varepsilon}} \\ y_{i j} \leftarrow \gamma_{j} \hat{x}_{i j}+\beta_{j} \end{array} $$
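The four update rules above translate almost line-for-line into a NumPy forward pass. Following the text, $i$ indexes the $m$ mini-batch samples and $j$ the features, so all statistics are per feature:

```python
import numpy as np

def batch_norm_forward(x, gamma, beta, eps=1e-5):
    """Batch norm over a mini-batch x of shape (m, num_features).

    Statistics are computed per feature j, across the m samples i.
    """
    mu = x.mean(axis=0)                    # mu_j
    var = x.var(axis=0)                    # sigma_j^2
    x_hat = (x - mu) / np.sqrt(var + eps)  # normalize to ~N(0, 1)
    return gamma * x_hat + beta            # trainable affine transform
```

With `gamma = 1` and `beta = 0`, each output feature has (approximately) zero mean and unit standard deviation over the batch.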

Why does it work?

Covariate shift - per gradient update, a module must adapt its weights to fit the data better, but also adapt to the change in its input distribution. Remember, each module's inputs depend on other parameterized modules.


This interpretation doesn't explain practical observations:

  1. Why does batch norm work better after the nonlinearity?
  2. Why have $\gamma$ and $\beta$ parameters that reshape our Gaussian if the problem is covariate shift? The original reasoning is that this gives the choice back to the model itself.

There is another interpretation: batch norm simplifies the learning dynamics.

  • Neural network outputs are determined by higher-order interactions between layers
  • These interactions complicate the gradient update
  • The mean of the BatchNorm output is $\beta$, its std is $\gamma$
  • These are independent of the activation values themselves
  • Higher-order interactions are suppressed, so training becomes easier

The benefits

  1. Higher learning rates, which means faster training
  2. Neurons of all layers are kept in a near-optimal activation "regime"
  3. Model regularization
    1. The per-mini-batch mean and variance are noisy estimates
    2. This noise acts as regularization and reduces overfitting

Test inference

How do we ship the Batch Norm layer after training? We might not have batches at test time.

Usually: keep a moving average of the mean and variance during training, and plug them in at test time. In the limit, the moving average of mini-batch statistics approaches the full-dataset statistics.
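A sketch of that recipe, assuming an exponential moving average with a momentum hyperparameter (the exact update rule and default momentum vary across frameworks):

```python
import numpy as np

def update_running_stats(run_mean, run_var, batch_mean, batch_var, momentum=0.1):
    # Exponential moving average of mini-batch statistics, updated once
    # per training step. momentum=0.1 is an assumed default.
    run_mean = (1 - momentum) * run_mean + momentum * batch_mean
    run_var = (1 - momentum) * run_var + momentum * batch_var
    return run_mean, run_var

def batch_norm_inference(x, run_mean, run_var, gamma, beta, eps=1e-5):
    # At test time, plug in the stored statistics instead of
    # batch statistics; no dependence on the test batch size.
    x_hat = (x - run_mean) / np.sqrt(run_var + eps)
    return gamma * x_hat + beta
```

Because inference uses only the stored statistics, the layer works even on a single test example.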

Disadvantages

Requires large mini-batches. Cannot work with a mini-batch of size 1 ($\sigma=0$).

And for small mini-batches we don't get very accurate gradient estimates anyway

Awkward to use with recurrent neural networks

Instance Normalization

Similar to layer normalization, but per channel and per training example

Basic idea: network should be agnostic to the contrast of the original image

Originally proposed for style transfer

Not as good in image classification

Group Normalization

Same as instance norm, but over groups of channels: a middle ground between layer normalization and instance normalization.

Better than batch normalization for small batches, for example < 32; competitive for larger batches.

Useful for object detection/segmentation networks: they rely on high-resolution images and cannot afford big mini-batches.
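A minimal NumPy sketch of group norm for image tensors makes the relationship to the other variants concrete: with one group per channel it reduces to instance norm, and with a single group it normalizes over all of $(C, H, W)$ like layer norm.

```python
import numpy as np

def group_norm(x, num_groups, gamma, beta, eps=1e-5):
    """Group norm for x of shape (N, C, H, W): statistics are computed
    per sample and per group of channels, so they are independent of
    the batch size N (works even for N = 1)."""
    n, c, h, w = x.shape
    g = x.reshape(n, num_groups, c // num_groups, h, w)
    mu = g.mean(axis=(2, 3, 4), keepdims=True)
    var = g.var(axis=(2, 3, 4), keepdims=True)
    g = (g - mu) / np.sqrt(var + eps)
    x_hat = g.reshape(n, c, h, w)
    # Per-channel affine parameters, broadcast over the spatial dims.
    return gamma.reshape(1, c, 1, 1) * x_hat + beta.reshape(1, c, 1, 1)
```

`group_norm(x, C, ...)` gives instance norm; `group_norm(x, 1, ...)` gives a layer-norm-style normalization over channels and space.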

Weight Normalization

Instead of normalizing activations, normalize the weights by re-parameterizing them:

$$ \boldsymbol{w}=g \frac{\boldsymbol{v}}{\|\boldsymbol{v}\|} $$

Separates the norm from the direction.
Similar to dividing by the standard deviation in batch normalization.
Can be combined with mean-only batch normalization: subtract the mean (but do not divide by the standard deviation), then apply weight normalization.
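The re-parameterization $\boldsymbol{w} = g\,\boldsymbol{v}/\|\boldsymbol{v}\|$ is a one-liner per output unit. A sketch for a full weight matrix, with one scale $g_k$ per output unit (row):

```python
import numpy as np

def weight_norm_matrix(V, g):
    """Weight normalization: w_k = g_k * v_k / ||v_k|| for each row k.

    V holds the directions (one row per output unit), g the learned
    per-unit scales; each row of the result has norm exactly g[k]."""
    norms = np.linalg.norm(V, axis=1, keepdims=True)
    return g[:, None] * V / norms
```

During training, gradients flow to both `g` (the norm) and `V` (the direction) separately, which is the point of the decomposition.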

L2 and normalization

L2 regularization has no regularizing effect when combined with normalization. Instead, it influences the scale of the weights, and thereby the effective learning rate. [2]
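The scale-invariance behind this claim can be checked numerically: with a normalization layer after a linear map, rescaling the weights leaves the output essentially unchanged, so L2's weight shrinkage cannot change what the network computes, only the relative step size of gradient updates. A minimal sketch, using layer norm for simplicity:

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize over the feature dimension (last axis).
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
W = rng.normal(size=(8, 8))
x = rng.normal(size=(4, 8))

out = layer_norm(x @ W.T)
# L2 is free to shrink W by any positive factor...
out_scaled = layer_norm(x @ (0.1 * W).T)
# ...but the normalized output is numerically (almost) unchanged;
# the tiny difference comes only from the eps term.
assert np.allclose(out, out_scaled, atol=1e-2)
```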