StreamingT2V / t2v_enhanced /model /layers /conv_channel_extension.py
hpoghos's picture
add code
f949b3f
raw
history blame
6.09 kB
import torch
import torch.nn as nn
from typing import Union
from torch.nn.common_types import _size_2_t
class Conv2D_SubChannels(nn.Conv2d):
    """2D convolution that tolerates checkpoints with MORE channels than itself.

    When the state dict being loaded contains a conv weight/bias with more
    output or input channels than this layer, the surplus channels are sliced
    off so the checkpoint still loads (e.g. loading a larger pretrained model
    into a slimmed-down variant). Otherwise behaves exactly like ``nn.Conv2d``.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: _size_2_t,
                 stride: _size_2_t = 1,
                 padding: Union[str, _size_2_t] = 0,
                 dilation: _size_2_t = 1,
                 groups: int = 1,
                 bias: bool = True,
                 padding_mode: str = 'zeros',
                 device=None,
                 dtype=None,
                 ) -> None:
        super().__init__(in_channels, out_channels, kernel_size, stride,
                         padding, dilation, groups, bias, padding_mode, device, dtype)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # Shrink oversized checkpoint tensors in place, then defer to the
        # standard loader.
        weight_key = prefix + "weight"
        bias_key = prefix + "bias"
        if weight_key in state_dict and (
                state_dict[weight_key].shape[0] > self.out_channels
                or state_dict[weight_key].shape[1] > self.in_channels):
            print(
                f"Model checkpoint has too many channels. Excluding channels of convolution {prefix}.")
            # Fix: guard on key presence. A checkpoint may lack a bias entry
            # even when this layer has one; previously this raised KeyError,
            # now the parent loader reports it as a missing key instead.
            if self.bias is not None and bias_key in state_dict:
                state_dict[bias_key] = state_dict[bias_key][:self.out_channels]
            # Keep only the leading out/in channels of the checkpoint weight.
            # NOTE(review): dim 1 of a conv weight is in_channels // groups, so
            # this comparison/slice assumes groups == 1 — confirm for grouped convs.
            state_dict[weight_key] = state_dict[weight_key][:self.out_channels,
                                                            :self.in_channels]
        return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                             missing_keys, unexpected_keys, error_msgs)
class Conv2D_ExtendedChannels(nn.Conv2d):
    """2D convolution with extra channels appended beyond a checkpoint's size.

    The layer is constructed with ``in_channels + in_channel_extension`` input
    and ``out_channels + out_channel_extension`` output channels. When a state
    dict holding the smaller (un-extended) tensors is loaded, the missing
    channels are zero-initialized so the extended layer initially reproduces
    the original layer on the original channels.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: _size_2_t,
                 stride: _size_2_t = 1,
                 padding: Union[str, _size_2_t] = 0,
                 dilation: _size_2_t = 1,
                 groups: int = 1,
                 bias: bool = True,
                 padding_mode: str = 'zeros',
                 device=None,
                 dtype=None,
                 in_channel_extension: int = 0,
                 out_channel_extension: int = 0,
                 ) -> None:
        super().__init__(in_channels + in_channel_extension, out_channels + out_channel_extension,
                         kernel_size, stride, padding, dilation, groups, bias,
                         padding_mode, device, dtype)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        print(f"Call extend channel loader with {prefix}")
        weight_key = prefix + "weight"
        bias_key = prefix + "bias"
        if weight_key in state_dict and (
                state_dict[weight_key].shape[0] < self.out_channels
                or state_dict[weight_key].shape[1] < self.in_channels):
            print(
                f"Model checkpoint has insufficient channels. Extending channels of convolution {prefix} by adding zeros.")
            # Fix: guard on key presence. A checkpoint without a bias entry
            # previously raised KeyError here; now the parent loader reports it
            # as a missing key instead.
            if self.bias is not None and bias_key in state_dict:
                bias = state_dict[bias_key]
                state_dict[bias_key] = torch.cat(
                    [bias,
                     torch.zeros(self.out_channels - len(bias),
                                 dtype=bias.dtype, layout=bias.layout, device=bias.device)])
            weight = state_dict[weight_key]
            # Place the checkpoint weight in the leading out/in channels; the
            # extension channels stay zero so they contribute nothing until trained.
            # NOTE(review): dim 1 of a conv weight is in_channels // groups, so
            # this assumes groups == 1 — confirm for grouped convs.
            extended_weight = torch.zeros(self.out_channels, self.in_channels,
                                          weight.shape[2], weight.shape[3],
                                          device=weight.device, dtype=weight.dtype,
                                          layout=weight.layout)
            extended_weight[:weight.shape[0], :weight.shape[1]] = weight
            state_dict[weight_key] = extended_weight
        return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                             missing_keys, unexpected_keys, error_msgs)
if __name__ == "__main__":
    # Smoke test: verify that Conv2D_ExtendedChannels loaded from a plain
    # nn.Conv2d checkpoint reproduces the original outputs on the original
    # channels, and that the extension parameters are trainable.
    class MyModel(nn.Module):
        def __init__(self, conv_type: str, c_in, c_out, in_extension, out_extension) -> None:
            super().__init__()
            if conv_type != "normal":
                self.conv1 = Conv2D_ExtendedChannels(
                    c_in, c_out, 3, padding=1, in_channel_extension=in_extension,
                    out_channel_extension=out_extension, bias=True)
            else:
                self.conv1 = nn.Conv2d(c_in, c_out, 3, padding=1, bias=True)

        def forward(self, x):
            return self.conv1(x)

    c_in = 9
    c_out = 12
    c_in_ext = 0
    c_out_ext = 3

    model = MyModel("normal", c_in, c_out, c_in_ext, c_out_ext)
    # Renamed from 'input' to avoid shadowing the builtin.
    sample = torch.randn((4, c_in + c_in_ext, 128, 128))
    out_normal = model(sample[:, :c_in])

    # Fix: use a checkpoint extension (was the misleading "model_dummy.py").
    torch.save(model.state_dict(), "model_dummy.pt")

    model_2 = MyModel("special", c_in, c_out, c_in_ext, c_out_ext)
    model_2.load_state_dict(torch.load("model_dummy.pt"))
    out_model_2 = model_2(sample)
    out_special = out_model_2[:, :c_out]  # channels shared with the original model
    out_new = out_model_2[:, c_out:]      # zero-initialized extension channels

    model_3 = MyModel("special", c_in, c_out, c_in_ext, c_out_ext)
    model_3.load_state_dict(model_2.state_dict())

    print(
        f"Difference: Forward pass with extended convolution minus initial convolution: {(out_normal-out_special).abs().max()}")
    print(f"Compared tensors with shape: ",
          out_normal.shape, out_special.shape)

    # A few SGD steps to confirm the extension bias entries receive gradients.
    if model_3.conv1.bias is not None:
        criterion = nn.MSELoss()
        before_opt = model_3.conv1.bias.detach().clone()
        target = torch.ones_like(out_model_2)
        optimizer = torch.optim.SGD(
            model_3.parameters(), lr=0.01, momentum=0.9)
        for step in range(10):  # renamed from 'iter' to avoid shadowing the builtin
            optimizer.zero_grad()
            out = model_3(sample)
            loss = criterion(out, target)
            loss.backward()
            optimizer.step()
        print(
            f"Weights before and after are the same? {before_opt[c_out:].detach()} | {model_3.conv1.bias[c_out:].detach()} ")
        print(model_3.conv1.bias, model_2.conv1.bias)