Please consider this: I am attempting to follow the error messages and rewrite the code to be compatible with torch.Tensor, in case this is the only shot I’ll ever have at using my GPU with SD. The problem is that I have very little experience with this kind of stuff… I haven’t run it through a validator yet: first come, first served.
from functools import partial
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from .attention import AdaGroupNorm
class Upsample1D(nn.Module):
    """
    1D upsampling layer with an optional convolution.

    Parameters:
        channels: number of channels the input must have along dim 1.
        use_conv: when True, a kernel-3 `nn.Conv1d` is applied after nearest-neighbour interpolation.
        use_conv_transpose: when True, a strided `nn.ConvTranspose1d` does the upsampling itself; this takes
            precedence over `use_conv`.
        out_channels: number of channels produced by the convolution; falls back to `channels` when falsy.
        name: stored unchanged as `self.name`.
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        # Exactly one of: transposed conv, plain conv, or no conv at all.
        if use_conv_transpose:
            layer = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            layer = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
        else:
            layer = None
        self.conv = layer

    def forward(self, x):
        """Upsample `x` along its last dimension and return the result."""
        assert x.shape[1] == self.channels

        # The transposed convolution replaces the interpolation entirely.
        if self.use_conv_transpose:
            return self.conv(x)

        upsampled = F.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(upsampled) if self.use_conv else upsampled
class Downsample1D(nn.Module):
    """
    1D downsampling layer (stride 2) with an optional convolution.

    Parameters:
        channels: number of channels the input must have along dim 1.
        use_conv: when True, a kernel-3 strided `nn.Conv1d` is used; otherwise average pooling.
        out_channels: number of channels produced by the convolution; falls back to `channels` when falsy.
            Must equal `channels` when `use_conv` is False.
        padding: padding passed to the convolution.
        name: stored unchanged as `self.name`.
    """

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        stride = 2
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        self.name = name

        if not use_conv:
            # Pooling cannot change the channel count.
            assert self.channels == self.out_channels
            self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)
        else:
            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)

    def forward(self, x):
        """Halve the length of `x` along its last dimension and return the result."""
        assert x.shape[1] == self.channels
        if not self.use_conv:
            return F.avg_pool1d(x, kernel_size=2, stride=2)
        return self.conv(x)
class FirUpsample2D(nn.Module):
    """
    2D upsampling layer that filters with an FIR kernel, optionally fused with a 3x3 convolution.

    Parameters:
        channels: number of input channels (only used to build the convolution).
        out_channels: number of output channels; falls back to `channels` when falsy.
        use_conv: when True, the weights of a 3x3 `nn.Conv2d` are applied as a transposed convolution before
            FIR filtering, and the convolution bias is added afterwards.
        fir_kernel: FIR filter taps passed to `_upsample_2d`.
    """

    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.use_conv = use_conv
        self.fir_kernel = fir_kernel
        self.out_channels = out_channels

    def _upsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
        """Fused `upsample_2d()` followed by `Conv2d()`.

        Padding is performed only once at the beginning, not between the operations.

        Args:
            x: Input tensor; indexed as `[N, C, H, W]`.
            weight: Convolution weight, required when `self.use_conv` is True. Indexed as
                `[outChannels, inChannels, convH, convW]` (dim 1 is read as `inChannels`, dims 2 and 3 as the
                spatial kernel size). Ignored otherwise.
            kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is
                `[1] * factor`.
            factor: Integer upsampling factor (default: 2).
            gain: Scaling factor for signal magnitude (default: 1.0).

        Returns:
            output: The upsampled tensor returned by `upfirdn2d_native`.
        """
        assert isinstance(factor, int) and factor >= 1

        # Setup filter kernel.
        if kernel is None:
            kernel = [1] * factor

        # Normalise the (possibly separable) kernel, then scale it by gain * factor**2.
        kernel = torch.tensor(kernel, dtype=torch.float32)
        if kernel.ndim == 1:
            kernel = torch.outer(kernel, kernel)
        kernel /= torch.sum(kernel)
        kernel = kernel * (gain * (factor**2))

        if self.use_conv:
            convH = weight.shape[2]
            convW = weight.shape[3]
            inC = weight.shape[1]

            pad_value = (kernel.shape[0] - factor) - (convW - 1)
            stride = (factor, factor)

            # Determine data dimensions.
            output_shape = (
                (x.shape[2] - 1) * factor + convH,
                (x.shape[3] - 1) * factor + convW,
            )
            output_padding = (
                output_shape[0] - (x.shape[2] - 1) * stride[0] - convH,
                output_shape[1] - (x.shape[3] - 1) * stride[1] - convW,
            )
            assert output_padding[0] >= 0 and output_padding[1] >= 0
            num_groups = x.shape[1] // inC

            # Transpose weights: flip spatially and swap the in/out channel axes per group.
            weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
            weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
            weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))

            # The input is `x`; the previous code referenced an undefined `hidden_states` here.
            inverse_conv = F.conv_transpose2d(
                x, weight, stride=stride, output_padding=output_padding, padding=0
            )

            # `upfirdn2d_native` names its first parameter `tensor`, so it is passed positionally.
            output = upfirdn2d_native(
                inverse_conv,
                kernel.to(device=inverse_conv.device),
                pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1),
            )
        else:
            # `kernel` is a tensor at this point, so its size comes from `kernel.shape`.
            pad_value = kernel.shape[0] - factor
            output = upfirdn2d_native(
                x,
                kernel.to(device=x.device),
                up=factor,
                pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
            )

        return output

    def forward(self, x):
        """Upsample `x` with the configured FIR kernel, adding the convolution bias when `use_conv` is set."""
        if self.use_conv:
            output = self._upsample_2d(x, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
            output = output + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            output = self._upsample_2d(x, kernel=self.fir_kernel, factor=2)
        return output
class FirDownsample2D(nn.Module):
    """
    2D downsampling layer that filters with an FIR kernel, optionally fused with a 3x3 convolution.

    Parameters:
        channels: number of input channels (only used to build the convolution).
        out_channels: number of output channels; falls back to `channels` when falsy.
        use_conv: when True, a 3x3 convolution with stride `factor` is applied after FIR filtering and the
            convolution bias is added in `forward`.
        fir_kernel: FIR filter taps passed to `_downsample_2d`.
    """

    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
            # NOTE: assigning an `nn.Parameter` as an attribute also registers it on this module under the
            # name `weight`, in addition to `Conv2d_0.weight`.
            self.weight = self.Conv2d_0.weight
        else:
            self.weight = None
        self.fir_kernel = fir_kernel
        self.use_conv = use_conv
        self.out_channels = out_channels

    def _downsample_2d(self, x, kernel=None, factor=2, gain=1):
        """Fused `Conv2d()` followed by `downsample_2d()`.

        Args:
            x: Input tensor; handed to `upfirdn2d_native` as-is.
            kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
                The default is `[1] * factor`.
            factor: Integer downsampling factor (default: 2).
            gain: Scaling factor for signal magnitude (default: 1.0).

        Returns:
            The downsampled tensor. The convolution bias is NOT added here.
        """
        assert isinstance(factor, int) and factor >= 1
        if kernel is None:
            kernel = [1] * factor

        # Build the kernel with torch (NumPy is not imported by this module).
        kernel = torch.tensor(kernel, dtype=torch.float32)
        if kernel.ndim == 1:
            kernel = torch.outer(kernel, kernel)
        kernel /= torch.sum(kernel)
        kernel = kernel * gain

        if self.use_conv:
            _, _, convH, convW = self.weight.shape
            p = (kernel.shape[0] - factor) + (convW - 1)
            s = [factor, factor]
            # Filter at full resolution first; the strided convolution does the downsampling.
            x = upfirdn2d_native(x, kernel.to(device=x.device), pad=((p + 1) // 2, p // 2))
            x = F.conv2d(x, self.weight, stride=s, padding=0)
        else:
            p = kernel.shape[0] - factor
            x = upfirdn2d_native(x, kernel.to(device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
        return x

    def forward(self, x):
        """Downsample `x` with the configured FIR kernel, adding the convolution bias when `use_conv` is set."""
        if self.use_conv:
            output = self._downsample_2d(x, kernel=self.fir_kernel)
            output = output + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            output = self._downsample_2d(x, kernel=self.fir_kernel, factor=2)
        return output
class KDownsample2D(nn.Module):
    """
    Downsamples by 2 with a fixed 4x4 smoothing kernel applied independently to every channel.

    Parameters:
        pad_mode: padding mode handed to `F.pad` before the strided convolution.
    """

    def __init__(self, pad_mode="reflect"):
        super().__init__()
        self.pad_mode = pad_mode
        taps = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]])
        self.pad = taps.shape[1] // 2 - 1
        # Outer product of the 1D taps; kept out of the state dict (persistent=False).
        self.register_buffer("kernel", torch.matmul(taps.T, taps), persistent=False)

    def forward(self, x):
        """Pad `x`, then convolve with stride 2 using a channel-diagonal weight."""
        if not isinstance(x, torch.Tensor):
            raise TypeError("Expected input to be a torch.Tensor")

        padded = F.pad(x, (self.pad,) * 4, self.pad_mode)
        n_channels = padded.shape[1]
        kernel_h, kernel_w = self.kernel.shape[0], self.kernel.shape[1]

        # Only the (c, c) entries are filled, so channels never mix.
        weight = padded.new_zeros([n_channels, n_channels, kernel_h, kernel_w])
        diag = torch.arange(n_channels, device=padded.device)
        weight[diag, diag] = self.kernel.to(weight)
        return F.conv2d(padded, weight, stride=2)
class KUpsample2D(nn.Module):
    """
    Upsamples by 2 with a fixed 4x4 smoothing kernel applied independently to every channel.

    Parameters:
        pad_mode: padding mode handed to `F.pad` before the transposed convolution.
    """

    def __init__(self, pad_mode="reflect"):
        super().__init__()
        self.pad_mode = pad_mode
        # Taps are doubled relative to the downsampling kernel.
        kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2
        self.pad = kernel_1d.shape[1] // 2 - 1
        # Outer product of the 1D taps; kept out of the state dict (persistent=False).
        self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)

    def forward(self, x):
        """Return `x` upsampled to twice its height and width.

        Args:
            x: Input tensor indexed as `[N, C, H, W]`.

        Returns:
            Tensor of shape `[N, C, 2 * H, 2 * W]`.
        """
        x = F.pad(x, ((self.pad + 1) // 2,) * 4, self.pad_mode)

        # Channel-diagonal weight: only the (c, c) entries are filled, so channels never mix.
        weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
        indices = torch.arange(x.shape[1], device=x.device)
        weight[indices, indices] = self.kernel.to(weight)

        # `F.conv_transpose2d` takes no `output_size` argument. With the padding added above, cropping
        # `self.pad * 2 + 1` pixels per side already yields exactly twice the original spatial size, so no
        # `output_padding` is needed either.
        return F.conv_transpose2d(x, weight, stride=2, padding=self.pad * 2 + 1)
class ResnetBlock2D(nn.Module):
    """
    Basic residual block: Conv -> BatchNorm -> ReLU -> Conv -> BatchNorm, plus a skip connection, then ReLU.

    Parameters:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size, padding, dilation, groups, bias: forwarded to both 2D convolutions.
        stride: stride of the first convolution and of the skip projection. The second convolution always
            uses stride 1 so that the main path and the skip path end up with the same spatial size.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu1 = nn.ReLU(inplace=True)
        # Only conv1 is strided; striding conv2 as well would shrink the main path more than the skip path.
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, 1, padding, dilation, groups, bias)
        self.bn2 = nn.BatchNorm2d(out_channels)
        if in_channels == out_channels and stride == 1:
            self.identity = nn.Identity()
        else:
            # Project the skip path whenever the channel count or the spatial size changes.
            self.identity = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channels),
            )
        self.relu2 = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply the residual block to `x` and return the activated sum of main and skip paths."""
        identity = self.identity(x)

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)

        x = self.conv2(x)
        x = self.bn2(x)

        x += identity
        x = self.relu2(x)
        return x
class Mish(torch.nn.Module):
    """Mish activation: `x * tanh(softplus(x))`, applied element-wise."""

    def forward(self, hidden_states):
        """Return `hidden_states` gated by `tanh(softplus(hidden_states))`."""
        softplus = torch.nn.functional.softplus(hidden_states)
        gate = torch.tanh(softplus)
        return hidden_states * gate
def get_activation(name):
    """Return a new activation module for `name`.

    Supported names are "relu", "leaky_relu" and "mish".

    Raises:
        ValueError: if `name` is not one of the supported names.
    """
    if name == "relu":
        return torch.nn.ReLU(inplace=True)
    if name == "leaky_relu":
        return torch.nn.LeakyReLU(negative_slope=0.01, inplace=True)
    if name == "mish":
        return Mish()
    raise ValueError(f"{name} is not a valid activation function name.")
def rearrange_dims(tensor):
    """Reshape `tensor` depending on its number of dimensions.

    - 2 dims `[a, b]`       -> `[a, b, 1]` (a trailing singleton axis is appended)
    - 3 dims `[a, b, c]`    -> `[a, b, 1, c]` (a singleton axis is inserted at position 2)
    - 4 dims `[a, b, c, d]` -> `[a, b, d]` (index 0 is selected along axis 2)

    Raises:
        ValueError: if `tensor` does not have 2, 3 or 4 dimensions.
    """
    num_dims = len(tensor.shape)
    if num_dims == 2:
        return tensor.unsqueeze(-1)
    if num_dims == 3:
        return tensor.unsqueeze(2)
    if num_dims == 4:
        return tensor[:, :, 0, :]
    # Report the number of dimensions: `len(tensor)` is the size of the first axis and fails outright for
    # 0-d tensors.
    raise ValueError(f"`len(tensor.shape)`: {num_dims} has to be 2, 3 or 4.")
class Conv1dBlock(nn.Module):
    """
    Conv1d --> GroupNorm --> Mish

    Parameters:
        inp_channels: number of input channels of the convolution.
        out_channels: number of output channels of the convolution (and channels normalised by GroupNorm).
        kernel_size: convolution kernel size; the padding is `kernel_size // 2`.
        n_groups: number of groups for `nn.GroupNorm`.
    """

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()
        padding = kernel_size // 2
        self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=padding)
        self.group_norm = nn.GroupNorm(n_groups, out_channels)
        self.mish = nn.Mish()

    def forward(self, x):
        """Run the convolution, group normalisation and Mish activation on `x`."""
        features = self.conv1d(x)
        # A trailing singleton axis is added for the normalisation and removed again right after.
        normalised = self.group_norm(features.unsqueeze(-1)).squeeze(-1)
        return self.mish(normalised)
class ResidualTemporalBlock1D(nn.Module):
    """
    Two `Conv1dBlock`s with a time-embedding bias added in between and a residual connection around them.

    Parameters:
        inp_channels: number of input channels.
        out_channels: number of output channels.
        embed_dim: size of the time embedding fed to the linear projection.
        kernel_size: kernel size of both `Conv1dBlock`s.
    """

    def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5):
        super().__init__()
        self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
        self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)

        self.time_emb_act = nn.Mish()
        self.time_emb = nn.Linear(embed_dim, out_channels)

        # A 1x1 convolution matches the channel count on the skip path when it changes.
        if inp_channels == out_channels:
            self.residual_conv = nn.Identity()
        else:
            self.residual_conv = nn.Conv1d(inp_channels, out_channels, 1)

    def forward(self, x, t):
        """
        Args:
            x : [ batch_size x inp_channels x horizon ]
            t : [ batch_size x embed_dim ]

        returns:
            out : [ batch_size x out_channels x horizon ]
        """
        time_bias = rearrange_dims(self.time_emb(self.time_emb_act(t)))
        hidden = self.conv_in(x) + time_bias
        hidden = self.conv_out(hidden)
        shortcut = self.residual_conv(x)
        return hidden + shortcut
def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
    r"""Upsample a batch of 2D images with the given FIR filter.

    The filter is normalised to sum to one and then scaled by `gain * factor**2`, so that constant input
    pixels are scaled by `gain` after zero-insertion upsampling. The actual filtering is delegated to
    `upfirdn2d_native`.

    Args:
        hidden_states: Input tensor, handed to `upfirdn2d_native` as-is.
        kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is
            `[1] * factor`.
        factor: Integer upsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        output: The upsampled tensor returned by `upfirdn2d_native`.
    """
    assert isinstance(factor, int) and factor >= 1

    taps = [1] * factor if kernel is None else kernel
    fir = torch.tensor(taps, dtype=torch.float32)
    if fir.ndim == 1:
        # Separable filter: expand to 2D via the outer product.
        fir = torch.outer(fir, fir)
    fir /= torch.sum(fir)
    fir = fir * (gain * (factor**2))

    pad_value = fir.shape[0] - factor
    pad_before = (pad_value + 1) // 2 + factor - 1
    pad_after = pad_value // 2
    return upfirdn2d_native(
        hidden_states,
        fir.to(device=hidden_states.device),
        up=factor,
        pad=(pad_before, pad_after),
    )
def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
    r"""Downsample a batch of 2D images with the given FIR filter.

    The filter is normalised to sum to one and then scaled by `gain`, so that constant input pixels are
    scaled by `gain`. The actual filtering is delegated to `upfirdn2d_native`.

    Args:
        hidden_states: Input tensor, handed to `upfirdn2d_native` as-is.
        kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is
            `[1] * factor`.
        factor: Integer downsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        output: The downsampled tensor returned by `upfirdn2d_native`.
    """
    assert isinstance(factor, int) and factor >= 1

    if kernel is None:
        # Keep the default as a plain list: wrapping an existing tensor in `torch.tensor(...)` below would
        # trigger PyTorch's copy-construct warning.
        kernel = [1] * factor

    kernel = torch.tensor(kernel, dtype=torch.float32)
    if kernel.ndim == 1:
        # Separable filter: expand to 2D via the outer product.
        kernel = torch.outer(kernel, kernel)
    kernel /= torch.sum(kernel)
    kernel = kernel * gain

    pad_value = kernel.shape[0] - factor
    output = upfirdn2d_native(
        hidden_states, kernel.to(device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2)
    )
    return output
def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)):
    """Upsample by zero insertion, pad or crop, filter with a 2D FIR kernel, then downsample.

    Args:
        tensor: Input of shape `[N, C, H, W]`.
        kernel: 2D FIR filter of shape `[kernel_h, kernel_w]`; must be compatible with `tensor` for
            `F.conv2d` (same device and dtype).
        up: Integer upsampling factor, applied to both spatial axes.
        down: Integer downsampling factor, applied to both spatial axes.
        pad: `(before, after)` padding applied to both spatial axes after upsampling. Negative values crop.

    Returns:
        Tensor of shape `[N, C, out_h, out_w]` with
        `out_h = (H * up + pad[0] + pad[1] - kernel_h) // down + 1` (and likewise for `out_w`).
    """
    # All size arithmetic is done with plain Python ints; `int()` also accepts 0-d integer tensors.
    up_x = up_y = int(up)
    down_x = down_y = int(down)
    pad_x0 = pad_y0 = int(pad[0])
    pad_x1 = pad_y1 = int(pad[1])

    _, channel, in_h, in_w = tensor.shape
    # Fold batch and channel together; every image is filtered independently.
    tensor = tensor.reshape(-1, in_h, in_w, 1)

    _, in_h, in_w, minor = tensor.shape
    kernel_h, kernel_w = kernel.shape

    # Zero insertion: add (up - 1) zeros after every pixel along both spatial axes.
    out = tensor.view(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    # Positive padding is applied here; negative padding is handled by the crop below.
    out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
    out = out.to(tensor.device)  # Move back to mps if necessary
    out = out[
        :,
        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
        :,
    ]

    out = out.permute(0, 3, 1, 2)
    out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
    # `F.conv2d` computes a cross-correlation, so the kernel is flipped to obtain a true convolution.
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    out = out.permute(0, 2, 3, 1)
    # Downsample by keeping every `down`-th sample.
    out = out[:, ::down_y, ::down_x, :]

    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1

    return out.view(-1, channel, out_h, out_w)
Current output: `ImportError: cannot import name 'Downsample2D' from 'diffusers.models.resnet'`