jadechoghari
committed on
Update mar.py
Browse files
mar.py
CHANGED
@@ -10,12 +10,12 @@ from torch.utils.checkpoint import checkpoint
|
|
10 |
|
11 |
from timm.models.vision_transformer import Block
|
12 |
|
13 |
-
from .diffloss import DiffLoss
|
|
|
14 |
|
15 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
16 |
def mask_by_order(mask_len, order, bsz, seq_len):
|
17 |
-
masking = torch.zeros(bsz, seq_len).
|
18 |
-
masking = torch.scatter(masking, dim=-1, index=order[:, :mask_len.long()], src=torch.ones(bsz, seq_len).
|
19 |
return masking
|
20 |
|
21 |
|
@@ -156,7 +156,7 @@ class MAR(nn.Module):
|
|
156 |
order = np.array(list(range(self.seq_len)))
|
157 |
np.random.shuffle(order)
|
158 |
orders.append(order)
|
159 |
-
orders = torch.Tensor(np.array(orders)).
|
160 |
return orders
|
161 |
|
162 |
def random_masking(self, x, orders):
|
@@ -180,7 +180,7 @@ class MAR(nn.Module):
|
|
180 |
# random drop class embedding during training
|
181 |
if self.training:
|
182 |
drop_latent_mask = torch.rand(bsz) < self.label_drop_prob
|
183 |
-
drop_latent_mask = drop_latent_mask.unsqueeze(-1).
|
184 |
class_embedding = drop_latent_mask * self.fake_latent + (1 - drop_latent_mask) * class_embedding
|
185 |
|
186 |
x[:, :self.buffer_size] = class_embedding.unsqueeze(1)
|
@@ -262,8 +262,8 @@ class MAR(nn.Module):
|
|
262 |
def sample_tokens(self, bsz, num_iter=64, cfg=1.0, cfg_schedule="linear", labels=None, temperature=1.0, progress=False):
|
263 |
|
264 |
# init and sample generation orders
|
265 |
-
mask = torch.ones(bsz, self.seq_len).
|
266 |
-
tokens = torch.zeros(bsz, self.seq_len, self.token_embed_dim).
|
267 |
orders = self.sample_orders(bsz)
|
268 |
|
269 |
indices = list(range(num_iter))
|
@@ -291,10 +291,10 @@ class MAR(nn.Module):
|
|
291 |
|
292 |
# mask ratio for the next round, following MaskGIT and MAGE.
|
293 |
mask_ratio = np.cos(math.pi / 2. * (step + 1) / num_iter)
|
294 |
-
mask_len = torch.Tensor([np.floor(self.seq_len * mask_ratio)]).
|
295 |
|
296 |
# masks out at least one for the next iteration
|
297 |
-
mask_len = torch.maximum(torch.Tensor([1]).
|
298 |
torch.minimum(torch.sum(mask, dim=-1, keepdims=True) - 1, mask_len))
|
299 |
|
300 |
# get masking for next iteration and locations to be predicted in this iteration
|
|
|
10 |
|
11 |
from timm.models.vision_transformer import Block
|
12 |
|
13 |
+
from models.diffloss import DiffLoss
|
14 |
+
|
15 |
|
|
|
16 |
def mask_by_order(mask_len, order, bsz, seq_len):
    """Return a boolean mask marking the first ``mask_len`` indices of each order.

    For every row ``i``, the positions listed in ``order[i, :mask_len]`` are
    set to ``True``; every other position is ``False``.

    Args:
        mask_len: Number of leading entries of each row of ``order`` to mark.
            Converted with ``int()``, so a tensor must hold exactly one
            element; fractional values are truncated toward zero.
        order: Integer index tensor with one row per sample; it is sliced as
            ``order[:, :mask_len]`` and used as the scatter index along the
            last dimension.
        bsz: Number of rows in the returned mask.
        seq_len: Number of columns in the returned mask.

    Returns:
        A ``torch.bool`` tensor of shape ``(bsz, seq_len)`` allocated on the
        same device as ``order``.
    """
    # Allocate on the device of `order` instead of hard-coding CUDA so the
    # function also works on CPU-only hosts and never mixes devices.
    device = order.device
    num_masked = int(mask_len)
    masking = torch.zeros(bsz, seq_len, device=device)
    masking = torch.scatter(
        masking,
        dim=-1,
        index=order[:, :num_masked],
        src=torch.ones(bsz, seq_len, device=device),
    ).bool()
    return masking
|
20 |
|
21 |
|
|
|
156 |
order = np.array(list(range(self.seq_len)))
|
157 |
np.random.shuffle(order)
|
158 |
orders.append(order)
|
159 |
+
orders = torch.Tensor(np.array(orders)).cuda().long()
|
160 |
return orders
|
161 |
|
162 |
def random_masking(self, x, orders):
|
|
|
180 |
# random drop class embedding during training
|
181 |
if self.training:
|
182 |
drop_latent_mask = torch.rand(bsz) < self.label_drop_prob
|
183 |
+
drop_latent_mask = drop_latent_mask.unsqueeze(-1).cuda().to(x.dtype)
|
184 |
class_embedding = drop_latent_mask * self.fake_latent + (1 - drop_latent_mask) * class_embedding
|
185 |
|
186 |
x[:, :self.buffer_size] = class_embedding.unsqueeze(1)
|
|
|
262 |
def sample_tokens(self, bsz, num_iter=64, cfg=1.0, cfg_schedule="linear", labels=None, temperature=1.0, progress=False):
|
263 |
|
264 |
# init and sample generation orders
|
265 |
+
mask = torch.ones(bsz, self.seq_len).cuda()
|
266 |
+
tokens = torch.zeros(bsz, self.seq_len, self.token_embed_dim).cuda()
|
267 |
orders = self.sample_orders(bsz)
|
268 |
|
269 |
indices = list(range(num_iter))
|
|
|
291 |
|
292 |
# mask ratio for the next round, following MaskGIT and MAGE.
|
293 |
mask_ratio = np.cos(math.pi / 2. * (step + 1) / num_iter)
|
294 |
+
mask_len = torch.Tensor([np.floor(self.seq_len * mask_ratio)]).cuda()
|
295 |
|
296 |
# masks out at least one for the next iteration
|
297 |
+
mask_len = torch.maximum(torch.Tensor([1]).cuda(),
|
298 |
torch.minimum(torch.sum(mask, dim=-1, keepdims=True) - 1, mask_len))
|
299 |
|
300 |
# get masking for next iteration and locations to be predicted in this iteration
|