glucosedao_gpu / gluformer /variance.py
Livia_Zaharia
added code for the first time
bacf16b
raw
history blame
616 Bytes
import torch
import torch.nn as nn
import torch.nn.functional as F
class Variance(nn.Module):
    """Reduce a sequence of feature vectors to a single bounded scalar.

    The module first maps every feature vector to one value
    (``d_model -> 1``), applies ReLU and dropout, then mixes those values
    across axis 1 (``len_seq + 1 -> 1``) and squashes the result with a
    scaled tanh.

    NOTE(review): assumes ``x`` is ``(batch, len_seq + 1, d_model)``, which
    would give an output of shape ``(batch, 1, 1)`` -- confirm against the
    caller.

    Args:
        d_model: Size of the last axis of the input.
        r_drop: Dropout probability applied after the first activation.
        len_seq: Sequence length, not counting the extra person-token
            position that is added on top of it.
    """

    def __init__(self, d_model: int, r_drop: float, len_seq: int) -> None:
        super().__init__()
        # Collapses each feature vector to a single value.
        self.proj1 = nn.Linear(d_model, 1)
        self.dropout = nn.Dropout(r_drop)
        self.activ1 = nn.ReLU()
        # The sequence axis has one extra entry for the embedded person
        # token, hence ``len_seq + 1`` input features here.
        self.proj2 = nn.Linear(len_seq + 1, 1)
        self.activ2 = nn.Tanh()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the head on ``x`` and return a value bounded by +/-10."""
        # Feature axis -> 1, followed by non-linearity and regularisation.
        per_position = self.dropout(self.activ1(self.proj1(x)))
        # Swap axis 1 with the last axis so the second projection acts
        # across axis 1 of the input instead of the (now size-1) feature axis.
        mixed = self.proj2(per_position.transpose(1, -1))
        # tanh bounds the value to [-1, 1]; scale to [-10, 10] range.
        return self.activ2(mixed) * 10