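# NOTE(review): the lines below are leftover text from the web page this file
# was copied from; they are commented out so the module parses as Python.
# alibabasglab's picture
# Upload 161 files
# 8e8cd3e verified
# raw
# history blame
# 14 kB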
import torch
import torch.nn as nn
from torch import Tensor
import torch.nn.init as init
import torch.nn.functional as F
# Module-level epsilon for numerical stability. The layers in this file pass
# their own eps=1e-8 literals instead; possibly imported elsewhere — confirm
# before removing.
EPS = 1.0e-8
class GlobalLayerNorm(nn.Module):
    """Calculate Global Layer Normalization (gLN).

    Each sample is normalized with a single mean/variance computed over all
    non-batch dimensions: (C, L) for 3-D input or (C, K, S) for 4-D input.
    An optional learnable per-channel affine transform is then applied.

    Arguments
    ---------
    dim : int
        Number of channels C (size of dimension 1 of the input); sizes the
        affine parameters.
    shape : int
        Rank of the expected input: 3 for [N, C, L] or 4 for [N, C, K, S].
        Only used to shape the affine parameters so they broadcast over the
        non-channel dimensions; ignored when ``elementwise_affine`` is False.
    eps : float
        A value added to the variance for numerical stability.
    elementwise_affine : bool
        When True, this module has learnable per-channel affine parameters
        initialized to ones (for weights) and zeros (for biases).

    Raises
    ------
    ValueError
        If ``elementwise_affine`` is True and ``shape`` is not 3 or 4.

    Example
    -------
    >>> x = torch.randn(5, 10, 20)
    >>> GLN = GlobalLayerNorm(10, 3)
    >>> x_norm = GLN(x)
    """

    def __init__(self, dim, shape, eps=1e-8, elementwise_affine=True):
        super(GlobalLayerNorm, self).__init__()
        self.dim = dim
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            if shape not in (3, 4):
                # Previously no parameters were created here and forward()
                # failed later with an AttributeError; fail early and clearly.
                raise ValueError(
                    f"GlobalLayerNorm supports shape 3 or 4, got {shape!r}"
                )
            # Channel axis first, then one singleton per trailing axis so the
            # parameters broadcast over [N, C, L] or [N, C, K, S].
            param_shape = (self.dim, 1) if shape == 3 else (self.dim, 1, 1)
            self.weight = nn.Parameter(torch.ones(*param_shape))
            self.bias = nn.Parameter(torch.zeros(*param_shape))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

    def forward(self, x):
        """Returns the normalized tensor (same shape as ``x``).

        Arguments
        ---------
        x : torch.Tensor
            Tensor of size [N, C, K, S] or [N, C, L].

        Raises
        ------
        ValueError
            If ``x`` is neither 3-D nor 4-D (previously such inputs were
            returned unchanged without any normalization).
        """
        # gLN statistics: one mean/var per sample, i.e. N x 1 x 1 (x 1).
        if x.dim() == 3:
            reduce_dims = (1, 2)
        elif x.dim() == 4:
            reduce_dims = (1, 2, 3)
        else:
            raise ValueError(
                f"GlobalLayerNorm expects a 3-D or 4-D input, got {x.dim()}-D"
            )
        mean = torch.mean(x, reduce_dims, keepdim=True)
        var = torch.mean((x - mean) ** 2, reduce_dims, keepdim=True)
        if self.elementwise_affine:
            return self.weight * (x - mean) / torch.sqrt(var + self.eps) + self.bias
        return (x - mean) / torch.sqrt(var + self.eps)
class CumulativeLayerNorm(nn.LayerNorm):
    """Channel-wise layer normalization applied independently at every position.

    Despite the name, no cumulative (running) statistics are computed: the
    channel axis (dimension 1) is moved last, ``nn.LayerNorm`` with
    ``eps=1e-8`` normalizes over it, and the axes are moved back.

    Arguments
    ---------
    dim : int
        Number of channels C to normalize over.
    elementwise_affine : bool
        Whether the LayerNorm has learnable per-channel affine parameters.

    Example
    -------
    >>> x = torch.randn(5, 10, 20)
    >>> CLN = CumulativeLayerNorm(10)
    >>> x_norm = CLN(x)
    """

    def __init__(self, dim, elementwise_affine=True):
        super().__init__(dim, elementwise_affine=elementwise_affine, eps=1e-8)

    def forward(self, x):
        """Normalize ``x`` over its channel dimension.

        Arguments
        ---------
        x : torch.Tensor
            Tensor of size [N, C, K, S] or [N, C, L]. Inputs of any other rank
            are returned unchanged.
        """
        rank = x.dim()
        if rank == 4:
            # [N, C, K, S] -> [N, K, S, C] so LayerNorm sees channels last.
            channels_last = x.permute(0, 2, 3, 1).contiguous()
            normed = super().forward(channels_last)
            # back to [N, C, K, S]
            return normed.permute(0, 3, 1, 2).contiguous()
        if rank == 3:
            # [N, C, L] -> [N, L, C] -> LayerNorm -> [N, C, L]
            normed = super().forward(torch.transpose(x, 1, 2))
            return torch.transpose(normed, 1, 2)
        return x
def select_norm(norm, dim, shape):
    """Build the normalization layer named by ``norm``.

    Args:
        norm (str): "gln" -> GlobalLayerNorm, "cln" -> CumulativeLayerNorm,
            "ln" -> single-group GroupNorm (eps=1e-8); any other value falls
            back to BatchNorm1d.
        dim (int): Number of channels to normalize.
        shape (int): Input rank; only forwarded to GlobalLayerNorm.

    Returns:
        nn.Module: A freshly constructed normalization layer.
    """
    if norm == "gln":
        layer = GlobalLayerNorm(dim, shape, elementwise_affine=True)
    elif norm == "cln":
        layer = CumulativeLayerNorm(dim, elementwise_affine=True)
    elif norm == "ln":
        layer = nn.GroupNorm(1, dim, eps=1e-8)
    else:
        # Any unrecognized name (e.g. "bn") selects batch normalization.
        layer = nn.BatchNorm1d(dim)
    return layer
class Swish(nn.Module):
    """Swish activation, ``x * sigmoid(x)``.

    A smooth, non-monotonic activation reported to match or outperform ReLU
    on deep networks across domains such as image classification and machine
    translation.
    """

    def __init__(self):
        super().__init__()

    def forward(self, inputs: Tensor) -> Tensor:
        gate = torch.sigmoid(inputs)
        return inputs * gate
class GLU(nn.Module):
    """Gated Linear Unit along dimension ``dim``.

    The input is split into two halves along ``dim``; the first half is
    multiplied by the sigmoid of the second. Introduced in "Language Modeling
    with Gated Convolutional Networks".
    """

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, inputs: Tensor) -> Tensor:
        value, gate = torch.chunk(inputs, 2, dim=self.dim)
        return value * torch.sigmoid(gate)
class Transpose(nn.Module):
    """nn.Module wrapper around ``torch.transpose`` so it can sit in nn.Sequential.

    ``shape`` holds the pair of dimensions to swap, e.g. ``(1, 2)``.
    """

    def __init__(self, shape: tuple):
        super().__init__()
        self.shape = shape

    def forward(self, x: Tensor) -> Tensor:
        return torch.transpose(x, *self.shape)
class Linear(nn.Module):
    """Thin wrapper around ``torch.nn.Linear`` with custom initialization.

    The weight is re-initialized with Xavier-uniform and, when present, the
    bias is set to zeros.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super().__init__()
        layer = nn.Linear(in_features, out_features, bias=bias)
        init.xavier_uniform_(layer.weight)
        if bias:
            init.zeros_(layer.bias)
        self.linear = layer

    def forward(self, x: Tensor) -> Tensor:
        projected = self.linear(x)
        return projected
class DepthwiseConv1d(nn.Module):
    """
    1-D depthwise convolution (``nn.Conv1d`` with ``groups == in_channels``).

    When groups == in_channels and out_channels == K * in_channels, where K is a positive integer,
    this operation is termed in literature as depthwise convolution.

    Args:
        in_channels (int): Number of channels in the input
        out_channels (int): Number of channels produced by the convolution; must be a
            multiple of ``in_channels`` (checked with an ``assert``)
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        bias (bool, optional): If True, adds a learnable bias to the output. Default: False

    Inputs: inputs
        - **inputs** (batch, in_channels, time): Tensor containing input vector

    Returns: outputs
        - **outputs** (batch, out_channels, time'): Tensor produced by depthwise 1-D convolution.
          time' equals time only for stride 1, odd kernel_size and padding (kernel_size - 1) // 2.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        bias: bool = False,
    ) -> None:
        super(DepthwiseConv1d, self).__init__()
        assert out_channels % in_channels == 0, "out_channels should be constant multiple of in_channels"
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            groups=in_channels,  # one filter group per input channel
            stride=stride,
            padding=padding,
            bias=bias,
        )

    def forward(self, inputs: Tensor) -> Tensor:
        return self.conv(inputs)
class PointwiseConv1d(nn.Module):
    """
    1-D pointwise convolution: an ``nn.Conv1d`` with ``kernel_size=1``.

    Mixes channels independently at every time step; commonly used to change
    the channel dimension.

    Args:
        in_channels (int): Number of channels in the input
        out_channels (int): Number of channels produced by the convolution
        stride (int, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        bias (bool, optional): If True, adds a learnable bias to the output. Default: True

    Inputs: inputs
        - **inputs** (batch, in_channels, time): Tensor containing input vector

    Returns: outputs
        - **outputs** (batch, out_channels, time): Tensor produced by pointwise 1-D convolution.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 1,
        padding: int = 0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        conv_options = dict(kernel_size=1, stride=stride, padding=padding, bias=bias)
        self.conv = nn.Conv1d(in_channels, out_channels, **conv_options)

    def forward(self, inputs: Tensor) -> Tensor:
        mixed = self.conv(inputs)
        return mixed
class ConvModule(nn.Module):
    """
    Residual depthwise 1-D convolution over the time axis.

    The input (batch, time, dim) is transposed to (batch, dim, time), passed
    through a bias-free depthwise convolution with stride 1 and padding
    (kernel_size - 1) // 2 (so the length is preserved for odd kernel sizes),
    transposed back and added to the input.

    Note: unlike the Conformer convolution module, no pointwise convolution,
    GLU, batch norm or dropout is applied here. ``expansion_factor`` is only
    checked to equal 2 and ``dropout_p`` is unused; both are kept for
    interface compatibility.

    Args:
        in_channels (int): Number of channels (``dim``) of the input
        kernel_size (int, optional): Size of the convolving kernel; must be odd. Default: 17
        expansion_factor (int, optional): Must be 2; otherwise unused. Default: 2
        dropout_p (float, optional): Unused. Default: 0.1

    Inputs: inputs
        inputs (batch, time, dim): Tensor contains input sequences

    Outputs: outputs
        outputs (batch, time, dim): ``inputs`` plus the depthwise-convolved sequence.
    """

    def __init__(
        self,
        in_channels: int,
        kernel_size: int = 17,
        expansion_factor: int = 2,
        dropout_p: float = 0.1,
    ) -> None:
        super(ConvModule, self).__init__()
        assert (kernel_size - 1) % 2 == 0, "kernel_size should be a odd number for 'SAME' padding"
        assert expansion_factor == 2, "Currently, Only Supports expansion_factor 2"
        self.sequential = nn.Sequential(
            Transpose(shape=(1, 2)),  # (batch, time, dim) -> (batch, dim, time)
            DepthwiseConv1d(in_channels, in_channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2),
        )

    def forward(self, inputs: Tensor) -> Tensor:
        # Transpose the conv output back to (batch, time, dim) before the residual add.
        return inputs + self.sequential(inputs).transpose(1, 2)
class ConvModule_Dilated(nn.Module):
    """
    Residual depthwise 1-D convolution over the time axis.

    The input (batch, time, dim) is transposed to (batch, dim, time), passed
    through a bias-free depthwise convolution with stride 1 and padding
    (kernel_size - 1) // 2 (so the length is preserved for odd kernel sizes),
    transposed back and added to the input. No dilation is used despite the
    class name.

    Note: no pointwise convolution, GLU, batch norm or dropout is applied.
    ``expansion_factor`` is only checked to equal 2 and ``dropout_p`` is
    unused; both are kept for interface compatibility.

    Args:
        in_channels (int): Number of channels (``dim``) of the input
        kernel_size (int, optional): Size of the convolving kernel; must be odd. Default: 17
        expansion_factor (int, optional): Must be 2; otherwise unused. Default: 2
        dropout_p (float, optional): Unused. Default: 0.1

    Inputs: inputs
        inputs (batch, time, dim): Tensor contains input sequences

    Outputs: outputs
        outputs (batch, time, dim): ``inputs`` plus the depthwise-convolved sequence.
    """

    def __init__(
        self,
        in_channels: int,
        kernel_size: int = 17,
        expansion_factor: int = 2,
        dropout_p: float = 0.1,
    ) -> None:
        # Previously super(ConvModule_Gating, self): this class is not a
        # subclass of ConvModule_Gating, so construction always failed.
        super(ConvModule_Dilated, self).__init__()
        assert (kernel_size - 1) % 2 == 0, "kernel_size should be a odd number for 'SAME' padding"
        assert expansion_factor == 2, "Currently, Only Supports expansion_factor 2"
        self.sequential = nn.Sequential(
            Transpose(shape=(1, 2)),  # (batch, time, dim) -> (batch, dim, time)
            DepthwiseConv1d(in_channels, in_channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2),
        )

    def forward(self, inputs: Tensor) -> Tensor:
        # Transpose the conv output back to (batch, time, dim) before the residual add.
        return inputs + self.sequential(inputs).transpose(1, 2)
class DilatedDenseNet(nn.Module):
    """Densely connected stack of dilated grouped convolutions along the length axis.

    Layer ``i`` (1-based) uses dilation ``2 ** (i - 1)`` along the length axis.
    Its input is the concatenation, on the channel axis, of every earlier
    layer's output followed by the block input, so it has ``in_channels * i``
    channels. Each layer is:

        zero-pad -> Conv2d(kernel (2*lorder-1, 1), groups=in_channels, no bias)
        -> InstanceNorm2d(affine) -> PReLU

    The padding is ``dilation * (lorder - 1)`` on both ends, which keeps the
    length unchanged. Only the last layer's output is returned.

    Args:
        depth (int): number of dilated layers.
        lorder (int): half-width of the kernel; kernel length is 2*lorder-1.
        in_channels (int): channel count C of the input and of each layer's output.

    Input: 3-D tensor (B, L, C) with C == in_channels (required by the
        unsqueeze/permute in forward). Presumably L is time — confirm with callers.
    Output: tensor of shape (B, L, C).
    """

    def __init__(self, depth=4, lorder=20, in_channels=64):
        super().__init__()
        self.depth = depth
        self.in_channels = in_channels
        # Not used by forward(); kept so the module's attributes are unchanged.
        self.pad = nn.ConstantPad2d((1, 1, 1, 0), value=0.)
        self.twidth = 2 * lorder - 1
        self.kernel_size = (self.twidth, 1)
        for idx in range(1, self.depth + 1):
            dilation = 2 ** (idx - 1)
            pad_rows = dilation * (lorder - 1)
            setattr(self, f"pad{idx}", nn.ConstantPad2d((0, 0, pad_rows, pad_rows), value=0.))
            setattr(
                self,
                f"conv{idx}",
                nn.Conv2d(
                    self.in_channels * idx,
                    self.in_channels,
                    kernel_size=self.kernel_size,
                    dilation=(dilation, 1),
                    groups=self.in_channels,
                    bias=False,
                ),
            )
            setattr(self, f"norm{idx}", nn.InstanceNorm2d(in_channels, affine=True))
            setattr(self, f"prelu{idx}", nn.PReLU(self.in_channels))

    def forward(self, x):
        # (B, L, C) -> (B, 1, L, C) -> (B, C, L, 1): channels become Conv2d channels.
        features = x.unsqueeze(1).permute(0, 3, 2, 1)
        for idx in range(1, self.depth + 1):
            out = getattr(self, f"pad{idx}")(features)
            out = getattr(self, f"conv{idx}")(out)
            out = getattr(self, f"norm{idx}")(out)
            out = getattr(self, f"prelu{idx}")(out)
            # Newest output first, then everything seen so far.
            features = torch.cat([out, features], dim=1)
        # (B, C, L, 1) -> (B, 1, L, C) -> (B, L, C)
        return out.permute(0, 3, 2, 1).squeeze(1)
class FFConvM_Dilated(nn.Module):
    """Feed-forward block: norm -> linear -> SiLU -> DilatedDenseNet -> dropout.

    The layers are stored in ``self.mdl`` (an ``nn.Sequential``) in exactly
    that order.

    Args:
        dim_in: feature size of the input's last dimension, given to
            ``norm_klass`` and to the input side of the linear projection.
        dim_out: output feature size of the linear projection and the channel
            count of the DilatedDenseNet (depth 2, lorder 17).
        norm_klass: normalization layer class, called as ``norm_klass(dim_in)``.
        dropout: dropout probability applied last.

    Input is presumably (batch, length, dim_in), since DilatedDenseNet needs
    a 3-D tensor — confirm with callers.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        norm_klass=nn.LayerNorm,
        dropout=0.1,
    ):
        super().__init__()
        layers = [
            norm_klass(dim_in),
            nn.Linear(dim_in, dim_out),
            nn.SiLU(),
            DilatedDenseNet(depth=2, lorder=17, in_channels=dim_out),
            nn.Dropout(dropout),
        ]
        self.mdl = nn.Sequential(*layers)

    def forward(
        self,
        x,
    ):
        return self.mdl(x)