# Source snapshot: commit 41f97d1 (file size 2,200 bytes).
import torch
import torch.nn as nn
import clip
from model import ST_GCN_18
class ContrastiveModule(nn.Module):
    """CLIP-style contrastive model pairing an IMU encoder with CLIP's text tower.

    A pretrained CLIP ViT-B/32 model is loaded and its image tower
    (``model.visual``) is deleted. An ``ST_GCN_18`` network is attached as
    ``self.model.acc`` and is used to encode IMU inputs instead. Text is
    encoded with the remaining CLIP text transformer.

    Args:
        args: Configuration namespace. Reads ``args.gyro`` and ``args.stft``
            (each doubles the ST-GCN input channel count when truthy) and
            ``args.stage``. When ``args.stage == 'finetune'`` it also reads
            ``args.num_class`` and builds the linear head ``self.fc``.
    """

    def __init__(self, args):
        super().__init__()
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # Only the text tower of CLIP is reused, so the preprocessing
        # transform returned by clip.load is not needed.
        model, _preprocess = clip.load("ViT-B/32", device=device)
        del model.visual
        self.model = model

        # 3 base input channels, doubled for each enabled optional input.
        base_channel = 3
        if args.gyro:
            base_channel *= 2
        if args.stft:
            base_channel *= 2
        # NOTE(review): ST_GCN_18 is built on the CPU while the CLIP weights
        # were loaded onto `device`; presumably the caller moves the whole
        # module with `.to(device)` afterwards -- confirm.
        self.model.acc = ST_GCN_18(in_channels=base_channel)
        # Cast every parameter/buffer (CLIP text tower + IMU encoder) to fp32.
        self.model = self.model.float()

        if args.stage == 'finetune':
            # Assumes the IMU encoder emits 512-d features -- TODO confirm
            # against ST_GCN_18's output size.
            self.fc = nn.Linear(512, args.num_class)

    @staticmethod
    def _l2_normalize(features, eps=1e-12):
        """Scale ``features`` to unit L2 norm along the last dimension.

        The norm is clamped to ``eps`` so an all-zero feature vector yields
        zeros instead of NaNs. Any vector whose norm exceeds ``eps`` is
        unaffected by the clamp.
        """
        return features / features.norm(dim=-1, keepdim=True).clamp_min(eps)

    def encode_image(self, image):
        """Encode an IMU input with the ST-GCN encoder.

        Args:
            image: IMU input tensor; cast to float32 before encoding.

        Returns:
            The encoder output with its last two dimensions squeezed away
            (each only if it has size 1).
        """
        features = self.model.acc(image.float())
        return features.squeeze(-1).squeeze(-1)

    def encode_text(self, text):
        """Encode tokenized text with CLIP's text transformer.

        Args:
            text: Integer token ids of shape (batch, seq_len). ``seq_len`` must
                match the length of ``self.model.positional_embedding``.

        Returns:
            One projected feature vector per sequence, taken at the position
            of each sequence's highest token id.
        """
        x = self.model.token_embedding(text).float()  # b,t,512
        x = x + self.model.positional_embedding.float()
        x = x.permute(1, 0, 2)  # b,t,512 -> t,b,512
        x = self.model.transformer(x)
        x = x.permute(1, 0, 2)  # t,b,512 -> b,t,512
        x = self.model.ln_final(x).float()  # b,t,512
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        eot_positions = text.argmax(dim=-1)
        x = x[torch.arange(x.shape[0]), eot_positions] @ self.model.text_projection  # b,512
        return x

    def classifier(self, image):
        """Fine-tuning head: IMU encoder followed by the linear layer ``fc``.

        Raises:
            AttributeError: if ``fc`` was never built (``args.stage`` was not
                ``'finetune'`` at construction time).
        """
        if not hasattr(self, "fc"):
            raise AttributeError(
                "ContrastiveModule.classifier requires the 'fc' head, which is "
                "only created when args.stage == 'finetune'"
            )
        return self.fc(self.encode_image(image))

    def forward(self, inputs_imu, inputs_text):
        """Compute scaled cosine-similarity logits between IMU and text.

        Returns:
            ``(logits_per_image, logits_per_text)``: the first has one row per
            IMU sample and one column per text; the second is its transpose.
        """
        imu_features = self._l2_normalize(self.encode_image(inputs_imu))
        text_features = self._l2_normalize(self.encode_text(inputs_text))

        # logit_scale is stored in log space, so exponentiate it.
        logit_scale = self.model.logit_scale.exp()
        logits_per_image = logit_scale * imu_features @ text_features.t()
        logits_per_text = logits_per_image.t()
        return logits_per_image, logits_per_text