alibabasglab's picture
Upload 161 files
8e8cd3e verified
raw
history blame
17.1 kB
import numpy as np
import torch.nn.functional as F
import torch
import torch.nn as nn
from models.mossformer_gan_se.conformer import ConformerBlock
class LearnableSigmoid(nn.Module):
"""A learnable sigmoid activation function that scales the output
based on the input features.
Args:
in_features (int): The number of input features for the sigmoid function.
beta (float, optional): A scaling factor for the sigmoid output. Default is 1.
Attributes:
beta (float): The scaling factor for the sigmoid function.
slope (Parameter): Learnable parameter that adjusts the slope of the sigmoid.
"""
def __init__(self, in_features, beta=1):
"""Initializes the LearnableSigmoid module.
Args:
in_features (int): Number of input features.
beta (float, optional): Scaling factor for the sigmoid output.
"""
super().__init__()
self.beta = beta # Scaling factor for the sigmoid
self.slope = nn.Parameter(torch.ones(in_features)) # Learnable slope parameter
self.slope.requiresGrad = True # Ensure gradient updates
def forward(self, x):
"""Forward pass of the learnable sigmoid function.
Args:
x (torch.Tensor): Input tensor with shape [batch_size, in_features].
Returns:
torch.Tensor: The scaled sigmoid output tensor.
"""
return self.beta * torch.sigmoid(self.slope * x)
#%% Spectrograms
def segment_specs(y, seg_length=15, seg_hop=4, max_length=None):
"""Segments a spectrogram into smaller segments for input to a CNN.
Each segment includes neighboring frequency bins to preserve
contextual information.
Args:
y (torch.Tensor): Input spectrogram tensor of shape [B, H, W],
where B is batch size, H is number of mel bands,
and W is the length of the spectrogram.
seg_length (int): Length of each segment (must be odd). Default is 15.
seg_hop (int): Hop length for segmenting the spectrogram. Default is 4.
max_length (int, optional): Maximum number of windows allowed. If the number of
windows exceeds this, a ValueError is raised.
Returns:
torch.Tensor: Segmented tensor with shape [B*n, C, H, seg_length], where n is the
number of segments, C is the number of channels (always 1).
Raises:
ValueError: If seg_length is even or if the number of windows exceeds max_length.
"""
# Ensure segment length is odd
if seg_length % 2 == 0:
raise ValueError('seg_length must be odd! (seg_length={})'.format(seg_length))
# Convert input to tensor if it's not already
if not torch.is_tensor(y):
y = torch.tensor(y)
B, _, _ = y.size() # Extract batch size and dimensions
for b in range(B):
x = y[b, :, :] # Extract the current batch's spectrogram
n_wins = x.shape[1] - (seg_length - 1) # Calculate number of windows
# Segment the mel-spectrogram
idx1 = torch.arange(seg_length) # Indices for segment length
idx2 = torch.arange(n_wins) # Indices for number of windows
idx3 = idx1.unsqueeze(0) + idx2.unsqueeze(1) # Create indices for segments
x = x.transpose(1, 0)[idx3, :].unsqueeze(1).transpose(3, 2) # Rearrange dimensions for CNN input
# Adjust segments based on hop length
if seg_hop > 1:
x = x[::seg_hop, :] # Downsample segments
n_wins = int(np.ceil(n_wins / seg_hop)) # Update number of windows
# Pad the segments if max_length is specified
if max_length is not None:
if max_length < n_wins:
raise ValueError('n_wins {} > max_length {}. Increase max window length max_segments!'.format(n_wins, max_length))
x_padded = torch.zeros((max_length, x.shape[1], x.shape[2], x.shape[3])) # Create a padded tensor
x_padded[:n_wins, :] = x # Fill the padded tensor with the segments
x = x_padded # Update x to the padded tensor
# Concatenate segments from each batch
if b == 0:
z = x.unsqueeze(0) # Initialize z for the first batch
else:
z = torch.cat((z, x.unsqueeze(0)), axis=0) # Concatenate to z
# Reshape the final tensor for output
B, n, c, f, t = z.size()
z = z.view(B * n, c, f, t) # Combine batch and segment dimensions
return z # Return the segmented spectrogram tensor
class AdaptCNN(nn.Module):
"""
AdaptCNN: A convolutional neural network (CNN) with adaptive max pooling that
can be used as a framewise model. This architecture is more flexible than a
standard CNN, which requires a fixed input dimension. The model consists of six
convolutional layers, with adaptive pooling at each layer to handle varying input sizes.
Args:
input_channels (int): Number of input channels (default is 2).
c_out_1 (int): Number of output channels for the first convolutional layer (default is 16).
c_out_2 (int): Number of output channels for the second convolutional layer (default is 32).
c_out_3 (int): Number of output channels for the third and subsequent convolutional layers (default is 64).
kernel_size (list or int): Size of the convolutional kernels (default is [3, 3]).
dropout (float): Dropout rate for regularization (default is 0.2).
pool_1 (list): Pooling parameters for the first adaptive pooling layer (default is [101, 7]).
pool_2 (list): Pooling parameters for the second adaptive pooling layer (default is [50, 7]).
pool_3 (list): Pooling parameters for the third adaptive pooling layer (default is [25, 5]).
pool_4 (list): Pooling parameters for the fourth adaptive pooling layer (default is [12, 5]).
pool_5 (list): Pooling parameters for the fifth adaptive pooling layer (default is [6, 3]).
fc_out_h (int, optional): Number of output units for the final fully connected layer. If None, the output size is determined from previous layers.
Attributes:
name (str): Name of the model.
dropout (Dropout): Dropout layer for regularization.
conv1, conv2, conv3, conv4, conv5, conv6 (Conv2d): Convolutional layers.
bn1, bn2, bn3, bn4, bn5, bn6 (BatchNorm2d): Batch normalization layers.
fc (Linear, optional): Fully connected layer.
fan_out (int): Output dimension of the final layer.
"""
def __init__(self,
input_channels=2,
c_out_1=16,
c_out_2=32,
c_out_3=64,
kernel_size=[3, 3],
dropout=0.2,
pool_1=[101, 7],
pool_2=[50, 7],
pool_3=[25, 5],
pool_4=[12, 5],
pool_5=[6, 3],
fc_out_h=None):
"""Initializes the AdaptCNN model with the specified parameters."""
super().__init__()
self.name = 'CNN_adapt'
# Model parameters
self.input_channels = input_channels
self.c_out_1 = c_out_1
self.c_out_2 = c_out_2
self.c_out_3 = c_out_3
self.kernel_size = kernel_size
self.pool_1 = pool_1
self.pool_2 = pool_2
self.pool_3 = pool_3
self.pool_4 = pool_4
self.pool_5 = pool_5
self.dropout_rate = dropout
self.fc_out_h = fc_out_h
# Dropout layer for regularization
self.dropout = nn.Dropout2d(p=self.dropout_rate)
# Ensure kernel_size is a tuple
if isinstance(self.kernel_size, int):
self.kernel_size = (self.kernel_size, self.kernel_size)
# Set kernel size for the last convolutional layer
self.kernel_size_last = (self.kernel_size[0], self.pool_5[1])
# Determine padding for convolutional layers based on kernel size
if self.kernel_size[1] == 1:
self.cnn_pad = (1, 0) # No padding needed for 1D convolution
else:
self.cnn_pad = (1, 1) # Padding for 2D convolution
# Define convolutional layers with batch normalization
self.conv1 = nn.Conv2d(self.input_channels, self.c_out_1, self.kernel_size, padding=self.cnn_pad)
self.bn1 = nn.BatchNorm2d(self.conv1.out_channels)
self.conv2 = nn.Conv2d(self.conv1.out_channels, self.c_out_2, self.kernel_size, padding=self.cnn_pad)
self.bn2 = nn.BatchNorm2d(self.conv2.out_channels)
self.conv3 = nn.Conv2d(self.conv2.out_channels, self.c_out_3, self.kernel_size, padding=self.cnn_pad)
self.bn3 = nn.BatchNorm2d(self.conv3.out_channels)
self.conv4 = nn.Conv2d(self.conv3.out_channels, self.c_out_3, self.kernel_size, padding=self.cnn_pad)
self.bn4 = nn.BatchNorm2d(self.conv4.out_channels)
self.conv5 = nn.Conv2d(self.conv4.out_channels, self.c_out_3, self.kernel_size, padding=self.cnn_pad)
self.bn5 = nn.BatchNorm2d(self.conv5.out_channels)
self.conv6 = nn.Conv2d(self.conv5.out_channels, self.c_out_3, self.kernel_size_last, padding=(1, 0))
self.bn6 = nn.BatchNorm2d(self.conv6.out_channels)
# Define fully connected layer if output size is specified
if self.fc_out_h:
self.fc = nn.Linear(self.conv6.out_channels * self.pool_3[0], self.fc_out_h)
self.fan_out = self.fc_out_h
else:
self.fan_out = (self.conv6.out_channels * self.pool_3[0])
def forward(self, x):
"""Defines the forward pass of the AdaptCNN model.
Args:
x (torch.Tensor): Input tensor of shape [batch_size, input_channels, height, width].
Returns:
torch.Tensor: Output tensor after passing through the CNN layers.
"""
# Forward pass through each layer with ReLU activation and adaptive pooling
x = F.relu(self.bn1(self.conv1(x))) # First convolutional layer
x = F.adaptive_max_pool2d(x, output_size=(self.pool_1)) # Adaptive pooling after conv1
x = F.relu(self.bn2(self.conv2(x))) # Second convolutional layer
x = F.adaptive_max_pool2d(x, output_size=(self.pool_2)) # Adaptive pooling after conv2
x = self.dropout(x) # Apply dropout
x = F.relu(self.bn3(self.conv3(x))) # Third convolutional layer
x = F.adaptive_max_pool2d(x, output_size=(self.pool_3)) # Adaptive pooling after conv3
x = self.dropout(x) # Apply dropout
x = F.relu(self.bn4(self.conv4(x))) # Fourth convolutional layer
x = F.adaptive_max_pool2d(x, output_size=(self.pool_4)) # Adaptive pooling after conv4
x = self.dropout(x) # Apply dropout
x = F.relu(self.bn5(self.conv5(x))) # Fifth convolutional layer
x = F.adaptive_max_pool2d(x, output_size=(self.pool_5)) # Adaptive pooling after conv5
x = self.dropout(x) # Apply dropout
x = F.relu(self.bn6(self.conv6(x))) # Last convolutional layer
# Flatten the output for the fully connected layer
x = x.view(-1, self.conv6.out_channels * self.pool_5[0])
# Apply fully connected layer if defined
if self.fc_out_h:
x = self.fc(x) # Fully connected output
return x # Return the output tensor
class PoolAttFF(nn.Module):
"""
PoolAttFF: An attention pooling module with an additional feed-forward network.
This module performs attention-based pooling on input features followed by a
feed-forward neural network. The attention mechanism helps in focusing on the
important parts of the input while pooling.
Args:
d_input (int): The dimensionality of the input features (default is 384).
output_size (int): The size of the output after the feed-forward network (default is 1).
h (int): The size of the hidden layer in the feed-forward network (default is 128).
dropout (float): The dropout rate for regularization (default is 0.1).
Attributes:
linear1 (Linear): First linear layer transforming input features to hidden size.
linear2 (Linear): Second linear layer producing attention scores.
linear3 (Linear): Final linear layer producing the output.
activation (function): Activation function used in the network (ReLU).
dropout (Dropout): Dropout layer for regularization.
"""
def __init__(self, d_input=384, output_size=1, h=128, dropout=0.1):
"""Initializes the PoolAttFF module with the specified parameters."""
super().__init__()
# Define the feed-forward layers
self.linear1 = nn.Linear(d_input, h) # First linear layer
self.linear2 = nn.Linear(h, 1) # Second linear layer for attention scores
self.linear3 = nn.Linear(d_input, output_size) # Final output layer
self.activation = F.relu # Activation function
self.dropout = nn.Dropout(dropout) # Dropout layer for regularization
def forward(self, x):
"""Defines the forward pass of the PoolAttFF module.
Args:
x (torch.Tensor): Input tensor of shape [batch_size, seq_len, d_input].
Returns:
torch.Tensor: Output tensor after attention pooling and feed-forward network.
"""
# Compute attention scores
att = self.linear2(self.dropout(self.activation(self.linear1(x))))
att = att.transpose(2, 1) # Transpose for softmax
# Apply softmax to get attention weights
att = F.softmax(att, dim=2) # Softmax along the sequence length
# Perform attention pooling
x = torch.bmm(att, x) # Batch matrix multiplication
x = x.squeeze(1) # Remove unnecessary dimension
x = self.linear3(x) # Final output layer
return x # Return the output tensor
class Discriminator(nn.Module):
"""
Discriminator: A neural network that predicts a normalized PESQ value
between a predicted waveform (x) and a ground truth waveform (y).
The model concatenates the two input waveforms, processes them through
a convolutional network (CNN), applies self-attention, and outputs a
value between 0 and 1 using a sigmoid activation function.
Args:
ndf (int): Number of filters in the convolutional layers (not directly used in this implementation).
in_channel (int): Number of input channels (default is 2).
Attributes:
dim (int): Dimensionality of the feature representation (default is 384).
cnn (AdaptCNN): CNN model for feature extraction.
att (Sequential): Sequential stack of Conformer blocks for attention processing.
pool (PoolAttFF): Attention pooling module.
sigmoid (LearnableSigmoid): Sigmoid layer for final output.
"""
def __init__(self, ndf, in_channel=2):
"""Initializes the Discriminator with specified parameters."""
super().__init__()
self.dim = 384 # Dimensionality of the feature representation
self.cnn = AdaptCNN() # CNN model for feature extraction
# Define attention layers using Conformer blocks
self.att = nn.Sequential(
ConformerBlock(dim=self.dim, dim_head=self.dim // 4, heads=4,
conv_kernel_size=31, attn_dropout=0.2, ff_dropout=0.2),
ConformerBlock(dim=self.dim, dim_head=self.dim // 4, heads=4,
conv_kernel_size=31, attn_dropout=0.2, ff_dropout=0.2)
)
# Define attention pooling module
self.pool = PoolAttFF()
self.sigmoid = LearnableSigmoid(1) # Sigmoid layer for output normalization
def forward(self, x, y):
"""Defines the forward pass of the Discriminator.
Args:
x (torch.Tensor): Predicted waveform tensor of shape [batch_size, 1, height, width].
y (torch.Tensor): Ground truth waveform tensor of shape [batch_size, 1, height, width].
Returns:
torch.Tensor: Output tensor representing the predicted PESQ value.
"""
B, _, _, _ = x.size() # Get the batch size from input x
x = segment_specs(x.squeeze(1)) # Segment and process predicted waveform
y = segment_specs(y.squeeze(1)) # Segment and process ground truth waveform
# Concatenate the processed waveforms
xy = torch.cat([x, y], dim=1) # Concatenate along the channel dimension
cnn_out = self.cnn(xy) # Extract features using CNN
_, d = cnn_out.size() # Get dimensions of CNN output
cnn_out = cnn_out.view(B, -1, d) # Reshape for attention processing
att_out = self.att(cnn_out) # Apply self-attention layers
pool_out = self.pool(att_out) # Apply attention pooling module
out = self.sigmoid(pool_out) # Normalize output using sigmoid function
return out # Return the predicted PESQ value