# Spaces: Running on Zero
import numpy as np | |
import torch.nn.functional as F | |
import torch | |
import torch.nn as nn | |
from models.mossformer_gan_se.conformer import ConformerBlock | |
class LearnableSigmoid(nn.Module):
    """Sigmoid with a learnable per-feature slope and a fixed output scale.

    Computes ``beta * sigmoid(slope * x)``, where ``slope`` is a learnable
    vector of length ``in_features`` that is broadcast against the last
    dimension of ``x``. The output therefore lies between 0 and ``beta``.

    Args:
        in_features (int): Length of the learnable slope vector; it must match
            (or broadcast against) the last dimension of the input.
        beta (float, optional): Constant output scale. Default is 1.

    Attributes:
        beta (float): Constant output scale (not learned).
        slope (nn.Parameter): Learnable slopes, initialized to ones.
    """

    def __init__(self, in_features, beta=1):
        """Initialize the module.

        Args:
            in_features (int): Length of the learnable slope vector.
            beta (float, optional): Constant output scale. Default is 1.
        """
        super().__init__()
        self.beta = beta
        # nn.Parameter tracks gradients by default (requires_grad=True), so no
        # extra flag is needed to make the slope trainable.
        self.slope = nn.Parameter(torch.ones(in_features))

    def forward(self, x):
        """Apply the scaled, slope-adjusted sigmoid.

        Args:
            x (torch.Tensor): Input whose last dimension broadcasts against
                ``in_features`` (e.g. ``[batch_size, in_features]``).

        Returns:
            torch.Tensor: ``beta * sigmoid(slope * x)``, same shape as the
            broadcast of ``slope`` and ``x``.
        """
        return self.beta * torch.sigmoid(self.slope * x)
# %% Spectrograms
def segment_specs(y, seg_length=15, seg_hop=4, max_length=None):
    """Cut each spectrogram of a batch into overlapping windows of consecutive frames.

    For every batch item, a window of ``seg_length`` consecutive time frames is
    taken at every frame offset, and then every ``seg_hop``-th window is kept.
    Each window keeps all frequency bins, so it forms one ``[1, H, seg_length]``
    "image" for a CNN.

    Args:
        y (torch.Tensor or array-like): Spectrograms of shape [B, H, W]. B is
            the batch size, H the number of frequency (mel) bands and W the
            number of time frames. Non-tensors are converted with
            ``torch.tensor``.
        seg_length (int): Number of frames per window (must be odd). Default is 15.
        seg_hop (int): Keep every ``seg_hop``-th window; values <= 1 keep all
            windows. Default is 4.
        max_length (int, optional): If given, each item's windows are
            zero-padded up to exactly this many windows.

    Returns:
        torch.Tensor: Shape [B * n, 1, H, seg_length], batch-major (all windows
        of item 0 come first). n is ``max_length`` when given, otherwise
        ``ceil((W - seg_length + 1) / seg_hop)``. The result has the same dtype
        and device as ``y``.

    Raises:
        ValueError: If seg_length is even or if the number of windows exceeds max_length.
    """
    if seg_length % 2 == 0:
        raise ValueError('seg_length must be odd! (seg_length={})'.format(seg_length))
    if not torch.is_tensor(y):
        y = torch.tensor(y)
    B, _, W = y.size()

    n_wins = W - (seg_length - 1)  # number of full windows at hop 1
    # idx[i, j] = i + j: time-frame index of the j-th frame in window i.
    starts = torch.arange(n_wins, device=y.device)
    idx = starts.unsqueeze(1) + torch.arange(seg_length, device=y.device).unsqueeze(0)
    # [B, H, W] -> [B, W, H] -> gather windows -> [B, n_wins, seg_length, H]
    segs = y.transpose(1, 2)[:, idx, :]
    # -> [B, n_wins, 1, H, seg_length] (channel axis added, frames last)
    segs = segs.unsqueeze(2).transpose(3, 4)

    if seg_hop > 1:
        segs = segs[:, ::seg_hop]
        n_wins = (n_wins + seg_hop - 1) // seg_hop  # == ceil(n_wins / seg_hop)

    if max_length is not None:
        if max_length < n_wins:
            raise ValueError('n_wins {} > max_length {}. Increase max window length max_segments!'.format(n_wins, max_length))
        # new_zeros keeps the input's dtype and device (a bare torch.zeros
        # would silently move the result to CPU float32).
        padded = segs.new_zeros((B, max_length) + tuple(segs.shape[2:]))
        padded[:, :n_wins] = segs
        segs = padded

    _, n, c, f, t = segs.shape
    # Merge batch and window axes; reshape handles the non-contiguous layout.
    return segs.reshape(B * n, c, f, t)
class AdaptCNN(nn.Module):
    """CNN front-end with adaptive max pooling, usable as a framewise model.

    Adaptive pooling forces each intermediate feature map to a fixed spatial
    size, so inputs of varying height and width are accepted.

    The network has six convolutional layers:

    * conv1-conv5 are each followed by BatchNorm, ReLU and adaptive max pooling
      (to pool_1 ... pool_5). Dropout2d is applied after pools 2-5.
    * conv6 has kernel width ``pool_5[1]`` and no width padding, so it reduces
      the width to 1.

    The result is flattened to ``conv6.out_channels * pool_5[0]`` features per
    input, optionally followed by a Linear layer.

    Args:
        input_channels (int): Number of input channels (default 2).
        c_out_1 (int): Output channels of conv1 (default 16).
        c_out_2 (int): Output channels of conv2 (default 32).
        c_out_3 (int): Output channels of conv3 through conv6 (default 64).
        kernel_size (int or sequence of 2 ints): Kernel of conv1-conv5 (default (3, 3)).
        dropout (float): Dropout2d probability (default 0.2).
        pool_1 (sequence of 2 ints): Output size of the pooling after conv1 (default (101, 7)).
        pool_2 (sequence of 2 ints): Output size of the pooling after conv2 (default (50, 7)).
        pool_3 (sequence of 2 ints): Output size of the pooling after conv3 (default (25, 5)).
        pool_4 (sequence of 2 ints): Output size of the pooling after conv4 (default (12, 5)).
        pool_5 (sequence of 2 ints): Output size of the pooling after conv5 (default (6, 3)).
        fc_out_h (int, optional): If set, a final Linear layer maps the
            flattened features to this many units.

    Attributes:
        name (str): Model name ('CNN_adapt').
        dropout (Dropout2d): Dropout applied between stages.
        conv1, conv2, conv3, conv4, conv5, conv6 (Conv2d): Convolutional layers.
        bn1, bn2, bn3, bn4, bn5, bn6 (BatchNorm2d): Batch normalization layers.
        fc (Linear, optional): Final projection, only present when fc_out_h is set.
        fan_out (int): Size of the last dimension of the tensor returned by forward().
    """

    def __init__(self,
                 input_channels=2,
                 c_out_1=16,
                 c_out_2=32,
                 c_out_3=64,
                 kernel_size=(3, 3),
                 dropout=0.2,
                 pool_1=(101, 7),
                 pool_2=(50, 7),
                 pool_3=(25, 5),
                 pool_4=(12, 5),
                 pool_5=(6, 3),
                 fc_out_h=None):
        """Build the layers; defaults are tuples to avoid shared mutable defaults."""
        super().__init__()
        self.name = 'CNN_adapt'
        self.input_channels = input_channels
        self.c_out_1 = c_out_1
        self.c_out_2 = c_out_2
        self.c_out_3 = c_out_3
        self.kernel_size = kernel_size
        self.pool_1 = pool_1
        self.pool_2 = pool_2
        self.pool_3 = pool_3
        self.pool_4 = pool_4
        self.pool_5 = pool_5
        self.dropout_rate = dropout
        self.fc_out_h = fc_out_h

        self.dropout = nn.Dropout2d(p=self.dropout_rate)

        if isinstance(self.kernel_size, int):
            self.kernel_size = (self.kernel_size, self.kernel_size)
        # conv6 spans the full pooled width so its output width is 1.
        self.kernel_size_last = (self.kernel_size[0], self.pool_5[1])
        # Height is always padded by 1; width is padded only for kernels wider than 1.
        if self.kernel_size[1] == 1:
            self.cnn_pad = (1, 0)
        else:
            self.cnn_pad = (1, 1)

        # Creation order is kept (conv1, bn1, ..., conv6, bn6, fc) so parameter
        # initialization consumes the RNG exactly as before.
        self.conv1 = nn.Conv2d(self.input_channels, self.c_out_1, self.kernel_size, padding=self.cnn_pad)
        self.bn1 = nn.BatchNorm2d(self.conv1.out_channels)
        self.conv2 = nn.Conv2d(self.conv1.out_channels, self.c_out_2, self.kernel_size, padding=self.cnn_pad)
        self.bn2 = nn.BatchNorm2d(self.conv2.out_channels)
        self.conv3 = nn.Conv2d(self.conv2.out_channels, self.c_out_3, self.kernel_size, padding=self.cnn_pad)
        self.bn3 = nn.BatchNorm2d(self.conv3.out_channels)
        self.conv4 = nn.Conv2d(self.conv3.out_channels, self.c_out_3, self.kernel_size, padding=self.cnn_pad)
        self.bn4 = nn.BatchNorm2d(self.conv4.out_channels)
        self.conv5 = nn.Conv2d(self.conv4.out_channels, self.c_out_3, self.kernel_size, padding=self.cnn_pad)
        self.bn5 = nn.BatchNorm2d(self.conv5.out_channels)
        self.conv6 = nn.Conv2d(self.conv5.out_channels, self.c_out_3, self.kernel_size_last, padding=(1, 0))
        self.bn6 = nn.BatchNorm2d(self.conv6.out_channels)

        # forward() flattens conv6's output to out_channels * pool_5[0]
        # features. The last pooling stage is pool_5, not pool_3.
        n_flat = self.conv6.out_channels * self.pool_5[0]
        if self.fc_out_h:
            self.fc = nn.Linear(n_flat, self.fc_out_h)
            self.fan_out = self.fc_out_h
        else:
            self.fan_out = n_flat

    def forward(self, x):
        """Run the CNN.

        Args:
            x (torch.Tensor): Input of shape [batch_size, input_channels, height, width].

        Returns:
            torch.Tensor: Shape [batch_size, fan_out].
        """
        stages = (
            (self.conv1, self.bn1, self.pool_1, False),
            (self.conv2, self.bn2, self.pool_2, True),
            (self.conv3, self.bn3, self.pool_3, True),
            (self.conv4, self.bn4, self.pool_4, True),
            (self.conv5, self.bn5, self.pool_5, True),
        )
        for conv, bn, pool_size, use_dropout in stages:
            x = F.relu(bn(conv(x)))
            x = F.adaptive_max_pool2d(x, output_size=pool_size)
            if use_dropout:
                x = self.dropout(x)
        x = F.relu(self.bn6(self.conv6(x)))
        # conv6 output is [N, C, pool_5[0], 1]; flatten channels x height.
        x = x.view(-1, self.conv6.out_channels * self.pool_5[0])
        if self.fc_out_h:
            x = self.fc(x)
        return x
class PoolAttFF(nn.Module):
    """Attention pooling over a sequence followed by a linear read-out.

    How it works:

    * A two-layer scorer (d_input -> h with ReLU and dropout, then h -> 1)
      assigns one logit to every sequence position.
    * The logits are softmax-normalized over the sequence.
    * The normalized weights form a weighted sum of the input vectors.
    * The pooled vector is projected to ``output_size`` values.

    Args:
        d_input (int): Feature size of each sequence element (default 384).
        output_size (int): Number of output values per batch item (default 1).
        h (int): Hidden width of the scoring MLP (default 128).
        dropout (float): Dropout probability inside the scorer (default 0.1).

    Attributes:
        linear1 (Linear): d_input -> h projection of the scorer.
        linear2 (Linear): h -> 1 projection producing attention logits.
        linear3 (Linear): d_input -> output_size read-out of the pooled vector.
        activation (callable): ReLU used in the scorer.
        dropout (Dropout): Dropout applied to the scorer's hidden layer.
    """

    def __init__(self, d_input=384, output_size=1, h=128, dropout=0.1):
        """Create the scorer and read-out layers."""
        super().__init__()
        self.linear1 = nn.Linear(d_input, h)
        self.linear2 = nn.Linear(h, 1)
        self.linear3 = nn.Linear(d_input, output_size)
        self.activation = F.relu
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Pool a batch of sequences into one vector each and project it.

        Args:
            x (torch.Tensor): Shape [batch_size, seq_len, d_input].

        Returns:
            torch.Tensor: Shape [batch_size, output_size].
        """
        hidden = self.dropout(self.activation(self.linear1(x)))
        # [B, L, 1] -> [B, 1, L]: one row of attention logits per batch item.
        logits = self.linear2(hidden).transpose(1, 2)
        weights = torch.softmax(logits, dim=-1)
        # [B, 1, L] @ [B, L, d_input] -> [B, 1, d_input] -> [B, d_input]
        pooled = torch.matmul(weights, x)[:, 0, :]
        return self.linear3(pooled)
class Discriminator(nn.Module):
    """Metric discriminator that scores an estimated spectrogram against a reference.

    Pipeline:

    * Both inputs are cut into overlapping frame windows with ``segment_specs``.
    * The windows are stacked along the channel axis and encoded
      window-by-window by ``AdaptCNN``.
    * The per-window features are regrouped per batch item and refined by two
      Conformer blocks.
    * The sequence is attention-pooled to a single value per item.
    * A ``LearnableSigmoid`` (beta=1) squashes that value, so each score lies
      between 0 and 1.

    The original description states this value is a normalized PESQ estimate.

    Args:
        ndf (int): Not used by this implementation; kept for interface compatibility.
        in_channel (int): Not used; the CNN is built with its default of 2
            input channels (one estimated + one reference window).

    Attributes:
        dim (int): Feature width fed to the Conformer blocks (384).
            NOTE(review): this must equal AdaptCNN's flattened output width
            (64 channels * pool height 6 with its defaults); it is not derived
            automatically.
        cnn (AdaptCNN): Per-window feature extractor.
        att (nn.Sequential): Two Conformer blocks over the window sequence.
        pool (PoolAttFF): Attention pooling producing one value per item.
        sigmoid (LearnableSigmoid): Output squashing with a learnable slope.
    """

    def __init__(self, ndf, in_channel=2):
        """Build the CNN, Conformer stack, pooling head and output sigmoid."""
        super().__init__()
        self.dim = 384
        self.cnn = AdaptCNN()
        conformer_cfg = dict(dim=self.dim, dim_head=self.dim // 4, heads=4,
                             conv_kernel_size=31, attn_dropout=0.2, ff_dropout=0.2)
        self.att = nn.Sequential(*[ConformerBlock(**conformer_cfg) for _ in range(2)])
        self.pool = PoolAttFF()
        self.sigmoid = LearnableSigmoid(1)

    def forward(self, x, y):
        """Score estimate ``x`` against reference ``y``.

        Args:
            x (torch.Tensor): Estimated spectrogram-like tensor, shape
                [batch_size, 1, height, width].
            y (torch.Tensor): Reference tensor with the same shape as ``x``.

        Returns:
            torch.Tensor: Shape [batch_size, 1], one score per batch item.
        """
        batch, _, _, _ = x.shape
        est_windows = segment_specs(x.squeeze(1))
        ref_windows = segment_specs(y.squeeze(1))
        # Pair every estimated window with its reference window as 2 channels.
        window_feats = self.cnn(torch.cat([est_windows, ref_windows], dim=1))
        _, feat_dim = window_feats.shape
        # segment_specs is batch-major, so this regroups windows per item.
        sequence = window_feats.view(batch, -1, feat_dim)
        return self.sigmoid(self.pool(self.att(sequence)))