guysrn committed on
Commit
d5e984f
·
1 Parent(s): cda00c1

causal_video_autoencoder: add option to halve channels in depth-to-space upsample block

Browse files
xora/models/autoencoders/causal_video_autoencoder.py CHANGED
@@ -455,6 +455,8 @@ class Decoder(nn.Module):
455
  block_params = block_params if isinstance(block_params, dict) else {}
456
  if block_name == "res_x_y":
457
  output_channel = output_channel * block_params.get("multiplier", 2)
 
 
458
 
459
  self.conv_in = make_conv_nd(
460
  dims,
@@ -501,11 +503,13 @@ class Decoder(nn.Module):
501
  dims=dims, in_channels=input_channel, stride=(1, 2, 2)
502
  )
503
  elif block_name == "compress_all":
 
504
  block = DepthToSpaceUpsample(
505
  dims=dims,
506
  in_channels=input_channel,
507
  stride=(2, 2, 2),
508
  residual=block_params.get("residual", False),
 
509
  )
510
  else:
511
  raise ValueError(f"unknown layer: {block_name}")
@@ -614,10 +618,14 @@ class UNetMidBlock3D(nn.Module):
614
 
615
 
616
  class DepthToSpaceUpsample(nn.Module):
617
- def __init__(self, dims, in_channels, stride, residual=False):
 
 
618
  super().__init__()
619
  self.stride = stride
620
- self.out_channels = np.prod(stride) * in_channels
 
 
621
  self.conv = make_conv_nd(
622
  dims=dims,
623
  in_channels=in_channels,
@@ -627,6 +635,7 @@ class DepthToSpaceUpsample(nn.Module):
627
  causal=True,
628
  )
629
  self.residual = residual
 
630
 
631
  def forward(self, x, causal: bool = True):
632
  if self.residual:
@@ -638,7 +647,8 @@ class DepthToSpaceUpsample(nn.Module):
638
  p2=self.stride[1],
639
  p3=self.stride[2],
640
  )
641
- x_in = x_in.repeat(1, np.prod(self.stride), 1, 1, 1)
 
642
  if self.stride[0] == 2:
643
  x_in = x_in[:, :, 1:, :, :]
644
  x = self.conv(x, causal=causal)
 
455
  block_params = block_params if isinstance(block_params, dict) else {}
456
  if block_name == "res_x_y":
457
  output_channel = output_channel * block_params.get("multiplier", 2)
458
+ if block_name == "compress_all":
459
+ output_channel = output_channel * block_params.get("multiplier", 1)
460
 
461
  self.conv_in = make_conv_nd(
462
  dims,
 
503
  dims=dims, in_channels=input_channel, stride=(1, 2, 2)
504
  )
505
  elif block_name == "compress_all":
506
+ output_channel = output_channel // block_params.get("multiplier", 1)
507
  block = DepthToSpaceUpsample(
508
  dims=dims,
509
  in_channels=input_channel,
510
  stride=(2, 2, 2),
511
  residual=block_params.get("residual", False),
512
+ out_channels_reduction_factor=block_params.get("multiplier", 1),
513
  )
514
  else:
515
  raise ValueError(f"unknown layer: {block_name}")
 
618
 
619
 
620
  class DepthToSpaceUpsample(nn.Module):
621
+ def __init__(
622
+ self, dims, in_channels, stride, residual=False, out_channels_reduction_factor=1
623
+ ):
624
  super().__init__()
625
  self.stride = stride
626
+ self.out_channels = (
627
+ np.prod(stride) * in_channels // out_channels_reduction_factor
628
+ )
629
  self.conv = make_conv_nd(
630
  dims=dims,
631
  in_channels=in_channels,
 
635
  causal=True,
636
  )
637
  self.residual = residual
638
+ self.out_channels_reduction_factor = out_channels_reduction_factor
639
 
640
  def forward(self, x, causal: bool = True):
641
  if self.residual:
 
647
  p2=self.stride[1],
648
  p3=self.stride[2],
649
  )
650
+ num_repeat = np.prod(self.stride) // self.out_channels_reduction_factor
651
+ x_in = x_in.repeat(1, num_repeat, 1, 1, 1)
652
  if self.stride[0] == 2:
653
  x_in = x_in[:, :, 1:, :, :]
654
  x = self.conv(x, causal=causal)