Different Normalization Layers in Deep Learning

Weight Standardization — An Alternative to Batch Normalization

Nilesh Vijayrania
Towards Data Science

Deep learning has been revolutionizing many subfields such as natural language processing, computer vision, and robotics. It involves training carefully designed deep neural networks, and many design decisions affect how these networks train. Some of these design decisions include:

  • Which types of layers should the network use (convolutional, linear, recurrent, etc.), and how many layers deep should it be?
  • What kind of normalization layer should we use, if any?
  • Which loss function should we optimize?

These design decisions largely depend on the underlying task we are trying to solve and require a deeper understanding of the options at hand. In this post, I will focus on the second point, “different normalization layers in deep learning”, and cover the following methods.

  • Batch Normalization
  • Weight Normalization
  • Layer Normalization
  • Group Normalization
  • Weight Standardization

Batch Normalization (BN)

Batch Normalization focuses on standardizing the inputs to any particular layer (i.e. the activations from previous layers). Standardizing the inputs means that the inputs to any layer in the network should have approximately zero mean and unit variance. Mathematically, a BN layer transforms each input in the current mini-batch by subtracting the mean of that input over the mini-batch and dividing by its standard deviation; these statistics are computed separately for each feature.

However, a layer does not necessarily work best with inputs of exactly zero mean and unit variance; the model might perform better with some other mean and variance. Hence the BN layer also introduces two learnable parameters, a scale γ and a shift β, which let the network learn whatever mean and variance suit it.

The whole layer operation is as follows: it takes an input x_i and transforms it into y_i as described in the table below.

Credits: Ioffe and Szegedy, Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift
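For reference, the transform in the credited table (Algorithm 1 of Ioffe and Szegedy) can be written as follows, where m is the mini-batch size and ε is a small constant added for numerical stability. All quantities are computed separately for each feature.

```latex
\begin{aligned}
\mu_\mathcal{B} &= \frac{1}{m}\sum_{i=1}^{m} x_i
  && \text{(mini-batch mean)}\\
\sigma_\mathcal{B}^2 &= \frac{1}{m}\sum_{i=1}^{m} \left(x_i - \mu_\mathcal{B}\right)^2
  && \text{(mini-batch variance)}\\
\hat{x}_i &= \frac{x_i - \mu_\mathcal{B}}{\sqrt{\sigma_\mathcal{B}^2 + \epsilon}}
  && \text{(normalize)}\\
y_i &= \gamma \hat{x}_i + \beta \equiv \mathrm{BN}_{\gamma,\beta}(x_i)
  && \text{(scale and shift)}
\end{aligned}
```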

How does BN help neural network training? Intuitively, in gradient descent the network calculates the gradient based on the current inputs to each layer and updates the weights in the direction opposite to the gradient. But since the layers are stacked one after the other, even a slight update to the weights of an earlier layer can substantially change the distribution of the inputs to a later layer, so the current gradient may give a suboptimal signal for that layer (the original paper calls this internal covariate shift). BN constrains the distribution of the inputs to each layer (i.e. the activations from the previous layer), which helps the network produce better gradients for the weight updates. Hence BN often provides a more stable and faster training regime.

However, Batch Normalization has a few drawbacks:

  • BN calculates the batch statistics (mini-batch mean and variance) in every training iteration, so it requires large batch sizes during training to approximate the population mean and variance well from the mini-batch. This makes BN hard to use in applications such as object detection and semantic segmentation, which generally work with high input resolutions (often as large as 1024 × 2048), where training with large batches is not computationally feasible.
  • BN does not work well with RNNs. RNNs have recurrent connections to previous time steps, so applying BN would require computing and storing separate statistics for each time step, which adds complexity and makes BN harder to use with RNNs.
  • Training and test computations differ: at test (or inference) time, the BN layer does not calculate the mean and variance from the test mini-batch (steps 1 and 2 of the algorithm table above) but instead uses fixed estimates of the mean and variance computed from the training data. This requires care when using BN and introduces additional complexity. In PyTorch, model.eval() puts the model in evaluation mode, and BN layers then use the running mean and variance accumulated during training instead of batch statistics.
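The train/test difference in the last point is easiest to see in code. Below is a minimal NumPy sketch of batch normalization for inputs of shape (N, D), not any framework's actual implementation. It assumes the running statistics are tracked with an exponential moving average controlled by a hypothetical `momentum` of 0.1 (frameworks use similar schemes, though details such as biased vs. unbiased variance may differ).

```python
import numpy as np


class BatchNorm1d:
    """Minimal batch normalization over inputs of shape (N, D) (illustrative sketch)."""

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        self.gamma = np.ones(num_features)   # learnable scale
        self.beta = np.zeros(num_features)   # learnable shift
        self.eps = eps
        self.momentum = momentum             # assumed EMA factor for running stats
        self.running_mean = np.zeros(num_features)
        self.running_var = np.ones(num_features)
        self.training = True

    def __call__(self, x):
        if self.training:
            # Statistics of the current mini-batch, one value per feature
            mean = x.mean(axis=0)
            var = x.var(axis=0)
            # Update the running estimates that will be used at inference time
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
        else:
            # Inference: fixed statistics gathered during training
            mean, var = self.running_mean, self.running_var
        x_hat = (x - mean) / np.sqrt(var + self.eps)
        return self.gamma * x_hat + self.beta


rng = np.random.default_rng(0)
bn = BatchNorm1d(4)
x = rng.normal(loc=3.0, scale=2.0, size=(32, 4))
y = bn(x)            # training mode: each feature of y has ~zero mean, ~unit std
bn.training = False  # analogous to calling model.eval() in PyTorch
y_eval = bn(x[:1])   # a single sample works, since no batch statistics are needed
```

In training mode the output depends on the other samples in the batch; in evaluation mode each sample is normalized with the stored running statistics, which is why forgetting to switch modes changes a model's predictions.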

Weight Normalization

Due to these disadvantages of Batch Normalization, T. Salimans and D. P. Kingma proposed Weight Normalization. Their idea is to decouple the length of each weight vector from its direction, reparameterizing the network to speed up training.

What does reparameterization mean for Weight Normalization?

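The authors of the Weight Normalization paper suggested performing gradient descent on two parameters, g (the length of the weight vector) and v (which determines its direction), instead of w, in the following manner: w = (g / ||v||) v. Here g is a scalar and v is a vector with the same shape as w, so the norm of w is always g, independently of v.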
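As a sketch of the reparameterization w = (g / ||v||) v, the NumPy code below applies it row by row to the weight matrix of a linear layer (one g per output unit) and computes the gradients with respect to g and v from a given gradient with respect to w via the chain rule. The function names are hypothetical and chosen for illustration only.

```python
import numpy as np


def weight_norm(v, g):
    """Weight Normalization reparameterization w = g * v / ||v||.

    v: (out_features, in_features) direction parameters, one row per output unit
    g: (out_features,) length parameters, one scalar per output unit
    """
    norms = np.linalg.norm(v, axis=1, keepdims=True)  # ||v|| for each row
    return g[:, None] * v / norms


def weight_norm_grads(v, g, grad_w):
    """Chain-rule gradients of the loss w.r.t. g and v, given dL/dw."""
    norms = np.linalg.norm(v, axis=1, keepdims=True)
    # dL/dg = (dL/dw . v) / ||v||, computed row by row
    grad_g = np.sum(grad_w * v, axis=1) / norms[:, 0]
    # dL/dv = (g / ||v||) dL/dw - (g dL/dg / ||v||^2) v
    grad_v = (g[:, None] / norms) * grad_w - (g * grad_g)[:, None] / norms**2 * v
    return grad_g, grad_v


rng = np.random.default_rng(0)
v = rng.normal(size=(3, 5))
g = np.array([0.5, 1.0, 2.0])
w = weight_norm(v, g)
print(np.linalg.norm(w, axis=1))  # each row's length equals the corresponding g (up to rounding)
```

Note that the gradient with respect to v is orthogonal to v: rescaling v does not change w, so gradient descent only adjusts the direction through v and the length through g.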

Like Batch Normalization, Weight Normalization speeds up training, and unlike BN, it is applicable to RNNs as well. However, training deep networks with Weight Normalization is significantly less stable than with Batch Normalization, and hence it is not widely used in practice.

Layer Normalization (LN)

Inspired by the results of Batch Normalization, Jimmy Lei Ba, Jamie Ryan Kiros, and Geoffrey Hinton proposed Layer Normalization, which normalizes the activations along the feature direction instead of the mini-batch direction. This removes the dependency on the batch, overcoming the drawbacks of BN listed above, and makes it easy to apply to RNNs as well.

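In essence, Layer Normalization computes the mean and variance over all features of each individual sample and normalizes that sample's activations to zero mean and unit variance; as in BN, a learnable per-feature scale and shift are applied afterwards.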
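A minimal NumPy sketch of Layer Normalization for activations of shape (N, D) follows (the function name is illustrative). Because the statistics are computed per sample, normalizing one sample on its own gives exactly the same result as normalizing it inside a batch, which is what makes LN independent of the batch size.

```python
import numpy as np


def layer_norm(x, gamma, beta, eps=1e-5):
    """Layer Normalization for activations of shape (N, D).

    Statistics are computed over the D features of each sample separately.
    gamma and beta (shape (D,)) are the learnable per-feature scale and shift.
    """
    mean = x.mean(axis=1, keepdims=True)  # one mean per sample
    var = x.var(axis=1, keepdims=True)    # one variance per sample
    x_hat = (x - mean) / np.sqrt(var + eps)
    return gamma * x_hat + beta


rng = np.random.default_rng(0)
x = rng.normal(loc=5.0, scale=3.0, size=(4, 8))
gamma, beta = np.ones(8), np.zeros(8)
y = layer_norm(x, gamma, beta)
y0 = layer_norm(x[:1], gamma, beta)  # the first sample normalized on its own
print(np.allclose(y[:1], y0))        # True: the batch does not affect the result
```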

Group Normalization (GN)

Like Layer Normalization, Group Normalization is applied along the feature direction, but unlike LN, it divides the features into groups and normalizes each group separately. In practice, Group Normalization often performs better than Layer Normalization in convolutional networks, and its number of groups (num_groups) is tuned as a hyperparameter.

If you find BN, LN, and GN confusing, the image below summarizes them precisely. Given activations of shape (N, C, H, W), BN computes its statistics over the N, H, and W axes separately for each channel; LN computes them over the C, H, and W axes separately for each sample; and GN is like LN but additionally divides the C channels into groups and computes statistics for each group of each sample individually.
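The difference between the three methods comes down to which axes the statistics are computed over. The NumPy sketch below makes this explicit for activations of shape (N, C, H, W); the function names are illustrative, and the learnable scale and shift are omitted for brevity.

```python
import numpy as np


def normalize(x, axes, eps=1e-5):
    """Standardize x to zero mean and unit variance over the given axes."""
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def batch_norm(x):
    # One set of statistics per channel, shared across N, H, W
    return normalize(x, axes=(0, 2, 3))


def layer_norm(x):
    # One set of statistics per sample, shared across C, H, W
    return normalize(x, axes=(1, 2, 3))


def group_norm(x, num_groups):
    # One set of statistics per (sample, group of C // num_groups channels)
    n, c, h, w = x.shape
    assert c % num_groups == 0, "C must be divisible by num_groups"
    grouped = x.reshape(n, num_groups, c // num_groups, h, w)
    return normalize(grouped, axes=(2, 3, 4)).reshape(n, c, h, w)


rng = np.random.default_rng(0)
x = rng.normal(size=(2, 6, 4, 4))  # (N, C, H, W)
gn = group_norm(x, num_groups=3)   # 3 groups of 2 channels each
# With a single group, the group spans all channels, so GN reduces to LN
print(np.allclose(group_norm(x, num_groups=1), layer_norm(x)))  # True
```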

Image Credits: Siyuan Qiao et al.: Weight Standardization

Let's now look at what Weight Standardization is.

Weight Standardization (WS)

Weight Standardization transforms the weights of a layer to have zero mean and unit variance. The layer could be a convolution layer, an RNN layer, a linear layer, etc. For a layer whose weight has shape (N, *), where N is the number of output channels (not the batch size) and * represents one or more remaining dimensions, weight standardization standardizes the weights along the * dimension(s), i.e. separately for each of the N output channels.
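Written as equations (a sketch in our own notation), assume the weight has been reshaped to O × I, where row i holds the I weights feeding output channel i (for a 2D convolution, I = in_channels × kernel_height × kernel_width), and ε is a small constant for numerical stability:

```latex
\begin{aligned}
\hat{W}_{i,j} &= \frac{W_{i,j} - \mu_{W_{i,\cdot}}}{\sigma_{W_{i,\cdot}}},\\
\mu_{W_{i,\cdot}} &= \frac{1}{I}\sum_{j=1}^{I} W_{i,j},\\
\sigma_{W_{i,\cdot}} &= \sqrt{\frac{1}{I}\sum_{j=1}^{I}\left(W_{i,j} - \mu_{W_{i,\cdot}}\right)^2 + \epsilon}
\end{aligned}
```

The PyTorch code that follows differs slightly: it adds ε to the (unbiased) standard deviation rather than inside the square root, which has a negligible effect in practice.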

Below is sample code implementing weight standardization for a 2D convolution layer in PyTorch.

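import torch.nn as nn
import torch.nn.functional as F


class Conv2d(nn.Conv2d):
    """Conv2d whose weights are standardized per output channel on every forward pass."""

    def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
        super().__init__(in_channels, out_channels, kernel_size, **kwargs)

    def forward(self, x):
        weight = self.weight  # shape: (out_channels, in_channels // groups, kH, kW)
        # Mean and std over every dimension except the output-channel dimension
        weight_mean = weight.mean(dim=(1, 2, 3), keepdim=True)
        std = weight.std(dim=(1, 2, 3), keepdim=True) + 1e-5
        weight = (weight - weight_mean) / std
        return F.conv2d(x, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)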

The basic idea is to transform the weights only during the forward pass and compute the activations with the standardized weights. PyTorch's autograd handles the backward pass automatically, including the gradients through the standardization step. A linear layer can be implemented in the same way.
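For the linear-layer case, here is a framework-agnostic NumPy sketch of the forward pass (the names `standardize_weight` and `ws_linear` are hypothetical). In PyTorch the same idea would subclass nn.Linear and pass the standardized weight to F.linear, mirroring the Conv2d example. As in that example, ε is added to the standard deviation; note that NumPy's std is the biased estimator, whereas torch's std defaults to the unbiased one.

```python
import numpy as np


def standardize_weight(weight, eps=1e-5):
    """Standardize each output unit's weights to zero mean and unit std.

    Works for any weight of shape (N, *): all dimensions after the first are
    flattened, so each of the N output channels is standardized separately.
    """
    flat = weight.reshape(weight.shape[0], -1)
    mean = flat.mean(axis=1, keepdims=True)
    std = flat.std(axis=1, keepdims=True) + eps  # eps added to the std, as in the Conv2d example
    return ((flat - mean) / std).reshape(weight.shape)


def ws_linear(x, weight, bias=None):
    """Forward pass of a linear layer with Weight Standardization.

    x: (batch, in_features); weight: (out_features, in_features).
    Only the weights used in this forward pass are standardized; the stored
    parameters themselves are left unchanged.
    """
    y = x @ standardize_weight(weight).T
    return y if bias is None else y + bias


rng = np.random.default_rng(0)
x = rng.normal(size=(8, 16))
weight = rng.normal(loc=0.3, scale=2.0, size=(4, 16))
bias = np.zeros(4)
out = ws_linear(x, weight, bias)
print(out.shape)  # (8, 4)
```

Because each row is standardized, shifting or positively rescaling a row of the stored weights leaves the layer's output essentially unchanged.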

Siyuan Qiao et al. introduced Weight Standardization in their paper “Micro-Batch Training with Batch-Channel Normalization and Weight Standardization” and found that Group Normalization combined with Weight Standardization can match or outperform BN, even with a batch size as small as 1. As shown in the graph below, the authors trained GN, BN, and GN+WS with ResNet-50 and ResNet-101 on ImageNet classification and MS COCO object detection, and found that GN+WS consistently performs at least as well as the BN version while using much smaller batches than BN needs. This makes GN+WS attractive for dense prediction tasks such as semantic and instance segmentation, which usually cannot be trained with large batch sizes due to memory constraints.

Image Credits: Siyuan Qiao et al., Weight Standardization. GN+WS effect on classification and object detection tasks [4]

In conclusion, normalization layers often help to speed up and stabilize the learning process. If training with large batches isn't an issue and the network doesn't have any recurrent connections, Batch Normalization could be used. For training with small batches, Group Normalization combined with Weight Standardization could be tried instead of Batch Normalization, and for recurrent layers such as LSTMs and GRUs, Layer Normalization, which does not depend on the batch, is a natural alternative.

One important thing to note is that, in practice, normalization layers are placed between the Linear/Conv/RNN layer and the non-linearity (ReLU, tanh, etc.), so that the activations reaching the non-linear activation function are centered around zero. This can help avoid dead neurons that never activate because of a poor random initialization, and hence can improve training.

The references used for this post are listed below; consult them for further experimental details.

  1. Ioffe, Sergey, and Christian Szegedy. “Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.” arXiv preprint arXiv:1502.03167 (2015).
  2. Salimans, Tim, and Durk P. Kingma. “Weight normalization: A simple reparameterization to accelerate training of deep neural networks.” Advances in Neural Information Processing Systems 29 (2016): 901-909.
  3. Ba, Jimmy Lei, Jamie Ryan Kiros, and Geoffrey E. Hinton. “Layer normalization.” arXiv preprint arXiv:1607.06450 (2016).
  4. Qiao, Siyuan, et al. “Weight standardization.” arXiv preprint arXiv:1903.10520 (2019).
