LeetCode-Notes

Factorization Machine (FM)

Machine learning


Last updated 4 years ago


Factorization Machine (FM) Implementation

1. Formula
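For reference, the degree-2 FM model (the same formula as in the code docstring in section 3), where $b$ is the global bias, $w_i$ the first-order weights, $n$ the number of features, and $\mathbf{v}_i \in \mathbb{R}^k$ the embedding of feature $i$:

```latex
\hat{y}(\mathbf{x}) = b + \sum_{i=1}^{n} w_i x_i
  + \sum_{i=1}^{n} \sum_{j=i+1}^{n} \langle \mathbf{v}_i, \mathbf{v}_j \rangle \, x_i x_j,
\qquad
\langle \mathbf{v}_i, \mathbf{v}_j \rangle = \sum_{f=1}^{k} v_{i,f} \, v_{j,f}
```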

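Here we assume the FM input is a feature vector x from a single field. Fields are typically groups such as User, Item, and Tag. With a single field, FM crosses the feature embeddings directly. With multiple fields (as in DeepFM, where each field has several embedding vectors and corresponding feature values), first sum the embedding vectors within each field so that every field is represented by one pooled embedding vector, then apply the FM crossing across the fields. The entries of x can be continuous or sparse values, and each entry has its own embedding vector. During crossing, a continuous value rescales its embedding vector (v_i · x_i), while a sparse feature acts as an index into a lookup table that selects its embedding vector.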
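A minimal plain-Python sketch of the input preparation described above, assuming continuous features arrive as values with one embedding per feature and sparse features arrive as active indices into an embedding table. `build_scaled_embeddings` and `pool_field` are hypothetical helper names used only for illustration. A one-hot sparse feature has value 1, so looking up its row is the same as scaling it by 1.

```python
def build_scaled_embeddings(dense_values, dense_table, sparse_ids, sparse_table):
    """Return the vectors v_i * x_i that enter the FM crossing (hypothetical helper).

    - dense_values[i] is a continuous value x_i and dense_table[i] its embedding v_i,
      so the contribution is v_i scaled element-wise by x_i.
    - sparse_ids are the active one-hot indices (x_i = 1), so each embedding is a
      plain row lookup in sparse_table.
    """
    scaled = []
    for x_i, v_i in zip(dense_values, dense_table):
        scaled.append([x_i * v for v in v_i])  # rescale by the continuous value
    for idx in sparse_ids:
        scaled.append(list(sparse_table[idx]))  # lookup-table selection
    return scaled


def pool_field(vectors):
    """Sum-pool all embedding vectors of one field into a single k-dim vector."""
    k = len(vectors[0])
    return [sum(vec[f] for vec in vectors) for f in range(k)]


# Usage: one field with two continuous features and one active sparse id (k = 3)
scaled = build_scaled_embeddings(
    dense_values=[2.0, 0.5],
    dense_table=[[1.0, 0.0, 1.0], [2.0, 4.0, 0.0]],
    sparse_ids=[2],
    sparse_table=[[0.1, 0.2, 0.3], [1.0, 1.0, 1.0], [0.0, -1.0, 2.0]],
)
pooled = pool_field(scaled)
print(scaled)  # [[2.0, 0.0, 2.0], [1.0, 2.0, 0.0], [0.0, -1.0, 2.0]]
print(pooled)  # [3.0, 1.0, 4.0]
```

With several fields, `pool_field` would be applied once per field and the resulting pooled vectors crossed with each other.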

2. Simplification

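The second (pairwise interaction) term can be rewritten so that it is computed in O(kn) time instead of O(kn²), where n is the number of features and k is the embedding dimension. The FM computation is therefore linear in n, which is one of FM's advantages over traditional hand-crafted feature crosses.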
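The rewrite follows from the identity $\sum_{i<j} a_i a_j = \frac{1}{2}\left[\left(\sum_i a_i\right)^2 - \sum_i a_i^2\right]$, applied separately to each embedding dimension $f$ with $a_i = v_{i,f}\, x_i$:

```latex
\sum_{i=1}^{n} \sum_{j=i+1}^{n} \langle \mathbf{v}_i, \mathbf{v}_j \rangle \, x_i x_j
= \frac{1}{2} \sum_{f=1}^{k} \left[ \left( \sum_{i=1}^{n} v_{i,f}\, x_i \right)^{2}
  - \sum_{i=1}^{n} v_{i,f}^{2}\, x_i^{2} \right]
```

Each bracket needs one pass over the n features, and there are k brackets, which gives the O(kn) cost.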

In addition, if the outermost loop does not sum the k values, the result is a crossed embedding vector of length k. This vector can be used as the input to the DNN in Neural FM (NFM).
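A plain-Python check of the simplification (no PyTorch needed). `pairwise_naive` loops over all feature pairs in O(kn²); `bi_interaction` uses the O(kn) form but keeps the k dimensions separate, which is the crossed vector described above (NFM calls this Bi-Interaction pooling). Summing that vector gives the scalar FM interaction term. The function names are illustrative, not from any library.

```python
import random


def pairwise_naive(x, V):
    """O(k n^2): sum over i < j of <v_i, v_j> * x_i * x_j."""
    n, k = len(x), len(V[0])
    total = 0.0
    for i in range(n):
        for j in range(i + 1, n):
            dot = sum(V[i][f] * V[j][f] for f in range(k))
            total += dot * x[i] * x[j]
    return total


def bi_interaction(x, V):
    """O(k n): k-dim vector 0.5 * [(sum_i v_if x_i)^2 - sum_i (v_if x_i)^2]."""
    n, k = len(x), len(V[0])
    out = []
    for f in range(k):
        s = sum(V[i][f] * x[i] for i in range(n))          # square of sum
        sq = sum((V[i][f] * x[i]) ** 2 for i in range(n))  # sum of squares
        out.append(0.5 * (s * s - sq))
    return out


def pairwise_fast(x, V):
    """Reduce the k-dim crossed vector to the scalar FM interaction term."""
    return sum(bi_interaction(x, V))


random.seed(0)
n, k = 6, 4
x = [random.random() for _ in range(n)]
V = [[random.gauss(0, 1) for _ in range(k)] for _ in range(n)]
print(pairwise_naive(x, V), pairwise_fast(x, V))  # equal up to float rounding
print(bi_interaction(x, V))                       # length-k vector, e.g. for NFM

# Small hand-checkable case: <v1, v2> = 3, x1 * x2 = 2, so the term is 6
print(bi_interaction([1.0, 2.0], [[1.0, 0.0], [3.0, 1.0]]))  # [6.0, 0.0]
```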

3. PyTorch Code

# Factorization Machine: PyTorch implementation
import torch
from torch import nn


class FM(nn.Module):
  def __init__(self, fea_dim, embed_dim, reduce_sum=True):
    super().__init__()
    self.reduce_sum = reduce_sum
    # First-order part: b + sum_i w_i * x_i
    self.linear = nn.Linear(fea_dim, 1)
    # Embedding matrix V: one embed_dim-sized vector v_i per input feature
    self.V = nn.Parameter(torch.randn(fea_dim, embed_dim))

  def forward(self, x):
    r"""
    x: input with shape (batch_size, fea_dim), i.e. the feature vector before embedding

    Notes
      1. Fields in FM usually include User, Item, Tag; different fields describe different
         aspects of the user, the item, or other factors.
      2. FM: y = b + \sum_{i=1}^n w_i x_i + \sum_{i=1}^n \sum_{j=i+1}^n <v_i, v_j> x_i x_j
         Every x_i, whether sparse or continuous, has its own embedding v_i, and v_i * x_i
         rescales that embedding by the feature value.
      3. The pairwise term simplifies to
         0.5 * \sum_{f=1}^k [ (\sum_{i=1}^n v_{i,f} x_i)^2 - \sum_{i=1}^n v_{i,f}^2 x_i^2 ]
         which costs only O(kn). If a vector is needed instead of a scalar, simply drop
         the outer \sum_{f=1}^k.
    """
    # First-order term, shape (batch_size,); squeeze(-1) keeps the batch dim when batch_size == 1
    linear_output = self.linear(x).squeeze(-1)
    # (sum_i v_i x_i)^2 per dimension f: (batch, n) @ (n, embed_dim) -> (batch, embed_dim)
    square_of_sum = torch.pow(torch.matmul(x, self.V), 2)
    # sum_i v_i^2 x_i^2 per dimension f: (batch, embed_dim)
    sum_of_square = torch.matmul(torch.pow(x, 2), torch.pow(self.V, 2))
    cross_fea = 0.5 * (square_of_sum - sum_of_square)  # (batch, embed_dim)
    if self.reduce_sum:
      # Single field: sum over the embedding dimensions (dim=1) to get one scalar per sample
      cross_fea = torch.sum(cross_fea, dim=1)
      return linear_output + cross_fea
    # Otherwise return the first-order term and the crossed vector (e.g. as NFM input)
    return linear_output, cross_fea


batch_size, fea_dim, emb_dim = 10, 5, 8
x = torch.rand((batch_size, fea_dim))
fm = FM(fea_dim=fea_dim, embed_dim=emb_dim, reduce_sum=True)
fm2 = FM(fea_dim=fea_dim, embed_dim=emb_dim, reduce_sum=False)
y = fm(x)    # shape (batch_size,)
y2 = fm2(x)  # tuple: (batch_size,), (batch_size, emb_dim)
print(y.shape, y2[0].shape, y2[1].shape)

4. Reference

https://github.com/rixwew/pytorch-fm/blob/f74ad19771eda104e99874d19dc892e988ec53fa/torchfm/layer.py#L64