# Find_The_Fake / models.py
# Arnab Das
# bug fix
# 5fde11f
# raw
# history blame
# 28.8 kB
import torch
import random
import fairseq
import numpy as np
import torch.nn as nn
from torch import Tensor
from typing import Union
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
class SSLModel(nn.Module):
    """Front-end wrapper around a pre-trained XLSR checkpoint loaded with fairseq.

    The checkpoint file ``xlsr2_300m.pt`` is downloaded from the Hugging Face
    Hub repository ``arnabdas8901/aasist-trained-asvspoof2024``.

    Attributes set by the constructor:
        model:   first model of the ensemble returned by fairseq.
        device:  the ``device`` argument, stored as given.
        out_dim: 1024, the embedding width consumers of this class rely on.
    """

    def __init__(self, device):
        super().__init__()
        # Change the repository / filename here to use another pre-trained
        # XLSR model.
        checkpoint_path = hf_hub_download(
            "arnabdas8901/aasist-trained-asvspoof2024",
            filename='xlsr2_300m.pt')
        ensemble, _cfg, _task = \
            fairseq.checkpoint_utils.load_model_ensemble_and_task(
                [checkpoint_path])
        self.model = ensemble[0]
        self.device = device
        self.out_dim = 1024

    def extract_feat(self, input_data):
        """Return the SSL embeddings for ``input_data``.

        Args:
            input_data: tensor of shape (batch, length); a 3-D tensor is also
                accepted, in which case only index 0 of the last axis is used.

        Returns:
            The ``'x'`` entry of the wrapped model's output
            (presumably (batch, frames, dim) -- TODO confirm against fairseq).
        """
        # Lazily align the wrapped model with the input's device and dtype.
        reference = next(self.model.parameters())
        aligned = (reference.device == input_data.device
                   and reference.dtype == input_data.dtype)
        if not aligned:
            self.model.to(input_data.device, dtype=input_data.dtype)
            # NOTE(review): the wrapped model is switched to training mode
            # whenever it has to be moved -- confirm this is intended when the
            # surrounding network is being evaluated.
            self.model.train()

        # The wrapped model expects (batch, length).
        waveform = input_data[:, :, 0] if input_data.ndim == 3 else input_data
        return self.model(waveform, mask=False, features_only=True)['x']
# ---------AASIST back-end------------------------#
''' Jee-weon Jung, Hee-Soo Heo, Hemlata Tak, Hye-jin Shim, Joon Son Chung, Bong-Jin Lee, Ha-Jin Yu and Nicholas Evans.
AASIST: Audio Anti-Spoofing Using Integrated Spectro-Temporal Graph Attention Networks.
In Proc. ICASSP 2022, pp: 6367--6371.'''
class GraphAttentionLayer(nn.Module):
    """Graph attention layer over a fully connected set of nodes.

    Args:
        in_dim: feature size of every input node.
        out_dim: feature size of every output node.
        **kwargs: optional ``temperature`` (default 1.0) dividing the
            attention logits before the softmax.
    """

    def __init__(self, in_dim, out_dim, **kwargs):
        super().__init__()

        # attention map
        self.att_proj = nn.Linear(in_dim, out_dim)
        self.att_weight = self._init_new_params(out_dim, 1)

        # node projections (attended / plain)
        self.proj_with_att = nn.Linear(in_dim, out_dim)
        self.proj_without_att = nn.Linear(in_dim, out_dim)

        self.bn = nn.BatchNorm1d(out_dim)
        self.input_drop = nn.Dropout(p=0.2)
        self.act = nn.SELU(inplace=True)

        # softmax temperature
        self.temp = kwargs.get("temperature", 1.)

    def forward(self, x):
        """Map node features ``x`` (#bs, #node, #dim) to (#bs, #node, out_dim)."""
        nodes = self.input_drop(x)
        attention = self._derive_att_map(nodes)
        projected = self._project(nodes, attention)
        return self.act(self._apply_BN(projected))

    def _pairwise_mul_nodes(self, x):
        """Element-wise product of every node pair.

        x         :(#bs, #node, #dim)
        out_shape :(#bs, #node, #node, #dim), entry [b, i, j] = x[b, i] * x[b, j]
        """
        # Broadcasting (#bs, #node, 1, #dim) against (#bs, 1, #node, #dim).
        return x.unsqueeze(2) * x.unsqueeze(1)

    def _derive_att_map(self, x):
        """Attention weights normalised with a softmax over dim -2.

        x         :(#bs, #node, #dim)
        out_shape :(#bs, #node, #node, 1)
        """
        # (#bs, #node, #node, out_dim)
        hidden = torch.tanh(self.att_proj(self._pairwise_mul_nodes(x)))
        # (#bs, #node, #node, 1), scaled by the temperature
        logits = torch.matmul(hidden, self.att_weight) / self.temp
        return F.softmax(logits, dim=-2)

    def _project(self, x, att_map):
        """Sum of the attention-aggregated projection and the plain projection."""
        aggregated = torch.matmul(att_map.squeeze(-1), x)
        return self.proj_with_att(aggregated) + self.proj_without_att(x)

    def _apply_BN(self, x):
        """Batch-normalise over the last axis by flattening all leading axes."""
        shape = x.size()
        flat = x.view(-1, shape[-1])
        return self.bn(flat).view(shape)

    def _init_new_params(self, *size):
        """Create a float32 parameter of the given size, Xavier-normal initialised."""
        weight = nn.Parameter(torch.FloatTensor(*size))
        nn.init.xavier_normal_(weight)
        return weight
class HtrgGraphAttentionLayer(nn.Module):
    """Heterogeneous graph attention layer over two node types plus a master node.

    Args:
        in_dim: feature size of the input nodes (both types) and the master.
        out_dim: feature size of the output nodes and the updated master.
        **kwargs: optional ``temperature`` (default 1.0) dividing the
            attention logits before the softmax.
    """

    def __init__(self, in_dim, out_dim, **kwargs):
        super().__init__()

        # per-type projections into a shared space
        self.proj_type1 = nn.Linear(in_dim, in_dim)
        self.proj_type2 = nn.Linear(in_dim, in_dim)

        # attention maps (node-to-node and node-to-master)
        self.att_proj = nn.Linear(in_dim, out_dim)
        self.att_projM = nn.Linear(in_dim, out_dim)

        # separate attention vectors per block of the node-to-node map
        self.att_weight11 = self._init_new_params(out_dim, 1)
        self.att_weight22 = self._init_new_params(out_dim, 1)
        self.att_weight12 = self._init_new_params(out_dim, 1)
        self.att_weightM = self._init_new_params(out_dim, 1)

        # output projections for the nodes
        self.proj_with_att = nn.Linear(in_dim, out_dim)
        self.proj_without_att = nn.Linear(in_dim, out_dim)

        # output projections for the master node
        self.proj_with_attM = nn.Linear(in_dim, out_dim)
        self.proj_without_attM = nn.Linear(in_dim, out_dim)

        self.bn = nn.BatchNorm1d(out_dim)
        self.input_drop = nn.Dropout(p=0.2)
        self.act = nn.SELU(inplace=True)

        # softmax temperature
        self.temp = kwargs.get("temperature", 1.)

    def forward(self, x1, x2, master=None):
        """Update both node sets and the master node.

        x1     :(#bs, #node1, #dim)
        x2     :(#bs, #node2, #dim)
        master : tensor broadcastable against (#bs, 1, #dim); when ``None``
                 the mean of the projected nodes is used.

        Returns ``(x1, x2, master)`` with ``out_dim`` features each.
        """
        num_type1, num_type2 = x1.size(1), x2.size(1)

        nodes = torch.cat([self.proj_type1(x1), self.proj_type2(x2)], dim=1)

        # The default master is taken before input dropout is applied.
        if master is None:
            master = torch.mean(nodes, dim=1, keepdim=True)

        nodes = self.input_drop(nodes)
        att_map = self._derive_att_map(nodes, num_type1, num_type2)

        # directional edge: nodes -> master
        master = self._update_master(nodes, master)

        nodes = self.act(self._apply_BN(self._project(nodes, att_map)))

        # split the concatenated nodes back into the two types
        out_type1 = nodes.narrow(1, 0, num_type1)
        out_type2 = nodes.narrow(1, num_type1, num_type2)
        return out_type1, out_type2, master

    def _update_master(self, x, master):
        """Attend from the master node over all nodes and project the result."""
        master_att = self._derive_att_map_master(x, master)
        return self._project_master(x, master, master_att)

    def _pairwise_mul_nodes(self, x):
        """Element-wise product of every node pair.

        x         :(#bs, #node, #dim)
        out_shape :(#bs, #node, #node, #dim), entry [b, i, j] = x[b, i] * x[b, j]
        """
        # Broadcasting (#bs, #node, 1, #dim) against (#bs, 1, #node, #dim).
        return x.unsqueeze(2) * x.unsqueeze(1)

    def _derive_att_map_master(self, x, master):
        """Attention of the master node over the nodes (softmax over dim -2).

        x         :(#bs, #node, #dim)
        master    : broadcastable against ``x``
        out_shape :(#bs, #node, 1)
        """
        hidden = torch.tanh(self.att_projM(x * master))
        logits = torch.matmul(hidden, self.att_weightM) / self.temp
        return F.softmax(logits, dim=-2)

    def _derive_att_map(self, x, num_type1, num_type2):
        """Node-to-node attention with a distinct weight vector per type pair.

        x         :(#bs, #node, #dim)
        out_shape :(#bs, #node, #node, 1)
        """
        # (#bs, #node, #node, out_dim)
        hidden = torch.tanh(self.att_proj(self._pairwise_mul_nodes(x)))

        # (#bs, #node, #node, 1), filled block by block
        logits = torch.zeros_like(hidden[:, :, :, 0]).unsqueeze(-1)
        type1 = slice(None, num_type1)
        type2 = slice(num_type1, None)
        blocks = (
            (type1, type1, self.att_weight11),
            (type2, type2, self.att_weight22),
            # both cross-type blocks share the same weight vector
            (type1, type2, self.att_weight12),
            (type2, type1, self.att_weight12),
        )
        for rows, cols, weight in blocks:
            logits[:, rows, cols, :] = torch.matmul(
                hidden[:, rows, cols, :], weight)

        return F.softmax(logits / self.temp, dim=-2)

    def _project(self, x, att_map):
        """Sum of the attention-aggregated projection and the plain projection."""
        aggregated = torch.matmul(att_map.squeeze(-1), x)
        return self.proj_with_att(aggregated) + self.proj_without_att(x)

    def _project_master(self, x, master, att_map):
        """Project the attention-weighted node sum and add the projected master."""
        # (#bs, 1, #node) @ (#bs, #node, #dim) -> (#bs, 1, #dim)
        weights = att_map.squeeze(-1).unsqueeze(1)
        aggregated = torch.matmul(weights, x)
        return self.proj_with_attM(aggregated) + self.proj_without_attM(master)

    def _apply_BN(self, x):
        """Batch-normalise over the last axis by flattening all leading axes."""
        shape = x.size()
        flat = x.view(-1, shape[-1])
        return self.bn(flat).view(shape)

    def _init_new_params(self, *size):
        """Create a float32 parameter of the given size, Xavier-normal initialised."""
        weight = nn.Parameter(torch.FloatTensor(*size))
        nn.init.xavier_normal_(weight)
        return weight
class GraphPool(nn.Module):
    """Score-based graph pooling that keeps the top ``k`` fraction of nodes.

    Args:
        k: ratio of nodes to keep (at least one node is always kept).
        in_dim: feature size of every node.
        p: dropout probability applied before scoring; no dropout when
            ``p`` is not positive.
    """

    def __init__(self, k: float, in_dim: int, p: Union[float, int]):
        super().__init__()
        self.k = k
        self.sigmoid = nn.Sigmoid()
        self.proj = nn.Linear(in_dim, 1)
        if p > 0:
            self.drop = nn.Dropout(p=p)
        else:
            self.drop = nn.Identity()
        self.in_dim = in_dim

    def forward(self, h):
        """Score the nodes of ``h`` (#bs, #node, #dim) and pool them."""
        # Dropout only affects the scores; the pooled features come from `h`.
        scores = self.sigmoid(self.proj(self.drop(h)))
        return self.top_k_graph(scores, h, self.k)

    def top_k_graph(self, scores, h, k):
        """
        args
        =====
        scores: attention-based weights (#bs, #node, 1)
        h: graph data (#bs, #node, #dim)
        k: ratio of remaining nodes, (float)

        returns
        =====
        h: score-weighted features of the kept nodes (#bs, #node', #dim),
           ordered by decreasing score
        """
        _bs, total_nodes, n_feat = h.size()
        keep = max(int(total_nodes * k), 1)
        top_idx = torch.topk(scores, keep, dim=1)[1]
        # One gather index per feature column of each kept node.
        top_idx = top_idx.expand(-1, -1, n_feat)
        return torch.gather(h * scores, 1, top_idx)
class Residual_block(nn.Module):
    """Two-convolution residual block operating on 4-D feature maps.

    Args:
        nb_filts: ``[in_channels, out_channels]``.
        first: when true, no ``bn1`` layer is created.
    """

    def __init__(self, nb_filts, first=False):
        super().__init__()
        in_channels, out_channels = nb_filts[0], nb_filts[1]
        self.first = first

        if not self.first:
            self.bn1 = nn.BatchNorm2d(num_features=in_channels)

        # conv1 grows the height by one row, conv2 shrinks it back.
        self.conv1 = nn.Conv2d(in_channels=in_channels,
                               out_channels=out_channels,
                               kernel_size=(2, 3),
                               padding=(1, 1),
                               stride=1)
        self.selu = nn.SELU(inplace=True)
        self.bn2 = nn.BatchNorm2d(num_features=out_channels)
        self.conv2 = nn.Conv2d(in_channels=out_channels,
                               out_channels=out_channels,
                               kernel_size=(2, 3),
                               padding=(0, 1),
                               stride=1)

        # A channel-matching convolution on the skip path is only needed
        # when the channel count changes.
        self.downsample = nb_filts[0] != nb_filts[1]
        if self.downsample:
            self.conv_downsample = nn.Conv2d(in_channels=in_channels,
                                             out_channels=out_channels,
                                             padding=(0, 1),
                                             kernel_size=(1, 3),
                                             stride=1)

    def forward(self, x):
        """Return ``conv2(selu(bn2(conv1(x)))) + skip(x)``."""
        if not self.first:
            # NOTE(review): the normalised/activated tensor is never consumed
            # because conv1 below is fed `x` directly. The call is kept so
            # bn1's running statistics keep being updated in training mode and
            # existing checkpoints produce unchanged outputs.
            self.selu(self.bn1(x))

        out = self.conv1(x)
        out = self.selu(self.bn2(out))
        out = self.conv2(out)

        skip = self.conv_downsample(x) if self.downsample else x
        out += skip
        return out
class Residual_block_aasist(nn.Module):
    """Residual block followed by a (1, 3) max-pooling.

    Args:
        nb_filts: ``[in_channels, out_channels]``.
        first: when true, no ``bn1`` layer is created.
    """

    def __init__(self, nb_filts, first=False):
        super().__init__()
        in_channels, out_channels = nb_filts[0], nb_filts[1]
        self.first = first

        if not self.first:
            self.bn1 = nn.BatchNorm2d(num_features=in_channels)

        # conv1 grows the height by one row, conv2 shrinks it back.
        self.conv1 = nn.Conv2d(in_channels=in_channels,
                               out_channels=out_channels,
                               kernel_size=(2, 3),
                               padding=(1, 1),
                               stride=1)
        self.selu = nn.SELU(inplace=True)
        self.bn2 = nn.BatchNorm2d(num_features=out_channels)
        self.conv2 = nn.Conv2d(in_channels=out_channels,
                               out_channels=out_channels,
                               kernel_size=(2, 3),
                               padding=(0, 1),
                               stride=1)

        # A channel-matching convolution on the skip path is only needed
        # when the channel count changes.
        self.downsample = nb_filts[0] != nb_filts[1]
        if self.downsample:
            self.conv_downsample = nn.Conv2d(in_channels=in_channels,
                                             out_channels=out_channels,
                                             padding=(0, 1),
                                             kernel_size=(1, 3),
                                             stride=1)

        # pooling along the last axis only
        self.mp = nn.MaxPool2d((1, 3))

    def forward(self, x):
        """Return ``mp(conv2(selu(bn2(conv1(x)))) + skip(x))``."""
        if not self.first:
            # NOTE(review): the normalised/activated tensor is never consumed
            # because conv1 below is fed `x` directly. The call is kept so
            # bn1's running statistics keep being updated in training mode and
            # existing checkpoints produce unchanged outputs.
            self.selu(self.bn1(x))

        out = self.conv1(x)
        out = self.selu(self.bn2(out))
        out = self.conv2(out)

        skip = self.conv_downsample(x) if self.downsample else x
        out += skip
        return self.mp(out)
class Model(nn.Module):
    """SSL front-end (``SSLModel``) followed by an AASIST graph back-end.

    Args:
        args: accepted for interface compatibility; not read by this class.
        device: stored on the instance and forwarded to ``SSLModel``.
    """

    def __init__(self, args, device):
        super().__init__()
        self.device = device

        # AASIST parameters
        filts = [128, [1, 32], [32, 32], [32, 64], [64, 64]]
        gat_dims = [64, 32]
        pool_ratios = [0.5, 0.5, 0.5, 0.5]
        temperatures = [2.0, 2.0, 100.0, 100.0]
        enc_dim = filts[-1][-1]

        # wav2vec 2.0 front-end and its projection to 128 features
        self.ssl_model = SSLModel(self.device)
        self.LL = nn.Linear(self.ssl_model.out_dim, 128)

        self.first_bn = nn.BatchNorm2d(num_features=1)
        self.first_bn1 = nn.BatchNorm2d(num_features=64)
        self.drop = nn.Dropout(0.5, inplace=True)
        self.drop_way = nn.Dropout(0.2, inplace=True)
        self.selu = nn.SELU(inplace=True)

        # RawNet2-style encoder; every block sits in its own Sequential so
        # parameter names keep the form ``encoder.<i>.0.*``.
        encoder_specs = (
            (filts[1], True),
            (filts[2], False),
            (filts[3], False),
            (filts[4], False),
            (filts[4], False),
            (filts[4], False),
        )
        self.encoder = nn.Sequential(*[
            nn.Sequential(Residual_block(nb_filts=spec, first=is_first))
            for spec, is_first in encoder_specs
        ])

        self.attention = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=(1, 1)),
            nn.SELU(inplace=True),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 64, kernel_size=(1, 1)),
        )

        # position encoding for the 42 spectral nodes and the two master nodes
        self.pos_S = nn.Parameter(torch.randn(1, 42, enc_dim))
        self.master1 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))
        self.master2 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))

        # Graph module
        self.GAT_layer_S = GraphAttentionLayer(
            enc_dim, gat_dims[0], temperature=temperatures[0])
        self.GAT_layer_T = GraphAttentionLayer(
            enc_dim, gat_dims[0], temperature=temperatures[1])

        # HS-GAL layers (two parallel branches, two layers each)
        self.HtrgGAT_layer_ST11 = HtrgGraphAttentionLayer(
            gat_dims[0], gat_dims[1], temperature=temperatures[2])
        self.HtrgGAT_layer_ST12 = HtrgGraphAttentionLayer(
            gat_dims[1], gat_dims[1], temperature=temperatures[2])
        self.HtrgGAT_layer_ST21 = HtrgGraphAttentionLayer(
            gat_dims[0], gat_dims[1], temperature=temperatures[2])
        self.HtrgGAT_layer_ST22 = HtrgGraphAttentionLayer(
            gat_dims[1], gat_dims[1], temperature=temperatures[2])

        # Graph pooling layers
        self.pool_S = GraphPool(pool_ratios[0], gat_dims[0], 0.3)
        self.pool_T = GraphPool(pool_ratios[1], gat_dims[0], 0.3)
        self.pool_hS1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
        self.pool_hT1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
        self.pool_hS2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
        self.pool_hT2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)

        self.out_layer = nn.Linear(5 * gat_dims[1], 2)

    def _hs_gal_branch(self, first_layer, second_layer, pool_S, pool_T,
                       out_T, out_S, master):
        """Run one heterogeneous branch: layer, pooling, residual second layer."""
        branch_T, branch_S, branch_master = first_layer(
            out_T, out_S, master=master)
        branch_S = pool_S(branch_S)
        branch_T = pool_T(branch_T)

        aug_T, aug_S, aug_master = second_layer(
            branch_T, branch_S, master=branch_master)
        return (branch_T + aug_T,
                branch_S + aug_S,
                branch_master + aug_master)

    def forward(self, x):
        """Return two output scores per sample, shape (#bs, 2).

        ``x`` is squeezed on its last axis and handed to the SSL front-end
        (presumably raw waveforms of shape (#bs, #samples) -- TODO confirm).
        """
        # ---- pre-trained wav2vec front-end (fine-tuned) ----
        feats = self.ssl_model.extract_feat(x.squeeze(-1))
        feats = self.LL(feats)  # (bs, frame_number, feat_out_dim)

        # post-processing on front-end features
        feats = feats.transpose(1, 2)  # (bs, feat_out_dim, frame_number)
        feats = feats.unsqueeze(dim=1)  # add channel
        feats = F.max_pool2d(feats, (3, 3))
        feats = self.selu(self.first_bn(feats))

        # RawNet2-based encoder
        enc = self.encoder(feats)
        enc = self.selu(self.first_bn1(enc))
        att_logits = self.attention(enc)

        # ---- spectral branch: attention-weighted sum over the last axis ----
        spectral_w = F.softmax(att_logits, dim=-1)
        e_S = torch.sum(enc * spectral_w, dim=-1).transpose(1, 2) + self.pos_S
        out_S = self.pool_S(self.GAT_layer_S(e_S))  # (#bs, #node, #dim)

        # ---- temporal branch: attention-weighted sum over axis -2 ----
        temporal_w = F.softmax(att_logits, dim=-2)
        e_T = torch.sum(enc * temporal_w, dim=-2).transpose(1, 2)
        out_T = self.pool_T(self.GAT_layer_T(e_T))

        # The learnable master nodes are (1, 1, dim) and broadcast over the
        # batch inside the heterogeneous layers.
        out_T1, out_S1, master1 = self._hs_gal_branch(
            self.HtrgGAT_layer_ST11, self.HtrgGAT_layer_ST12,
            self.pool_hS1, self.pool_hT1, out_T, out_S, self.master1)
        out_T2, out_S2, master2 = self._hs_gal_branch(
            self.HtrgGAT_layer_ST21, self.HtrgGAT_layer_ST22,
            self.pool_hS2, self.pool_hT2, out_T, out_S, self.master2)

        out_T1 = self.drop_way(out_T1)
        out_T2 = self.drop_way(out_T2)
        out_S1 = self.drop_way(out_S1)
        out_S2 = self.drop_way(out_S2)
        master1 = self.drop_way(master1)
        master2 = self.drop_way(master2)

        # element-wise maximum across the two branches
        merged_T = torch.max(out_T1, out_T2)
        merged_S = torch.max(out_S1, out_S2)
        merged_master = torch.max(master1, master2)

        # Readout operation
        T_max = torch.max(torch.abs(merged_T), dim=1)[0]
        T_avg = torch.mean(merged_T, dim=1)
        S_max = torch.max(torch.abs(merged_S), dim=1)[0]
        S_avg = torch.mean(merged_S, dim=1)

        last_hidden = torch.cat(
            [T_max, T_avg, S_max, S_avg, merged_master.squeeze(1)], dim=1)
        last_hidden = self.drop(last_hidden)
        return self.out_layer(last_hidden)
class CONV(nn.Module):
    """Fixed bank of Hamming-windowed band-pass filters applied with ``conv1d``.

    The band edges are spaced uniformly on the mel scale between 0 Hz and
    ``sample_rate / 2``. The filters are plain tensors (not parameters), so
    they are not trained and do not appear in the ``state_dict``.

    Args:
        out_channels: number of band-pass filters.
        kernel_size: filter length; even values are incremented by one so the
            filters are symmetric.
        sample_rate: sampling rate used to place the band edges.
        in_channels: must be 1.
        stride, padding, dilation: forwarded to ``F.conv1d``.
        bias: must be false.
        groups: must not exceed 1.
        mask: stored on the instance; ``forward`` only looks at its own
            ``mask`` argument.

    Raises:
        ValueError: if ``in_channels != 1``, ``bias`` is set or ``groups > 1``.
    """

    @staticmethod
    def to_mel(hz):
        """Convert a frequency in Hz to the mel scale."""
        return 2595 * np.log10(1 + hz / 700)

    @staticmethod
    def to_hz(mel):
        """Convert a mel-scale value back to Hz."""
        return 700 * (10**(mel / 2595) - 1)

    def __init__(self,
                 out_channels,
                 kernel_size,
                 sample_rate=16000,
                 in_channels=1,
                 stride=1,
                 padding=0,
                 dilation=1,
                 bias=False,
                 groups=1,
                 mask=False):
        super().__init__()
        if in_channels != 1:
            msg = "SincConv only support one input channel (here, in_channels = {%i})" % (
                in_channels)
            raise ValueError(msg)
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.sample_rate = sample_rate

        # Forcing the filters to be odd (i.e. perfectly symmetric)
        if kernel_size % 2 == 0:
            self.kernel_size = self.kernel_size + 1
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.mask = mask
        if bias:
            raise ValueError('SincConv does not support bias.')
        if groups > 1:
            raise ValueError('SincConv does not support groups.')

        # Band edges: out_channels + 1 points, uniform on the mel scale.
        NFFT = 512
        f = int(self.sample_rate / 2) * np.linspace(0, 1, int(NFFT / 2) + 1)
        fmel = self.to_mel(f)
        fmelmax = np.max(fmel)
        fmelmin = np.min(fmel)
        filbandwidthsmel = np.linspace(fmelmin, fmelmax, self.out_channels + 1)
        filbandwidthsf = self.to_hz(filbandwidthsmel)

        self.mel = filbandwidthsf
        # Sample positions centred on zero: -(K-1)/2 ... (K-1)/2.
        self.hsupp = torch.arange(-(self.kernel_size - 1) / 2,
                                  (self.kernel_size - 1) / 2 + 1)
        self.band_pass = torch.zeros(self.out_channels, self.kernel_size)

        # The window does not depend on the band, so build it once.
        window = Tensor(np.hamming(self.kernel_size))
        for i, (fmin, fmax) in enumerate(zip(self.mel[:-1], self.mel[1:])):
            # Ideal band-pass = difference of two low-pass sinc filters.
            hHigh = (2*fmax/self.sample_rate) * \
                np.sinc(2*fmax*self.hsupp/self.sample_rate)
            hLow = (2*fmin/self.sample_rate) * \
                np.sinc(2*fmin*self.hsupp/self.sample_rate)
            hideal = hHigh - hLow
            self.band_pass[i, :] = window * Tensor(hideal)

    def forward(self, x, mask=False):
        """Filter ``x`` with the band-pass bank.

        Args:
            x: input passed to ``F.conv1d`` with a single input channel.
            mask: when true, a random run of consecutive filters (fewer than
                20, and never more than the bank holds) is zeroed for this
                call.

        Returns:
            The ``F.conv1d`` output with ``out_channels`` channels.
        """
        band_pass_filter = self.band_pass.clone().to(x.device)
        if mask:
            n_filters = band_pass_filter.shape[0]
            # Clamp the run length so a small bank cannot make the start
            # range empty (random.randint raises ValueError in that case).
            width = min(int(np.random.uniform(0, 20)), n_filters)
            start = random.randint(0, n_filters - width)
            band_pass_filter[start:start + width, :] = 0

        # Kept as an attribute so the filters of the last call stay inspectable.
        self.filters = band_pass_filter.view(self.out_channels, 1,
                                             self.kernel_size)

        return F.conv1d(x,
                        self.filters,
                        stride=self.stride,
                        padding=self.padding,
                        dilation=self.dilation,
                        bias=None,
                        groups=1)
class AASIST_Model(nn.Module):
    """AASIST network operating on the output of a fixed band-pass front-end.

    Args:
        args: accepted for interface compatibility; not read by this class.
        device: accepted for interface compatibility; not read by this class.
    """

    def __init__(self, args, device):
        super().__init__()

        filts = [70, [1, 32], [32, 32], [32, 64], [64, 64]]
        gat_dims = [64, 32]
        pool_ratios = [0.5, 0.7, 0.5, 0.5]
        temperatures = [2.0, 2.0, 100.0, 100.0]
        enc_dim = filts[-1][-1]

        # band-pass front-end
        self.conv_time = CONV(out_channels=filts[0],
                              kernel_size=128,
                              in_channels=1)
        self.first_bn = nn.BatchNorm2d(num_features=1)

        self.drop = nn.Dropout(0.5, inplace=True)
        self.drop_way = nn.Dropout(0.2, inplace=True)
        self.selu = nn.SELU(inplace=True)

        # Every block sits in its own Sequential so parameter names keep the
        # form ``encoder.<i>.0.*``.
        encoder_specs = (
            (filts[1], True),
            (filts[2], False),
            (filts[3], False),
            (filts[4], False),
            (filts[4], False),
            (filts[4], False),
        )
        self.encoder = nn.Sequential(*[
            nn.Sequential(Residual_block_aasist(nb_filts=spec, first=is_first))
            for spec, is_first in encoder_specs
        ])

        # position encoding for the 23 spectral nodes and the two master nodes
        self.pos_S = nn.Parameter(torch.randn(1, 23, enc_dim))
        self.master1 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))
        self.master2 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))

        # Graph module
        self.GAT_layer_S = GraphAttentionLayer(
            enc_dim, gat_dims[0], temperature=temperatures[0])
        self.GAT_layer_T = GraphAttentionLayer(
            enc_dim, gat_dims[0], temperature=temperatures[1])

        # HS-GAL layers (two parallel branches, two layers each)
        self.HtrgGAT_layer_ST11 = HtrgGraphAttentionLayer(
            gat_dims[0], gat_dims[1], temperature=temperatures[2])
        self.HtrgGAT_layer_ST12 = HtrgGraphAttentionLayer(
            gat_dims[1], gat_dims[1], temperature=temperatures[2])
        self.HtrgGAT_layer_ST21 = HtrgGraphAttentionLayer(
            gat_dims[0], gat_dims[1], temperature=temperatures[2])
        self.HtrgGAT_layer_ST22 = HtrgGraphAttentionLayer(
            gat_dims[1], gat_dims[1], temperature=temperatures[2])

        # Graph pooling layers
        self.pool_S = GraphPool(pool_ratios[0], gat_dims[0], 0.3)
        self.pool_T = GraphPool(pool_ratios[1], gat_dims[0], 0.3)
        self.pool_hS1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
        self.pool_hT1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
        self.pool_hS2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
        self.pool_hT2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)

        self.out_layer = nn.Linear(5 * gat_dims[1], 2)

    def _hs_gal_branch(self, first_layer, second_layer, pool_S, pool_T,
                       out_T, out_S, master):
        """Run one heterogeneous branch: layer, pooling, residual second layer."""
        branch_T, branch_S, branch_master = first_layer(
            out_T, out_S, master=master)
        branch_S = pool_S(branch_S)
        branch_T = pool_T(branch_T)

        aug_T, aug_S, aug_master = second_layer(
            branch_T, branch_S, master=branch_master)
        return (branch_T + aug_T,
                branch_S + aug_S,
                branch_master + aug_master)

    def forward(self, x, Freq_aug=False):
        """Return two output scores per sample, shape (#bs, 2).

        Args:
            x: input that gets a channel axis inserted at position 1 before
                the band-pass front-end (presumably raw waveforms of shape
                (#bs, #samples) -- TODO confirm).
            Freq_aug: forwarded to the front-end as its ``mask`` flag.
        """
        feats = self.conv_time(x.unsqueeze(1), mask=Freq_aug)
        feats = feats.unsqueeze(dim=1)
        feats = F.max_pool2d(torch.abs(feats), (3, 3))
        feats = self.selu(self.first_bn(feats))

        # get embeddings using encoder
        # (#bs, #filt, #spec, #seq)
        e = self.encoder(feats)
        abs_e = torch.abs(e)

        # spectral GAT (GAT-S): max along time
        e_S = torch.max(abs_e, dim=3)[0].transpose(1, 2) + self.pos_S
        out_S = self.pool_S(self.GAT_layer_S(e_S))  # (#bs, #node, #dim)

        # temporal GAT (GAT-T): max along freq
        e_T = torch.max(abs_e, dim=2)[0].transpose(1, 2)
        out_T = self.pool_T(self.GAT_layer_T(e_T))

        # The learnable master nodes are (1, 1, dim) and broadcast over the
        # batch inside the heterogeneous layers.
        out_T1, out_S1, master1 = self._hs_gal_branch(
            self.HtrgGAT_layer_ST11, self.HtrgGAT_layer_ST12,
            self.pool_hS1, self.pool_hT1, out_T, out_S, self.master1)
        out_T2, out_S2, master2 = self._hs_gal_branch(
            self.HtrgGAT_layer_ST21, self.HtrgGAT_layer_ST22,
            self.pool_hS2, self.pool_hT2, out_T, out_S, self.master2)

        out_T1 = self.drop_way(out_T1)
        out_T2 = self.drop_way(out_T2)
        out_S1 = self.drop_way(out_S1)
        out_S2 = self.drop_way(out_S2)
        master1 = self.drop_way(master1)
        master2 = self.drop_way(master2)

        # element-wise maximum across the two branches
        merged_T = torch.max(out_T1, out_T2)
        merged_S = torch.max(out_S1, out_S2)
        merged_master = torch.max(master1, master2)

        # Readout operation
        T_max = torch.max(torch.abs(merged_T), dim=1)[0]
        T_avg = torch.mean(merged_T, dim=1)
        S_max = torch.max(torch.abs(merged_S), dim=1)[0]
        S_avg = torch.mean(merged_S, dim=1)

        last_hidden = torch.cat(
            [T_max, T_avg, S_max, S_avg, merged_master.squeeze(1)], dim=1)
        last_hidden = self.drop(last_hidden)
        return self.out_layer(last_hidden)