Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
import torch.nn as nn | |
from torch import Tensor | |
import torch.nn.init as init | |
import torch.nn.functional as F | |
class UniDeepFsmn(nn.Module):
    """Feedforward sequential memory (FSMN) block with a residual connection.

    The input is layer-normalised, passed through ``Linear -> Swish`` and a
    bias-free projection, then filtered along the time axis by a depthwise
    ``Conv2d`` whose kernel covers ``lorder - 1`` past frames, the current
    frame and ``rorder - 1`` future frames (``rorder == lorder`` here). The
    projection plus its filtered version is added back to the input.

    Args:
        input_dim: Feature size of the input (last axis).
        output_dim: Feature size produced by the projection. The depthwise
            conv is built with ``input_dim`` input channels and the result is
            added to ``input``, so in practice this must equal ``input_dim``.
        lorder: Memory order. When ``None`` construction returns early, no
            layers are created and ``forward`` cannot be used.
        hidden_size: Width of the intermediate linear layer.
        dropout_p: Probability for ``self.dropout``.
            NOTE(review): the dropout layer is created but never applied in
            ``forward``.
    """

    def __init__(self, input_dim, output_dim, lorder=None, hidden_size=None, dropout_p=0.1):
        super(UniDeepFsmn, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        if lorder is None:
            return
        self.lorder = lorder
        self.rorder = lorder
        self.hidden_size = hidden_size
        self.linear = nn.Linear(input_dim, hidden_size)
        self.project = nn.Linear(hidden_size, output_dim, bias=False)
        # Kernel height lorder + rorder - 1 over time, width 1 over the
        # singleton axis; groups=input_dim makes the convolution depthwise.
        self.conv1 = nn.Conv2d(
            input_dim,
            output_dim,
            [self.lorder + self.rorder - 1, 1],
            [1, 1],
            groups=input_dim,
            bias=False,
        )
        self.norm = nn.LayerNorm(input_dim)
        self.dropout = nn.Dropout(p=dropout_p)
        self.swish = Swish()

    def forward(self, input):
        """Apply the FSMN block.

        Args:
            input: Tensor of shape (batch, time, input_dim).

        Returns:
            Tensor of shape (batch, time, input_dim).
        """
        f1 = self.swish(self.linear(self.norm(input)))
        p1 = self.project(f1)
        # (b, T, h) -> (b, 1, T, h) -> (b, h, T, 1): features become conv channels.
        x_per = torch.unsqueeze(p1, 1).permute(0, 3, 2, 1)
        # Pad only the time axis (lorder - 1 before, rorder - 1 after) so the
        # convolution output keeps length T.
        y = F.pad(x_per, [0, 0, self.lorder - 1, self.rorder - 1])
        out = x_per + self.conv1(y)
        # (b, h, T, 1) -> (b, 1, T, h)
        out1 = out.permute(0, 3, 2, 1)
        # Drop only the singleton channel axis. A bare squeeze() would also
        # drop batch/time/feature axes of size 1 and mis-broadcast against
        # ``input`` (e.g. (b, 1, h) + (b, h) -> (b, b, h) when T == 1).
        return input + out1.squeeze(1)
class GlobalLayerNorm(nn.Module):
    """Calculate Global Layer Normalization.

    Each sample is normalised with a single mean and variance computed over
    all of its non-batch axes, then optionally scaled and shifted per channel.

    Arguments
    ---------
    dim : int
        Number of channels ``C`` (size of axis 1 of the input). Sets the size
        of the learnable weight and bias.
    shape : int
        Rank of the inputs the affine parameters are shaped for: ``3`` for
        ``[N, C, L]`` inputs, ``4`` for ``[N, C, K, S]`` inputs. Only used when
        ``elementwise_affine`` is True.
    eps : float
        A value added to the denominator for numerical stability.
    elementwise_affine : bool
        If True, the module has learnable per-channel affine parameters
        initialized to ones (weights) and zeros (biases).

    Raises
    ------
    ValueError
        If ``elementwise_affine`` is True and ``shape`` is neither 3 nor 4.

    Example
    -------
    >>> x = torch.randn(5, 10, 20)
    >>> GLN = GlobalLayerNorm(10, 3)
    >>> x_norm = GLN(x)
    """

    def __init__(self, dim, shape, eps=1e-8, elementwise_affine=True):
        super(GlobalLayerNorm, self).__init__()
        self.dim = dim
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            if shape == 3:
                self.weight = nn.Parameter(torch.ones(self.dim, 1))
                self.bias = nn.Parameter(torch.zeros(self.dim, 1))
            elif shape == 4:
                self.weight = nn.Parameter(torch.ones(self.dim, 1, 1))
                self.bias = nn.Parameter(torch.zeros(self.dim, 1, 1))
            else:
                # Without this, weight/bias would stay undefined and the error
                # would only surface later, inside forward().
                raise ValueError(f"shape must be 3 or 4, got {shape!r}")
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

    def forward(self, x):
        """Returns the normalized tensor.

        Arguments
        ---------
        x : torch.Tensor
            Tensor of size [N, C, K, S] or [N, C, L]. Tensors of any other
            rank are returned unchanged.
        """
        # gln: one mean/var per sample, shaped N x 1 x 1 (x 1)
        if x.dim() == 3:
            axes = (1, 2)
        elif x.dim() == 4:
            axes = (1, 2, 3)
        else:
            # NOTE(review): other ranks pass through untouched, as before.
            return x
        mean = torch.mean(x, axes, keepdim=True)
        var = torch.mean((x - mean) ** 2, axes, keepdim=True)
        if self.elementwise_affine:
            return self.weight * (x - mean) / torch.sqrt(var + self.eps) + self.bias
        return (x - mean) / torch.sqrt(var + self.eps)
class CumulativeLayerNorm(nn.LayerNorm):
    """Channel-wise layer normalization for channel-first tensors.

    Despite the name, no statistics are accumulated over time: the channel
    axis is moved last, ``nn.LayerNorm`` normalises it independently at every
    position, and the axes are moved back.

    Arguments
    ---------
    dim : int
        Number of channels (size of axis 1 of the input).
    elementwise_affine : bool
        Whether to learn a per-channel weight and bias.

    Example
    -------
    >>> x = torch.randn(5, 10, 20)
    >>> CLN = CumulativeLayerNorm(10)
    >>> x_norm = CLN(x)
    """

    def __init__(self, dim, elementwise_affine=True):
        super().__init__(dim, elementwise_affine=elementwise_affine, eps=1e-8)

    def forward(self, x):
        """Normalise ``x`` over its channel axis.

        Arguments
        ---------
        x : torch.Tensor
            Tensor of size [N, C, K, S] or [N, C, L]; other ranks are returned
            unchanged.
        """
        rank = x.dim()
        if rank == 4:
            # [N, C, K, S] -> [N, K, S, C] -> normalise C -> [N, C, K, S]
            channels_last = x.permute(0, 2, 3, 1).contiguous()
            normed = super().forward(channels_last)
            return normed.permute(0, 3, 1, 2).contiguous()
        if rank == 3:
            # [N, C, L] -> [N, L, C] -> normalise C -> [N, C, L]
            normed = super().forward(x.transpose(1, 2))
            return normed.transpose(1, 2)
        return x
def select_norm(norm, dim, shape):
    """Build a normalization layer by short name.

    ``"gln"`` -> GlobalLayerNorm, ``"cln"`` -> CumulativeLayerNorm,
    ``"ln"`` -> single-group GroupNorm; any other value -> BatchNorm1d.
    ``shape`` is only forwarded to GlobalLayerNorm.
    """
    # Checked in order with ==; the lambdas defer construction so only the
    # selected layer is built.
    factories = (
        ("gln", lambda: GlobalLayerNorm(dim, shape, elementwise_affine=True)),
        ("cln", lambda: CumulativeLayerNorm(dim, elementwise_affine=True)),
        ("ln", lambda: nn.GroupNorm(1, dim, eps=1e-8)),
    )
    for key, make in factories:
        if norm == key:
            return make()
    # Anything unrecognised (e.g. "bn") falls back to batch norm.
    return nn.BatchNorm1d(dim)
class Swish(nn.Module):
    """Swish activation, ``x * sigmoid(x)``.

    A smooth, non-monotonic activation that consistently matches or
    outperforms ReLU on deep networks across domains such as image
    classification and machine translation.
    """

    def __init__(self):
        super().__init__()

    def forward(self, inputs: Tensor) -> Tensor:
        gate = torch.sigmoid(inputs)
        return inputs * gate
class GLU(nn.Module):
    """Gated Linear Unit.

    Splits the input into two halves along ``dim`` and gates the first half
    with the sigmoid of the second, as introduced for language modelling in
    "Language Modeling with Gated Convolutional Networks".

    Args:
        dim: Axis along which the input is split in two.
    """

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, inputs: Tensor) -> Tensor:
        value, gate = torch.chunk(inputs, 2, dim=self.dim)
        return value * torch.sigmoid(gate)
class Transpose(nn.Module):
    """Module form of ``torch.transpose`` so an axis swap can live inside
    ``nn.Sequential``.

    Args:
        shape: The pair of axes ``(dim0, dim1)`` to swap.
    """

    def __init__(self, shape: tuple):
        super().__init__()
        self.shape = shape

    def forward(self, x: Tensor) -> Tensor:
        return torch.transpose(x, *self.shape)
class Linear(nn.Module):
    """``nn.Linear`` with Xavier-uniform weights and a zero-initialised bias.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        bias: Whether to include an additive bias (initialised to zeros).
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super().__init__()
        layer = nn.Linear(in_features, out_features, bias=bias)
        # Re-initialise after nn.Linear's own default init.
        init.xavier_uniform_(layer.weight)
        if bias:
            init.zeros_(layer.bias)
        self.linear = layer

    def forward(self, x: Tensor) -> Tensor:
        return self.linear(x)
class DepthwiseConv1d(nn.Module):
    """
    Depthwise 1-D convolution: an ``nn.Conv1d`` with ``groups == in_channels``.

    When groups == in_channels and out_channels == K * in_channels, where K is a positive integer,
    this operation is termed in literature as depthwise convolution.

    Args:
        in_channels (int): Number of channels in the input
        out_channels (int): Number of channels produced by the convolution; must be a
            multiple of ``in_channels``
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        bias (bool, optional): If True, adds a learnable bias to the output. Default: False

    Inputs: inputs
        - **inputs** (batch, in_channels, time): Tensor containing input vector
    Returns: outputs
        - **outputs** (batch, out_channels, time'): Tensor produced by the depthwise 1-D
          convolution; ``time' == time`` for stride 1 with ``padding == (kernel_size - 1) // 2``
          and an odd ``kernel_size``.

    Raises:
        AssertionError: if ``out_channels`` is not a multiple of ``in_channels``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        bias: bool = False,
    ) -> None:
        super(DepthwiseConv1d, self).__init__()
        # Each input channel gets out_channels // in_channels filters, so the
        # ratio must be a whole number.
        assert out_channels % in_channels == 0, "out_channels should be constant multiple of in_channels"
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            groups=in_channels,
            stride=stride,
            padding=padding,
            bias=bias,
        )

    def forward(self, inputs: Tensor) -> Tensor:
        return self.conv(inputs)
class DepthwiseConv2d(nn.Module):
    """
    Residual depthwise convolution along the time axis of a 3-D tensor.

    Despite the name, the module consumes and returns (batch, channels, time)
    tensors. A trailing singleton axis is added and the result is filtered by a
    depthwise ``nn.Conv2d`` whose kernel is ``(2 * kernel_size - 1, 1)``, i.e.
    ``kernel_size - 1`` frames of context on each side of the current frame.
    Symmetric zero padding keeps the time length unchanged, and the filtered
    signal is added back to the input.

    Args:
        in_channels (int): Number of channels in the input
        out_channels (int): Number of channels produced by the convolution. Must be a
            multiple of ``in_channels``; because of the residual add it must in
            practice equal ``in_channels``.
        kernel_size (int): One-sided context length (see above)
        stride (int, optional): Accepted but ignored; the convolution always uses stride 1
        padding (int or tuple, optional): Accepted but ignored; padding is derived from ``kernel_size``
        bias (bool, optional): Accepted but ignored; the convolution never has a bias

    Inputs: inputs
        - **inputs** (batch, in_channels, time): Tensor containing input vector
    Returns: outputs
        - **outputs** (batch, out_channels, time): ``inputs`` plus its depthwise-filtered version
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        bias: bool = False,
    ) -> None:
        super(DepthwiseConv2d, self).__init__()
        assert out_channels % in_channels == 0, "out_channels should be constant multiple of in_channels"
        self.lorder = kernel_size
        self.rorder = self.lorder
        # NOTE: stride, padding and bias are not forwarded; the layer is always
        # stride-1, length-preserving and bias-free.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            [self.lorder + self.rorder - 1, 1],
            [1, 1],
            groups=in_channels,
            bias=False,
        )

    def forward(self, inputs: Tensor) -> Tensor:
        # (batch, channels, time) -> (batch, channels, time, 1)
        x = torch.unsqueeze(inputs, -1)
        # Pad the time axis by lorder - 1 before and rorder - 1 after so the
        # convolution output keeps the input length.
        y = F.pad(x, [0, 0, self.lorder - 1, self.rorder - 1])
        out = x + self.conv(y)
        return out.squeeze(-1)
class PointwiseConv1d(nn.Module):
    """
    Pointwise (kernel size 1) 1-D convolution: a per-position linear map across
    channels, commonly used to change the channel count.

    Args:
        in_channels (int): Number of channels in the input
        out_channels (int): Number of channels produced by the convolution
        stride (int, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        bias (bool, optional): If True, adds a learnable bias to the output. Default: True

    Inputs: inputs
        - **inputs** (batch, in_channels, time): Tensor containing input vector
    Returns: outputs
        - **outputs** (batch, out_channels, time): Tensor produced by the pointwise 1-D convolution.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 1,
        padding: int = 0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        conv_options = dict(kernel_size=1, stride=stride, padding=padding, bias=bias)
        self.conv = nn.Conv1d(in_channels, out_channels, **conv_options)

    def forward(self, inputs: Tensor) -> Tensor:
        return self.conv(inputs)
class ConvModule(nn.Module):
    """
    Residual depthwise convolution over time (modified Conformer convolution module).

    Only the depthwise convolution of the Conformer block is kept: the input is
    transposed to channel-first, filtered by a "same"-padded depthwise Conv1d,
    transposed back and added to the input.

    Args:
        in_channels (int): Number of channels in the input (its last axis)
        kernel_size (int, optional): Odd size of the depthwise kernel. Default: 31
        expansion_factor (int, optional): Must be 2; otherwise unused
        dropout_p (float, optional): Unused

    Inputs: inputs
        inputs (batch, time, dim): Tensor contains input sequences
    Outputs: outputs
        outputs (batch, time, dim): ``inputs`` plus the filtered sequence.
    """

    def __init__(
        self,
        in_channels: int,
        kernel_size: int = 31,
        expansion_factor: int = 2,
        dropout_p: float = 0.1,
    ) -> None:
        super().__init__()
        assert (kernel_size - 1) % 2 == 0, "kernel_size should be a odd number for 'SAME' padding"
        assert expansion_factor == 2, "Currently, Only Supports expansion_factor 2"
        same_padding = (kernel_size - 1) // 2
        self.sequential = nn.Sequential(
            Transpose(shape=(1, 2)),  # (batch, time, dim) -> (batch, dim, time)
            DepthwiseConv1d(in_channels, in_channels, kernel_size, stride=1, padding=same_padding),
        )

    def forward(self, inputs: Tensor) -> Tensor:
        filtered = self.sequential(inputs)  # (batch, dim, time)
        return inputs + filtered.transpose(1, 2)
class ConvModule_Gating(nn.Module):
    """
    Multiplicative depthwise-convolution gate over time (modified Conformer convolution module).

    The input is transposed to channel-first, filtered by a "same"-padded
    depthwise Conv1d, transposed back, and multiplied elementwise with the
    input (no residual add and no sigmoid on the gate).

    Args:
        in_channels (int): Number of channels in the input (its last axis)
        kernel_size (int, optional): Odd size of the depthwise kernel. Default: 21
        expansion_factor (int, optional): Must be 2; otherwise unused
        dropout_p (float, optional): Unused

    Inputs: inputs
        inputs (batch, time, dim): Tensor contains input sequences
    Outputs: outputs
        outputs (batch, time, dim): ``inputs`` times the filtered sequence.
    """

    def __init__(
        self,
        in_channels: int,
        kernel_size: int = 21,
        expansion_factor: int = 2,
        dropout_p: float = 0.1,
    ) -> None:
        super(ConvModule_Gating, self).__init__()
        # The default used to be 20, which always failed this odd-size check,
        # so the class could not be constructed without an explicit kernel_size.
        assert (kernel_size - 1) % 2 == 0, "kernel_size should be a odd number for 'SAME' padding"
        assert expansion_factor == 2, "Currently, Only Supports expansion_factor 2"
        self.sequential = nn.Sequential(
            Transpose(shape=(1, 2)),  # (batch, time, dim) -> (batch, dim, time)
            DepthwiseConv1d(in_channels, in_channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2),
        )

    def forward(self, inputs: Tensor) -> Tensor:
        return inputs * self.sequential(inputs).transpose(1, 2)
class Conformer_ConvModule(nn.Module):
    """
    Conformer convolution module with a residual connection.

    Pipeline: single-group GroupNorm -> pointwise conv to 2x channels -> GLU
    (back to ``in_channels``) -> "same"-padded depthwise conv -> BatchNorm ->
    Swish -> pointwise conv -> dropout. The result is added to the input.

    Args:
        in_channels (int): Number of channels in the input
        kernel_size (int, optional): Odd size of the depthwise kernel. Default: 21
        expansion_factor (int, optional): Must be 2. Default: 2
        dropout_p (float, optional): probability of dropout

    Inputs: inputs
        inputs (batch, dim, time): Tensor contains input sequences
    Outputs: outputs
        outputs (batch, dim, time): Tensor produced by the conformer convolution module.
    """

    def __init__(
        self,
        in_channels: int,
        kernel_size: int = 21,
        expansion_factor: int = 2,
        dropout_p: float = 0.1,
    ) -> None:
        super().__init__()
        assert (kernel_size - 1) % 2 == 0, "kernel_size should be a odd number for 'SAME' padding"
        assert expansion_factor == 2, "Currently, Only Supports expansion_factor 2"
        expanded_channels = in_channels * expansion_factor
        same_padding = (kernel_size - 1) // 2
        self.sequential = nn.Sequential(
            select_norm('ln', in_channels, 3),
            PointwiseConv1d(in_channels, expanded_channels, stride=1, padding=0, bias=True),
            GLU(dim=1),  # splits the 2x channels in half -> in_channels
            DepthwiseConv1d(in_channels, in_channels, kernel_size, stride=1, padding=same_padding),
            select_norm('bn', in_channels, 3),
            Swish(),
            PointwiseConv1d(in_channels, in_channels, stride=1, padding=0, bias=True),
            nn.Dropout(p=dropout_p),
        )

    def forward(self, inputs: Tensor) -> Tensor:
        residual = inputs
        return residual + self.sequential(inputs)
class FeedForwardModule(nn.Module):
    """
    Conformer-style feed-forward block.

    Applies LayerNorm -> Linear (expand by ``expansion_factor``) -> Swish ->
    Dropout -> Linear (project back) -> Dropout. Layer normalization is applied
    to the input first (pre-norm). No residual connection is added here:
    ``forward`` returns only the transformed tensor, so the residual is
    presumably added by the caller.

    Args:
        encoder_dim (int): Dimension of conformer encoder
        expansion_factor (int): Expansion factor of the hidden layer
        dropout_p (float): Ratio of dropout

    Inputs: inputs
        - **inputs** (batch, time, dim): Tensor contains input sequences
    Outputs: outputs
        - **outputs** (batch, time, dim): Tensor produced by the feed forward module.
    """

    def __init__(
        self,
        encoder_dim: int = 512,
        expansion_factor: int = 4,
        dropout_p: float = 0.1,
    ) -> None:
        super().__init__()
        hidden_dim = encoder_dim * expansion_factor
        self.sequential = nn.Sequential(
            nn.LayerNorm(encoder_dim),
            Linear(encoder_dim, hidden_dim, bias=True),
            Swish(),
            nn.Dropout(p=dropout_p),
            Linear(hidden_dim, encoder_dim, bias=True),
            nn.Dropout(p=dropout_p),
        )

    def forward(self, inputs: Tensor) -> Tensor:
        outputs = self.sequential(inputs)
        return outputs