File size: 3,500 Bytes
8e8cd3e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
from models.mossformer2_se.mossformer2 import MossFormer_MaskNet
import torch.nn as nn

class MossFormer2_SE_48K(nn.Module):
    """
    The MossFormer2_SE_48K model for speech enhancement.

    Thin wrapper that owns a ``TestNet`` instance (which in turn holds the
    MossFormer MaskNet) and exposes TestNet's outputs together with the
    predicted mask.

    Arguments
    ---------
    args : Namespace
        Configuration arguments. Currently unused; accepted so the wrapper
        can be constructed uniformly from a run configuration and extended
        later.

    Example
    ---------
    >>> model = MossFormer2_SE_48K(args)
    >>> x = torch.randn(10, 180, 2000)  # [B, N, S], N = 180 input channels
    >>> out_list, mask = model(x)       # mask is out_list[-1]
    >>> net = model.model               # the underlying TestNet
    >>> mask = net(x)[-1]               # TestNet itself returns a list
    """

    def __init__(self, args):
        super().__init__()
        # ``args`` is intentionally unused for now (see class docstring).
        self.model = TestNet()

    def forward(self, x):
        """
        Forward pass through the model.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor of dimension [B, N, S], where B is the batch size,
            N is the number of input channels (180 for the default TestNet),
            and S is the sequence length.

        Returns
        -------
        outputs : list of torch.Tensor
            The list returned by ``TestNet.forward``; it currently holds a
            single element, the predicted mask.

        mask : torch.Tensor
            The last element of ``outputs``, i.e. the mask predicted by the
            MossFormer MaskNet.
        """
        # TestNet returns a *list* of outputs (currently just ``[mask]``).
        # Unpacking it straight into two names would raise ValueError, so
        # return the whole list plus its final entry to honour the
        # documented (outputs, mask) pair.
        outputs = self.model(x)
        mask = outputs[-1]
        return outputs, mask


class TestNet(nn.Module):
    """
    The TestNet class for testing the MossFormer MaskNet implementation.

    Wraps a single ``MossFormer_MaskNet`` and adapts the input layout it is
    fed. ``forward`` returns a list whose only element is the mask produced
    by the MaskNet.

    Arguments
    ---------
    n_layers : int
        Intended model depth. Stored on the instance but not currently used
        to build the network.
    in_channels : int, keyword-only
        Number of input channels N; passed to ``MossFormer_MaskNet`` as
        ``in_channels`` (default 180).
    out_channels : int, keyword-only
        Passed to ``MossFormer_MaskNet`` as ``out_channels`` (default 512).
    out_channels_final : int, keyword-only
        Passed to ``MossFormer_MaskNet`` as ``out_channels_final``
        (default 961).
    """

    def __init__(self, n_layers=18, *, in_channels=180, out_channels=512,
                 out_channels_final=961):
        super().__init__()
        self.n_layers = n_layers  # kept for configuration compatibility; unused
        # The defaults reproduce the previously hard-coded configuration, so
        # ``TestNet()`` builds exactly the same network (and state-dict keys).
        self.mossformer = MossFormer_MaskNet(
            in_channels=in_channels,
            out_channels=out_channels,
            out_channels_final=out_channels_final,
        )

    def forward(self, input):
        """
        Forward pass through the TestNet model.

        Arguments
        ---------
        input : torch.Tensor
            Input tensor of dimension [B, N, S], where B is the batch size,
            N is the number of input channels (``in_channels``), and S is the
            sequence length.

        Returns
        -------
        out_list : list
            Single-element list holding the mask tensor predicted by the
            MossFormer_MaskNet.
        """
        # ``input`` shadows the builtin; the name is kept for API compatibility.
        # Swap axes 1 and 2 so the MaskNet receives [B, S, N].
        x = input.transpose(1, 2)
        mask = self.mossformer(x)
        return [mask]