Guy Shiran commited on
Commit
427926d
·
unverified ·
2 Parent(s): c4b2a35 d5e984f

Merge pull request #33 from LightricksResearch/compress-all-half-channels

Browse files
xora/models/autoencoders/causal_video_autoencoder.py CHANGED
@@ -455,6 +455,8 @@ class Decoder(nn.Module):
455
  block_params = block_params if isinstance(block_params, dict) else {}
456
  if block_name == "res_x_y":
457
  output_channel = output_channel * block_params.get("multiplier", 2)
 
 
458
 
459
  self.conv_in = make_conv_nd(
460
  dims,
@@ -503,11 +505,13 @@ class Decoder(nn.Module):
503
  dims=dims, in_channels=input_channel, stride=(1, 2, 2)
504
  )
505
  elif block_name == "compress_all":
 
506
  block = DepthToSpaceUpsample(
507
  dims=dims,
508
  in_channels=input_channel,
509
  stride=(2, 2, 2),
510
  residual=block_params.get("residual", False),
 
511
  )
512
  else:
513
  raise ValueError(f"unknown layer: {block_name}")
@@ -618,10 +622,14 @@ class UNetMidBlock3D(nn.Module):
618
 
619
 
620
  class DepthToSpaceUpsample(nn.Module):
621
- def __init__(self, dims, in_channels, stride, residual=False):
 
 
622
  super().__init__()
623
  self.stride = stride
624
- self.out_channels = np.prod(stride) * in_channels
 
 
625
  self.conv = make_conv_nd(
626
  dims=dims,
627
  in_channels=in_channels,
@@ -631,6 +639,7 @@ class DepthToSpaceUpsample(nn.Module):
631
  causal=True,
632
  )
633
  self.residual = residual
 
634
 
635
  def forward(self, x, causal: bool = True):
636
  if self.residual:
@@ -642,7 +651,8 @@ class DepthToSpaceUpsample(nn.Module):
642
  p2=self.stride[1],
643
  p3=self.stride[2],
644
  )
645
- x_in = x_in.repeat(1, np.prod(self.stride), 1, 1, 1)
 
646
  if self.stride[0] == 2:
647
  x_in = x_in[:, :, 1:, :, :]
648
  x = self.conv(x, causal=causal)
 
455
  block_params = block_params if isinstance(block_params, dict) else {}
456
  if block_name == "res_x_y":
457
  output_channel = output_channel * block_params.get("multiplier", 2)
458
+ if block_name == "compress_all":
459
+ output_channel = output_channel * block_params.get("multiplier", 1)
460
 
461
  self.conv_in = make_conv_nd(
462
  dims,
 
505
  dims=dims, in_channels=input_channel, stride=(1, 2, 2)
506
  )
507
  elif block_name == "compress_all":
508
+ output_channel = output_channel // block_params.get("multiplier", 1)
509
  block = DepthToSpaceUpsample(
510
  dims=dims,
511
  in_channels=input_channel,
512
  stride=(2, 2, 2),
513
  residual=block_params.get("residual", False),
514
+ out_channels_reduction_factor=block_params.get("multiplier", 1),
515
  )
516
  else:
517
  raise ValueError(f"unknown layer: {block_name}")
 
622
 
623
 
624
  class DepthToSpaceUpsample(nn.Module):
625
+ def __init__(
626
+ self, dims, in_channels, stride, residual=False, out_channels_reduction_factor=1
627
+ ):
628
  super().__init__()
629
  self.stride = stride
630
+ self.out_channels = (
631
+ np.prod(stride) * in_channels // out_channels_reduction_factor
632
+ )
633
  self.conv = make_conv_nd(
634
  dims=dims,
635
  in_channels=in_channels,
 
639
  causal=True,
640
  )
641
  self.residual = residual
642
+ self.out_channels_reduction_factor = out_channels_reduction_factor
643
 
644
  def forward(self, x, causal: bool = True):
645
  if self.residual:
 
651
  p2=self.stride[1],
652
  p3=self.stride[2],
653
  )
654
+ num_repeat = np.prod(self.stride) // self.out_channels_reduction_factor
655
+ x_in = x_in.repeat(1, num_repeat, 1, 1, 1)
656
  if self.stride[0] == 2:
657
  x_in = x_in[:, :, 1:, :, :]
658
  x = self.conv(x, causal=causal)