Batch Norm


Paper: Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift (Ioffe & Szegedy, 2015), https://arxiv.org/pdf/1502.03167.pdf

Batch Normalization transformation:
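For a mini-batch $\mathcal{B} = \{x_1, \dots, x_m\}$, the paper defines the transform as:

$$
\mu_\mathcal{B} = \frac{1}{m}\sum_{i=1}^{m} x_i,
\qquad
\sigma_\mathcal{B}^2 = \frac{1}{m}\sum_{i=1}^{m}(x_i - \mu_\mathcal{B})^2
$$

$$
\hat{x}_i = \frac{x_i - \mu_\mathcal{B}}{\sqrt{\sigma_\mathcal{B}^2 + \epsilon}},
\qquad
y_i = \gamma\,\hat{x}_i + \beta
$$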

Training:

During training, update γ (gamma) and β (beta) by gradient descent on each batch, and also update the running mean E[x] and running variance Var[x] from the batch statistics.

At inference, use the running mean E[x] and running variance Var[x], rather than the sample mean and sample variance, to compute the output.
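A minimal standalone sketch of that running-statistics update (a hypothetical example, using the same momentum convention as the code below, where momentum is the decay applied to the old running value):

import torch

momentum = 0.95                        # decay on the previous running value
running_mean = torch.zeros(3)
for _ in range(200):
    batch = torch.randn(32, 3) + 5.0   # batches drawn around mean 5
    mu = batch.mean(dim=0)             # per-feature batch mean
    running_mean = momentum * running_mean + (1 - momentum) * mu
print(running_mean)                    # converges toward ~[5, 5, 5], the value used at inference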

Gradient computation during training (the running mean and variance are updated alongside, but no gradient flows through them):
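From the paper, backpropagation through the normalization follows the chain rule through $\hat{x}_i$, $\sigma_\mathcal{B}^2$ and $\mu_\mathcal{B}$:

$$
\frac{\partial \ell}{\partial \hat{x}_i} = \frac{\partial \ell}{\partial y_i}\,\gamma
$$

$$
\frac{\partial \ell}{\partial \sigma_\mathcal{B}^2} = \sum_{i=1}^{m}\frac{\partial \ell}{\partial \hat{x}_i}\,(x_i - \mu_\mathcal{B})\cdot\left(-\tfrac{1}{2}\right)\left(\sigma_\mathcal{B}^2 + \epsilon\right)^{-3/2}
$$

$$
\frac{\partial \ell}{\partial \mu_\mathcal{B}} = \sum_{i=1}^{m}\frac{\partial \ell}{\partial \hat{x}_i}\cdot\frac{-1}{\sqrt{\sigma_\mathcal{B}^2 + \epsilon}} + \frac{\partial \ell}{\partial \sigma_\mathcal{B}^2}\cdot\frac{\sum_{i=1}^{m} -2(x_i - \mu_\mathcal{B})}{m}
$$

$$
\frac{\partial \ell}{\partial x_i} = \frac{\partial \ell}{\partial \hat{x}_i}\cdot\frac{1}{\sqrt{\sigma_\mathcal{B}^2 + \epsilon}} + \frac{\partial \ell}{\partial \sigma_\mathcal{B}^2}\cdot\frac{2(x_i - \mu_\mathcal{B})}{m} + \frac{\partial \ell}{\partial \mu_\mathcal{B}}\cdot\frac{1}{m}
$$

$$
\frac{\partial \ell}{\partial \gamma} = \sum_{i=1}^{m}\frac{\partial \ell}{\partial y_i}\,\hat{x}_i,
\qquad
\frac{\partial \ell}{\partial \beta} = \sum_{i=1}^{m}\frac{\partial \ell}{\partial y_i}
$$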

Coding it in PyTorch:

import torch
from torch import nn

class BatchNorm(nn.Module):
  def __init__(self, in_dim, out_dim):
    super(BatchNorm, self).__init__()
    # the linear layer plays the role of the learnable rescaling
    # (a full affine map instead of per-feature gamma/beta)
    self.linear = nn.Linear(in_dim, out_dim, bias=True)
    self.in_dim, self.out_dim = in_dim, out_dim
    self.eps = 0.001
    self.momentum = 0.95  # decay applied to the previous running value
    # buffers, not parameters: updated from batch statistics, not by gradients
    self.register_buffer('running_mean', torch.zeros(in_dim, dtype=torch.float32))
    self.register_buffer('running_var', torch.ones(in_dim, dtype=torch.float32))

  def forward(self, x, inference=False):
    if not inference:
      # normalize x with the statistics of the current batch,
      # computed along dim 0 (the sample dimension)
      mu = x.mean(dim=0)
      var = x.var(dim=0, unbiased=False)  # biased variance, as in the paper
      with torch.no_grad():
        # update running mean/var with the batch statistics;
        # no_grad keeps gradients out of the buffers
        self.running_mean = self.momentum * self.running_mean + (1 - self.momentum) * mu
        self.running_var = self.momentum * self.running_var + (1 - self.momentum) * var
      x = (x - mu) / torch.sqrt(var + self.eps)
    else:
      # at inference, use the estimated running mean/var rather than batch statistics
      x = (x - self.running_mean) / torch.sqrt(self.running_var + self.eps)
    # linear projection to rescale and recombine features
    out = self.linear(x)
    return out

x = torch.tensor([[1.1, 1, 2, 3, 4, 5],
                  [10, 11, 12, 13, 14, 15],
                  [21, 22, 23, 24, 25, 26]], dtype=torch.float32)
x2 = torch.rand((3, 6), dtype=torch.float32)
bn = BatchNorm(in_dim=6, out_dim=3)

bn(x), bn(x2), bn(x, inference=True)
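As a sanity check, the normalization step can be compared against PyTorch's built-in nn.BatchNorm1d. A sketch: note that PyTorch's momentum is the weight on the new batch statistic, so momentum=0.05 there matches the 0.95 decay used above.

import torch
from torch import nn

torch.manual_seed(0)
x = torch.randn(8, 6)

# affine=False isolates the normalization step (no gamma/beta);
# momentum=0.05 matches the 0.95 decay convention used above
ref = nn.BatchNorm1d(6, eps=0.001, momentum=0.05, affine=False)
ref.train()
y = ref(x)
print(y.mean(dim=0))                 # ~0 per feature
print(y.var(dim=0, unbiased=False))  # ~1 per feature
ref.eval()
y_eval = ref(x)                      # now normalized with running statistics, like inference=True above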
