import torch
import torch.nn as nn
import torch.nn.functional as F
import models.frcrn_se.complex_nn as complex_nn
from models.frcrn_se.se_layer import SELayer
class Encoder(nn.Module):
"""
Encoder module for a neural network, responsible for downsampling input features.
This module consists of a convolutional layer followed by batch normalization and a Leaky ReLU activation.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
kernel_size (tuple): Size of the convolutional kernel.
stride (tuple): Stride of the convolution.
padding (tuple, optional): Padding for the convolution. If None, 'SAME' padding is applied.
complex (bool, optional): If True, use complex convolution layers. Default is False.
padding_mode (str, optional): Padding mode for convolution. Default is "zeros".
"""
def __init__(self, in_channels, out_channels, kernel_size, stride, padding=None, complex=False, padding_mode="zeros"):
super().__init__()
# Determine padding for 'SAME' padding if not provided
if padding is None:
padding = [(i - 1) // 2 for i in kernel_size]
# Select convolution and batch normalization layers based on complex flag
if complex:
conv = complex_nn.ComplexConv2d
bn = complex_nn.ComplexBatchNorm2d
else:
conv = nn.Conv2d
bn = nn.BatchNorm2d
# Define convolutional layer, batch normalization, and activation function
self.conv = conv(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, padding_mode=padding_mode)
self.bn = bn(out_channels)
self.relu = nn.LeakyReLU(inplace=True)
def forward(self, x):
"""
Forward pass through the encoder.
Args:
x (torch.Tensor): Input tensor of shape (B, C, H, W) where B is batch size,
C is the number of channels, H is height, and W is width.
Returns:
torch.Tensor: Output tensor after applying convolution, batch normalization, and activation.
"""
x = self.conv(x) # Apply convolution
x = self.bn(x) # Apply batch normalization
x = self.relu(x) # Apply Leaky ReLU activation
return x
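
# Illustrative usage sketch (added comment, not part of the original model code). It shows
# the real-valued path only and the "SAME"-style padding rule (k - 1) // 2 applied when
# padding is None; the tensor shape below is an assumed (batch, channel, freq, time) layout.
#
#   enc = Encoder(1, 45, kernel_size=(5, 3), stride=(1, 1))   # padding defaults to (2, 1)
#   enc(torch.randn(2, 1, 257, 100)).shape                    # -> torch.Size([2, 45, 257, 100])
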
class Decoder(nn.Module):
"""
Decoder module for a neural network, responsible for upsampling input features.
This module consists of a transposed convolutional layer followed by batch normalization
and a Leaky ReLU activation.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
kernel_size (tuple): Size of the transposed convolutional kernel.
stride (tuple): Stride of the transposed convolution.
padding (tuple, optional): Padding for the transposed convolution. Default is (0, 0).
complex (bool, optional): If True, use complex transposed convolution layers. Default is False.
"""
def __init__(self, in_channels, out_channels, kernel_size, stride, padding=(0, 0), complex=False):
super().__init__()
# Select transposed convolution and batch normalization layers based on complex flag
if complex:
tconv = complex_nn.ComplexConvTranspose2d
bn = complex_nn.ComplexBatchNorm2d
else:
tconv = nn.ConvTranspose2d
bn = nn.BatchNorm2d
# Define transposed convolutional layer, batch normalization, and activation function
self.transconv = tconv(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
self.bn = bn(out_channels)
self.relu = nn.LeakyReLU(inplace=True)
def forward(self, x):
"""
Forward pass through the decoder.
Args:
x (torch.Tensor): Input tensor of shape (B, C, H, W) where B is batch size,
C is the number of channels, H is height, and W is width.
Returns:
torch.Tensor: Output tensor after applying transposed convolution, batch normalization, and activation.
"""
x = self.transconv(x) # Apply transposed convolution
x = self.bn(x) # Apply batch normalization
x = self.relu(x) # Apply Leaky ReLU activation
return x
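
# Illustrative usage sketch (added comment, assumed shapes): a stride-(2, 1) transposed
# convolution roughly doubles the frequency axis while leaving the time axis almost
# unchanged, mirroring the downsampling done by the matching Encoder block.
#
#   dec = Decoder(128, 128, kernel_size=(5, 2), stride=(2, 1), padding=(0, 1))
#   dec(torch.randn(2, 128, 13, 64)).shape     # -> torch.Size([2, 128, 29, 63])
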
class UNet(nn.Module):
"""
    U-Net architecture for handling both real and complex inputs.
    This model uses an encoder-decoder structure with skip connections between corresponding
    encoder and decoder layers. Squeeze-and-Excitation (SE) layers provide channel attention,
    and complex FSMN blocks inserted between the convolutional stages model temporal context.
Args:
input_channels (int, optional): Number of input channels. Default is 1.
complex (bool, optional): If True, use complex layers. Default is False.
model_complexity (int, optional): Determines the number of channels in the model. Default is 45.
        model_depth (int, optional): Total number of encoder and decoder layers; the network builds
            model_depth // 2 encoder/decoder pairs. Default is 20.
padding_mode (str, optional): Padding mode for convolutions. Default is "zeros".
"""
def __init__(self, input_channels=1,
complex=False,
model_complexity=45,
model_depth=20,
padding_mode="zeros"):
super().__init__()
# Adjust model complexity for complex models
if complex:
model_complexity = int(model_complexity // 1.414)
# Initialize model parameters based on specified complexity and depth
self.set_size(model_complexity=model_complexity, input_channels=input_channels, model_depth=model_depth)
self.encoders = []
self.model_length = model_depth // 2
self.fsmn = complex_nn.ComplexUniDeepFsmn(128, 128, 128)
self.se_layers_enc = []
self.fsmn_enc = []
# Build the encoder structure
for i in range(self.model_length):
fsmn_enc = complex_nn.ComplexUniDeepFsmn_L1(128, 128, 128)
self.add_module("fsmn_enc{}".format(i), fsmn_enc)
self.fsmn_enc.append(fsmn_enc)
module = Encoder(self.enc_channels[i], self.enc_channels[i + 1], kernel_size=self.enc_kernel_sizes[i],
stride=self.enc_strides[i], padding=self.enc_paddings[i], complex=complex, padding_mode=padding_mode)
self.add_module("encoder{}".format(i), module)
self.encoders.append(module)
se_layer_enc = SELayer(self.enc_channels[i + 1], 8)
self.add_module("se_layer_enc{}".format(i), se_layer_enc)
self.se_layers_enc.append(se_layer_enc)
# Build the decoder structure
self.decoders = []
self.fsmn_dec = []
self.se_layers_dec = []
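        # Note (added explanatory comment): each Decoder below is declared with
        # dec_channels[i] * 2 input channels because, after the first stage, forward()
        # feeds every decoder the previous decoder output concatenated with an
        # SE-refined encoder skip feature of equal channel width.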
for i in range(self.model_length):
fsmn_dec = complex_nn.ComplexUniDeepFsmn_L1(128, 128, 128)
self.add_module("fsmn_dec{}".format(i), fsmn_dec)
self.fsmn_dec.append(fsmn_dec)
module = Decoder(self.dec_channels[i] * 2, self.dec_channels[i + 1], kernel_size=self.dec_kernel_sizes[i],
stride=self.dec_strides[i], padding=self.dec_paddings[i], complex=complex)
self.add_module("decoder{}".format(i), module)
self.decoders.append(module)
if i < self.model_length - 1:
se_layer_dec = SELayer(self.dec_channels[i + 1], 8)
self.add_module("se_layer_dec{}".format(i), se_layer_dec)
self.se_layers_dec.append(se_layer_dec)
# Define final linear layer based on complex flag
if complex:
conv = complex_nn.ComplexConv2d
else:
conv = nn.Conv2d
        linear = conv(self.dec_channels[-1], 1, 1)  # Final 1x1 convolution projecting to a single output channel
self.add_module("linear", linear)
self.complex = complex
self.padding_mode = padding_mode
# Convert lists to ModuleLists for proper parameter registration
self.decoders = nn.ModuleList(self.decoders)
self.encoders = nn.ModuleList(self.encoders)
self.se_layers_enc = nn.ModuleList(self.se_layers_enc)
self.se_layers_dec = nn.ModuleList(self.se_layers_dec)
self.fsmn_enc = nn.ModuleList(self.fsmn_enc)
self.fsmn_dec = nn.ModuleList(self.fsmn_dec)
def forward(self, inputs):
"""
Forward pass for the UNet model.
This method processes the input tensor through the encoder-decoder architecture,
applying convolutional layers, FSMNs, and SE layers. Skip connections are used
to merge features from the encoder to the decoder.
Args:
            inputs (torch.Tensor): Input tensor of shape (batch, channels, freq, time); when the
                complex layers are used, a trailing dimension of size 2 carries the real and
                imaginary parts (see the cmp_spec shape comment below).
Returns:
torch.Tensor: Output tensor after processing, representing the computed features.
"""
x = inputs # Initialize input tensor
xs = [] # List to store input tensors for skip connections
xs_se = [] # List to store outputs after applying SE layers
xs_se.append(x) # Add the initial input to the SE outputs list
# Forward pass through the encoder layers
for i, encoder in enumerate(self.encoders):
xs.append(x) # Store the current input for skip connections
if i > 0:
x = self.fsmn_enc[i](x) # Apply FSMN if not the first encoder
x = encoder(x) # Apply the encoder layer
xs_se.append(self.se_layers_enc[i](x)) # Apply SE layer and store the result
x = self.fsmn(x) # Apply the final FSMN after all encoders
p = x # Initialize output tensor for decoders
# Forward pass through the decoder layers
for i, decoder in enumerate(self.decoders):
p = decoder(p) # Apply the decoder layer
if i < self.model_length - 1:
p = self.fsmn_dec[i](p) # Apply FSMN if not the last decoder
if i == self.model_length - 1:
break # Stop processing at the last decoder layer
if i < self.model_length - 2:
p = self.se_layers_dec[i](p) # Apply SE layer for intermediate decoders
p = torch.cat([p, xs_se[self.model_length - 1 - i]], dim=1) # Concatenate skip connection
# Final output processing
# cmp_spec: [batch, 1, 513, 64, 2]
cmp_spec = self.linear(p) # Apply linear transformation to produce final output
return cmp_spec # Return the computed output tensor
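    # Skip-connection indexing sketch (added comment): with model_length = L, decoder i
    # (for i < L - 1) is concatenated with xs_se[L - 1 - i], i.e. the SE-refined output of
    # encoder L - 2 - i. The raw input stored at xs_se[0] is never concatenated, and the
    # deepest SE output xs_se[L] is unused because the decoder stack instead starts from
    # the FSMN-processed output of the last encoder.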
def set_size(self, model_complexity, model_depth=20, input_channels=1):
"""
Set the architecture parameters for the UNet model based on specified complexity and depth.
This method configures the encoder and decoder layers of the UNet by setting the number of channels,
kernel sizes, strides, and paddings for each layer according to the provided model complexity
and depth.
Args:
model_complexity (int): Base number of channels for the model.
            model_depth (int, optional): Total number of encoder and decoder layers; model_depth // 2
                encoder/decoder pairs are configured. Default is 20.
input_channels (int, optional): Number of input channels to the model. Default is 1.
Raises:
ValueError: If an unknown model depth is provided.
"""
# Configuration for model depth of 14
if model_depth == 14:
# Set encoder channels for model depth of 14
self.enc_channels = [input_channels,
128,
128,
128,
128,
128,
128,
128]
# Define kernel sizes for encoder layers
self.enc_kernel_sizes = [(5, 2),
(5, 2),
(5, 2),
(5, 2),
(5, 2),
(5, 2),
(2, 2)]
# Define strides for encoder layers
self.enc_strides = [(2, 1),
(2, 1),
(2, 1),
(2, 1),
(2, 1),
(2, 1),
(2, 1)]
# Define paddings for encoder layers
self.enc_paddings = [(0, 1),
(0, 1),
(0, 1),
(0, 1),
(0, 1),
(0, 1),
(0, 1)]
# Set decoder channels for model depth of 14
self.dec_channels = [64,
128,
128,
128,
128,
128,
128,
1]
# Define kernel sizes for decoder layers
self.dec_kernel_sizes = [(2, 2),
(5, 2),
(5, 2),
(5, 2),
(6, 2),
(5, 2),
(5, 2)]
# Define strides for decoder layers
self.dec_strides = [(2, 1),
(2, 1),
(2, 1),
(2, 1),
(2, 1),
(2, 1),
(2, 1)]
# Define paddings for decoder layers
self.dec_paddings = [(0, 1),
(0, 1),
(0, 1),
(0, 1),
(0, 1),
(0, 1),
(0, 1)]
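            # Added note: in the depth-14 configuration every encoder stage outputs 128
            # channels and all encoder strides are (2, 1), so downsampling happens only
            # along the frequency axis.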
# Configuration for model depth of 20
elif model_depth == 20:
# Set encoder channels for model depth of 20
self.enc_channels = [input_channels,
model_complexity,
model_complexity,
model_complexity * 2,
model_complexity * 2,
model_complexity * 2,
model_complexity * 2,
model_complexity * 2,
model_complexity * 2,
model_complexity * 2,
128]
# Define kernel sizes for encoder layers
self.enc_kernel_sizes = [(7, 1),
(1, 7),
(6, 4),
(7, 5),
(5, 3),
(5, 3),
(5, 3),
(5, 3),
(5, 3),
(5, 3)]
# Define strides for encoder layers
self.enc_strides = [(1, 1),
(1, 1),
(2, 2),
(2, 1),
(2, 2),
(2, 1),
(2, 2),
(2, 1),
(2, 2),
(2, 1)]
# Define paddings for encoder layers
self.enc_paddings = [(3, 0),
(0, 3),
                                 None,  # None -> "SAME"-style padding computed per dimension in Encoder.__init__
None,
                                 None,
None,
None,
None,
None,
None]
# Set decoder channels for model depth of 20
self.dec_channels = [0,
model_complexity * 2,
model_complexity * 2,
model_complexity * 2,
model_complexity * 2,
model_complexity * 2,
model_complexity * 2,
model_complexity * 2,
model_complexity * 2,
model_complexity * 2,
model_complexity * 2,
model_complexity * 2]
# Define kernel sizes for decoder layers
self.dec_kernel_sizes = [(4, 3),
(4, 2),
(4, 3),
(4, 2),
(4, 3),
(4, 2),
(6, 3),
(7, 4),
(1, 7),
(7, 1)]
# Define strides for decoder layers
self.dec_strides = [(2, 1),
(2, 2),
(2, 1),
(2, 2),
(2, 1),
(2, 2),
(2, 1),
(2, 2),
(1, 1),
(1, 1)]
# Define paddings for decoder layers
self.dec_paddings = [(1, 1),
(1, 0),
(1, 1),
(1, 0),
(1, 1),
(1, 0),
(2, 1),
(2, 1),
(0, 3),
(3, 0)]
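            # Added note: the decoder stride list is the encoder stride list in reverse,
            # so each Decoder stage undoes the frequency/time downsampling of its Encoder
            # counterpart before the final 1x1 projection layer.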
else:
# Raise an error if an unknown model depth is specified
            raise ValueError("Unknown model depth: {}".format(model_depth))
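

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch, not part of the original file). It exercises
    # only the real-valued Encoder/Decoder blocks; the full UNet and the complex path require
    # the models.frcrn_se package to be importable from the repository root.
    dummy = torch.randn(2, 1, 257, 100)   # assumed (batch, channel, freq, time) layout
    enc = Encoder(1, 45, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0))
    dec = Decoder(45, 1, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0))
    h = enc(dummy)                         # -> (2, 45, 257, 100)
    y = dec(h)                             # -> (2, 1, 257, 100)
    print(h.shape, y.shape)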