"""Conv2d variants whose `_load_from_state_dict` tolerates checkpoints with a
mismatched number of channels: `Conv2D_SubChannels` truncates oversized
checkpoints, while `Conv2D_ExtendedChannels` zero-pads undersized ones."""
import torch
import torch.nn as nn
from typing import Union

from torch.nn.common_types import _size_2_t


class Conv2D_SubChannels(nn.Conv2d):
    """Conv2d that can load a checkpoint holding *more* channels than itself
    by keeping only the leading input/output channels."""

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: _size_2_t,
                 stride: _size_2_t = 1,
                 padding: Union[str, _size_2_t] = 0,
                 dilation: _size_2_t = 1,
                 groups: int = 1,
                 bias: bool = True,
                 padding_mode: str = 'zeros',
                 device=None,
                 dtype=None,
                 ) -> None:
        super().__init__(in_channels, out_channels, kernel_size, stride,
                         padding, dilation, groups, bias, padding_mode, device, dtype)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # If the checkpoint carries more channels than this layer expects,
        # truncate the stored tensors to the leading channels before the
        # regular loading logic runs.
        weight_key, bias_key = prefix + "weight", prefix + "bias"
        if weight_key in state_dict and (
                state_dict[weight_key].shape[0] > self.out_channels
                or state_dict[weight_key].shape[1] > self.in_channels):
            print(f"Model checkpoint has too many channels. "
                  f"Excluding channels of convolution {prefix}.")
            if self.bias is not None:
                state_dict[bias_key] = state_dict[bias_key][:self.out_channels]
            state_dict[weight_key] = state_dict[weight_key][:self.out_channels,
                                                            :self.in_channels]
        return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                             missing_keys, unexpected_keys, error_msgs)
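
# Usage note (illustrative, not from the original file): Conv2D_SubChannels is
# a drop-in replacement for nn.Conv2d when the checkpoint was trained with
# *more* channels than the current model, e.g.
#
#     conv = Conv2D_SubChannels(8, 16, 3)      # model wants 8 -> 16 channels
#     conv.load_state_dict(checkpoint_state)   # checkpoint holds 16 -> 32
#
# where `checkpoint_state` is a hypothetical state dict; on load, only the
# leading output/input channel slices of the stored weight and bias are kept.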


class Conv2D_ExtendedChannels(nn.Conv2d):
    """Conv2d that can load a checkpoint holding *fewer* channels than itself
    by zero-padding the missing input/output channels."""

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: _size_2_t,
                 stride: _size_2_t = 1,
                 padding: Union[str, _size_2_t] = 0,
                 dilation: _size_2_t = 1,
                 groups: int = 1,
                 bias: bool = True,
                 padding_mode: str = 'zeros',
                 device=None,
                 dtype=None,
                 in_channel_extension: int = 0,
                 out_channel_extension: int = 0,
                 ) -> None:
        super().__init__(in_channels + in_channel_extension,
                         out_channels + out_channel_extension, kernel_size, stride,
                         padding, dilation, groups, bias, padding_mode, device, dtype)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        print(f"Call extend channel loader with {prefix}")
        weight_key, bias_key = prefix + "weight", prefix + "bias"
        if weight_key in state_dict and (
                state_dict[weight_key].shape[0] < self.out_channels
                or state_dict[weight_key].shape[1] < self.in_channels):
            print(f"Model checkpoint has insufficient channels. "
                  f"Extending channels of convolution {prefix} by adding zeros.")
            if self.bias is not None:
                bias = state_dict[bias_key]
                state_dict[bias_key] = torch.cat([
                    bias,
                    torch.zeros(self.out_channels - len(bias), dtype=bias.dtype,
                                layout=bias.layout, device=bias.device)])
            # Zero-pad the stored kernel up to the extended shape; the original
            # weights occupy the leading channel slices.
            weight = state_dict[weight_key]
            extended_weight = torch.zeros(self.out_channels, self.in_channels,
                                          weight.shape[2], weight.shape[3],
                                          device=weight.device, dtype=weight.dtype,
                                          layout=weight.layout)
            extended_weight[:weight.shape[0], :weight.shape[1]] = weight
            state_dict[weight_key] = extended_weight
        return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                             missing_keys, unexpected_keys, error_msgs)
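
# Usage note (illustrative, not from the original file): Conv2D_ExtendedChannels
# covers the opposite case, a checkpoint with *fewer* channels than the model.
# Because the missing weight and bias entries are filled with zeros, the
# extended layer reproduces the checkpointed convolution on the original
# channels right after loading, while the new channels start at zero and are
# trainable. The `__main__` block below demonstrates exactly this.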


if __name__ == "__main__":
    class MyModel(nn.Module):
        def __init__(self, conv_type: str, c_in, c_out, in_extension, out_extension) -> None:
            super().__init__()
            if conv_type != "normal":
                self.conv1 = Conv2D_ExtendedChannels(
                    c_in, c_out, 3, padding=1,
                    in_channel_extension=in_extension,
                    out_channel_extension=out_extension, bias=True)
            else:
                self.conv1 = nn.Conv2d(c_in, c_out, 3, padding=1, bias=True)

        def forward(self, x):
            return self.conv1(x)

    c_in = 9
    c_out = 12
    c_in_ext = 0
    c_out_ext = 3

    # Sanity check: save a plain convolution, load it into the extended one,
    # and compare the two forward passes on the shared output channels.
    model = MyModel("normal", c_in, c_out, c_in_ext, c_out_ext)
    x = torch.randn((4, c_in + c_in_ext, 128, 128))
    out_normal = model(x[:, :c_in])
    torch.save(model.state_dict(), "model_dummy.pt")

    model_2 = MyModel("special", c_in, c_out, c_in_ext, c_out_ext)
    model_2.load_state_dict(torch.load("model_dummy.pt"))
    out_model_2 = model_2(x)
    out_special = out_model_2[:, :c_out]  # channels shared with the checkpoint
    out_new = out_model_2[:, c_out:]      # freshly added, zero-initialized channels

    model_3 = MyModel("special", c_in, c_out, c_in_ext, c_out_ext)
    model_3.load_state_dict(model_2.state_dict())
    print(f"Difference: forward pass with extended convolution minus initial "
          f"convolution: {(out_normal - out_special).abs().max()}")
    print("Compared tensors with shape:", out_normal.shape, out_special.shape)

    if model_3.conv1.bias is not None:
        # Short training loop to confirm that the zero-initialized extension
        # parameters receive gradients and are actually updated.
        criterion = nn.MSELoss()
        before_opt = model_3.conv1.bias.detach().clone()
        target = torch.ones_like(out_model_2)
        optimizer = torch.optim.SGD(
            model_3.parameters(), lr=0.01, momentum=0.9)
        for _ in range(10):
            optimizer.zero_grad()
            out = model_3(x)
            loss = criterion(out, target)
            loss.backward()
            optimizer.step()
        print(f"Extension bias before vs. after training: "
              f"{before_opt[c_out:]} | {model_3.conv1.bias[c_out:].detach()}")
        print(model_3.conv1.bias, model_2.conv1.bias)
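
    # A minimal extra check for Conv2D_SubChannels (not in the original demo;
    # `big_conv` and `small_conv` are illustrative names): load a larger plain
    # convolution's state dict into a smaller sub-channel convolution and
    # verify that the surviving weights are the leading slice of the original.
    big_conv = nn.Conv2d(c_in + 3, c_out + 3, 3, padding=1, bias=True)
    small_conv = Conv2D_SubChannels(c_in, c_out, 3, padding=1, bias=True)
    small_conv.load_state_dict(big_conv.state_dict())
    weights_match = torch.equal(small_conv.weight, big_conv.weight[:c_out, :c_in])
    bias_match = torch.equal(small_conv.bias, big_conv.bias[:c_out])
    print(f"Sub-channel load keeps leading weights/bias: {weights_match}, {bias_match}")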