# Batch Norm

**Paper**: <https://arxiv.org/pdf/1502.03167.pdf>

**Batch Normalization transformation:**

![](https://1268307957-files.gitbook.io/~/files/v0/b/gitbook-legacy-files/o/assets%2F-MU4s3p9Ql9V0G1xvT9E%2F-MZJcScjI8w0zqcaxGRx%2F-MZJknVPNye7idJ_IvjN%2Fimage.png?alt=media\&token=b42ddcdc-d4d7-4b02-b575-f1670f52fc4a)
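In symbols, the transform shown above is (per the paper): compute the batch mean and (biased) variance, standardize, then rescale and shift with the learned parameters \\(\gamma, \beta\\):

```latex
\mu_B = \frac{1}{m}\sum_{i=1}^{m} x_i, \qquad
\sigma_B^2 = \frac{1}{m}\sum_{i=1}^{m} (x_i - \mu_B)^2
```

```latex
\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}, \qquad
y_i = \gamma \hat{x}_i + \beta
```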

**Training:**

During training, gamma and beta are updated by gradient descent using the batch samples, while running estimates of the mean E\[x] and variance Var\[x] are also maintained.

At inference, the running mean E\[x] and running variance Var\[x] are used in place of the per-batch sample mean and variance to compute the output.

![](https://1268307957-files.gitbook.io/~/files/v0/b/gitbook-legacy-files/o/assets%2F-MU4s3p9Ql9V0G1xvT9E%2F-MZJcScjI8w0zqcaxGRx%2F-MZJl1_6XtS7p1WAJ6Qw%2Fimage.png?alt=media\&token=609e3d74-434d-4a5d-9b05-844f9fc88c8c)
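A minimal sketch of this train/inference distinction using PyTorch's built-in `nn.BatchNorm1d` (the feature count and data here are illustrative):

```python
import torch
from torch import nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(num_features=3)
x = torch.randn(16, 3) * 2.0 + 5.0   # batch with mean ~5, std ~2

bn.train()
y_train = bn(x)   # normalized with this batch's own mean/var;
                  # running_mean/running_var are updated as a side effect

bn.eval()
y_eval = bn(x)    # normalized with the running estimates instead

# In training mode the output is standardized per feature
print(y_train.mean(dim=0))   # ~0
print(y_train.std(dim=0))    # ~1
```

Note PyTorch's `momentum` (default 0.1) weights the *new* batch statistic in the running-average update, the opposite convention of the code further below.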

**Gradient computation** during training (backpropagating through the batch mean and variance):

![](https://1268307957-files.gitbook.io/~/files/v0/b/gitbook-legacy-files/o/assets%2F-MU4s3p9Ql9V0G1xvT9E%2F-MZJcScjI8w0zqcaxGRx%2F-MZJloiCzUXbb31-mvQR%2Fimage.png?alt=media\&token=c24d426e-869e-4a9f-abbe-4aced66a22d4)
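The gradient equations above can be checked numerically against autograd. A sketch (shapes, seed, and variable names are illustrative):

```python
import torch

# Manual backward pass for batch norm, per the paper's gradient equations,
# verified against autograd. x is (N, D); gamma, beta are (D,).
torch.manual_seed(0)
N, D, eps = 8, 4, 1e-5
x = torch.randn(N, D, dtype=torch.float64, requires_grad=True)
gamma = torch.randn(D, dtype=torch.float64, requires_grad=True)
beta = torch.randn(D, dtype=torch.float64, requires_grad=True)

mu = x.mean(dim=0)
var = x.var(dim=0, unbiased=False)   # biased variance, as in the paper
xhat = (x - mu) / torch.sqrt(var + eps)
y = gamma * xhat + beta

dy = torch.randn_like(y)             # arbitrary upstream gradient
y.backward(dy)

# Manual gradients: chain rule through xhat, then through var and mu
with torch.no_grad():
    std_inv = 1.0 / torch.sqrt(var + eps)
    dxhat = dy * gamma
    dvar = (dxhat * (x - mu)).sum(0) * -0.5 * std_inv**3
    dmu = (-dxhat * std_inv).sum(0) + dvar * (-2.0 * (x - mu)).mean(0)
    dx = dxhat * std_inv + dvar * 2.0 * (x - mu) / N + dmu / N
    dgamma = (dy * xhat).sum(0)
    dbeta = dy.sum(0)

assert torch.allclose(dx, x.grad)
assert torch.allclose(dgamma, gamma.grad)
assert torch.allclose(dbeta, beta.grad)
```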

**Coding in PyTorch**

```python
import torch
from torch import nn

class BatchNorm(nn.Module):
  def __init__(self, in_dim, out_dim):
    super().__init__()
    # The linear layer plays the role of the per-feature scale (gamma)
    # and shift (beta), generalized to a full learned projection
    self.linear = nn.Linear(in_dim, out_dim, bias=True)
    self.in_dim, self.out_dim = in_dim, out_dim
    self.eps = 0.001
    self.momentum = 0.95  # weight on the *old* running value
    # Buffers so the running stats follow the module (.to, .state_dict)
    self.register_buffer("running_mean", torch.zeros(in_dim))
    self.register_buffer("running_var", torch.ones(in_dim))  # var starts at 1, not 0

  def forward(self, x, inference=False):
    if not inference:
      # Training: normalize with the batch's own statistics,
      # computed along dim 0 (the sample dimension)
      mu = x.mean(dim=0)
      var = x.var(dim=0, unbiased=False)  # biased variance, as in the paper
      # Update the running estimates; detach so no gradients
      # flow through the running averages
      with torch.no_grad():
        self.running_mean = self.momentum * self.running_mean + (1 - self.momentum) * mu
        self.running_var = self.momentum * self.running_var + (1 - self.momentum) * var
      x = (x - mu) / torch.sqrt(var + self.eps)
    else:
      # Inference: use the running estimates instead of batch statistics
      x = (x - self.running_mean) / torch.sqrt(self.running_var + self.eps)
    # Linear projection to rescale and shift the normalized features
    return self.linear(x)

x = torch.tensor([[1.1, 1, 2, 3, 4, 5],
                  [10, 11, 12, 13, 14, 15],
                  [21, 22, 23, 24, 25, 26]], dtype=torch.float32)
x2 = torch.rand((3, 6), dtype=torch.float32)
bn = BatchNorm(in_dim=6, out_dim=3)

bn(x), bn(x2), bn(x, inference=True)
```
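As a sanity check on the running-average update used above (assuming the same convention, where `momentum` weights the old value): feeding the same batch repeatedly drives the running estimates to the batch statistics, since the contribution of the initial values decays as `momentum**t`.

```python
import torch

# Sketch: the exponential running-average update converges to the
# batch statistics when the same batch is seen repeatedly
torch.manual_seed(0)
momentum = 0.95
x = torch.randn(32, 6) * 3.0 + 2.0
mu, var = x.mean(dim=0), x.var(dim=0, unbiased=False)

running_mean = torch.zeros(6)
running_var = torch.ones(6)
for _ in range(200):   # 0.95**200 ~ 3.5e-5, so the init is washed out
    running_mean = momentum * running_mean + (1 - momentum) * mu
    running_var = momentum * running_var + (1 - momentum) * var

assert torch.allclose(running_mean, mu, atol=1e-3)
assert torch.allclose(running_var, var, atol=1e-3)
```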
