Batch Normalization: Stabilizing Deep Neural Network Training
Deriving batch normalization and gradients from scratch.
We know that feature scaling makes gradient descent's job easier and helps it converge faster. Feature scaling is performed on the dataset as a pre-processing step. But once the normalized input is fed to a deep network, the inputs of each layer depend on the parameters of all the layers before it. Even a small change in those parameters is amplified as it propagates, and it shifts the distribution of the inputs to the internal layers. This is known as internal covariate shift.
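To make the idea concrete, here is a small NumPy sketch. The dataset is standardized once up front, yet the inputs to the next layer are not standardized, and their statistics move as soon as the first layer's weights change. The two-feature dataset, the 4-unit layer `W1` and the size of the weight update are arbitrary illustrative choices.

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative raw dataset: two features on very different scales.
X = np.column_stack([rng.normal(50.0, 10.0, 1000), rng.normal(0.0, 0.01, 1000)])

# Feature scaling as a pre-processing step: zero mean, unit variance per feature.
X_scaled = (X - X.mean(axis=0)) / X.std(axis=0)

# A hypothetical first layer (random weights) followed by ReLU.
W1 = rng.normal(0.0, 1.0, (2, 4))
h = np.maximum(0.0, X_scaled @ W1)

# An update to W1, standing in for a gradient step.
W1_updated = W1 + 0.5 * rng.normal(0.0, 1.0, W1.shape)
h_updated = np.maximum(0.0, X_scaled @ W1_updated)

print("input mean:", X_scaled.mean(axis=0).round(3), "input std:", X_scaled.std(axis=0).round(3))
print("layer-2 input mean before update:", h.mean(axis=0).round(3))
print("layer-2 input mean after update: ", h_updated.mean(axis=0).round(3))
```

The network inputs stay fixed, but the distribution seen by the second layer changes with every update to `W1`.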
Batch Normalization is an idea introduced by Ioffe & Szegedy (2015) to reduce this shift.
It is implemented as a layer (with trainable parameters) that normalizes the activations of the previous layer. Through backpropagation, the network can learn whether, and to what extent, the activations should be normalized. The layer is usually inserted immediately after fully connected or convolutional layers and before the nonlinearities. In this way it reduces internal covariate shift in deep networks.
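A minimal NumPy sketch of why the learnable parameters matter. With the default initialization $\gamma = 1, \beta = 0$ the layer outputs the normalized batch. If training were to drive $\gamma$ to $\sqrt{\sigma^2 + \epsilon}$ and $\beta$ to the batch mean, the layer would reproduce its input unchanged. The batch and its shape are made up for illustration.

```python
import numpy as np

eps = 1e-8
rng = np.random.default_rng(1)
X = rng.normal(3.0, 2.0, size=(64, 5))   # illustrative mini-batch: 64 examples, 5 features

mu = X.mean(axis=0)
var = X.var(axis=0)
X_norm = (X - mu) / np.sqrt(var + eps)

# Default initialization (gamma = 1, beta = 0): output is the normalized batch.
out_default = 1.0 * X_norm + 0.0

# If gamma -> sqrt(var + eps) and beta -> mu, the normalization is undone
# and the layer acts as the identity on this batch.
gamma = np.sqrt(var + eps)
beta = mu
out_identity = gamma * X_norm + beta
```

Anything between these two extremes is also reachable, which is what lets backpropagation decide how much normalization to keep.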
Advantages of BatchNorm
- Improves gradient flow through very deep networks
- Reduces dependency on careful initialization
- Allows higher learning rates
- Provides regularization and reduces dependency on dropout
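One concrete reason behind the initialization and learning-rate points: Ioffe & Szegedy note that a batch-normalized output is invariant to the scale of the preceding weights, $\mathrm{BN}((aW)u) = \mathrm{BN}(Wu)$ (up to the small $\epsilon$), and the gradient with respect to the scaled weights shrinks by $1/a$. The sketch below only checks the invariance numerically for a random layer. The helper `batchnorm` and the factor of 100 are illustrative choices, not part of the post's code.

```python
import numpy as np


def batchnorm(a, eps=1e-8):
    # Normalize each column of the pre-activation batch `a` (gamma = 1, beta = 0).
    return (a - a.mean(axis=0)) / np.sqrt(a.var(axis=0) + eps)


rng = np.random.default_rng(2)
x = rng.normal(size=(128, 10))
W = rng.normal(size=(10, 6))

out = batchnorm(x @ W)
out_scaled = batchnorm(x @ (100.0 * W))   # same weights, 100x larger scale
print(np.max(np.abs(out - out_scaled)))
```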
Forward Propagation
In the forward pass, we calculate the mean and variance of the batch, normalize the input to zero mean and unit variance, and then scale and shift it with the learnable parameters $\gamma$ and $\beta$, respectively. A small constant $\epsilon$ is added to the variance to avoid division by zero.
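Written out for a mini-batch $B = \{x_1, \dots, x_m\}$ of a single activation (the implementation applies this to every activation column at once), the forward pass is:

$$
\mu_B = \frac{1}{m}\sum_{i=1}^{m} x_i, \qquad
\sigma_B^2 = \frac{1}{m}\sum_{i=1}^{m} (x_i - \mu_B)^2,
$$

$$
\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}, \qquad
y_i = \gamma \hat{x}_i + \beta .
$$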
The implementation is straightforward:
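```python
# X has shape (N, C, H, W); gamma and beta have shape (1, C*H*W)
n_X, c_X, h_X, w_X = X.shape
X_flat = X.reshape(n_X, c_X * h_X * w_X)      # one row per example
mu = np.mean(X_flat, axis=0)                  # batch mean of each activation
var = np.var(X_flat, axis=0)                  # batch variance of each activation
X_norm = (X_flat - mu) / np.sqrt(var + 1e-8)  # zero mean, unit variance
out = gamma * X_norm + beta                   # learnable scale and shift
```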
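A self-contained version of the snippet above, wrapped in a hypothetical helper `batchnorm_forward` so it can be run on random data. With $\gamma = 1, \beta = 0$ every activation column should come out with zero mean and unit variance. The input shape (32, 3, 4, 4) is an arbitrary choice.

```python
import numpy as np


def batchnorm_forward(X, gamma, beta, eps=1e-8):
    """Per-activation batch norm on an (N, C, H, W) batch, mirroring the snippet above."""
    n_X, c_X, h_X, w_X = X.shape
    X_flat = X.reshape(n_X, c_X * h_X * w_X)      # one row per example
    mu = np.mean(X_flat, axis=0)                  # batch mean per activation
    var = np.var(X_flat, axis=0)                  # biased batch variance (1/N)
    X_norm = (X_flat - mu) / np.sqrt(var + eps)   # zero mean, unit variance
    out = gamma * X_norm + beta                   # learnable scale and shift
    return out.reshape(X.shape)


rng = np.random.default_rng(3)
X = rng.normal(5.0, 3.0, size=(32, 3, 4, 4))
D = 3 * 4 * 4
out = batchnorm_forward(X, np.ones((1, D)), np.zeros((1, D)))
flat = out.reshape(32, -1)
print(np.abs(flat.mean(axis=0)).max(), flat.var(axis=0).min())
```

Calling it with `gamma = 2` and `beta = 1` instead gives columns with mean 1 and standard deviation 2, so $\gamma$ and $\beta$ directly set the output scale and offset.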
Backward Propagation
For our backward pass, we need the gradients $\frac{\partial C}{\partial x_i}$, $\frac{\partial C}{\partial \gamma}$ and $\frac{\partial C}{\partial \beta}$. To get them, we calculate the intermediate gradients node by node through the computational graph, starting from the output.
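Since $y_i = \gamma \hat{x}_i + \beta$, where $\hat{x}_i$ is the normalized input, the parameter gradients follow directly from the upstream gradient $\frac{\partial C}{\partial y_i}$. The sums run over the $m$ examples in the batch, separately for each feature:

$$
\frac{\partial C}{\partial \beta} = \sum_{i=1}^{m} \frac{\partial C}{\partial y_i}, \qquad
\frac{\partial C}{\partial \gamma} = \sum_{i=1}^{m} \frac{\partial C}{\partial y_i}\,\hat{x}_i
$$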
Now we have the gradients for both learnable parameters. Next, we work toward the gradient with respect to the input.
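By the chain rule, the gradient first flows into $\hat{x}_i$ through $\gamma$, and from there into the batch variance $\sigma_B^2$, which enters every $\hat{x}_i$ through the factor $1/\sqrt{\sigma_B^2 + \epsilon}$ ($m$ is the batch size):

$$
\frac{\partial C}{\partial \hat{x}_i} = \frac{\partial C}{\partial y_i}\,\gamma
$$

$$
\frac{\partial C}{\partial \sigma_B^2} = \sum_{i=1}^{m} \frac{\partial C}{\partial \hat{x}_i}\,(x_i - \mu_B)\cdot\left(-\frac{1}{2}\right)\left(\sigma_B^2 + \epsilon\right)^{-3/2}
$$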
We can see from the computational graph that $\mu_B$ feeds into two nodes (the normalized input $\hat{x}_i$ and the variance $\sigma_B^2$), so we need to add up the gradients from both nodes.
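Summing the two contributions gives the gradient of the batch mean. The second term is exactly zero, because $\sum_i (x_i - \mu_B) = 0$, although the implementation still computes it:

$$
\frac{\partial C}{\partial \mu_B} = \sum_{i=1}^{m} \frac{\partial C}{\partial \hat{x}_i}\cdot\frac{-1}{\sqrt{\sigma_B^2 + \epsilon}} \;+\; \frac{\partial C}{\partial \sigma_B^2}\cdot\frac{1}{m}\sum_{i=1}^{m} -2\,(x_i - \mu_B)
$$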
Now we have all the intermediate gradients needed for the input gradient. Since $x_i$ feeds into three nodes ($\hat{x}_i$, $\mu_B$ and $\sigma_B^2$), we add up the gradients from each of them.
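Adding the three contributions, with $m$ the batch size, gives the input gradient:

$$
\frac{\partial C}{\partial x_i} = \frac{\partial C}{\partial \hat{x}_i}\cdot\frac{1}{\sqrt{\sigma_B^2 + \epsilon}} \;+\; \frac{\partial C}{\partial \sigma_B^2}\cdot\frac{2\,(x_i - \mu_B)}{m} \;+\; \frac{\partial C}{\partial \mu_B}\cdot\frac{1}{m}
$$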
Translating these gradient expressions into Python gives our implementation of backprop through the BatchNorm layer:
```python
# mu, var, gamma and X_norm are cached from the forward pass
n_X, c_X, h_X, w_X = X.shape
# flatten the inputs and dout
X_flat = X.reshape(n_X, c_X * h_X * w_X)
dout = dout.reshape(n_X, c_X * h_X * w_X)

X_mu = X_flat - mu
var_inv = 1. / np.sqrt(var + 1e-8)

dX_norm = dout * gamma
dvar = np.sum(dX_norm * X_mu, axis=0) * -0.5 * (var + 1e-8)**(-3/2)
dmu = np.sum(dX_norm * -var_inv, axis=0) + dvar * 1/n_X * np.sum(-2. * X_mu, axis=0)
dX = (dX_norm * var_inv) + (dmu / n_X) + (dvar * 2/n_X * X_mu)
dX = dX.reshape(X.shape)                 # back to (N, C, H, W)

dbeta = np.sum(dout, axis=0)
dgamma = np.sum(dout * X_norm, axis=0)   # summed over the batch, like dbeta
```
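Hand-derived gradients are easy to get wrong, so it is worth checking them against finite differences. This sketch repackages the forward and backward snippets as two hypothetical helpers, `bn_forward` and `bn_backward`, acting on a 2-D batch. It uses the loss $C = \sum (y \odot R)$ for a random $R$, so that `dout` is simply $R$, and compares against central differences.

```python
import numpy as np

EPS = 1e-8  # same epsilon as the post's code


def bn_forward(X, gamma, beta):
    """Forward pass on a 2-D batch X of shape (N, D); returns output and cache."""
    mu = np.mean(X, axis=0)
    var = np.var(X, axis=0)
    X_norm = (X - mu) / np.sqrt(var + EPS)
    return gamma * X_norm + beta, (X, X_norm, mu, var)


def bn_backward(dout, gamma, cache):
    """Staged backward pass, following the gradient expressions above."""
    X, X_norm, mu, var = cache
    n_X = X.shape[0]
    X_mu = X - mu
    var_inv = 1.0 / np.sqrt(var + EPS)
    dX_norm = dout * gamma
    dvar = np.sum(dX_norm * X_mu, axis=0) * -0.5 * (var + EPS) ** (-3 / 2)
    dmu = np.sum(dX_norm * -var_inv, axis=0) + dvar * 1 / n_X * np.sum(-2.0 * X_mu, axis=0)
    dX = dX_norm * var_inv + dmu / n_X + dvar * 2 / n_X * X_mu
    dgamma = np.sum(dout * X_norm, axis=0)
    dbeta = np.sum(dout, axis=0)
    return dX, dgamma, dbeta


def numerical_grad(f, P, h=1e-5):
    """Central-difference gradient of the scalar f() w.r.t. array P (perturbed in place)."""
    grad = np.zeros_like(P)
    for idx in np.ndindex(P.shape):
        old = P[idx]
        P[idx] = old + h
        f_plus = f()
        P[idx] = old - h
        f_minus = f()
        P[idx] = old  # restore the original value
        grad[idx] = (f_plus - f_minus) / (2 * h)
    return grad


rng = np.random.default_rng(4)
N, D = 8, 5
X = rng.normal(2.0, 3.0, size=(N, D))
gamma = rng.normal(size=(1, D))
beta = rng.normal(size=(1, D))
R = rng.normal(size=(N, D))  # C = sum(out * R), so dC/dout = R


def loss():
    return np.sum(bn_forward(X, gamma, beta)[0] * R)


_, cache = bn_forward(X, gamma, beta)
dX, dgamma, dbeta = bn_backward(R, gamma, cache)

dX_num = numerical_grad(loss, X)
dgamma_num = numerical_grad(loss, gamma).ravel()
dbeta_num = numerical_grad(loss, beta).ravel()

print("max |dX - dX_num|:", np.max(np.abs(dX - dX_num)))
print("max |dgamma - dgamma_num|:", np.max(np.abs(dgamma - dgamma_num)))
print("max |dbeta - dbeta_num|:", np.max(np.abs(dbeta - dbeta_num)))
```

A useful side check: each column of `dX` sums to zero, because shifting every example of a feature by the same constant leaves the normalized output unchanged.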
Source code
Here is the source code for the BatchNorm layer, with both the forward and backward APIs implemented.
```python
import numpy as np


class Batchnorm():

    def __init__(self, X_dim):
        # X_dim is the per-example input shape (C, H, W)
        self.d_X, self.h_X, self.w_X = X_dim
        self.gamma = np.ones((1, int(np.prod(X_dim))))
        self.beta = np.zeros((1, int(np.prod(X_dim))))
        self.params = [self.gamma, self.beta]

    def forward(self, X):
        self.n_X = X.shape[0]
        self.X_shape = X.shape

        # flatten to (N, C*H*W) and normalize each activation over the batch
        self.X_flat = X.ravel().reshape(self.n_X, -1)
        self.mu = np.mean(self.X_flat, axis=0)
        self.var = np.var(self.X_flat, axis=0)
        self.X_norm = (self.X_flat - self.mu) / np.sqrt(self.var + 1e-8)
        out = self.gamma * self.X_norm + self.beta

        return out.reshape(self.X_shape)

    def backward(self, dout):
        dout = dout.ravel().reshape(dout.shape[0], -1)
        X_mu = self.X_flat - self.mu
        var_inv = 1. / np.sqrt(self.var + 1e-8)

        # parameter gradients, summed over the batch
        dbeta = np.sum(dout, axis=0)
        dgamma = np.sum(dout * self.X_norm, axis=0)

        # input gradient through the normalization
        dX_norm = dout * self.gamma
        dvar = np.sum(dX_norm * X_mu, axis=0) * -0.5 * (self.var + 1e-8)**(-3/2)
        dmu = np.sum(dX_norm * -var_inv, axis=0) + dvar * 1/self.n_X * np.sum(-2. * X_mu, axis=0)
        dX = (dX_norm * var_inv) + (dmu / self.n_X) + (dvar * 2/self.n_X * X_mu)
        dX = dX.reshape(self.X_shape)

        return dX, [dgamma, dbeta]
```
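The staged backward pass can also be collapsed into a single expression. Substitute $x_i - \mu_B = \hat{x}_i \sqrt{\sigma_B^2 + \epsilon}$ and drop the term of $\frac{\partial C}{\partial \mu_B}$ that multiplies $\sum_i (x_i - \mu_B) = 0$:

$$
\frac{\partial C}{\partial x_i} = \frac{1}{m\sqrt{\sigma_B^2+\epsilon}}\left( m\,\frac{\partial C}{\partial \hat{x}_i} - \sum_{j=1}^{m}\frac{\partial C}{\partial \hat{x}_j} - \hat{x}_i \sum_{j=1}^{m}\frac{\partial C}{\partial \hat{x}_j}\,\hat{x}_j \right)
$$

The sketch below compares the two forms on random 2-D data. The function names `bn_backward_staged` and `bn_backward_compact` are hypothetical. Both should agree to floating-point precision, and the compact one needs fewer intermediate arrays.

```python
import numpy as np


def bn_backward_staged(dout, X, gamma, eps=1e-8):
    # Staged gradients, written as in the post's backward pass.
    n_X = X.shape[0]
    mu, var = X.mean(axis=0), X.var(axis=0)
    X_mu = X - mu
    var_inv = 1.0 / np.sqrt(var + eps)
    dX_norm = dout * gamma
    dvar = np.sum(dX_norm * X_mu, axis=0) * -0.5 * (var + eps) ** (-3 / 2)
    dmu = np.sum(dX_norm * -var_inv, axis=0) + dvar * 1 / n_X * np.sum(-2.0 * X_mu, axis=0)
    return dX_norm * var_inv + dmu / n_X + dvar * 2 / n_X * X_mu


def bn_backward_compact(dout, X, gamma, eps=1e-8):
    # Same gradient in closed form: X_mu = X_norm / var_inv, zero-sum term dropped.
    n_X = X.shape[0]
    var_inv = 1.0 / np.sqrt(X.var(axis=0) + eps)
    X_norm = (X - X.mean(axis=0)) * var_inv
    dX_norm = dout * gamma
    return (var_inv / n_X) * (
        n_X * dX_norm
        - dX_norm.sum(axis=0)
        - X_norm * np.sum(dX_norm * X_norm, axis=0)
    )


rng = np.random.default_rng(5)
X = rng.normal(size=(16, 7))
gamma = rng.normal(size=(1, 7))
dout = rng.normal(size=(16, 7))

staged = bn_backward_staged(dout, X, gamma)
compact = bn_backward_compact(dout, X, gamma)
print(np.max(np.abs(staged - compact)))
```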
Citation
If you find this post useful, please cite it as:
Dahal, Paras. (May 2017). Batch Normalization: Stabilizing Deep Neural Network Training. Paras Dahal. https://parasdahal.com/batchnorm.
Or in BibTeX format:
@article{dahal2017batchnorm,
title = "Batch Normalization: Stabilizing Deep Neural Network Training",
author = "Dahal, Paras",
journal = "parasdahal.com",
year = "2017",
month = "May",
url = "https://parasdahal.com/batchnorm"
}