Spaces:
Running
Running
import torch | |
import random | |
import fairseq | |
import numpy as np | |
import torch.nn as nn | |
from torch import Tensor | |
from typing import Union | |
import torch.nn.functional as F | |
from huggingface_hub import hf_hub_download | |
class SSLModel(nn.Module):
    """Front-end wrapper around a pre-trained fairseq XLS-R checkpoint.

    The checkpoint file ``xlsr2_300m.pt`` is fetched from the Hugging Face Hub
    repository ``arnabdas8901/aasist-trained-asvspoof2024`` at construction.

    Args:
        device: stored on ``self.device``; the wrapped model is moved lazily
            in :meth:`extract_feat` to wherever the input tensor lives.
    """

    def __init__(self, device):
        super().__init__()
        # Change the repo id / filename here to use another pre-trained
        # XLSR checkpoint.
        cp_path = hf_hub_download("arnabdas8901/aasist-trained-asvspoof2024",
                                  filename='xlsr2_300m.pt')
        models, _cfg, _task = \
            fairseq.checkpoint_utils.load_model_ensemble_and_task([cp_path])
        self.model = models[0]
        self.device = device
        # Width of the features returned by extract_feat; downstream layers
        # size their input from this attribute.
        self.out_dim = 1024

    def extract_feat(self, input_data):
        """Return the SSL features for a batch of waveforms.

        Args:
            input_data: tensor of shape (batch, length) or
                (batch, length, channels); for 3-D input only channel 0 is
                used.

        Returns:
            The ``'x'`` entry of the wrapped model's output
            ([batch, length, dim] according to the original comment).
        """
        ssl = self.model

        # Move the model to the input's device / dtype if it is not there.
        ref_param = next(ssl.parameters())
        if ref_param.device != input_data.device \
                or ref_param.dtype != input_data.dtype:
            ssl.to(input_data.device, dtype=input_data.dtype)

        # Follow the wrapper's own train/eval mode. Forcing train() here (as
        # the code previously did) silently re-enabled dropout during
        # inference even after the parent module had been put in eval().
        ssl.train(self.training)

        # The model expects input in shape (batch, length).
        if input_data.ndim == 3:
            input_tmp = input_data[:, :, 0]
        else:
            input_tmp = input_data

        return ssl(input_tmp, mask=False, features_only=True)['x']
# ---------AASIST back-end------------------------# | |
''' Jee-weon Jung, Hee-Soo Heo, Hemlata Tak, Hye-jin Shim, Joon Son Chung, Bong-Jin Lee, Ha-Jin Yu and Nicholas Evans. | |
AASIST: Audio Anti-Spoofing Using Integrated Spectro-Temporal Graph Attention Networks. | |
In Proc. ICASSP 2022, pp: 6367--6371.''' | |
class GraphAttentionLayer(nn.Module):
    """Graph attention layer over a fully connected set of nodes.

    Args:
        in_dim: feature size of each input node.
        out_dim: feature size of each output node.
        **kwargs: optional ``temperature`` (default 1.) dividing the
            attention scores before the softmax.
    """

    def __init__(self, in_dim, out_dim, **kwargs):
        super().__init__()

        # attention map: projection followed by a learnable scoring vector
        self.att_proj = nn.Linear(in_dim, out_dim)
        self.att_weight = self._init_new_params(out_dim, 1)

        # node projections, with and without attention
        self.proj_with_att = nn.Linear(in_dim, out_dim)
        self.proj_without_att = nn.Linear(in_dim, out_dim)

        # batch norm, input dropout and activation
        self.bn = nn.BatchNorm1d(out_dim)
        self.input_drop = nn.Dropout(p=0.2)
        self.act = nn.SELU(inplace=True)

        # softmax temperature
        self.temp = kwargs.get("temperature", 1.)

    def forward(self, x):
        """Map ``x`` of shape (#bs, #node, #in_dim) to (#bs, #node, #out_dim)."""
        dropped = self.input_drop(x)
        attention = self._derive_att_map(dropped)
        projected = self._project(dropped, attention)
        normed = self._apply_BN(projected)
        return self.act(normed)

    def _pairwise_mul_nodes(self, x):
        """Element-wise product of every pair of nodes (for the attention map).

        x         :(#bs, #node, #dim)
        out_shape :(#bs, #node, #node, #dim)
        """
        # Broadcasting (bs, n, 1, d) against (bs, 1, n, d) yields
        # out[b, i, j] = x[b, i] * x[b, j].
        return x.unsqueeze(2) * x.unsqueeze(1)

    def _derive_att_map(self, x):
        """Compute the attention map.

        x         :(#bs, #node, #dim)
        out_shape :(#bs, #node, #node, 1)
        """
        pairs = self._pairwise_mul_nodes(x)
        # size: (#bs, #node, #node, #dim_out)
        hidden = torch.tanh(self.att_proj(pairs))
        # size: (#bs, #node, #node, 1)
        scores = torch.matmul(hidden, self.att_weight)
        # temperature scaling, then normalise along dim -2
        return F.softmax(scores / self.temp, dim=-2)

    def _project(self, x, att_map):
        """Sum of the attention-aggregated and the plain node projections."""
        aggregated = torch.matmul(att_map.squeeze(-1), x)
        return self.proj_with_att(aggregated) + self.proj_without_att(x)

    def _apply_BN(self, x):
        """Apply BatchNorm1d over the last dim by flattening leading dims."""
        shape = x.size()
        flat = x.view(-1, shape[-1])
        return self.bn(flat).view(shape)

    def _init_new_params(self, *size):
        """Create a Xavier-normal initialised parameter of the given size."""
        param = nn.Parameter(torch.FloatTensor(*size))
        nn.init.xavier_normal_(param)
        return param
class HtrgGraphAttentionLayer(nn.Module):
    """Heterogeneous graph attention layer over two node types plus a master node.

    Args:
        in_dim: feature size of the input nodes (both types) and the master.
        out_dim: feature size of the output nodes and the output master.
        **kwargs: optional ``temperature`` (default 1.) dividing the
            attention scores before the softmax.
    """

    def __init__(self, in_dim, out_dim, **kwargs):
        super().__init__()

        # per-type input projections (dimension preserving)
        self.proj_type1 = nn.Linear(in_dim, in_dim)
        self.proj_type2 = nn.Linear(in_dim, in_dim)

        # attention maps (node-to-node and node-to-master)
        self.att_proj = nn.Linear(in_dim, out_dim)
        self.att_projM = nn.Linear(in_dim, out_dim)

        # scoring vectors: type1-type1, type2-type2, cross-type, master
        self.att_weight11 = self._init_new_params(out_dim, 1)
        self.att_weight22 = self._init_new_params(out_dim, 1)
        self.att_weight12 = self._init_new_params(out_dim, 1)
        self.att_weightM = self._init_new_params(out_dim, 1)

        # projections for nodes and for the master node
        self.proj_with_att = nn.Linear(in_dim, out_dim)
        self.proj_without_att = nn.Linear(in_dim, out_dim)
        self.proj_with_attM = nn.Linear(in_dim, out_dim)
        self.proj_without_attM = nn.Linear(in_dim, out_dim)

        # batch norm, input dropout and activation
        self.bn = nn.BatchNorm1d(out_dim)
        self.input_drop = nn.Dropout(p=0.2)
        self.act = nn.SELU(inplace=True)

        # softmax temperature
        self.temp = kwargs.get("temperature", 1.)

    def forward(self, x1, x2, master=None):
        """Update both node sets and the master node.

        x1     :(#bs, #node1, #dim)
        x2     :(#bs, #node2, #dim)
        master :broadcastable to (#bs, 1, #dim); when ``None`` the mean of
                the projected, concatenated nodes is used.

        Returns ``(x1, x2, master)`` with ``out_dim`` features each.
        """
        num_type1 = x1.size(1)
        num_type2 = x2.size(1)

        # project each type, then treat all nodes as one graph
        x = torch.cat([self.proj_type1(x1), self.proj_type2(x2)], dim=1)

        if master is None:
            master = torch.mean(x, dim=1, keepdim=True)

        x = self.input_drop(x)

        # node-to-node attention
        att_map = self._derive_att_map(x, num_type1, num_type2)

        # directional edge for the master node
        master = self._update_master(x, master)

        x = self._project(x, att_map)
        x = self.act(self._apply_BN(x))

        # split the joint graph back into the two node types
        out1 = x.narrow(1, 0, num_type1)
        out2 = x.narrow(1, num_type1, num_type2)
        return out1, out2, master

    def _update_master(self, x, master):
        """Attend from the master node to all nodes and project the result."""
        master_att = self._derive_att_map_master(x, master)
        return self._project_master(x, master, master_att)

    def _pairwise_mul_nodes(self, x):
        """Element-wise product of every pair of nodes (for the attention map).

        x         :(#bs, #node, #dim)
        out_shape :(#bs, #node, #node, #dim)
        """
        # out[b, i, j] = x[b, i] * x[b, j]
        return x.unsqueeze(2) * x.unsqueeze(1)

    def _derive_att_map_master(self, x, master):
        """Attention of the master node over all nodes.

        x         :(#bs, #node, #dim)
        out_shape :(#bs, #node, 1)
        """
        hidden = torch.tanh(self.att_projM(x * master))
        scores = torch.matmul(hidden, self.att_weightM)
        # temperature scaling, then normalise across nodes
        return F.softmax(scores / self.temp, dim=-2)

    def _derive_att_map(self, x, num_type1, num_type2):
        """Node-to-node attention with a separate scoring vector per type pair.

        x         :(#bs, #node, #dim)
        out_shape :(#bs, #node, #node, 1)

        ``num_type2`` is accepted for interface symmetry but not read.
        """
        # size: (#bs, #node, #node, #dim_out)
        hidden = torch.tanh(self.att_proj(self._pairwise_mul_nodes(x)))

        # size: (#bs, #node, #node, 1)
        att_board = torch.zeros_like(hidden[:, :, :, 0]).unsqueeze(-1)

        type1 = slice(None, num_type1)
        type2 = slice(num_type1, None)
        # (rows, cols, scoring vector); both cross-type blocks share weight12
        quadrants = (
            (type1, type1, self.att_weight11),
            (type2, type2, self.att_weight22),
            (type1, type2, self.att_weight12),
            (type2, type1, self.att_weight12),
        )
        for rows, cols, weight in quadrants:
            att_board[:, rows, cols, :] = torch.matmul(
                hidden[:, rows, cols, :], weight)

        # temperature scaling, then normalise along dim -2
        return F.softmax(att_board / self.temp, dim=-2)

    def _project(self, x, att_map):
        """Sum of the attention-aggregated and the plain node projections."""
        aggregated = torch.matmul(att_map.squeeze(-1), x)
        return self.proj_with_att(aggregated) + self.proj_without_att(x)

    def _project_master(self, x, master, att_map):
        """Project the attention-pooled nodes and add the projected master."""
        # (#bs, 1, #node) @ (#bs, #node, #dim) -> (#bs, 1, #dim)
        pooled = torch.matmul(att_map.squeeze(-1).unsqueeze(1), x)
        return self.proj_with_attM(pooled) + self.proj_without_attM(master)

    def _apply_BN(self, x):
        """Apply BatchNorm1d over the last dim by flattening leading dims."""
        shape = x.size()
        flat = x.view(-1, shape[-1])
        return self.bn(flat).view(shape)

    def _init_new_params(self, *size):
        """Create a Xavier-normal initialised parameter of the given size."""
        param = nn.Parameter(torch.FloatTensor(*size))
        nn.init.xavier_normal_(param)
        return param
class GraphPool(nn.Module):
    """Score-based top-k node pooling.

    Args:
        k: ratio of nodes to keep (at least one node is always kept).
        in_dim: feature size of each node.
        p: dropout probability applied before scoring; ``p <= 0`` disables
            dropout.
    """

    def __init__(self, k: float, in_dim: int, p: Union[float, int]):
        super().__init__()
        self.k = k
        self.sigmoid = nn.Sigmoid()
        self.proj = nn.Linear(in_dim, 1)
        if p > 0:
            self.drop = nn.Dropout(p=p)
        else:
            self.drop = nn.Identity()
        self.in_dim = in_dim

    def forward(self, h):
        """Score every node of ``h`` (#bs, #node, #dim) and keep the top ones."""
        # Dropout only affects the scoring path; the gathered features come
        # from the undropped input.
        scores = self.sigmoid(self.proj(self.drop(h)))
        return self.top_k_graph(scores, h, self.k)

    def top_k_graph(self, scores, h, k):
        """Keep the highest-scoring nodes, scaled by their scores.

        args
        =====
        scores: attention-based weights (#bs, #node, 1)
        h: graph data (#bs, #node, #dim)
        k: ratio of remaining nodes, (float)

        returns
        =====
        h: graph pool applied data (#bs, #node', #dim)
        """
        _, total_nodes, n_feat = h.size()
        n_keep = max(int(total_nodes * k), 1)

        _, top_idx = torch.topk(scores, n_keep, dim=1)
        gather_idx = top_idx.expand(-1, -1, n_feat)

        weighted = h * scores
        return torch.gather(weighted, 1, gather_idx)
class Residual_block(nn.Module):
    """Two-convolution residual block (no pooling).

    Args:
        nb_filts: ``[in_channels, out_channels]``.
        first: when True the block has no ``bn1`` layer.
    """

    def __init__(self, nb_filts, first=False):
        super().__init__()
        in_ch, out_ch = nb_filts[0], nb_filts[1]
        self.first = first

        if not self.first:
            self.bn1 = nn.BatchNorm2d(num_features=in_ch)

        # conv1 pads one row on each side, conv2 pads none: together the two
        # (2, 3) kernels leave height and width unchanged.
        self.conv1 = nn.Conv2d(in_channels=in_ch,
                               out_channels=out_ch,
                               kernel_size=(2, 3),
                               padding=(1, 1),
                               stride=1)
        self.selu = nn.SELU(inplace=True)
        self.bn2 = nn.BatchNorm2d(num_features=out_ch)
        self.conv2 = nn.Conv2d(in_channels=out_ch,
                               out_channels=out_ch,
                               kernel_size=(2, 3),
                               padding=(0, 1),
                               stride=1)

        # A (1, 3) conv on the skip path matches channel counts when needed.
        self.downsample = in_ch != out_ch
        if self.downsample:
            self.conv_downsample = nn.Conv2d(in_channels=in_ch,
                                             out_channels=out_ch,
                                             padding=(0, 1),
                                             kernel_size=(1, 3),
                                             stride=1)

    def forward(self, x):
        """Return ``conv2(selu(bn2(conv1(x)))) + skip(x)``."""
        if not self.first:
            # NOTE(review): the bn1 + SELU result is never fed to conv1, which
            # reads ``x`` directly. This looks unintended, but weights trained
            # with this code depend on it, so it is preserved. The call is
            # kept because bn1 still updates its running stats in train mode.
            self.selu(self.bn1(x))

        out = self.conv1(x)
        out = self.selu(self.bn2(out))
        out = self.conv2(out)

        skip = self.conv_downsample(x) if self.downsample else x
        out += skip
        return out
class Residual_block_aasist(nn.Module):
    """Two-convolution residual block followed by (1, 3) max pooling.

    Args:
        nb_filts: ``[in_channels, out_channels]``.
        first: when True the block has no ``bn1`` layer.
    """

    def __init__(self, nb_filts, first=False):
        super().__init__()
        n_in = nb_filts[0]
        n_out = nb_filts[1]
        self.first = first

        if not self.first:
            self.bn1 = nn.BatchNorm2d(num_features=n_in)

        self.conv1 = nn.Conv2d(in_channels=n_in,
                               out_channels=n_out,
                               kernel_size=(2, 3),
                               padding=(1, 1),
                               stride=1)
        self.selu = nn.SELU(inplace=True)
        self.bn2 = nn.BatchNorm2d(num_features=n_out)
        self.conv2 = nn.Conv2d(in_channels=n_out,
                               out_channels=n_out,
                               kernel_size=(2, 3),
                               padding=(0, 1),
                               stride=1)

        # Skip path needs its own conv only when the channel count changes.
        if n_in != n_out:
            self.downsample = True
            self.conv_downsample = nn.Conv2d(in_channels=n_in,
                                             out_channels=n_out,
                                             padding=(0, 1),
                                             kernel_size=(1, 3),
                                             stride=1)
        else:
            self.downsample = False

        # Pools the last dimension only (window and stride of 3).
        self.mp = nn.MaxPool2d((1, 3))

    def forward(self, x):
        """Return ``maxpool(conv2(selu(bn2(conv1(x)))) + skip(x))``."""
        if not self.first:
            # NOTE(review): the bn1 + SELU result is discarded because conv1
            # reads ``x`` directly. This looks unintended, but weights trained
            # with this code depend on it, so it is preserved. The call stays
            # because bn1 still updates its running stats in train mode.
            self.selu(self.bn1(x))

        residual = self.conv_downsample(x) if self.downsample else None

        out = self.conv1(x)
        out = self.bn2(out)
        out = self.selu(out)
        out = self.conv2(out)

        out += x if residual is None else residual
        return self.mp(out)
class Model(nn.Module):
    """XLS-R (wav2vec 2.0) front-end followed by an AASIST graph back-end.

    Args:
        args: not read by this class; kept for interface compatibility.
        device: stored on ``self.device`` and handed to ``SSLModel``.

    ``forward`` returns a ``(#bs, 2)`` tensor produced by ``out_layer``.
    """

    def __init__(self, args, device):
        super().__init__()
        self.device = device

        # AASIST parameters. Only filts[1:], pool_ratios[:3] and
        # temperatures[:3] are read below.
        filts = [128, [1, 32], [32, 32], [32, 64], [64, 64]]
        gat_dims = [64, 32]
        pool_ratios = [0.5, 0.5, 0.5, 0.5]
        temperatures = [2.0, 2.0, 100.0, 100.0]
        enc_dim = filts[-1][-1]

        ####
        # create network wav2vec 2.0
        ####
        self.ssl_model = SSLModel(self.device)
        self.LL = nn.Linear(self.ssl_model.out_dim, 128)

        self.first_bn = nn.BatchNorm2d(num_features=1)
        self.first_bn1 = nn.BatchNorm2d(num_features=64)
        self.drop = nn.Dropout(0.5, inplace=True)
        self.drop_way = nn.Dropout(0.2, inplace=True)
        self.selu = nn.SELU(inplace=True)

        # RawNet2 encoder: six residual blocks, each wrapped in its own
        # nn.Sequential (the nesting is part of the state_dict key names).
        block_specs = [
            (filts[1], True),
            (filts[2], False),
            (filts[3], False),
            (filts[4], False),
            (filts[4], False),
            (filts[4], False),
        ]
        self.encoder = nn.Sequential(*(
            nn.Sequential(Residual_block(nb_filts=spec, first=is_first))
            for spec, is_first in block_specs))

        self.attention = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=(1, 1)),
            nn.SELU(inplace=True),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 64, kernel_size=(1, 1)),
        )

        # position encoding; 42 spectral nodes = 128 features after the
        # (3, 3) max-pool in forward
        self.pos_S = nn.Parameter(torch.randn(1, 42, enc_dim))

        self.master1 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))
        self.master2 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))

        # Graph module
        self.GAT_layer_S = GraphAttentionLayer(enc_dim,
                                               gat_dims[0],
                                               temperature=temperatures[0])
        self.GAT_layer_T = GraphAttentionLayer(enc_dim,
                                               gat_dims[0],
                                               temperature=temperatures[1])

        # HS-GAL layer
        self.HtrgGAT_layer_ST11 = HtrgGraphAttentionLayer(
            gat_dims[0], gat_dims[1], temperature=temperatures[2])
        self.HtrgGAT_layer_ST12 = HtrgGraphAttentionLayer(
            gat_dims[1], gat_dims[1], temperature=temperatures[2])
        self.HtrgGAT_layer_ST21 = HtrgGraphAttentionLayer(
            gat_dims[0], gat_dims[1], temperature=temperatures[2])
        self.HtrgGAT_layer_ST22 = HtrgGraphAttentionLayer(
            gat_dims[1], gat_dims[1], temperature=temperatures[2])

        # Graph pooling layers
        self.pool_S = GraphPool(pool_ratios[0], gat_dims[0], 0.3)
        self.pool_T = GraphPool(pool_ratios[1], gat_dims[0], 0.3)
        self.pool_hS1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
        self.pool_hT1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
        self.pool_hS2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
        self.pool_hT2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)

        self.out_layer = nn.Linear(5 * gat_dims[1], 2)

    def _hetero_branch(self, layer_a, layer_b, pool_S, pool_T,
                       out_T, out_S, master):
        """Run one two-layer heterogeneous branch with residual additions."""
        t_nodes, s_nodes, master = layer_a(out_T, out_S, master=master)
        s_nodes = pool_S(s_nodes)
        t_nodes = pool_T(t_nodes)

        t_aug, s_aug, master_aug = layer_b(t_nodes, s_nodes, master=master)
        return t_nodes + t_aug, s_nodes + s_aug, master + master_aug

    def forward(self, x):
        """Classify a batch of waveforms.

        ``x`` is passed to the SSL front-end after ``squeeze(-1)``; the
        result is a ``(#bs, 2)`` tensor.
        """
        # ---- pre-trained wav2vec model (fine-tuned together with the rest)
        x_ssl_feat = self.ssl_model.extract_feat(x.squeeze(-1))
        x = self.LL(x_ssl_feat)  # (bs, frame_number, feat_out_dim)

        # post-processing on front-end features
        x = x.transpose(1, 2)   # (bs, feat_out_dim, frame_number)
        x = x.unsqueeze(dim=1)  # add channel
        x = F.max_pool2d(x, (3, 3))
        x = self.selu(self.first_bn(x))

        # RawNet2-based encoder
        x = self.encoder(x)
        x = self.selu(self.first_bn1(x))

        w = self.attention(x)

        # ------------SA for spectral feature-------------#
        spec_weights = F.softmax(w, dim=-1)
        spec_feat = torch.sum(x * spec_weights, dim=-1)
        e_S = spec_feat.transpose(1, 2) + self.pos_S
        out_S = self.pool_S(self.GAT_layer_S(e_S))  # (#bs, #node, #dim)

        # ------------SA for temporal feature-------------#
        temp_weights = F.softmax(w, dim=-2)
        temp_feat = torch.sum(x * temp_weights, dim=-2)
        e_T = temp_feat.transpose(1, 2)
        out_T = self.pool_T(self.GAT_layer_T(e_T))

        # The learnable master nodes are passed unexpanded, shape (1, 1, dim);
        # the layer arithmetic broadcasts them over the batch.
        # inference 1
        out_T1, out_S1, master1 = self._hetero_branch(
            self.HtrgGAT_layer_ST11, self.HtrgGAT_layer_ST12,
            self.pool_hS1, self.pool_hT1, out_T, out_S, self.master1)

        # inference 2
        out_T2, out_S2, master2 = self._hetero_branch(
            self.HtrgGAT_layer_ST21, self.HtrgGAT_layer_ST22,
            self.pool_hS2, self.pool_hT2, out_T, out_S, self.master2)

        out_T1 = self.drop_way(out_T1)
        out_T2 = self.drop_way(out_T2)
        out_S1 = self.drop_way(out_S1)
        out_S2 = self.drop_way(out_S2)
        master1 = self.drop_way(master1)
        master2 = self.drop_way(master2)

        # element-wise maximum across the two branches
        out_T = torch.max(out_T1, out_T2)
        out_S = torch.max(out_S1, out_S2)
        master = torch.max(master1, master2)

        # Readout operation
        T_max, _ = torch.max(torch.abs(out_T), dim=1)
        T_avg = torch.mean(out_T, dim=1)
        S_max, _ = torch.max(torch.abs(out_S), dim=1)
        S_avg = torch.mean(out_S, dim=1)

        last_hidden = torch.cat(
            [T_max, T_avg, S_max, S_avg, master.squeeze(1)], dim=1)
        last_hidden = self.drop(last_hidden)
        return self.out_layer(last_hidden)
class CONV(nn.Module):
    """Fixed (non-learned) bank of Hamming-windowed band-pass sinc filters.

    Band edges are spaced uniformly on the mel scale between 0 Hz and
    ``sample_rate / 2``. The filters are built once in ``__init__`` and kept
    as the plain tensor attribute ``band_pass`` (not a parameter or buffer).

    Args:
        out_channels: number of band-pass filters.
        kernel_size: filter length; an even value is incremented by one so
            the filters are symmetric.
        sample_rate: sampling rate in Hz used to place the band edges.
        in_channels: must be 1.
        stride, padding, dilation: forwarded to ``F.conv1d``.
        bias: must be falsy.
        groups: must not exceed 1.
        mask: stored on ``self.mask``; masking is actually controlled by the
            ``mask`` argument of ``forward``.

    Raises:
        ValueError: for ``in_channels != 1``, a truthy ``bias`` or
            ``groups > 1``.
    """

    # Both helpers are called as ``self.to_mel(...)`` / ``self.to_hz(...)``
    # from __init__, so they must be static: without the decorator the
    # instance would be passed as ``hz`` / ``mel`` and the call would raise
    # TypeError.
    @staticmethod
    def to_mel(hz):
        """Convert frequency in Hz to the mel scale."""
        return 2595 * np.log10(1 + hz / 700)

    @staticmethod
    def to_hz(mel):
        """Convert a mel-scale value back to Hz."""
        return 700 * (10**(mel / 2595) - 1)

    def __init__(self,
                 out_channels,
                 kernel_size,
                 sample_rate=16000,
                 in_channels=1,
                 stride=1,
                 padding=0,
                 dilation=1,
                 bias=False,
                 groups=1,
                 mask=False):
        super().__init__()
        if in_channels != 1:
            msg = "SincConv only support one input channel (here, in_channels = {%i})" % (
                in_channels)
            raise ValueError(msg)
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.sample_rate = sample_rate

        # Forcing the filters to be odd (i.e. perfectly symmetric)
        if kernel_size % 2 == 0:
            self.kernel_size = self.kernel_size + 1
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.mask = mask
        if bias:
            raise ValueError('SincConv does not support bias.')
        if groups > 1:
            raise ValueError('SincConv does not support groups.')

        # Band edges: uniform on the mel scale, converted back to Hz.
        NFFT = 512
        f = int(self.sample_rate / 2) * np.linspace(0, 1, int(NFFT / 2) + 1)
        fmel = self.to_mel(f)
        fmelmax = np.max(fmel)
        fmelmin = np.min(fmel)
        filbandwidthsmel = np.linspace(fmelmin, fmelmax, self.out_channels + 1)
        filbandwidthsf = self.to_hz(filbandwidthsmel)

        # Despite the name, ``self.mel`` holds the band edges in Hz.
        self.mel = filbandwidthsf
        self.hsupp = torch.arange(-(self.kernel_size - 1) / 2,
                                  (self.kernel_size - 1) / 2 + 1)
        self.band_pass = torch.zeros(self.out_channels, self.kernel_size)

        # The window does not depend on the band, so build it once.
        window = Tensor(np.hamming(self.kernel_size))
        for i, (fmin, fmax) in enumerate(zip(self.mel[:-1], self.mel[1:])):
            # Ideal band-pass = difference of two low-pass sinc filters.
            hHigh = (2*fmax/self.sample_rate) * \
                np.sinc(2*fmax*self.hsupp/self.sample_rate)
            hLow = (2*fmin/self.sample_rate) * \
                np.sinc(2*fmin*self.hsupp/self.sample_rate)
            hideal = hHigh - hLow
            self.band_pass[i, :] = window * Tensor(hideal)

    def forward(self, x, mask=False):
        """Filter ``x`` with the band-pass bank.

        Args:
            x: input passed to ``F.conv1d`` with single-channel filters.
            mask: when True, a random contiguous run of fewer than 20
                filters is zeroed for this call (frequency masking).

        Returns:
            The ``F.conv1d`` output with ``out_channels`` channels.
        """
        band_pass_filter = self.band_pass.clone().to(x.device)
        if mask:
            n_filters = band_pass_filter.shape[0]
            # Clamp the mask width to the number of filters; otherwise
            # random.randint would receive an empty range for small banks.
            width = min(int(np.random.uniform(0, 20)), n_filters)
            start = random.randint(0, n_filters - width)
            band_pass_filter[start:start + width, :] = 0

        # Kept on the instance, as in the original implementation.
        self.filters = band_pass_filter.view(self.out_channels, 1,
                                             self.kernel_size)

        return F.conv1d(x,
                        self.filters,
                        stride=self.stride,
                        padding=self.padding,
                        dilation=self.dilation,
                        bias=None,
                        groups=1)
class AASIST_Model(nn.Module):
    """AASIST operating on raw waveforms through a fixed sinc filter bank.

    Args:
        args: not read by this class; kept for interface compatibility.
        device: not read by this class; kept for interface compatibility.

    ``forward`` returns a ``(#bs, 2)`` tensor produced by ``out_layer``.
    """

    def __init__(self, args, device):
        super().__init__()

        # Only pool_ratios[:3] and temperatures[:3] are read below.
        filts = [70, [1, 32], [32, 32], [32, 64], [64, 64]]
        gat_dims = [64, 32]
        pool_ratios = [0.5, 0.7, 0.5, 0.5]
        temperatures = [2.0, 2.0, 100.0, 100.0]
        enc_dim = filts[-1][-1]

        # sinc filter-bank front-end
        self.conv_time = CONV(out_channels=filts[0],
                              kernel_size=128,
                              in_channels=1)
        self.first_bn = nn.BatchNorm2d(num_features=1)

        self.drop = nn.Dropout(0.5, inplace=True)
        self.drop_way = nn.Dropout(0.2, inplace=True)
        self.selu = nn.SELU(inplace=True)

        # Six residual blocks, each wrapped in its own nn.Sequential (the
        # nesting is part of the state_dict key names).
        block_specs = [
            (filts[1], True),
            (filts[2], False),
            (filts[3], False),
            (filts[4], False),
            (filts[4], False),
            (filts[4], False),
        ]
        self.encoder = nn.Sequential(*(
            nn.Sequential(Residual_block_aasist(nb_filts=spec, first=is_first))
            for spec, is_first in block_specs))

        # position encoding; 23 spectral nodes = 70 filters after the
        # (3, 3) max-pool in forward
        self.pos_S = nn.Parameter(torch.randn(1, 23, enc_dim))
        self.master1 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))
        self.master2 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))

        # graph attention layers (spectral / temporal)
        self.GAT_layer_S = GraphAttentionLayer(enc_dim,
                                               gat_dims[0],
                                               temperature=temperatures[0])
        self.GAT_layer_T = GraphAttentionLayer(enc_dim,
                                               gat_dims[0],
                                               temperature=temperatures[1])

        # heterogeneous graph attention layers, two per branch
        self.HtrgGAT_layer_ST11 = HtrgGraphAttentionLayer(
            gat_dims[0], gat_dims[1], temperature=temperatures[2])
        self.HtrgGAT_layer_ST12 = HtrgGraphAttentionLayer(
            gat_dims[1], gat_dims[1], temperature=temperatures[2])

        self.HtrgGAT_layer_ST21 = HtrgGraphAttentionLayer(
            gat_dims[0], gat_dims[1], temperature=temperatures[2])
        self.HtrgGAT_layer_ST22 = HtrgGraphAttentionLayer(
            gat_dims[1], gat_dims[1], temperature=temperatures[2])

        # graph pooling layers
        self.pool_S = GraphPool(pool_ratios[0], gat_dims[0], 0.3)
        self.pool_T = GraphPool(pool_ratios[1], gat_dims[0], 0.3)
        self.pool_hS1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
        self.pool_hT1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
        self.pool_hS2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
        self.pool_hT2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)

        self.out_layer = nn.Linear(5 * gat_dims[1], 2)

    def _hetero_branch(self, layer_a, layer_b, pool_S, pool_T,
                       out_T, out_S, master):
        """Run one two-layer heterogeneous branch with residual additions."""
        t_nodes, s_nodes, master = layer_a(out_T, out_S, master=master)
        s_nodes = pool_S(s_nodes)
        t_nodes = pool_T(t_nodes)

        t_aug, s_aug, master_aug = layer_b(t_nodes, s_nodes, master=master)
        return t_nodes + t_aug, s_nodes + s_aug, master + master_aug

    def forward(self, x, Freq_aug=False):
        """Classify a batch of waveforms.

        Args:
            x: input batch; a channel dimension is inserted at position 1
                before the sinc filter bank.
            Freq_aug: forwarded as ``mask`` to the filter bank (random
                frequency masking).

        Returns:
            ``(#bs, 2)`` tensor.
        """
        x = self.conv_time(x.unsqueeze(1), mask=Freq_aug)
        x = x.unsqueeze(dim=1)
        x = F.max_pool2d(torch.abs(x), (3, 3))
        x = self.selu(self.first_bn(x))

        # get embeddings using encoder
        # (#bs, #filt, #spec, #seq)
        e = self.encoder(x)
        e_abs = torch.abs(e)

        # spectral GAT (GAT-S)
        e_S, _ = torch.max(e_abs, dim=3)  # max along time
        e_S = e_S.transpose(1, 2) + self.pos_S
        out_S = self.pool_S(self.GAT_layer_S(e_S))  # (#bs, #node, #dim)

        # temporal GAT (GAT-T)
        e_T, _ = torch.max(e_abs, dim=2)  # max along freq
        e_T = e_T.transpose(1, 2)
        out_T = self.pool_T(self.GAT_layer_T(e_T))

        # The learnable master nodes are passed unexpanded, shape (1, 1, dim);
        # the layer arithmetic broadcasts them over the batch.
        # inference 1
        out_T1, out_S1, master1 = self._hetero_branch(
            self.HtrgGAT_layer_ST11, self.HtrgGAT_layer_ST12,
            self.pool_hS1, self.pool_hT1, out_T, out_S, self.master1)

        # inference 2
        out_T2, out_S2, master2 = self._hetero_branch(
            self.HtrgGAT_layer_ST21, self.HtrgGAT_layer_ST22,
            self.pool_hS2, self.pool_hT2, out_T, out_S, self.master2)

        out_T1 = self.drop_way(out_T1)
        out_T2 = self.drop_way(out_T2)
        out_S1 = self.drop_way(out_S1)
        out_S2 = self.drop_way(out_S2)
        master1 = self.drop_way(master1)
        master2 = self.drop_way(master2)

        # element-wise maximum across the two branches
        out_T = torch.max(out_T1, out_T2)
        out_S = torch.max(out_S1, out_S2)
        master = torch.max(master1, master2)

        # readout: max-of-abs and mean over nodes, plus the master node
        T_max, _ = torch.max(torch.abs(out_T), dim=1)
        T_avg = torch.mean(out_T, dim=1)
        S_max, _ = torch.max(torch.abs(out_S), dim=1)
        S_avg = torch.mean(out_S, dim=1)

        last_hidden = torch.cat(
            [T_max, T_avg, S_max, S_avg, master.squeeze(1)], dim=1)
        last_hidden = self.drop(last_hidden)
        return self.out_layer(last_hidden)