Deep & Cross Network (DCN)

Paper: https://arxiv.org/pdf/1708.05123.pdf

Implementation of DCN

import torch
from torch import nn
class CrossNet(nn.Module):
    """Cross network of DCN built from vector-weight cross layers.

    Every layer applies, row by row,

        x_{l+1} = x_0 * (x_l . w_l) + b_l + x_l

    where ``x_0`` is the original input row, ``x_l`` the output of the
    previous layer, and ``w_l`` / ``b_l`` the layer's weight and bias vectors.

    Args:
        fea_dim: Width of the input rows (and of every weight/bias vector).
        embed_dim: Unused; kept so existing callers keep working.
        layer_num: Number of stacked cross layers.
    """

    def __init__(self, fea_dim=5, embed_dim=5, layer_num=3):
        super(CrossNet, self).__init__()
        # One weight vector and one bias vector per cross layer.
        self.weight = nn.Parameter(torch.randn(layer_num, fea_dim))
        self.bias = nn.Parameter(torch.randn(layer_num, fea_dim))
        self.layer_num = layer_num

    def forward(self, x):
        """Apply all cross layers to ``x``.

        Args:
            x: 2-D tensor of shape ``(batch, fea_dim)``, i.e. the dense input
                obtained after concatenating the embeddings.

        Returns:
            A new tensor of shape ``(batch, fea_dim)``. ``x`` is not modified.
        """
        # Keep the original input untouched: it is reused by every layer, and
        # writing into it in place would both corrupt x_0 and break autograd.
        x0 = x
        xl = x
        for i in range(self.layer_num):
            # (batch, fea_dim) @ (fea_dim,) -> (batch,): per-row dot product.
            proj = torch.matmul(xl, self.weight[i])
            xl = x0 * proj.unsqueeze(-1) + self.bias[i] + xl
        return xl


class DCN(nn.Module):
    """Deep & Cross Network: a cross network and a DNN over a shared input.

    Sparse features are turned into embeddings by scaling the rows of the
    embedding matrix ``V``; the flattened embeddings are concatenated with the
    dense features and fed to both the cross network and the DNN. Their
    outputs are concatenated and projected to a single value per sample.

    Args:
        fea_dim: Number of sparse features (rows of ``V``).
        emb_dim: Embedding size of each sparse feature (columns of ``V``).
        dense_dim: Number of dense features.
    """

    def __init__(self, fea_dim=5, emb_dim=8, dense_dim=5):
        super(DCN, self).__init__()
        # Embedding matrix: one emb_dim-sized vector per sparse feature.
        self.V = nn.Parameter(torch.randn(fea_dim, emb_dim))

        dnn_input_size = emb_dim * fea_dim + dense_dim
        dnn_hidden_size = 256
        self.DNN = nn.Sequential(
            nn.Linear(dnn_input_size, dnn_hidden_size),
            nn.ReLU(),
            nn.Dropout(0.5),
        )

        self.CrossN = CrossNet(dnn_input_size, layer_num=3)
        # Built eagerly so its parameters are registered with the module from
        # the start (visible to optimizers, state_dict() and .to(device)).
        # Input width = DNN output width + cross-network output width.
        self.linear = nn.Linear(dnn_hidden_size + dnn_input_size, 1)

    def forward(self, x_dense, x_sparse):
        """Compute the DCN output for one batch.

        Args:
            x_dense: Tensor of shape ``(batch, dense_dim)``.
            x_sparse: Tensor of shape ``(batch, fea_dim)``; each entry scales
                the embedding vector of the corresponding feature.

        Returns:
            Tensor of shape ``(batch, 1)``.
        """
        # (batch, fea_dim, 1) * (fea_dim, emb_dim) -> (batch, fea_dim, emb_dim),
        # then flatten each sample to (fea_dim * emb_dim,).
        embedded = torch.flatten(torch.unsqueeze(x_sparse, -1) * self.V, start_dim=1)
        dense_input = torch.cat([embedded, x_dense], dim=1)
        cross_out = self.CrossN(dense_input)
        dnn_out = self.DNN(dense_input)
        cat_out = torch.cat([dnn_out, cross_out], dim=1)
        return self.linear(cat_out)


# Smoke run: build a DCN and push one random batch through it.
fea_dim = 5    # number of sparse features
emb_dim = 8    # embedding size per sparse feature
dense_dim = 5  # number of dense features
batch = 5      # samples in the demo batch
dcn = DCN(fea_dim=fea_dim, emb_dim=emb_dim, dense_dim=dense_dim)

# Thresholding random noise yields a boolean tensor of sparse indicators.
x_sparse = torch.randn(batch, fea_dim) < 0.5
x_dense = torch.randn(batch, dense_dim)
out = dcn(x_dense, x_sparse)
out

Reference:

Last updated