|
from dataclasses import dataclass, field |
|
from typing import Callable, List, Optional |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from einops import rearrange |
|
from jaxtyping import Float |
|
from torch import Tensor |
|
from torch.autograd import Function |
|
from torch.cuda.amp import custom_bwd, custom_fwd |
|
|
|
from spar3d.models.utils import BaseModule, normalize |
|
from spar3d.utils import get_device |
|
|
|
|
|
def conditional_decorator(decorator_with_args, condition, *args, **kwargs):
    """Build a decorator that applies ``decorator_with_args`` only if ``condition``.

    Two call styles are supported:

    * ``conditional_decorator(deco, cond)`` -- ``deco`` is a plain decorator
      and is applied as ``deco(fn)``.
    * ``conditional_decorator(deco, cond, *args, **kwargs)`` -- ``deco`` is a
      decorator factory and is applied as ``deco(*args, **kwargs)(fn)``.

    Args:
        decorator_with_args: The decorator (or decorator factory) to apply.
        condition: Truthy to apply the decorator, falsy to leave ``fn`` as is.
        *args: Positional arguments forwarded to the decorator factory.
        **kwargs: Keyword arguments forwarded to the decorator factory.

    Returns:
        A decorator that returns either the decorated function or, when
        ``condition`` is falsy, the original function unchanged.
    """

    def wrapper(fn):
        if not condition:
            return fn
        if not args and not kwargs:
            # Bare-decorator form: the decorator must be applied to ``fn``.
            # Returning the decorator object itself would silently replace
            # the function with the decorator.
            return decorator_with_args(fn)
        return decorator_with_args(*args, **kwargs)(fn)

    return wrapper
|
|
|
|
|
class PixelShuffleUpsampleNetwork(BaseModule):
    """Upsample each of the three triplane feature maps with convs + PixelShuffle.

    The network is a stack of ``conv_layers`` 2D convolutions (ReLU between
    them, none after the last) followed by ``nn.PixelShuffle``. The last
    convolution emits ``out_channels * scale_factor**2`` channels so that the
    pixel shuffle yields ``out_channels`` channels.
    """

    @dataclass
    class Config(BaseModule.Config):
        in_channels: int = 1024  # channels of the incoming planes
        out_channels: int = 40  # channels after the pixel shuffle
        scale_factor: int = 4  # upscale factor passed to nn.PixelShuffle

        conv_layers: int = 4  # number of Conv2d layers in the stack
        conv_kernel_size: int = 3  # kernel size shared by every Conv2d

    cfg: Config

    def configure(self) -> None:
        """Assemble ``self.upsample`` from the configuration."""
        cfg = self.cfg
        n_convs = cfg.conv_layers
        kernel = cfg.conv_kernel_size
        # (k - 1) // 2 padding keeps the spatial size for odd kernel sizes.
        pad = (kernel - 1) // 2
        # Channel count the pixel shuffle needs to end up with out_channels.
        pre_shuffle_channels = cfg.out_channels * cfg.scale_factor**2

        modules = []
        for idx in range(n_convs):
            is_last = idx == n_convs - 1
            width_out = pre_shuffle_channels if is_last else cfg.in_channels
            modules.append(
                nn.Conv2d(cfg.in_channels, width_out, kernel, padding=pad)
            )
            if not is_last:
                modules.append(nn.ReLU(inplace=True))

        modules.append(nn.PixelShuffle(cfg.scale_factor))
        self.upsample = nn.Sequential(*modules)

    def forward(
        self, triplanes: Float[Tensor, "B 3 Ci Hp Wp"]
    ) -> Float[Tensor, "B 3 Co Hp2 Wp2"]:
        """Upsample all three planes of every batch element.

        The plane axis is folded into the batch axis so the 2D network sees
        ordinary ``(N, C, H, W)`` input, then unfolded again afterwards.
        """
        flat_planes = rearrange(
            triplanes, "B Np Ci Hp Wp -> (B Np) Ci Hp Wp", Np=3
        )
        upsampled = self.upsample(flat_planes)
        return rearrange(upsampled, "(B Np) Co Hp Wp -> B Np Co Hp Wp", Np=3)
|
|
|
|
|
class _TruncExp(Function):
    """``exp`` with a gradient that is computed from a clamped input.

    The forward pass is a plain ``exp(x)``. The backward pass uses
    ``exp(min(x, 15))`` so the gradient stays bounded for large inputs.
    The AMP decorators are attached only when the string returned by
    ``get_device()`` contains ``"cuda"``.
    """

    @staticmethod
    @conditional_decorator(
        custom_fwd, "cuda" in get_device(), cast_inputs=torch.float32
    )
    def forward(ctx, x):
        # Keep the raw input; backward recomputes the (clamped) exponential.
        ctx.save_for_backward(x)
        return x.exp()

    @staticmethod
    @conditional_decorator(custom_bwd, "cuda" in get_device())
    def backward(ctx, g):
        (x,) = ctx.saved_tensors
        clamped = x.clamp(max=15)
        return g * clamped.exp()


# Functional entry point for the truncated exponential.
trunc_exp = _TruncExp.apply
|
|
|
|
|
def get_activation(name) -> Callable:
    """Look up an activation function by (case-insensitive) name.

    Args:
        name: Activation name, or ``None`` for the identity. Names not
            handled here are looked up as attributes of
            ``torch.nn.functional`` after lower-casing.

    Returns:
        A callable mapping a tensor to a tensor.

    Raises:
        ValueError: If the name is neither built in nor an attribute of
            ``torch.nn.functional``.
    """
    if name is None:
        return lambda x: x

    name = name.lower()
    if name in ("none", "linear", "identity"):
        return lambda x: x
    if name == "trunc_exp":
        # Return the function object itself rather than a wrapper around it.
        return trunc_exp

    builtin = {
        # Linear -> sRGB transfer curve, result clamped to [0, 1]. The inner
        # clamp presumably keeps ``pow`` away from inputs below the linear
        # segment threshold, since ``torch.where`` evaluates both branches.
        "lin2srgb": lambda x: torch.where(
            x > 0.0031308,
            torch.pow(torch.clamp(x, min=0.0031308), 1.0 / 2.4) * 1.055 - 0.055,
            12.92 * x,
        ).clamp(0.0, 1.0),
        "exp": lambda x: torch.exp(x),
        "shifted_exp": lambda x: torch.exp(x - 1.0),
        "shifted_trunc_exp": lambda x: trunc_exp(x - 1.0),
        "sigmoid": lambda x: torch.sigmoid(x),
        "tanh": lambda x: torch.tanh(x),
        "shifted_softplus": lambda x: F.softplus(x - 1.0),
        # Maps values from [-1, 1] to [0, 1].
        "scale_-11_01": lambda x: x * 0.5 + 0.5,
        "negative": lambda x: -x,
        "normalize_channel_last": lambda x: normalize(x),
        "normalize_channel_first": lambda x: normalize(x, dim=1),
    }
    if name in builtin:
        return builtin[name]

    # Fall back to anything torch.nn.functional provides (relu, gelu, ...).
    try:
        return getattr(F, name)
    except AttributeError:
        raise ValueError(f"Unknown activation function: {name}")
|
|
|
|
|
class LambdaModule(nn.Module):
    """Adapter that lets a plain callable be used as an ``nn.Module``.

    The wrapped callable is stored on ``self.lambd`` and invoked unchanged
    by ``forward``.
    """

    def __init__(self, lambd: Callable[[Tensor], Tensor]):
        super().__init__()
        self.lambd = lambd

    def forward(self, x):
        """Return ``self.lambd(x)``."""
        fn = self.lambd
        return fn(x)
|
|
|
|
|
def get_activation_module(name) -> torch.nn.Module:
    """Return the activation named ``name`` wrapped in a ``LambdaModule``.

    ``name`` is resolved by ``get_activation``, so the same names (and
    ``None`` for identity) are accepted.
    """
    activation_fn = get_activation(name)
    return LambdaModule(activation_fn)
|
|
|
|
|
@dataclass
class HeadSpec:
    """Description of one output head built by ``MaterialMLP``."""

    # Key under which the head is stored and under which its output is returned.
    name: str
    # Number of output features of the head's final linear layer.
    out_channels: int
    # Number of (Linear, activation) pairs placed before the final linear layer.
    n_hidden_layers: int
    # Activation name resolved via ``get_activation``; None means identity.
    output_activation: Optional[str] = None
    # Constant added to the head's output before the output activation.
    out_bias: float = 0.0
|
|
|
|
|
class MaterialMLP(BaseModule):
    """A set of independent MLP heads that all read the same input features.

    Each entry of ``cfg.heads`` becomes its own ``nn.Sequential`` of
    ``n_hidden_layers`` (Linear, activation) pairs followed by a final
    Linear layer. ``forward`` returns a dict mapping head name to output.
    """

    @dataclass
    class Config(BaseModule.Config):
        in_channels: int = 120  # feature width of the input to every head
        n_neurons: int = 64  # width of every hidden layer
        activation: str = "silu"  # hidden activation: "relu" or "silu"
        heads: List[HeadSpec] = field(default_factory=list)

    cfg: Config

    def configure(self) -> None:
        """Build one ``nn.Sequential`` per configured head."""
        assert len(self.cfg.heads) > 0, "MaterialMLP requires at least one head."
        heads = {}
        for head in self.cfg.heads:
            head_layers = []
            # Track the width feeding the next Linear layer. Starting from
            # in_channels keeps a head with zero hidden layers valid: its
            # only Linear then maps in_channels -> out_channels instead of
            # wrongly expecting n_neurons input features.
            width = self.cfg.in_channels
            for _ in range(head.n_hidden_layers):
                head_layers += [
                    nn.Linear(width, self.cfg.n_neurons),
                    self.make_activation(self.cfg.activation),
                ]
                width = self.cfg.n_neurons
            head_layers.append(nn.Linear(width, head.out_channels))
            heads[head.name] = nn.Sequential(*head_layers)
        self.heads = nn.ModuleDict(heads)

    def make_activation(self, activation):
        """Return a fresh in-place activation module for ``activation``.

        Raises:
            NotImplementedError: If ``activation`` is not "relu" or "silu".
        """
        if activation == "relu":
            return nn.ReLU(inplace=True)
        elif activation == "silu":
            return nn.SiLU(inplace=True)
        else:
            raise NotImplementedError(
                f"Unsupported hidden activation: {activation!r}"
            )

    def keys(self):
        """Return the names of all configured heads."""
        return self.heads.keys()

    def forward(
        self, x, include: Optional[List] = None, exclude: Optional[List] = None
    ):
        """Evaluate the selected heads on ``x``.

        Args:
            x: Input features passed unchanged to every selected head.
            include: If given, only heads whose name is in this list run.
            exclude: If given, heads whose name is in this list are skipped.

        Returns:
            Dict mapping head name to
            ``output_activation(head(x) + out_bias)``.

        Raises:
            ValueError: If both ``include`` and ``exclude`` are given.
        """
        if include is not None and exclude is not None:
            raise ValueError("Cannot specify both include and exclude.")
        if include is not None:
            heads = [h for h in self.cfg.heads if h.name in include]
        elif exclude is not None:
            heads = [h for h in self.cfg.heads if h.name not in exclude]
        else:
            heads = self.cfg.heads

        out = {
            head.name: get_activation(head.output_activation)(
                self.heads[head.name](x) + head.out_bias
            )
            for head in heads
        }

        return out
|
|