Crystalcareai committed
Commit 21d94a3 · verified · Parent: b2672e5

Update modeling_quiet.py

Files changed (1): modeling_quiet.py (+94 -76)
modeling_quiet.py CHANGED
@@ -23,6 +23,7 @@ import math
 import copy
 import os
 import time
+import pandas as pd
 import seaborn as sns
 import matplotlib.pyplot as plt
 import wandb
@@ -68,6 +69,73 @@ logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "QuietConfig"
 
+from reportlab.pdfgen import canvas
+from reportlab.lib.pagesizes import letter
+from reportlab.lib.colors import HexColor
+
+def save_tokens_with_rewards_to_pdf(input_ids, token_rewards, tokenizer, output_file="text.pdf", eps=0.2, eps2=0.5):
+    c = canvas.Canvas(output_file, pagesize=letter)
+    c.setFont("Courier", 8)
+    x, y = 50, 750
+    previous_text = ""
+    current_text = ""
+    for token_idx, reward in enumerate(token_rewards):
+        current_text = tokenizer.decode(input_ids[: token_idx + 1])
+        if current_text != previous_text:
+            diff_text = current_text[len(previous_text) :]
+            if "\n" in diff_text:
+                lines = diff_text.split("\n")
+                for line_idx, line in enumerate(lines):
+                    if line_idx > 0:
+                        x = 50
+                        y -= 12
+                    if abs(reward) < eps:
+                        opacity = 0
+                    elif abs(reward) > eps2:
+                        opacity = 0.8
+                    else:
+                        opacity = 0.8 * (abs(reward) - eps) / (eps2 - eps)
+                    text_width = c.stringWidth(line)
+                    if reward > 0:
+                        highlight_color = HexColor("#4CCD99")
+                    else:
+                        highlight_color = HexColor("#FFC700")
+                    highlight_color.alpha = opacity
+                    c.setFillColor(highlight_color)
+                    c.rect(x, y - 2, text_width, 10, fill=True, stroke=False)
+                    c.setFillColor(HexColor("#000000"))
+                    c.drawString(x, y, line)
+                    x += text_width
+            else:
+                if abs(reward) < eps:
+                    opacity = 0
+                elif abs(reward) > eps2:
+                    opacity = 0.8
+                else:
+                    opacity = 0.8 * (abs(reward) - eps) / (eps2 - eps)
+                text_width = c.stringWidth(diff_text)
+                if reward > 0:
+                    highlight_color = HexColor("#4CCD99")
+                else:
+                    highlight_color = HexColor("#FFC700")
+                highlight_color.alpha = opacity
+                c.setFillColor(highlight_color)
+                c.rect(x, y - 2, text_width, 10, fill=True, stroke=False)
+                c.setFillColor(HexColor("#000000"))
+                c.drawString(x, y, diff_text)
+                x += text_width
+            if x > 550:
+                x = 50
+                y -= 12
+                if y < 50:
+                    c.showPage()
+                    y = 750
+                    x = 50
+            previous_text = current_text
+    c.showPage()
+    c.save()
+
+
 # Copied from transformers.models.llama.modeling_llama._get_unpad_data
 def _get_unpad_data(attention_mask):
     seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
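The new save_tokens_with_rewards_to_pdf helper draws each newly decoded token and highlights it with an opacity proportional to |reward| between eps and eps2 (green #4CCD99 for positive rewards, yellow #FFC700 for negative). A minimal usage sketch, assuming a Mistral-style tokenizer and made-up per-token rewards (neither is part of this commit):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")  # assumed tokenizer, not from this commit
input_ids = tokenizer("Thinking before answering helps.").input_ids
token_rewards = [0.6 if i % 2 == 0 else -0.4 for i in range(len(input_ids))]  # fabricated rewards, illustration only
save_tokens_with_rewards_to_pdf(input_ids, token_rewards, tokenizer, output_file="rewards.pdf")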
@@ -257,13 +325,6 @@ class QuietAttention(nn.Module):
         use_cache: bool = False,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-
-        if past_key_value is not None:
-            expected_attention_mask_size = (bsz, 1, q_len, q_len + past_key_value.get_usable_length(q_len, self.layer_idx))
-            if attention_mask.size() != expected_attention_mask_size:
-                # Assuming the attention mask is larger than expected, slice it to match the expected size
-                attention_mask = attention_mask[:, :, :, -expected_attention_mask_size[-1]:]
-
         if "padding_mask" in kwargs:
             warnings.warn(
                 "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
@@ -277,10 +338,6 @@ class QuietAttention(nn.Module):
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
-        query_states = query_states.to(attention_mask.dtype)
-        key_states = key_states.to(attention_mask.dtype)
-        value_states = value_states.to(attention_mask.dtype)
 
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
@@ -311,16 +368,11 @@ class QuietAttention(nn.Module):
             )
 
         if attention_mask is not None:
-            if attention_mask.dim() == 3:
-                attention_mask = attention_mask.unsqueeze(1)
-            elif attention_mask.dim() == 2:
-                attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
-
-            if attention_mask.size(0) != bsz or attention_mask.size(-1) != kv_seq_len:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                 raise ValueError(
-                    f"Attention mask should be of size ({bsz}, 1, q_len, {kv_seq_len}), but is {attention_mask.size()}"
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                 )
-
+
         attn_weights = attn_weights + attention_mask
 
         # upcast attention to fp32
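The dim() == 2 / dim() == 3 unsqueeze fallbacks are dropped here, restoring the upstream Mistral-style contract: callers must pass the already expanded 4-D additive mask of shape (bsz, 1, q_len, kv_seq_len). A toy-shape sketch of such a mask (sizes invented for illustration):

import torch

bsz, q_len, kv_seq_len = 2, 5, 5
padding_mask = torch.ones(bsz, kv_seq_len)                      # 1 = real token, 0 = padding
causal = torch.tril(torch.ones(q_len, kv_seq_len))              # lower-triangular causal structure
keep = causal.unsqueeze(0).unsqueeze(1) * padding_mask[:, None, None, :]
additive_mask = (1.0 - keep) * torch.finfo(torch.float32).min   # 0 where attention is allowed, very negative elsewhere
assert additive_mask.shape == (bsz, 1, q_len, kv_seq_len)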
@@ -697,21 +749,11 @@ class QuietSdpaAttention(QuietAttention):
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
         if attention_mask is not None:
-            if attention_mask.dim() == 3:
-                attention_mask = attention_mask.unsqueeze(1)
-            elif attention_mask.dim() == 2:
-                attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
-
-            if attention_mask is not None:
-                if attention_mask.dim() == 3:
-                    attention_mask = attention_mask.unsqueeze(1)
-                elif attention_mask.dim() == 2:
-                    attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
-
-            if attention_mask.size(0) != bsz or attention_mask.size(-1) != kv_seq_len:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                 raise ValueError(
-                    f"Attention mask should be of size ({bsz}, 1, q_len, {kv_seq_len}), but is {attention_mask.size()}"
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                 )
+
         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
         # Reference: https://github.com/pytorch/pytorch/issues/112577.
         if query_states.device.type == "cuda" and attention_mask is not None:
@@ -719,12 +761,6 @@ class QuietSdpaAttention(QuietAttention):
             key_states = key_states.contiguous()
             value_states = value_states.contiguous()
 
-
-        # Cast query_states, key_states, and value_states to the same data type as attention_mask
-        query_states = query_states.to(attention_mask.dtype)
-        key_states = key_states.to(attention_mask.dtype)
-        value_states = value_states.to(attention_mask.dtype)
-
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,
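With the dtype casts to attention_mask.dtype removed, query/key/value stay in the model's dtype and the additive mask can be passed to torch.nn.functional.scaled_dot_product_attention as-is. A toy-shape sketch of that call (shapes invented):

import torch
import torch.nn.functional as F

bsz, num_heads, q_len, head_dim = 1, 4, 6, 8
q = torch.randn(bsz, num_heads, q_len, head_dim)
k = torch.randn(bsz, num_heads, q_len, head_dim)
v = torch.randn(bsz, num_heads, q_len, head_dim)
attn_mask = torch.zeros(bsz, 1, q_len, q_len)   # additive mask: 0 = attend, large negative = masked
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)
assert out.shape == (bsz, num_heads, q_len, head_dim)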
@@ -1291,28 +1327,8 @@ class QuietForCausalLM(QuietPreTrainedModel):
         # Generate the continuation
         continuation_length = self.n_ahead - 2
         new_key_values = past_key_values
+        generated_tokens = []
 
-        if self.n_ahead != 1 or self.n_ahead_talk != 1 or self.comparison_mode:
-            if attention_mask is None:
-                base_attention_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=0).to(input_ids.device)
-                base_attention_mask = base_attention_mask.view(1, 1, seq_len, seq_len)
-                base_attention_mask = base_attention_mask.repeat(input_ids.shape[0], 1, 1, 1)
-                attention_mask = base_attention_mask
-            elif attention_mask.dim() == 2:
-                if seq_len + past_key_values_length != attention_mask.shape[-1]:
-                    attention_mask = torch.cat(
-                        [torch.ones((attention_mask.shape[0], past_key_values_length), dtype=attention_mask.dtype, device=attention_mask.device), attention_mask],
-                        dim=-1
-                    )
-                attention_mask = _prepare_4d_causal_attention_mask(
-                    attention_mask,
-                    (batch_size, seq_len),
-                    inputs_embeds,
-                    past_key_values_length,
-                    sliding_window=self.config.sliding_window,
-                )
-
-        start_time = time.time()
         for continuation_idx in range(continuation_length):
             outputs = self.model(
                 input_ids=input_ids if continuation_idx == 0 else next_token_id.unsqueeze(-1).to(input_ids.device),
@@ -1326,9 +1342,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
                 return_dict=return_dict,
             )
             new_key_values = outputs.past_key_values
-
             hidden_states = outputs[0]
-
             logits = self.lm_head(hidden_states)
             logits = logits[:, -1, :] # Only consider the last token
 
@@ -1338,12 +1352,17 @@ class QuietForCausalLM(QuietPreTrainedModel):
 
             # Append the generated token to the input sequence
             input_ids = torch.cat([input_ids, next_token_id.unsqueeze(-1).to(input_ids.device)], dim=-1)
+            generated_tokens.append(next_token_id)
             seq_len += 1
 
             # Update the attention mask
             if attention_mask is not None:
                 attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
 
+            # Update the position ids
+            if position_ids is not None:
+                position_ids = torch.cat([position_ids, (position_ids[:, -1] + 1).unsqueeze(-1)], dim=-1)
+
         # Append the end thought token to the input sequence
         end_thought_token_id = self.tokenizer.convert_tokens_to_ids("<|endthought|>")
         input_ids = torch.cat([input_ids, torch.tensor([[end_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1)
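Each thought-generation step now records the sampled token and grows attention_mask and position_ids by one column alongside input_ids. A standalone sketch of that bookkeeping with toy tensors (not the model's real state):

import torch

batch_size = 2
input_ids = torch.tensor([[1, 5, 7], [1, 9, 2]])
attention_mask = torch.ones_like(input_ids)
position_ids = attention_mask.cumsum(-1) - 1
next_token_id = torch.tensor([11, 12])          # pretend sampled tokens, one per batch element

input_ids = torch.cat([input_ids, next_token_id.unsqueeze(-1)], dim=-1)
attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1), dtype=attention_mask.dtype)], dim=-1)
position_ids = torch.cat([position_ids, (position_ids[:, -1] + 1).unsqueeze(-1)], dim=-1)
assert input_ids.shape == attention_mask.shape == position_ids.shape == (batch_size, 4)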
@@ -1389,7 +1408,12 @@ class QuietForCausalLM(QuietPreTrainedModel):
 
         # Apply the language model head to get the final logits
         logits = self.lm_head(mixed_hidden_states)
-        return logits
+
+        # Decode the logits to get the generated text
+        generated_tokens = torch.cat(generated_tokens, dim=-1)
+        generated_text = self.tokenizer.decode(generated_tokens.squeeze(), skip_special_tokens=True)
+
+        return generated_text
 
     @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
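With this change the helper returns decoded text rather than logits: the per-step token ids collected in generated_tokens are concatenated and run through tokenizer.decode. A small decoding sketch with hypothetical ids and an assumed tokenizer (illustration only):

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")  # assumed tokenizer, not from this commit
generated_tokens = [torch.tensor([42]), torch.tensor([350]), torch.tensor([278])]  # one tensor per step, made-up ids
generated_text = tokenizer.decode(torch.cat(generated_tokens, dim=-1).squeeze(), skip_special_tokens=True)
print(generated_text)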
@@ -1662,9 +1686,6 @@ class QuietForCausalLM(QuietPreTrainedModel):
                     prev_rm_logits = rm_logits # for policy gradient
                     prev_rm_tokens = cur_rm_tokens # for policy gradient
 
-                hidden_states_lm = hidden_states
-                logits = self.lm_head(hidden_states_lm)
-
                 if ahead_idx == 0:
                     hidden_states_lm = hidden_states
                     logits = self.lm_head(hidden_states_lm)
@@ -1682,16 +1703,14 @@ class QuietForCausalLM(QuietPreTrainedModel):
                     assert self.no_residual
                     residual_logits = self.lm_head(hidden_states)
                     talk_hidden_states = hidden_states
-                    if 'hidden_states_lm' not in locals():
-                        hidden_states_lm = hidden_states
-                    rm_hidden_states = hidden_states
-                    if ahead_idx > self.n_ahead - 1:
-                        cur_base_hidden = torch.cat([
-                            base_hidden_states[..., ahead_idx - self.n_ahead + 1:, :],
-                            base_hidden_states[..., :ahead_idx - self.n_ahead + 1, :]
-                        ], dim=-2)
                 else:
-                    cur_base_hidden = base_hidden_states
+                    if ahead_idx > self.n_ahead - 1:
+                        cur_base_hidden = torch.cat([
+                            base_hidden_states[..., ahead_idx - self.n_ahead + 1:, :],
+                            base_hidden_states[..., :ahead_idx - self.n_ahead + 1, :]
+                        ], dim=-2)
+                    else:
+                        cur_base_hidden = base_hidden_states
 
                 if self.use_concat_talk_head:
                     # concatenate the hidden states with the original hidden states
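The relocated cur_base_hidden branch rotates base_hidden_states along the sequence dimension by ahead_idx - n_ahead + 1 positions once the lookahead index passes n_ahead - 1. A tiny numeric sketch of that rotation (values invented):

import torch

base_hidden_states = torch.arange(5).view(1, 5, 1)   # pretend (batch, seq, hidden) tensor
ahead_idx, n_ahead = 5, 4
shift = ahead_idx - n_ahead + 1
cur_base_hidden = torch.cat([
    base_hidden_states[..., shift:, :],
    base_hidden_states[..., :shift, :],
], dim=-2)
print(cur_base_hidden.squeeze().tolist())            # [2, 3, 4, 0, 1]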
@@ -1782,7 +1801,6 @@ class QuietForCausalLM(QuietPreTrainedModel):
                 if not self.comparison_mode and not (self.optimize_lm_head_only_at_start and (self.n_ahead + self.n_ahead_talk > 2)) or self.original_mode:
                     loss_list.append(loss)
                     talk_loss_list.append(nonzero_mean(loss).detach())
-
 
             if not attempted or self.comparison_mode:
                 rm_hidden_states = hidden_states
@@ -2366,4 +2384,4 @@ class QuietForSequenceClassification(QuietPreTrainedModel):
             past_key_values=transformer_outputs.past_key_values,
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
-        )
+        )