Hugo Flores commited on
Commit
9439b64
·
1 Parent(s): b862275

fix sampling logic for paella

Browse files
Files changed (1) hide show
  1. vampnet/modules/base.py +13 -18
vampnet/modules/base.py CHANGED
@@ -153,7 +153,6 @@ class VampBase(at.ml.BaseModel):
153
  sampling_steps: int = 12,
154
  start_tokens: Optional[torch.Tensor] = None,
155
  mask: Optional[torch.Tensor] = None,
156
- device: str = "cpu",
157
  temperature: Union[float, Tuple[float, float]] = 1.0,
158
  top_k: int = None,
159
  sample: str = "gumbel",
@@ -164,7 +163,8 @@ class VampBase(at.ml.BaseModel):
164
  typical_min_tokens=1,
165
  return_signal=True,
166
  ):
167
- r = torch.linspace(0, 1, sampling_steps + 1)[:-1][:, None].to(device)
 
168
  if renoise_steps == None:
169
  renoise_steps = sampling_steps - 1
170
 
@@ -186,7 +186,7 @@ class VampBase(at.ml.BaseModel):
186
  if self.noise_mode == "noise":
187
  z = torch.randint(
188
  0, self.vocab_size, size=(1, self.n_codebooks, time_steps)
189
- ).to(device)
190
  elif self.noise_mode == "mask":
191
  z = torch.full((1, self.n_codebooks, time_steps), self.mask_token)
192
  else:
@@ -197,19 +197,14 @@ class VampBase(at.ml.BaseModel):
197
  assert z.shape[0] == 1, f"batch size must be 1"
198
 
199
  if mask is None:
200
- mask = torch.ones(z.shape[0], z.shape[-1]).to(device).int()
201
-
202
- # apply mask
203
- assert mask.shape == (
204
- z.shape[0],
205
- z.shape[-1],
206
- ), f"mask must be shape (batch, seq_len), got {mask.shape}"
207
- mask = mask[:, None, :]
208
- mask = mask.repeat(1, z.shape[1], 1)
209
  mask[:, : self.n_conditioning_codebooks, :] = 0.0
210
 
211
- if self.noise_mode == "mask":
212
- z_true = z.clone()
213
 
214
  z, mask = self.add_noise(z, r=r[0], random_x=None, mask=mask)
215
  z_init = z.clone()
@@ -228,8 +223,8 @@ class VampBase(at.ml.BaseModel):
228
 
229
  z = self.sample_from_logits(
230
  logits,
231
- tmpt,
232
- top_k,
233
  sample=sample,
234
  typical_filtering=typical_filtering,
235
  typical_mass=typical_mass,
@@ -323,7 +318,7 @@ class VampBase(at.ml.BaseModel):
323
  # how many codebooks are we inferring vs conditioning on?
324
  n_infer_codebooks = self.n_codebooks - self.n_conditioning_codebooks
325
 
326
- for i in tqdm(range(sampling_steps)):
327
  # our current temperature
328
  tmpt = temperature[i]
329
 
@@ -450,7 +445,7 @@ class VampBase(at.ml.BaseModel):
450
  probs = torch.softmax(logits, dim=-1)
451
  inferred = torch.stack([pr.multinomial(1).squeeze(-1) for pr in probs])
452
  elif sample == "argmax":
453
- inferred = torch.softmax(probs, dim=-1).argmax(dim=-1)
454
  elif sample == "gumbel":
455
  inferred = gumbel_sample(logits, dim=-1)
456
  else:
 
153
  sampling_steps: int = 12,
154
  start_tokens: Optional[torch.Tensor] = None,
155
  mask: Optional[torch.Tensor] = None,
 
156
  temperature: Union[float, Tuple[float, float]] = 1.0,
157
  top_k: int = None,
158
  sample: str = "gumbel",
 
163
  typical_min_tokens=1,
164
  return_signal=True,
165
  ):
166
+
167
+ r = torch.linspace(0, 1, sampling_steps + 1)[:-1][:, None].to(self.device)
168
  if renoise_steps == None:
169
  renoise_steps = sampling_steps - 1
170
 
 
186
  if self.noise_mode == "noise":
187
  z = torch.randint(
188
  0, self.vocab_size, size=(1, self.n_codebooks, time_steps)
189
+ ).to(self.device)
190
  elif self.noise_mode == "mask":
191
  z = torch.full((1, self.n_codebooks, time_steps), self.mask_token)
192
  else:
 
197
  assert z.shape[0] == 1, f"batch size must be 1"
198
 
199
  if mask is None:
200
+ mask = torch.ones(z.shape[0], z.shape[-1]).to(self.device).int()
201
+ mask = mask[:, None, :]
202
+ mask = mask.repeat(1, z.shape[1], 1)
203
+
 
 
 
 
 
204
  mask[:, : self.n_conditioning_codebooks, :] = 0.0
205
 
206
+
207
+ z_true = z.clone()
208
 
209
  z, mask = self.add_noise(z, r=r[0], random_x=None, mask=mask)
210
  z_init = z.clone()
 
223
 
224
  z = self.sample_from_logits(
225
  logits,
226
+ top_k=top_k,
227
+ temperature=tmpt,
228
  sample=sample,
229
  typical_filtering=typical_filtering,
230
  typical_mass=typical_mass,
 
318
  # how many codebooks are we inferring vs conditioning on?
319
  n_infer_codebooks = self.n_codebooks - self.n_conditioning_codebooks
320
 
321
+ for i in range(sampling_steps):
322
  # our current temperature
323
  tmpt = temperature[i]
324
 
 
445
  probs = torch.softmax(logits, dim=-1)
446
  inferred = torch.stack([pr.multinomial(1).squeeze(-1) for pr in probs])
447
  elif sample == "argmax":
448
+ inferred = torch.softmax(logits, dim=-1).argmax(dim=-1)
449
  elif sample == "gumbel":
450
  inferred = gumbel_sample(logits, dim=-1)
451
  else: