Batch normalization is a recently developed technique to reduce training time. Sometimes, it can even improve the accuracy of a model. But what exactly is batch normalization? Simply put, it is a way to counteract the internal covariate shift between two layers of a neural network. If you are like me and did not find this answer very helpful, let’s break down what internal covariate shift is, why it messes with your network, and why batch normalization is a genius solution.

## What is Internal Covariate Shift

Given a neural network composed of several layers, forward propagation can be abstracted as function composition: each layer can be viewed as a function that takes a vector from $latex \mathbb{R}^n$ as input and outputs a vector in $latex \mathbb{R}^m$ (for a layer of $latex m$ neurons preceded by a layer of $latex n$ neurons). Each such function is a linear mapping from $latex \mathbb{R}^n$ to $latex \mathbb{R}^m$, given by a weight matrix $latex W$ of dimension $latex m \times n$, followed by the addition of the biases $latex b$ ($latex b \in \mathbb{R}^m$) and finally a non-linear activation (e.g. the sigmoid function).

$latex \sigma (x) = \frac{1}{1 + e^{-x}}$

$latex \sigma^{\prime}(x) = (1 - \sigma(x))\sigma (x)$

When a network is viewed as a composition of functions (the parallel between function composition and deep learning is the subject of much speculation among researchers), it becomes easier to see that the output of each function (except the first layer, i.e. the input function) is affected by small changes in all the preceding functions.
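To make the composition view concrete, here is a minimal NumPy sketch, assuming sigmoid activations, arbitrary layer sizes (4 → 3 → 2) and randomly drawn weights. `make_layer` and `network` are hypothetical helper names used for illustration only, not part of any library.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def make_layer(n, m, rng):
    """Return a function R^n -> R^m computing sigmoid(W x + b)."""
    W = rng.normal(0.0, 1.0, size=(m, n))  # weight matrix of shape (m, n)
    b = np.zeros(m)                        # biases, b in R^m
    return lambda x: sigmoid(W @ x + b)

rng = np.random.default_rng(0)
f1 = make_layer(4, 3, rng)  # first layer:  R^4 -> R^3
f2 = make_layer(3, 2, rng)  # second layer: R^3 -> R^2

def network(u):
    # forward propagation is just function composition: f2(f1(u))
    return f2(f1(u))

out = network(np.ones(4))
print(out.shape)
```

Any change to the weights inside `f1` changes what `f2` receives, which is exactly the dependency that internal covariate shift is about.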

Let’s consider the sigmoid activation function $latex \sigma$. Let $latex z = \sigma(Wu + b)$, where $latex u$ is the input and $latex W$ and $latex b$ are respectively the weights and the biases. It can be verified from the derivative of the sigmoid function that as $latex \lvert x \rvert$ grows, $latex \sigma^{\prime}(x)$ approaches $latex 0$, meaning that the gradient of the sigmoid function vanishes. In the case of our layer, with $latex x = Wu + b$, the gradient will vanish along every dimension except those where the component of $latex x$ is small in absolute value.
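You can check the vanishing gradient numerically. This short sketch simply evaluates $latex \sigma^{\prime}(x) = (1 - \sigma(x))\sigma(x)$ at a few arbitrarily chosen points, using only the standard library.

```python
import math

def sigmoid(x):
    # fine for the moderate inputs used here
    return 1.0 / (1.0 + math.exp(-x))

def sigmoid_prime(x):
    # derivative of the sigmoid: (1 - sigma(x)) * sigma(x)
    s = sigmoid(x)
    return (1.0 - s) * s

for x in [0.0, 2.0, 5.0, 10.0]:
    print(f"x = {x:4.1f}   sigma'(x) = {sigmoid_prime(x):.6f}")
```

The derivative peaks at 0.25 for $latex x = 0$ and is already tiny by $latex x = 10$, so a unit whose input lands far from zero passes almost no gradient back.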

Imagine a bunch of guys putting out a fire, organized in parallel lines, passing buckets of water from the front line to the back, where a building is burning. In this metaphor each guy is like a neuron, and each line of guys is like a layer of a neural network.
Now imagine that after each back-and-forth passage of the buckets, a large number of your firefighters stopped working or, worse, spilt almost all their water where they stand, and you, as their leader, were helpless to teach them how to hold their buckets. If they are training to become firefighters, as you can imagine, this would not be effective. This is what happens to a network subject to ICS: at each forward propagation and backpropagation, a large portion of your neurons may effectively stop learning for good.

The dramatic part is that this effect can amplify as your network gets deeper. At each backpropagation step, and for each layer, small changes in the output of the preceding layer might shift the distribution of the inputs of the layer of interest into the vanishing-gradient zone. Some interesting simulations can be found at batch norm simulation. By the end of training, many of your neurons might sit in the saturated regime of the nonlinearity, where they are prevented from learning anything, which will probably decrease the accuracy of your network.
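As a toy illustration of this mechanism (not a reproduction of any published simulation), the NumPy sketch below assumes Gaussian pre-activations whose mean drifts away from zero and counts a unit as "saturated" when its sigmoid gradient drops below 0.01, a threshold chosen arbitrarily for illustration.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def saturated_fraction(pre_activations, threshold=0.01):
    """Fraction of units whose sigmoid gradient falls below `threshold`.

    The 0.01 threshold is an arbitrary choice for illustration.
    """
    s = sigmoid(pre_activations)
    grad = s * (1.0 - s)
    return float(np.mean(grad < threshold))

rng = np.random.default_rng(0)
fractions = {}
for shift in [0.0, 2.0, 4.0, 6.0]:
    # pre-activations of 10,000 units whose mean has drifted by `shift`
    x = rng.normal(loc=shift, scale=1.0, size=10_000)
    fractions[shift] = saturated_fraction(x)
    print(f"mean shift {shift:.0f}: {fractions[shift]:.1%} of units saturated")
```

Nothing else changes between rows; only the distribution of the inputs moves, yet a growing share of units ends up in the flat part of the sigmoid.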

In a nutshell this is the Internal Covariate Shift.

## How ICS used to be taken care of

First, to mitigate the dire effects of ICS and the associated vanishing gradients, you could train your network with a very low learning rate, typically below $latex 10^{-4}$. This usually limits the shifts in the input distribution of each function that makes up the neural network, but at the expense of a much longer training time. Another popular technique was to carefully initialize the weights of the network, the most popular scheme being He initialization. But you then have to take new and complicated hyper-parameters into account. By the way, rectified linear units (ReLU) were also introduced to solve, at least partially, the problem of vanishing gradients.
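For reference, He initialization draws each weight from a zero-mean Gaussian with variance $latex 2 / n_{in}$, where $latex n_{in}$ is the number of inputs to the layer. The NumPy sketch below compares it with a naive fixed standard deviation of 0.01, assuming bias-free ReLU layers with an arbitrary width of 512 and depth of 10; `forward_relu` is a hypothetical helper.

```python
import numpy as np

def forward_relu(x, init_std, depth, rng):
    """Propagate x through `depth` square ReLU layers whose weights are
    drawn from N(0, init_std**2); biases are omitted for simplicity."""
    h = x
    for _ in range(depth):
        n_in = h.shape[0]
        W = rng.normal(0.0, init_std, size=(n_in, n_in))
        h = np.maximum(0.0, W @ h)
    return h

rng = np.random.default_rng(0)
width, depth = 512, 10
x = rng.normal(size=(width, 1000))  # 1000 random input vectors

he_std = np.sqrt(2.0 / width)  # He initialization: Var(w) = 2 / n_in
naive_std = 0.01

he_ms = np.mean(forward_relu(x, he_std, depth, rng) ** 2)
naive_ms = np.mean(forward_relu(x, naive_std, depth, rng) ** 2)
print("He init, mean squared activation:   ", he_ms)
print("naive init, mean squared activation:", naive_ms)
```

With He initialization the mean squared activation stays close to its input value of about 1 after ten layers, while with the naive initialization it collapses by many orders of magnitude.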

ICS and vanishing gradients made it harder to build deeper neural networks. Even when the ICS between two adjacent layers was not dramatic, meaning that not too many neurons were pushed into the saturated regime of the nonlinearity, the backpropagation step could end up being mostly about correcting the changes induced by the previous layer. This could also be counteracted by dropout. One can view dropout as a way to make the computation of each neuron more independent of the others. Indeed, if a change in parameters is detrimental to the task of the network and cannot be partially corrected by the following layer, the chances that such a change will be eliminated are higher.
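Dropout itself is short to write down. Below is a minimal NumPy sketch of the common "inverted dropout" variant, where the surviving activations are rescaled during training so that nothing needs to change at test time; the function name and the drop probability of 0.5 are assumptions for illustration.

```python
import numpy as np

def dropout(activations, drop_prob, rng, training=True):
    """Inverted dropout: zero each unit with probability `drop_prob` during
    training and rescale the survivors so the expected activation is unchanged."""
    if not training or drop_prob == 0.0:
        return activations  # at test time the layer is the identity
    keep_prob = 1.0 - drop_prob
    mask = rng.random(activations.shape) < keep_prob  # True = unit is kept
    return activations * mask / keep_prob

rng = np.random.default_rng(0)
h = np.ones((4, 5))
dropped = dropout(h, drop_prob=0.5, rng=rng)
print(dropped)  # every entry is either 0.0 or 2.0
```

A fresh mask is drawn at every call, so a different random subset of neurons is silenced at each training step.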

Overall, ICS was counteracted with careful initialization, a very low learning rate and the addition of dropout. The cost of going deeper was a much longer training time, due to dropout and the lower learning rate, plus new hyper-parameters to tune.
Again, take the example of my firefighter apprentices. Imagine that you want to train them to handle their buckets of water very, very slowly. This is what you do when you reduce the learning rate of the network during training.

If you want to implement dropout with your firefighters, you would also need to stop one guy from working at each training step (one forward and backward propagation). To avoid losing water, the idea is to force the guy who came before him to take responsibility for his own bucket instead of counting on the next guy to catch it.

As for He initialization, the metaphor starts to get out of hand!

## Batch Normalization as a solution

The solution came from the realization that a network converges faster when its inputs are transformed so that they have zero mean and unit variance. At first, Ioffe and colleagues tried subtracting from $latex x_k$, the output of the $latex k$-th layer, the mean of all training values at this layer. In doing so, the authors realized that some parameters could grow indefinitely through backpropagation while the overall loss remained unchanged. To remedy this effect, the authors also divided the centered output by the standard deviation (the square root of the variance) of the batch values. Ioffe et al. introduced this transformation inside the network, with $latex x_k$ the batch output of a layer, which is then transformed into $latex y_k$, the output of batch normalization. $latex E(x_k)$ is the mean of the training batch values and $latex Var(x_k)$ their variance. $latex \gamma$ and $latex \beta$ are two trainable parameters, and the beauty of this framework is that it is end-to-end differentiable, so both $latex \gamma$ and $latex \beta$ can be learned. $latex \epsilon$ is a small constant introduced to prevent division by zero in case the variance is zero.

$latex \hat{x}_k = \frac{x_k - E(x_k)}{\sqrt{Var(x_k) + \epsilon}}$

$latex y_k = \gamma_k \hat{x}_k + \beta_k$
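The two equations above translate almost line for line into code. Here is a minimal NumPy sketch of the training-time forward pass only, assuming a mini-batch laid out as (batch_size, features) with one $latex \gamma_k$, $latex \beta_k$ pair per feature. It ignores the running averages that libraries such as Keras keep for inference and, as in the original batch normalization paper, adds $latex \epsilon$ to the variance inside the square root. `batch_norm_forward` is a hypothetical name.

```python
import numpy as np

def batch_norm_forward(x, gamma, beta, eps=1e-5):
    """Batch normalization for a mini-batch x of shape (batch_size, features).

    Each feature k is normalized with the mean and variance computed over
    the batch, then scaled by gamma[k] and shifted by beta[k].
    """
    mean = x.mean(axis=0)                    # E(x_k), one value per feature
    var = x.var(axis=0)                      # Var(x_k), one value per feature
    x_hat = (x - mean) / np.sqrt(var + eps)  # zero mean, (almost) unit variance
    return gamma * x_hat + beta              # y_k = gamma_k * x_hat_k + beta_k

rng = np.random.default_rng(0)
x = rng.normal(loc=5.0, scale=3.0, size=(64, 4))  # a shifted, scaled batch
y = batch_norm_forward(x, gamma=np.ones(4), beta=np.zeros(4))
print(y.mean(axis=0).round(6), y.std(axis=0).round(3))
```

With $latex \gamma_k = 1$ and $latex \beta_k = 0$ every feature comes out centered with unit standard deviation; other values let the network learn to undo or rescale the normalization if that helps the task.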

To conclude my firefighter story, batch normalization would be like repositioning the firefighters at the right distance between each line, so that their individual learning is not affected too much by where the teammates before them placed the bucket.

Ioffe and colleagues succeeded in training their network with a much higher learning rate, while reducing the need for dropout and also allowing the use of saturating nonlinearities in deep networks. As a matter of fact, they reached the same accuracy with 14 times fewer training steps. And, as the cherry on top, they even gained a few percentage points of accuracy on image classification. Today batch normalization is commonly used for vision-related tasks and for deep learning in general.

## Keras implementation

You can easily add a batch-normalization layer to a Keras model:

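```python
from keras import Sequential
from keras.layers import BatchNormalization, Dense

your_dim = 20  # placeholder: set this to the dimension of your input vectors

model = Sequential()
model.add(Dense(64, input_dim=your_dim))  # Dense needs a number of units (64 is arbitrary)
model.add(BatchNormalization())           # normalizes the outputs of the Dense layer
```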

Hope that helps! Let me know in the comments if you have any questions, and happy coding.