jadechoghari commited on
Commit
eec2810
·
verified ·
1 Parent(s): 37d0984

Update mar.py

Browse files
Files changed (1) hide show
  1. mar.py +10 -10
mar.py CHANGED
@@ -10,12 +10,12 @@ from torch.utils.checkpoint import checkpoint
10
 
11
  from timm.models.vision_transformer import Block
12
 
13
- from .diffloss import DiffLoss
 
14
 
15
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
  def mask_by_order(mask_len, order, bsz, seq_len):
17
- masking = torch.zeros(bsz, seq_len).to(device)
18
- masking = torch.scatter(masking, dim=-1, index=order[:, :mask_len.long()], src=torch.ones(bsz, seq_len).to(device)).bool()
19
  return masking
20
 
21
 
@@ -156,7 +156,7 @@ class MAR(nn.Module):
156
  order = np.array(list(range(self.seq_len)))
157
  np.random.shuffle(order)
158
  orders.append(order)
159
- orders = torch.Tensor(np.array(orders)).to(device).long()
160
  return orders
161
 
162
  def random_masking(self, x, orders):
@@ -180,7 +180,7 @@ class MAR(nn.Module):
180
  # random drop class embedding during training
181
  if self.training:
182
  drop_latent_mask = torch.rand(bsz) < self.label_drop_prob
183
- drop_latent_mask = drop_latent_mask.unsqueeze(-1).to(device).to(x.dtype)
184
  class_embedding = drop_latent_mask * self.fake_latent + (1 - drop_latent_mask) * class_embedding
185
 
186
  x[:, :self.buffer_size] = class_embedding.unsqueeze(1)
@@ -262,8 +262,8 @@ class MAR(nn.Module):
262
  def sample_tokens(self, bsz, num_iter=64, cfg=1.0, cfg_schedule="linear", labels=None, temperature=1.0, progress=False):
263
 
264
  # init and sample generation orders
265
- mask = torch.ones(bsz, self.seq_len).to(device)
266
- tokens = torch.zeros(bsz, self.seq_len, self.token_embed_dim).to(device)
267
  orders = self.sample_orders(bsz)
268
 
269
  indices = list(range(num_iter))
@@ -291,10 +291,10 @@ class MAR(nn.Module):
291
 
292
  # mask ratio for the next round, following MaskGIT and MAGE.
293
  mask_ratio = np.cos(math.pi / 2. * (step + 1) / num_iter)
294
- mask_len = torch.Tensor([np.floor(self.seq_len * mask_ratio)]).to(device)
295
 
296
  # masks out at least one for the next iteration
297
- mask_len = torch.maximum(torch.Tensor([1]).to(device),
298
  torch.minimum(torch.sum(mask, dim=-1, keepdims=True) - 1, mask_len))
299
 
300
  # get masking for next iteration and locations to be predicted in this iteration
 
10
 
11
  from timm.models.vision_transformer import Block
12
 
13
+ from models.diffloss import DiffLoss
14
+
15
 
 
16
  def mask_by_order(mask_len, order, bsz, seq_len):
17
+ masking = torch.zeros(bsz, seq_len).cuda()
18
+ masking = torch.scatter(masking, dim=-1, index=order[:, :mask_len.long()], src=torch.ones(bsz, seq_len).cuda()).bool()
19
  return masking
20
 
21
 
 
156
  order = np.array(list(range(self.seq_len)))
157
  np.random.shuffle(order)
158
  orders.append(order)
159
+ orders = torch.Tensor(np.array(orders)).cuda().long()
160
  return orders
161
 
162
  def random_masking(self, x, orders):
 
180
  # random drop class embedding during training
181
  if self.training:
182
  drop_latent_mask = torch.rand(bsz) < self.label_drop_prob
183
+ drop_latent_mask = drop_latent_mask.unsqueeze(-1).cuda().to(x.dtype)
184
  class_embedding = drop_latent_mask * self.fake_latent + (1 - drop_latent_mask) * class_embedding
185
 
186
  x[:, :self.buffer_size] = class_embedding.unsqueeze(1)
 
262
  def sample_tokens(self, bsz, num_iter=64, cfg=1.0, cfg_schedule="linear", labels=None, temperature=1.0, progress=False):
263
 
264
  # init and sample generation orders
265
+ mask = torch.ones(bsz, self.seq_len).cuda()
266
+ tokens = torch.zeros(bsz, self.seq_len, self.token_embed_dim).cuda()
267
  orders = self.sample_orders(bsz)
268
 
269
  indices = list(range(num_iter))
 
291
 
292
  # mask ratio for the next round, following MaskGIT and MAGE.
293
  mask_ratio = np.cos(math.pi / 2. * (step + 1) / num_iter)
294
+ mask_len = torch.Tensor([np.floor(self.seq_len * mask_ratio)]).cuda()
295
 
296
  # masks out at least one for the next iteration
297
+ mask_len = torch.maximum(torch.Tensor([1]).cuda(),
298
  torch.minimum(torch.sum(mask, dim=-1, keepdims=True) - 1, mask_len))
299
 
300
  # get masking for next iteration and locations to be predicted in this iteration