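"""Distribution-valued prediction heads on top of a frozen AlphaCLIP backbone.

``ClipBasedHeadEstimator`` encodes an RGBA conditioning image with a frozen
AlphaCLIP visual encoder and passes the embedding through small per-head MLPs.
Each head parameterizes a probability distribution (Normal or Beta) over its
output, which can be returned as-is or collapsed to a point estimate (mean,
mode, or sample) and optionally reshaped or routed to the decoder features.
"""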
from dataclasses import dataclass, field
from typing import Any, List, Optional

import alpha_clip
import torch
import torch.nn as nn
from jaxtyping import Float
from torch import Tensor
from torchvision.transforms import Normalize

from spar3d.models.network import get_activation
from spar3d.models.utils import BaseModule

OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)


@dataclass
class HeadSpec:
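    """Configuration for a single prediction head: its name, output size, MLP depth, and output post-processing."""
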
    name: str
    out_channels: int
    n_hidden_layers: int
    output_activation: Optional[str] = None
    output_bias: float = 0.0
    add_to_decoder_features: bool = False
    shape: Optional[list[int]] = None
    distribution_eval: str = "sample"


class ClipBasedHeadEstimator(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        model: str = "ViT-L/14@336px"

        distribution: str = "beta"

        # ["mean", "mode", "sample", "sample_mean"]
        distribution_eval: str = "mode"

        activation: str = "relu"
        hidden_features: int = 512
        heads: List[HeadSpec] = field(default_factory=list)

    cfg: Config

    def configure(self):
        self.model, _ = alpha_clip.load(
            self.cfg.model,
        )  # change to your own ckpt path
        self.model.eval()

        if not hasattr(self.model.visual, "input_resolution"):
            self.img_size = 224
        else:
            self.img_size = self.model.visual.input_resolution
            # Check if img_size is subscriptable and pick the first element
            if hasattr(self.img_size, "__getitem__"):
                self.img_size = self.img_size[0]

        # Do not add the weights in self.model to the optimizer
        for param in self.model.parameters():
            param.requires_grad = False

        assert len(self.cfg.heads) > 0
        heads = {}
        for head in self.cfg.heads:
            head_layers = []
            in_feature = self.model.visual.output_dim

            for i in range(head.n_hidden_layers):
                head_layers += [
                    nn.Linear(
                        in_feature if i == 0 else self.cfg.hidden_features,
                        self.cfg.hidden_features,
                    ),
                    self.make_activation(self.cfg.activation),
                ]

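            # Shared trunk followed by two parallel branches, one per distribution
            # parameter (mean/variance for Normal, alpha/beta for Beta).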
            head_layers = [nn.Sequential(*head_layers)]
            head_layers += [
                nn.Sequential(
                    nn.Linear(
                        self.cfg.hidden_features,
                        self.cfg.hidden_features,
                    ),
                    self.make_activation(self.cfg.activation),
                    nn.Linear(self.cfg.hidden_features, 1),
                )
                for _ in range(2)
            ]
            heads[head.name] = nn.ModuleList(head_layers)
        self.heads = nn.ModuleDict(heads)

    def make_activation(self, activation):
        if activation == "relu":
            return nn.ReLU(inplace=True)
        elif activation == "silu":
            return nn.SiLU(inplace=True)
        else:
            raise NotImplementedError

    def forward(
        self,
        cond_image: Float[Tensor, "B 1 H W 4"],
        sample: bool = True,
    ) -> dict[str, Any]:
        # Run the model
        # Resize cond_image to the backbone's input resolution
        cond_image = cond_image.flatten(0, 1)
        cond_image = nn.functional.interpolate(
            cond_image.permute(0, 3, 1, 2),
            size=(self.img_size, self.img_size),
            mode="bilinear",
            align_corners=False,
        )
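        # Split off the alpha channel; AlphaCLIP consumes the masked, CLIP-normalized
        # RGB image together with a separately normalized alpha mask.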
        mask = cond_image[:, 3:4]
        cond_image = cond_image[:, :3] * mask
        cond_image = Normalize(
            mean=OPENAI_DATASET_MEAN,
            std=OPENAI_DATASET_STD,
        )(cond_image)
        mask = Normalize(0.5, 0.26)(mask).half()
        image_features = self.model.visual(cond_image.half(), mask).float()

        # Run the heads
        outputs = {}

        for head_dict in self.cfg.heads:
            head_name = head_dict.name
            shared_head, d1_h, d2_h = self.heads[head_name]
            shared_features = shared_head(image_features)
            d1, d2 = [head(shared_features).squeeze(-1) for head in [d1_h, d2_h]]
            if self.cfg.distribution == "normal":
                mean = d1
                var = d2
                if mean.shape[-1] == 1:
                    outputs[head_name] = torch.distributions.Normal(
                        mean + head_dict.output_bias,
                        torch.nn.functional.softplus(var),
                    )
                else:
                    outputs[head_name] = torch.distributions.MultivariateNormal(
                        mean + head_dict.output_bias,
                        torch.nn.functional.softplus(var).diag_embed(),
                    )
            elif self.cfg.distribution == "beta":
                outputs[head_name] = torch.distributions.Beta(
                    torch.nn.functional.softplus(d1 + head_dict.output_bias),
                    torch.nn.functional.softplus(d2 + head_dict.output_bias),
                )
            else:
                raise NotImplementedError

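        # Optionally collapse each distribution into a point estimate according to
        # the head's distribution_eval mode; the distribution itself is kept under
        # "<name>_dist".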
        if sample:
            for head_dict in self.cfg.heads:
                head_name = head_dict.name
                dist = outputs[head_name]

                if head_dict.distribution_eval == "mean":
                    out = dist.mean
                elif head_dict.distribution_eval == "mode":
                    out = dist.mode
                elif head_dict.distribution_eval == "sample_mean":
                    # Average over the sample dimension of the stacked draws
                    out = dist.sample([10]).mean(0)
                else:
                    # use rsample if gradient is needed
                    out = dist.rsample() if self.training else dist.sample()

                outputs[head_name] = get_activation(head_dict.output_activation)(out)
                outputs[f"{head_name}_dist"] = dist

        for head in self.cfg.heads:
            if head.shape:
                if not sample:
                    raise ValueError(
                        "Cannot reshape non-sampled probabilisitic outputs"
                    )
                outputs[head.name] = outputs[head.name].reshape(*head.shape)

            if head.add_to_decoder_features:
                outputs[f"decoder_{head.name}"] = outputs[head.name]
                del outputs[head.name]

        return outputs
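

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). The head name,
# config values, and dict-based construction below are assumptions about how
# BaseModule parses its Config; in practice the module is built from the
# project's own config files.
#
#   estimator = ClipBasedHeadEstimator(
#       {
#           "model": "ViT-L/14@336px",
#           "distribution": "beta",
#           "heads": [{"name": "scale", "out_channels": 1, "n_hidden_layers": 2}],
#       }
#   )
#   cond_image = torch.rand(1, 1, 336, 336, 4)    # [B, Nv, H, W, RGBA] in [0, 1]
#   outputs = estimator(cond_image, sample=True)  # {"scale": ..., "scale_dist": ...}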