Gradient histograms in a 10 layer network at initialization. The typical size of gradients is the same in all layers in a net without Batch Norm (left) and grows exponentially after inserting Batch Norm in every layer (right)

Deep learning practitioners know that using Batch Norm generally makes it easier to train deep networks. They also know that the presence of exploding gradients generally makes it harder to train deep networks. So recent work by Yang et al. might seem quite surprising: they show that our beloved Batch Norm can actually cause exploding gradients, at least at initialization time.

I like their paper because it has the rare quality of saying something both rigorous and meaningful about Batch Norm. Unfortunately, there’s not much intuition behind the 95 pages of technical mathematics comprising the paper. In this post I’ll provide a more intuitive explanation for the gradient explosion phenomenon. With a few “back-of-the-envelope” calculations (and admittedly a big envelope), we can show in simplified settings that gradient norms grow by $\sqrt{\pi/(\pi-1)} \approx 1.21$ in each layer of a ReLU network with Batch Normalization.

TL;DR Inserting Batch Norm into a network means that in the forward pass each neuron is divided by its standard deviation, $\sigma$, computed over a minibatch of samples. In the backward pass, gradients are divided by the same $\sigma$. In ReLU nets, we’ll show that this standard deviation is less than 1, in fact we can approximate $\sigma \approx \sqrt{(\pi-1)/\pi} \approx 0.82$. Since this occurs at every layer, gradient norms in early layers are amplified by roughly $1.21 \approx 1/0.82$ in every layer.

What are “exploding gradients”?

Before our calculation, let’s empirically show that “Batch Norm causes exploding gradients”. We’ll just initialize a net without Batch Norm, compute gradients, then insert Batch Norm into every layer of the net, recompute gradients, and see how they differ.

Our net without Batch Norm, our “vanilla net”, is shown below. It’s a boring old 10-layer, feedforward, fully connected network of uniform layer width $N=1024$. We’re using the ReLU nonlinearity, so we’ll initialize weights with Kaiming initialization. This means we’ll set all the biases to zero and sample every element of $W$ from a zero-mean Gaussian with variance $2/N$.

Our inputs $x_0(t)$ are 512 random Gaussian vectors. We’ll compute a linear loss over the network’s outputs: $E = \sum_{t=1}^{512} w \cdot x_{10}(t)$. We choose $w$ by drawing it from a unit Gaussian. Now we have everything we need to compute the gradients $dE/dx_l(t)$ at each layer $l$ in the vanilla network (we’ll use PyTorch to automate this process for us).

Now we take that exact network, loss, and set of input vectors, but insert Batch Normalization before the ReLU in every layer:

Again we can compute gradients $dE/dx_l(t)$ at each layer. Now we’ll visualize the gradient histograms in each layer for the vanilla and Batch Norm net. Each histogram is made up of $512 \times 1024$ scalars (gradients for 512 inputs × 1024 features):

In the vanilla network, the gradient histograms are basically the same in every layer. However, after inserting Batch Norm, the gradients grow as we move toward earlier layers. In fact, the widths of these histograms grow exponentially. It is in this sense that gradients explode. The rest of this post will be devoted to calculating how the widths of these histograms grow.
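Here is a minimal PyTorch sketch of the experiment described above, measuring the width (RMS) of each layer’s gradient histogram (the function names, seed, and printout are my own choices for this sketch, not the original code):

```python
import torch
import torch.nn as nn

def make_net(depth=10, width=1024, batchnorm=False):
    """Feedforward ReLU net with Kaiming initialization (zero biases, Var[W] = 2/N)."""
    layers = []
    for _ in range(depth):
        linear = nn.Linear(width, width)
        nn.init.kaiming_normal_(linear.weight, nonlinearity='relu')
        nn.init.zeros_(linear.bias)
        layers.append(linear)
        if batchnorm:
            layers.append(nn.BatchNorm1d(width))  # gamma = 1, beta = 0 at initialization
        layers.append(nn.ReLU())
    return nn.Sequential(*layers)

def grad_rms_per_layer(net, x0, w):
    """Width (RMS) of the gradient histogram dE/dx_l at each ReLU output."""
    acts, h = [], x0
    for module in net:
        h = module(h)
        if isinstance(module, nn.ReLU):
            h.retain_grad()
            acts.append(h)
    E = (h @ w).sum()  # linear loss  E = sum_t  w . x_10(t)
    E.backward()
    return [a.grad.pow(2).mean().sqrt().item() for a in acts]

torch.manual_seed(0)
x0 = torch.randn(512, 1024)  # 512 random Gaussian input vectors
w = torch.randn(1024)        # random readout vector for the linear loss

for use_bn in (False, True):
    rms = grad_rms_per_layer(make_net(batchnorm=use_bn), x0, w)
    print('Batch Norm' if use_bn else 'vanilla   ', ['%.3g' % r for r in rms])
```

For the vanilla net the printed RMS values should be roughly constant across layers; for the Batch Norm net they should grow by roughly 1.21 per layer toward the input, the factor we derive below.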

Gradients in a Batch Normalized Net

In this section we’ll review the details of Batch Normalization and how it modifies the forward and backward pass of a neural network.

Forward Pass

Each layer in our normalized network contains 3 modules: matrix multiply, Batch Norm, and ReLU. These are shown in the diagram above.

$y_l$, $z_l$, and $x_l$ denote the vector outputs of the matrix multiply, Batch Norm, and ReLU modules in layer $l$ for a single input. The element-wise product is denoted by $a \odot b$; we’ll abuse notation and denote element-wise division with $a/b$. The notation $\langle \cdot \rangle$ indicates a minibatch average, so $\mu_l$ and $\sigma_l^2$ are the mean and variance of each element of $y_l$ over a minibatch. $f$ is the ReLU nonlinearity, which acts element-wise on an input vector.

Because we are only going to examine networks at initialization time, we have made two simplifications. One, we ignore bias addition in the matrix multiply layer, as biases are initialized to zero. Two, we can ignore the part of Batch Norm with learnable parameters, which transforms $z_l \mapsto \gamma_l \odot z_l + \beta_l$, because $\gamma_l = 1$ and $\beta_l = 0$ at initialization.
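With these simplifications, the forward pass through layer $l$ can be written out explicitly (this is just the diagram above in equation form, using the notation defined earlier):

$$y_l = W_l\, x_{l-1}, \qquad z_l = \frac{y_l - \mu_l}{\sigma_l}, \qquad x_l = f(z_l),$$

where $\mu_l = \langle y_l \rangle$ and $\sigma_l^2 = \langle y_l^2 \rangle - \langle y_l \rangle^2$ are computed element-wise over the minibatch.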

Backward Pass

The diagram above shows the backwards pass through each module. Recall that the backwards pass applies the chain rule to compute gradients with respect to a module’s inputs, given gradients with respect to the module’s outputs.

For the Batch Norm module this gets rather tedious. Fortunately this is derived in the original Batch Norm paper, which I have simplified in the diagram above. If you want a step-by-step derivation, check out this blog post!
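For reference, the result (the standard Batch Norm backward pass, written in this post’s notation with $\gamma_l = 1$) is

$$\frac{dE}{dy_l} = \frac{1}{\sigma_l}\left[\frac{dE}{dz_l} - \Big\langle \frac{dE}{dz_l} \Big\rangle - z_l \odot \Big\langle z_l \odot \frac{dE}{dz_l} \Big\rangle\right],$$

where the second term comes from backpropagating through the minibatch mean $\mu_l$ and the third from the minibatch variance $\sigma_l^2$. These are exactly the two terms the approximations below will argue are negligible at initialization.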

Before moving on, we’ll highlight two superficial similarities between the forwards and backwards pass through a Batch Norm module. One, division by $\sigma_l$ occurs in each case. Two, in the same way Batch Norm centers the forwards pass, $\langle z_l \rangle = 0$, it also centers gradients in the backwards pass, $\langle dE/dy_l \rangle = 0$.

Typical Gradients in a Batch Normalized Net

Our goal was to understand how the width of the gradient histogram grows at each layer in our experiment. Recall that each histogram contained the gradient of every neuron in a layer, computed over all samples. Since this histogram is zero-mean, its squared width is easy to compute in our experiment: $\frac{1}{T}\sum_{t=1}^{T}\frac{1}{N}\sum_{i=1}^{N}\left(\frac{dE}{dx_{l,i}(t)}\right)^2$.

However, we want to understand theoretically how this width grows, and this sum is not particularly enlightening. Since N is large in our experiments, the unwieldy sum over neurons is a decent approximation for a much simpler quantity: the squared gradient of a single neuron, averaged over weights.

$$\text{Typical Squared Gradient} = \left\langle\!\!\left\langle \left(\frac{dE}{dx_l(t)}\right)^2 \right\rangle\!\!\right\rangle \approx \frac{1}{N}\sum_{i=1}^{N}\left(\frac{dE}{dx_{l,i}(t)}\right)^2$$

The double brackets $\langle\!\langle \cdot \rangle\!\rangle$ indicate an average over weights in all layers. Also note that we have dropped the index $i$ on the left side of the equation; it is the same for every neuron in a layer (because the weights have identical distributions). So the squared histogram width is nearly equal to the typical squared gradient, averaged over samples in the minibatch.

To compute the typical squared gradient, we just need to square each module’s backwards pass equation and average it over weights, for the three modules comprising each layer. Before deriving these, I’ll just show the results:

The ReLU module halves the average squared gradient and the matrix multiply module doubles it, so the combination preserves average squared gradients. In fact this was part of the original motivation given in the Kaiming initialization paper: set the weight variance to $2/N$ so gradients are preserved in networks without Batch Norm.

The interesting thing happens in the Batch Norm module. It amplifies average squared gradients by $\approx \pi/(\pi-1)$. Note the approximation sign here: to calculate this we assume we have a wide network and a large minibatch with “sufficient randomness”.
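Written out (anticipating the derivations below), the results for the ReLU, matrix multiply, and Batch Norm modules are:

$$\left\langle\!\!\left\langle \left(\frac{dE}{dz_l}\right)^2 \right\rangle\!\!\right\rangle = \frac{1}{2}\left\langle\!\!\left\langle \left(\frac{dE}{dx_l}\right)^2 \right\rangle\!\!\right\rangle, \qquad \left\langle\!\!\left\langle \left(\frac{dE}{dx_{l-1}}\right)^2 \right\rangle\!\!\right\rangle = 2\left\langle\!\!\left\langle \left(\frac{dE}{dy_l}\right)^2 \right\rangle\!\!\right\rangle, \qquad \left\langle\!\!\left\langle \left(\frac{dE}{dy_l}\right)^2 \right\rangle\!\!\right\rangle \approx \frac{\pi}{\pi-1}\left\langle\!\!\left\langle \left(\frac{dE}{dz_l}\right)^2 \right\rangle\!\!\right\rangle.$$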

Now let’s get to the derivation. Computing the typical squared backward pass through the matrix multiply and ReLU modules has been done before and is not so complicated, so we’ll move quickly to get the final results. If you want a more in depth derivation, check out either the Kaiming initialization paper or this blog post!

Matrix Multiply Module

Consider the backwards pass through a matrix multiply module, $\frac{dE}{dx_i} = \sum_j W_{ji}\frac{dE}{dy_j}$. To minimize clutter, we’ve removed the layer indices $l-1, l$ on the $x, y, W$. Now let’s derive the “typical backpropagation” equation.

Step 1: Square both sides and average over weight configurations:

$$\left\langle\!\!\left\langle \left(\frac{dE}{dx_i}\right)^2 \right\rangle\!\!\right\rangle = \sum_{j,k}\left\langle\!\!\left\langle W_{ji} W_{ki}\, \frac{dE}{dy_j}\frac{dE}{dy_k} \right\rangle\!\!\right\rangle$$

Step 2: To make further progress, we need to assume that forward pass quantities are independent from backwards pass quantities, over randomness in the weights. This is called the “gradient independence” assumption. This allows us to separate the W terms and dE/dy terms:

$$\left\langle\!\!\left\langle \left(\frac{dE}{dx_i}\right)^2 \right\rangle\!\!\right\rangle = \sum_{j,k}\left\langle\!\!\left\langle W_{ji} W_{ki}\right\rangle\!\!\right\rangle \left\langle\!\!\left\langle \frac{dE}{dy_j}\frac{dE}{dy_k} \right\rangle\!\!\right\rangle$$

Step 3: We can simplify the sum using the fact that the elements of $W$ are independent zero-mean variables with variance $2/N$, so $\left\langle\!\left\langle W_{ji} W_{ki}\right\rangle\!\right\rangle = 2/N$ if $j=k$ and $0$ otherwise:

$$\left\langle\!\!\left\langle \left(\frac{dE}{dx_i}\right)^2 \right\rangle\!\!\right\rangle = \frac{2}{N}\sum_{j}\left\langle\!\!\left\langle \left(\frac{dE}{dy_j}\right)^2 \right\rangle\!\!\right\rangle$$

Step 4: Use the fact that the typical squared gradient at each layer is the same for all elements (this follows from the identical distribution of $W$), so the sum over $j$ just contributes a factor of $N$:

$$\left\langle\!\!\left\langle \left(\frac{dE}{dx_{l-1}}\right)^2 \right\rangle\!\!\right\rangle = 2\left\langle\!\!\left\langle \left(\frac{dE}{dy_l}\right)^2 \right\rangle\!\!\right\rangle$$

ReLU Module

Consider the backwards pass through the ReLU module, $\frac{dE}{dz_i} = f'(z_i)\frac{dE}{dx_i}$. Let’s proceed like before.

Step 1: Square both sides of the backward pass equation and average over weights:

$$\left\langle\!\!\left\langle \left(\frac{dE}{dz_i}\right)^2 \right\rangle\!\!\right\rangle = \left\langle\!\!\left\langle f'(z_i)^2\left(\frac{dE}{dx_i}\right)^2 \right\rangle\!\!\right\rangle$$

Step 2: Separate the $f'$ terms and $dE/dx$ terms using the gradient independence assumption:

$$\left\langle\!\!\left\langle \left(\frac{dE}{dz_i}\right)^2 \right\rangle\!\!\right\rangle = \left\langle\!\!\left\langle f'(z_i)^2\right\rangle\!\!\right\rangle \left\langle\!\!\left\langle \left(\frac{dE}{dx_i}\right)^2 \right\rangle\!\!\right\rangle$$

Step 3: ReLU is a particularly simple activation. If $z_i > 0$, then $f' = 1$; otherwise $f' = 0$. This is great, since over weight configurations $z_i > 0$ half the time and $z_i < 0$ the other half. So the average is $\left\langle\!\left\langle f'(z_i)^2\right\rangle\!\right\rangle = 1/2$:

$$\left\langle\!\!\left\langle \left(\frac{dE}{dz}\right)^2 \right\rangle\!\!\right\rangle = \frac{1}{2}\left\langle\!\!\left\langle \left(\frac{dE}{dx}\right)^2 \right\rangle\!\!\right\rangle$$

At this point, we can combine the typical backwards pass equations for the ReLU and the matrix multiply module to see that if we didn’t have Batch Norm, gradients would be preserved:

$$\left\langle\!\!\left\langle \left(\frac{dE}{dx_{l-1}}\right)^2 \right\rangle\!\!\right\rangle = \left\langle\!\!\left\langle \left(\frac{dE}{dx_l}\right)^2 \right\rangle\!\!\right\rangle$$
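As a quick numerical sanity check (a sketch of my own, not part of the original derivation), we can verify the factors of 2 and 1/2 for a single Kaiming-initialized matrix multiply followed by a ReLU:

```python
import torch

torch.manual_seed(0)
N, T = 1024, 512                                 # layer width and minibatch size, as in the experiment

x_prev = torch.randn(T, N, requires_grad=True)   # x_{l-1}
W = torch.randn(N, N) * (2.0 / N) ** 0.5         # Kaiming init: Var[W_ij] = 2/N
y = x_prev @ W.t()                               # matrix multiply module: y_l = W_l x_{l-1}
x = torch.relu(y)                                # ReLU module (vanilla net, no Batch Norm)
y.retain_grad(); x.retain_grad()

w = torch.randn(N)                               # random linear loss  E = sum_t  w . x_l(t)
E = (x @ w).sum()
E.backward()

msq = lambda g: (g ** 2).mean().item()           # typical squared gradient of a tensor
print("matmul factor :", msq(x_prev.grad) / msq(y.grad))  # ~ 2
print("ReLU factor   :", msq(y.grad) / msq(x.grad))       # ~ 1/2
print("combined      :", msq(x_prev.grad) / msq(x.grad))  # ~ 1
```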

Batch Norm Module

In some sense, 2/3 of our work is done; in principle we just need to average the squared backwards pass equations over weights like we did before. However in a much more real sense, most of the work is still to come.

The first complication is the gradients that come from backpropagating through the minibatch mean $\mu$ and variance $\sigma^2$. We will argue that these gradients are nearly zero, at least if there is enough “randomness” in the data and the networks are sufficiently wide.

The second complication is the division by $\sigma$, the minibatch standard deviation. Since it is an average over samples, not weights, it will take some assumptions (and labor) to calculate.

Approximation 1: $\left\langle z \odot \frac{dE}{dz} \right\rangle \approx 0$

We are going to argue that $\left\langle z \odot \frac{dE}{dz} \right\rangle$, the term from backpropagating through $\sigma$, is nearly zero at initialization time.

To do so we’ll extend the gradient independence assumption to apply over the minibatch distribution. This means that for a typical configuration of weights, we assume forward pass quantities are independent from backward pass quantities over the minibatch distribution.¹

With this assumption we have

$$\left\langle z \odot \frac{dE}{dz} \right\rangle \approx \left\langle z \right\rangle \odot \left\langle \frac{dE}{dz} \right\rangle = 0$$

The last equality follows from the fact that $\langle z \rangle = 0$, because Batch Norm explicitly centers this variable.

Approximation 2: $\left\langle \frac{dE}{dz} \right\rangle \approx 0$

Now we’ll argue that $\left\langle \frac{dE}{dz} \right\rangle$, the term from backpropagating through $\mu$, is nearly zero at initialization time, except possibly in the last layer.

First note that Batch Norm ensures gradients with respect to $y$ are exactly zero-mean: $\left\langle \frac{dE}{dy} \right\rangle = 0$. With our backwards pass equations, we can write gradients for $z_l$ in terms of gradients for $y_{l+1}$: $\frac{dE}{dz_l} = f'(z_l) \odot \left(W_{l+1}^\top \frac{dE}{dy_{l+1}}\right)$.

Now we average both sides of this equation over samples in the minibatch. With the extended gradient independence assumption, $f'(z_l)$ and $\frac{dE}{dy_{l+1}}$ are independent quantities, so we can simplify the minibatch-averaged sum:

$$\left\langle \frac{dE}{dz_l} \right\rangle \approx \left\langle f'(z_l) \right\rangle \odot \left(W_{l+1}^\top \left\langle \frac{dE}{dy_{l+1}} \right\rangle\right) = 0$$

Note that this argument doesn’t really hold up in the last layer; we required a Batch Norm in layer $l+1$ to center incoming gradients to layer $l$. So if there are only $L$ layers in the network, gradient signals to $z_L$ might not be centered. But we won’t get too worked up if this doesn’t hold for just one layer.

Combining the last two approximations, we are left with the greatly simplified backwards pass equation:

$$\frac{dE}{dy_l} \approx \frac{1}{\sigma_l}\frac{dE}{dz_l}$$

The next step will be to average this over weights like we did for the matrix multiply and ReLU modules.

Approximation 3: $\sigma^2 \approx \frac{\pi-1}{\pi}$

We’re in the home stretch. Squaring, then averaging the final equation over weights (and using the gradient independence assumption to separate the $1/\sigma_l^2$ factor) gives us:

$$\left\langle\!\!\left\langle \left(\frac{dE}{dy_l}\right)^2 \right\rangle\!\!\right\rangle \approx \left\langle\!\!\left\langle \frac{1}{\sigma_l^2} \right\rangle\!\!\right\rangle \left\langle\!\!\left\langle \left(\frac{dE}{dz_l}\right)^2 \right\rangle\!\!\right\rangle$$

All that remains is to compute $\left\langle\!\left\langle 1/\sigma_l^2 \right\rangle\!\right\rangle$, the expected inverse minibatch variance of an element of $y_l$. While we already know from our simulations that $\left\langle\!\left\langle 1/\sigma_l^2 \right\rangle\!\right\rangle \approx 1.2^2$ (1.2 is the factor by which the width of the gradient histograms grew every layer), let’s calculate it analytically.

Step 1: Let’s start with something we do know the variance of. Because $z$ is the output of a Batch Norm module, it has to be zero-mean and unit-variance over the minibatch. Since $x$ is just a rectified version of $z$, it might seem like we can relate the statistics of $x$ to those of $z$: in particular, $\langle x \rangle = \int f(z)\,p(z)\,dz$ and $\langle x^2 \rangle = \int f(z)^2\,p(z)\,dz$, where $p(z)$ is the distribution of $z$ over the minibatch.

Unfortunately, knowing the mean and variance of $z$ is not enough to evaluate either of these integrals. It seems like a reasonable assumption, however, that $z$ takes a Gaussian distribution, because $z$ is just a centered and scaled version of $y$, and $y$ is the sum of $N$ random variables. Also, it’s easy enough to visualize histograms of $z$ in our experiment and confirm that it appears Gaussian.

Now we just have to integrate the unit Gaussian over the positive half:

$$\langle x^2 \rangle - \langle x \rangle^2 = \frac{1}{\sqrt{2\pi}}\int_0^{+\infty} z^2 e^{-\frac{z^2}{2}}\,dz - \left[\frac{1}{\sqrt{2\pi}}\int_0^{+\infty} z\, e^{-\frac{z^2}{2}}\,dz\right]^2$$

The first term is just $\frac{1}{2}$ (half the variance of a unit Gaussian). The second term is $\frac{1}{2\pi}$, which follows from direct integration: $\int z\, e^{-z^2/2}\,dz = -e^{-z^2/2}$.
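Putting the two terms together, the quantity we’ll need in Step 2 is

$$\langle x^2 \rangle - \langle x \rangle^2 = \frac{1}{2} - \frac{1}{2\pi} = \frac{\pi-1}{2\pi}.$$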

Step 2: Now we can relate the expected variance of an element of y to the variance of x (recall that the minibatch variance is computed over samples in the minibatch, and in general it depends on the network weights):

$$\sigma^2 = \langle y^2 \rangle - \langle y \rangle^2 = w^\top\left[\langle x x^\top \rangle - \langle x \rangle \langle x \rangle^\top\right] w = w^\top C_{xx}\, w$$

It is simple enough to compute the average minibatch variance $\left\langle\!\left\langle \sigma^2 \right\rangle\!\right\rangle$. When $i \neq j$ we have $\left\langle\!\left\langle w_i w_j \right\rangle\!\right\rangle = 0$, and when $i = j$ we have $\left\langle\!\left\langle w_i w_j \right\rangle\!\right\rangle = 2/N$. Since we have assumed the $x_i$ are identically distributed, we are left with the result:

$$\left\langle\!\!\left\langle \sigma^2 \right\rangle\!\!\right\rangle = 2\left(\langle x^2 \rangle - \langle x \rangle^2\right) = \frac{\pi-1}{\pi}$$

Step 3: We’ve computed $\left\langle\!\left\langle \sigma^2 \right\rangle\!\right\rangle$, but this is generally different from $\left\langle\!\left\langle 1/\sigma^2 \right\rangle\!\right\rangle$, which is the quantity we wanted. We’ll argue that in our experiment these are effectively reciprocals of each other, because fluctuations of $\sigma^2$ are small, so $\left\langle\!\left\langle 1/\sigma^2 \right\rangle\!\right\rangle \approx 1/\left\langle\!\left\langle \sigma^2 \right\rangle\!\right\rangle$.

In physics terms, we argue that $\sigma^2$, the minibatch variance of an element of $y$, is a self-averaging quantity. This means that for nearly every configuration of weights, $\sigma^2$ is roughly the same:²

$$\langle y^2 \rangle - \langle y \rangle^2 \approx \left\langle\!\!\left\langle\, \langle y^2 \rangle - \langle y \rangle^2 \,\right\rangle\!\!\right\rangle$$

We can empirically show this self-averaging assumption holds in our simulation. We’ll do so by visualizing the distribution of the first 3 elements of y in the 5th layer of our Batch Normalized net.

Interestingly each element of y seems to have a different mean, but we can see that fluctuations around the mean appear to have identical distributions. This means that the variance of every element is the same (roughly).
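As a quick numerical check of this self-averaging claim (an idealized sketch of my own, with independent unit-Gaussian $z$’s standing in for the actual layer activations), the minibatch variance of $y = w \cdot x$ for a single random weight vector already lands close to $(\pi-1)/\pi \approx 0.68$:

```python
import math
import torch

torch.manual_seed(0)
N, T = 1024, 512                          # width and minibatch size

z = torch.randn(T, N)                     # unit Gaussian, as guaranteed by Batch Norm
x = torch.relu(z)                         # rectified version: x = f(z)
w = torch.randn(N) * math.sqrt(2.0 / N)   # one row of a Kaiming-initialized weight matrix
y = x @ w                                 # one element of the next layer's matmul output, per sample

print("minibatch variance of y:", y.var(unbiased=False).item())  # ~ 0.68 for almost any w
print("predicted (pi-1)/pi    :", (math.pi - 1) / math.pi)
```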

Step 4: We’ve made it to the end. We can combine our estimate $\left\langle\!\left\langle 1/\sigma^2 \right\rangle\!\right\rangle = \frac{\pi}{\pi-1}$ with our typical backward pass equation through a Batch Norm module, $\left\langle\!\left\langle \left(\frac{dE}{dy_l}\right)^2 \right\rangle\!\right\rangle \approx \left\langle\!\left\langle \frac{1}{\sigma_l^2} \right\rangle\!\right\rangle \left\langle\!\left\langle \left(\frac{dE}{dz_l}\right)^2 \right\rangle\!\right\rangle$, to see how Batch Norm amplifies the gradients in the backwards pass:

$$\left\langle\!\!\left\langle \left(\frac{dE}{dy_l}\right)^2 \right\rangle\!\!\right\rangle \approx \left(\frac{\pi}{\pi-1}\right)\left\langle\!\!\left\langle \left(\frac{dE}{dz_l}\right)^2 \right\rangle\!\!\right\rangle$$

Result: Squared Gradients Amplified by $\frac{\pi}{\pi-1}$

Combining the typical backwards pass equations through every layer, we get the final result that squared gradients are amplified by $\frac{\pi}{\pi-1}$ at every layer:

$$\left\langle\!\!\left\langle \left(\frac{dE}{dx_{l-1}}\right)^2 \right\rangle\!\!\right\rangle \approx \left(\frac{\pi}{\pi-1}\right)\left\langle\!\!\left\langle \left(\frac{dE}{dx_l}\right)^2 \right\rangle\!\!\right\rangle$$
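For concreteness: $\pi/(\pi-1) \approx 1.47$, so the RMS gradient (the histogram width) grows by $\sqrt{\pi/(\pi-1)} \approx 1.21$ per layer, and over the 10 layers of our network this compounds to roughly $1.21^{10} \approx 6.7\times$.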

Of course we can empirically test our approximations by measuring the width of the gradient histograms in each layer of our simulation:

I also ran a 3rd experiment, this time freezing $\mu$ and $\sigma$. So the forward pass of this network was identical to the Batch Norm network, but gradients were not backpropagated through $\mu$ and $\sigma$. Earlier we predicted that, except possibly in the final layer, gradient backpropagation probably wasn’t changed much by $\mu$ and $\sigma$ in a random network.

In both cases, I measured the growth of the RMS gradient (the width of each histogram) to be 1.21 per layer, which as expected is very close to $\sqrt{\pi/(\pi-1)}$. Cool, the theory worked!
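For reference, one way to implement that frozen-statistics variant (a sketch of my own; the module name is made up) is to compute the batch mean and standard deviation as usual but detach them, so that no gradients flow through $\mu$ or $\sigma$:

```python
import torch
import torch.nn as nn

class FrozenStatBatchNorm(nn.Module):
    """Batch Norm whose minibatch mean/std are used in the forward pass
    but treated as constants in the backward pass (no gradients through mu, sigma)."""
    def forward(self, y):
        mu = y.mean(dim=0, keepdim=True).detach()                     # minibatch mean, detached
        sigma = y.std(dim=0, unbiased=False, keepdim=True).detach()   # minibatch std, detached
        return (y - mu) / sigma                                       # z = (y - mu) / sigma
```

Swapping this in for nn.BatchNorm1d leaves the forward pass essentially unchanged at initialization, but only the $\frac{1}{\sigma_l}\frac{dE}{dz_l}$ term survives in the backward pass.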

Normalized Preactivations Imply Very Nonlinear Input-Output Mappings

Before concluding, we’ll give one more way to think about the exploding gradient phenomenon: by normalizing the inputs to the ReLU nonlinearity, you are ensuring that the nonlinearity is used “more effectively” at each layer and the network is more likely to implement some crazy nonlinear function which massively distorts the geometry of the input space.

This probably makes more sense if we just look at the distribution of inputs to ReLU in the vanilla and normalized networks:

In deeper layers of the vanilla net, some of the preactivations lie mostly on one side of the nonlinearity. When this happens, the ReLU activation either computes the function $x_i = y_i$ for all inputs or it computes $x_i = 0$ for all inputs; either way, it’s acting in a linear fashion. This does not happen in the normalized network, because inputs to the ReLU are centered by Batch Norm.

The authors of the original “Mean Field Theory of Batch Normalization” paper would say that a network which normalizes preactivations, and therefore implements a highly nonlinear function, is in a state of transient chaos.

Does this matter for Real World Network Training?

I’m really not sure.

Maybe it doesn’t matter. One, the actual factor by which gradients explode, at least in our scenario, is not too extreme. They grow by 1.21 at each layer, so even after 10 layers gradients are just ~6x larger at the beginning. This is probably not a crippling difference, and it might be small compared to differences coming from variable layer widths, residual connections, and other complications in real world networks. Two, all this analysis was in random networks. Relating this to training dynamics is not a simple matter.

But maybe it does matter. The authors of the “Mean Field Theory of Batch Normalization” paper show that extremely deep feedforward nets (50+ layers) are hard or impossible to train with Batch Norm. One might object that the deepest real world networks aren’t just 50 layers of matrix multiplication or convolution stacked on top of each other; they include residual connections too. Perhaps this is no coincidence: maybe Batch Norm is not helpful, or is even harmful, in extremely deep nets without other tricks.

Acknowledgements

The post originated from work done with my advisor Sebastian Seung.

Footnotes

  1. This assumption might seem a little odd; at some level deep learning only works because the gradients are functions of the parameters at each layer, so they can tell us how to update the parameters to improve the loss. We are basically assuming there is enough “noise” in these gradients over the minibatch distribution: dependencies between forward pass quantities and backwards pass quantities are small compared to fluctuations in the values of those quantities over the minibatch. (If you still don’t like this assumption, we’ll later show empirically that it seems to hold in our experiment.)

  2. Actually, we could theoretically show that $\sigma^2$ is self-averaging by showing that the eigenvalues of the covariance matrix $\langle x x^\top \rangle - \langle x \rangle\langle x \rangle^\top$ are “well behaved” (meaning none of the eigenvalues have magnitude $O(N)$). However, doing so would take us a little far off track. Basically, if our network is wide, we have a large number of samples in the minibatch, and there is enough “randomness” in the samples (for instance, they can’t all be copies of the same vector), then $\sigma^2$ will be self-averaging. For this post, we’ll take the simpler route and just show that, empirically, $\sigma^2$ has small fluctuations.