Normalization
Normalization layers are an essential ingredient for stable training dynamics in large neural networks.
Variants:
- Batch normalization – the original paper
- Layer Normalization – used extensively in large DL models
- RMSNorm – heavily used in modern Transformers
- Group Normalization
- Weight Normalization
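Of these, RMSNorm is the simplest: it rescales by the root mean square of the features without subtracting the mean. A minimal NumPy sketch; the gain parameter `g` and the `eps` value are illustrative assumptions:

```python
import numpy as np

def rms_norm(x, g, eps=1e-6):
    # Scale by the root-mean-square of the features; no mean subtraction.
    rms = np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)
    return g * x / rms

x = np.array([[1.0, 2.0, 3.0, 4.0]])
g = np.ones(4)
y = rms_norm(x, g)
```

Note that the output is invariant to rescaling of the input, which is the property that matters for stability.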
Normalization Layer Placement
The placement of the normalization layers in a neural network affects its stability and training dynamics. The original transformer block placed the two normalization layers after the attention and MLP sublayers respectively, a setup termed "Post-Norm". Xiong et al. (2020) showed that the "Pre-Norm" setup, in which each normalization layer is moved to just before its sublayer (attention or MLP) while the residual path stays unnormalized, results in smoother gradients.
Post-Norm (original):
- $x \rightarrow \text{Attention} \rightarrow \text{Add \& Norm} \rightarrow \text{MLP} \rightarrow \text{Add \& Norm}$
Pre-Norm:
- $x \rightarrow \text{LayerNorm} \rightarrow \text{Attention} \rightarrow \text{Add} \rightarrow \text{LayerNorm} \rightarrow \text{MLP} \rightarrow \text{Add}$
Pre-Norm provides more stable training with smoother gradients (Xiong et al., 2020) because normalization occurs before the potentially destabilizing operations, improving gradient flow during backpropagation.
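The two placements can be sketched as follows; `layer_norm` is a plain NumPy implementation, and the attention/MLP sublayers are passed in as stand-in callables for illustration:

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)

def post_norm_block(x, sublayer_attn, sublayer_mlp):
    # Original transformer: Add & Norm after each sublayer.
    x = layer_norm(x + sublayer_attn(x))
    x = layer_norm(x + sublayer_mlp(x))
    return x

def pre_norm_block(x, sublayer_attn, sublayer_mlp):
    # Pre-Norm: normalize before each sublayer; the residual path
    # stays unnormalized, which gives smoother gradients at depth.
    x = x + sublayer_attn(layer_norm(x))
    x = x + sublayer_mlp(layer_norm(x))
    return x
```

With identity stand-ins for the sublayers, the Post-Norm output is always re-normalized per token, while the Pre-Norm output preserves the residual signal.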
Batch normalization
Input distributions change per layer, especially during training.
Batch normalization normalizes the layer inputs so that each feature is approximately $N(0,1)$ over the mini-batch:
$\hat{a}_{ij} = \frac{a_{ij} - \mu_j}{\sqrt{\sigma_j^2 + \epsilon}}$
followed by an affine transformation $a_{ij} \leftarrow \gamma_j \hat{a}_{ij} + \beta_j$, where the parameters $\gamma$ and $\beta$ are trainable.
Here $i$ runs over mini-batch samples and $j$ over the feature dimensions.
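A minimal NumPy sketch of this forward pass (the `eps` term for numerical stability is standard practice but its value here is an assumption):

```python
import numpy as np

def batch_norm(a, gamma, beta, eps=1e-5):
    # i runs over mini-batch samples (axis 0), j over features (axis 1).
    mu = a.mean(axis=0)
    var = a.var(axis=0)
    a_hat = (a - mu) / np.sqrt(var + eps)  # approximately N(0, 1) per feature
    return gamma * a_hat + beta            # learnable affine transform

a = np.random.default_rng(0).normal(5.0, 3.0, size=(32, 4))
out = batch_norm(a, gamma=np.ones(4), beta=np.zeros(4))
```

With the default $\gamma = 1$, $\beta = 0$, each output feature has mean 0 and standard deviation close to 1 over the batch.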
Why does it work?
Covariate shift – per gradient update, a module must adapt its weights to fit the data better, but also adapt to the change of its input distribution. Remember, each module's inputs depend on other parameterized modules.
This interpretation doesn't explain practical observations:
- Why does batch norm work better after the nonlinearity?
- Why have $\gamma$ and $\beta$ parameters to reshape our Gaussian if the problem is covariate shift? The original reasoning is that it gives that choice to the model itself.
There is another interpretation: batch norm simplifies the learning dynamics.
- Neural network outputs are determined by higher-order interactions between layers
- These interactions complicate the gradient update
- The mean of the BatchNorm output is $\beta$, its standard deviation is $\gamma$
- Both are independent of the activation values themselves
- Higher-order interactions are suppressed, so training becomes easier
The benefits
- Higher learning rates, which means faster training
- Neurons of all layers activated in near optimal "regime"
- Model regularization
    - The per-mini-batch mean and variance add some noise
    - The added noise reduces overfitting
Test inference
How do we ship the Batch Norm layer after training? We might not have batches at test time.
Usually: keep a moving average of the mean and variance during training, and plug them in at test time. In the limit, the moving average of mini-batch statistics approaches the full-dataset statistics.
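A sketch of this bookkeeping in NumPy; the exponential-moving-average `momentum` convention is an assumption here (frameworks differ in how they define it):

```python
import numpy as np

def update_running_stats(running_mean, running_var, batch, momentum=0.1):
    # Exponential moving average of mini-batch statistics,
    # updated once per training step.
    mu = batch.mean(axis=0)
    var = batch.var(axis=0)
    running_mean = (1 - momentum) * running_mean + momentum * mu
    running_var = (1 - momentum) * running_var + momentum * var
    return running_mean, running_var

def batch_norm_inference(a, running_mean, running_var, gamma, beta, eps=1e-5):
    # At test time, plug in the stored statistics instead of batch statistics.
    return gamma * (a - running_mean) / np.sqrt(running_var + eps) + beta
```

After enough updates over batches drawn from a fixed distribution, the running statistics settle near the true mean and variance.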
Disadvantages
Requires large mini-batches; cannot work with a mini-batch of size 1 ($\sigma = 0$).
And for small mini-batches we don't get very accurate gradients anyway.
Awkward to use with recurrent neural networks
- Must interleave it between recurrent layers
- Also, must store statistics per time step
- Can cause gradient explosion as well: https://arxiv.org/abs/2304.11692
Instance Normalization
Similar to layer normalization but per channel per training example
Basic idea: network should be agnostic to the contrast of the original image
Originally proposed for style transfer
Not as good in image classification
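A minimal NumPy sketch, assuming the common `(batch, channels, height, width)` layout: each channel of each sample is normalized over its own spatial positions, which makes the output invariant to per-image contrast and brightness shifts:

```python
import numpy as np

def instance_norm(x, eps=1e-5):
    # x: (batch, channels, height, width); normalize each channel of
    # each sample over its spatial dimensions only.
    mu = x.mean(axis=(2, 3), keepdims=True)
    var = x.var(axis=(2, 3), keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

img = np.random.default_rng(0).normal(size=(2, 3, 8, 8)) * 4 + 1
out = instance_norm(img)
```

Rescaling or shifting an input image leaves the output (almost) unchanged, which is the contrast-agnosticism the method was designed for.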
Group Normalization
Same as instance norm, but over groups of channels: it sits between layer normalization (all channels in one group) and instance normalization (one channel per group).
Better than batch normalization for small batches, for example < 32. Competitive for larger batches.
Useful for object detection/segmentation networks: they rely on high-resolution images and therefore cannot use big mini-batches.
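A NumPy sketch, again assuming a `(batch, channels, height, width)` layout and that the channel count is divisible by `num_groups`; with one group this reduces to layer norm over `(C, H, W)`, and with one channel per group to instance norm:

```python
import numpy as np

def group_norm(x, num_groups, eps=1e-5):
    # x: (batch, channels, height, width). Normalize over groups of channels
    # plus spatial dims; statistics are per-sample, independent of batch size.
    n, c, h, w = x.shape
    xg = x.reshape(n, num_groups, c // num_groups, h, w)
    mu = xg.mean(axis=(2, 3, 4), keepdims=True)
    var = xg.var(axis=(2, 3, 4), keepdims=True)
    return ((xg - mu) / np.sqrt(var + eps)).reshape(n, c, h, w)
```

Because the statistics never involve the batch dimension, the layer behaves identically at batch size 1 and batch size 256.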
Weight Normalization
Instead of normalizing activations, normalize weights by re-parameterizing weights
Separate the norm from the direction.
Similar to dividing by standard deviation in batch normalization.
Can be combined with mean-only batch normalization: subtract the mean (but do not divide by the standard deviation), then apply weight normalization.
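The re-parameterization itself is a one-liner; a NumPy sketch with a per-output-unit gain `g` (the variable names are illustrative):

```python
import numpy as np

def weight_norm(v, g):
    # Re-parameterize each row of the weight matrix as w = g * v / ||v||,
    # separating the norm (the learnable scalar g per output unit)
    # from the direction v.
    return g[:, None] * v / np.linalg.norm(v, axis=1, keepdims=True)

v = np.random.default_rng(1).normal(size=(3, 5))
g = np.array([1.0, 2.0, 0.5])
w = weight_norm(v, g)
```

By construction, the norm of each weight row equals its gain `g`, whatever direction `v` takes during training.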
L2 and normalization
L2 regularization has no regularizing effect when combined with normalization. Instead, it influences the scale of the weights, and thereby the effective learning rate. [2]
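This can be seen directly: scaling the weights by a constant leaves the normalized output essentially unchanged, so an L2 penalty cannot constrain the learned function, only the weight scale. A small NumPy demonstration, using layer norm as the example normalizer:

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)

rng = np.random.default_rng(0)
W = rng.normal(size=(4, 4))
x = rng.normal(size=(8, 4))

# Scaling W by any positive constant leaves the normalized output
# (almost) unchanged: the scale cancels in the normalization.
y1 = layer_norm(x @ W.T)
y2 = layer_norm(x @ (10 * W).T)
```

Since shrinking $\|W\|$ only rescales the pre-normalization activations, its real effect is on the relative step size of gradient updates, i.e. the effective learning rate.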