Dropout
Dropout implementation (from scratch in PyTorch)
class Dropout(torch.nn.Module):
    """Inverted dropout implemented from scratch.

    NOTE: unlike ``torch.nn.Dropout`` (where ``p`` is the probability of
    DROPPING a unit), here ``p`` is the probability of KEEPING a unit.
    With ``p == 1`` every unit is kept, i.e. dropout is disabled.

    During training, each element of the input is kept independently with
    probability ``p`` and the surviving elements are scaled by ``1/p`` so
    the expected activation magnitude matches evaluation mode. In eval
    mode the module is the identity.
    """

    def __init__(self, p=0.5):
        super(Dropout, self).__init__()
        # Guard against p == 0, which would divide by zero in forward().
        if not 0.0 < p <= 1.0:
            raise ValueError(f"keep probability p must be in (0, 1], got {p}")
        self.p = p  # keep probability

    def forward(self, x):
        # Standard nn.Module semantics: dropout only applies in training
        # mode (self.training defaults to True).
        if not self.training:
            return x
        # Bernoulli(keep=p) mask on the same device/dtype as x, rescaled
        # by 1/p ("inverted dropout") so output magnitude is preserved.
        dropout_mask = (torch.rand_like(x) < self.p) / self.p
        return x * dropout_mask  # element-wise multiplication
# Demo: push a small batch through Linear -> ReLU -> our Dropout.
size = 6
batch_size = 3
x = torch.tensor(
    [[1.1, 1, 2, 3, 4, 5],
     [10, 11, 12, 13, 14, 15],
     [21, 22, 23, 24, 25, 26]],
    dtype=torch.float32,
)
linear = nn.Linear(6, 6)
relu = nn.ReLU()
# Keep probability 1 -> every unit survives, dropout is effectively off.
dropout = Dropout(1)
x = dropout(relu(linear(x)))
x
Last updated