Crystalcareai committed "Update modeling_quiet.py"

modeling_quiet.py CHANGED (+94 -76)
@@ -23,6 +23,7 @@ import math
 import copy
 import os
 import time
+import pandas as pd
 import seaborn as sns
 import matplotlib.pyplot as plt
 import wandb
@@ -68,6 +69,73 @@ logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "QuietConfig"
 
+from reportlab.pdfgen import canvas
+from reportlab.lib.pagesizes import letter
+from reportlab.lib.colors import HexColor
+
+def save_tokens_with_rewards_to_pdf(input_ids, token_rewards, tokenizer, output_file="text.pdf", eps=0.2, eps2=0.5):
+    c = canvas.Canvas(output_file, pagesize=letter)
+    c.setFont("Courier", 8)
+    x, y = 50, 750
+    previous_text = ""
+    current_text = ""
+    for token_idx, reward in enumerate(token_rewards):
+        current_text = tokenizer.decode(input_ids[: token_idx + 1])
+        if current_text != previous_text:
+            diff_text = current_text[len(previous_text) :]
+            if "\n" in diff_text:
+                lines = diff_text.split("\n")
+                for line_idx, line in enumerate(lines):
+                    if line_idx > 0:
+                        x = 50
+                        y -= 12
+                    if abs(reward) < eps:
+                        opacity = 0
+                    elif abs(reward) > eps2:
+                        opacity = 0.8
+                    else:
+                        opacity = 0.8 * (abs(reward) - eps) / (eps2 - eps)
+                    text_width = c.stringWidth(line)
+                    if reward > 0:
+                        highlight_color = HexColor("#4CCD99")
+                    else:
+                        highlight_color = HexColor("#FFC700")
+                    highlight_color.alpha = opacity
+                    c.setFillColor(highlight_color)
+                    c.rect(x, y - 2, text_width, 10, fill=True, stroke=False)
+                    c.setFillColor(HexColor("#000000"))
+                    c.drawString(x, y, line)
+                    x += text_width
+            else:
+                if abs(reward) < eps:
+                    opacity = 0
+                elif abs(reward) > eps2:
+                    opacity = 0.8
+                else:
+                    opacity = 0.8 * (abs(reward) - eps) / (eps2 - eps)
+                text_width = c.stringWidth(diff_text)
+                if reward > 0:
+                    highlight_color = HexColor("#4CCD99")
+                else:
+                    highlight_color = HexColor("#FFC700")
+                highlight_color.alpha = opacity
+                c.setFillColor(highlight_color)
+                c.rect(x, y - 2, text_width, 10, fill=True, stroke=False)
+                c.setFillColor(HexColor("#000000"))
+                c.drawString(x, y, diff_text)
+                x += text_width
+            if x > 550:
+                x = 50
+                y -= 12
+            if y < 50:
+                c.showPage()
+                y = 750
+                x = 50
+            previous_text = current_text
+    c.showPage()
+    c.save()
+
+
 # Copied from transformers.models.llama.modeling_llama._get_unpad_data
 def _get_unpad_data(attention_mask):
     seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
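The hunk above adds save_tokens_with_rewards_to_pdf, a reportlab helper that re-decodes the sequence token by token and highlights each new span according to its reward. A minimal usage sketch, not part of the commit; the import path, checkpoint name, and reward values are placeholders:

# Sketch only: hypothetical call to the helper added above.
# The module path and tokenizer checkpoint are assumptions, and the rewards are made up.
import torch
from transformers import AutoTokenizer
from modeling_quiet import save_tokens_with_rewards_to_pdf  # assumed import path

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")  # placeholder checkpoint
input_ids = tokenizer("Thinking before speaking helps.", return_tensors="pt").input_ids[0]
token_rewards = [0.1 * (i % 7 - 3) for i in range(len(input_ids))]  # one fake reward per token

save_tokens_with_rewards_to_pdf(
    input_ids,
    token_rewards,
    tokenizer,
    output_file="token_rewards.pdf",
    eps=0.2,   # rewards below this magnitude get no highlight
    eps2=0.5,  # rewards above this magnitude get the full 0.8 opacity
)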
@@ -257,13 +325,6 @@ class QuietAttention(nn.Module):
         use_cache: bool = False,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-
-        if past_key_value is not None:
-            expected_attention_mask_size = (bsz, 1, q_len, q_len + past_key_value.get_usable_length(q_len, self.layer_idx))
-            if attention_mask.size() != expected_attention_mask_size:
-                # Assuming the attention mask is larger than expected, slice it to match the expected size
-                attention_mask = attention_mask[:, :, :, -expected_attention_mask_size[-1]:]
-
         if "padding_mask" in kwargs:
             warnings.warn(
                 "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
@@ -277,10 +338,6 @@ class QuietAttention(nn.Module):
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
-        query_states = query_states.to(attention_mask.dtype)
-        key_states = key_states.to(attention_mask.dtype)
-        value_states = value_states.to(attention_mask.dtype)
 
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
@@ -311,16 +368,11 @@ class QuietAttention(nn.Module):
         )
 
         if attention_mask is not None:
-            if attention_mask.dim() == 3:
-                attention_mask = attention_mask.unsqueeze(1)
-            elif attention_mask.dim() == 2:
-                attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
-
-            if attention_mask.size(0) != bsz or attention_mask.size(-1) != kv_seq_len:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                 raise ValueError(
-                    f"Attention mask should be of size (
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                 )
-
+
         attn_weights = attn_weights + attention_mask
 
         # upcast attention to fp32
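Both attention classes now require the mask to arrive as a 4D additive mask of size (bsz, 1, q_len, kv_seq_len) instead of reshaping 2D or 3D masks in place. A standalone sketch of building a mask with that shape, using assumed example dimensions rather than code from this file:

# Sketch only: construct an additive causal mask with the (bsz, 1, q_len, kv_seq_len)
# shape that the new size check expects. The dimensions are arbitrary examples.
import torch

bsz, q_len, past_len = 2, 5, 3
kv_seq_len = q_len + past_len

padding_mask = torch.ones(bsz, kv_seq_len, dtype=torch.bool)  # True = real token
causal = torch.tril(torch.ones(q_len, kv_seq_len), diagonal=past_len).bool()  # query i sees keys <= i + past_len
keep = (causal & padding_mask[:, None, :]).unsqueeze(1)  # (bsz, 1, q_len, kv_seq_len)

attention_mask = torch.zeros(bsz, 1, q_len, kv_seq_len)
attention_mask = attention_mask.masked_fill(~keep, torch.finfo(torch.float32).min)
assert attention_mask.size() == (bsz, 1, q_len, kv_seq_len)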
@@ -697,21 +749,11 @@ class QuietSdpaAttention(QuietAttention):
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
         if attention_mask is not None:
-            if attention_mask.dim() == 3:
-                attention_mask = attention_mask.unsqueeze(1)
-            elif attention_mask.dim() == 2:
-                attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
-
-        if attention_mask is not None:
-            if attention_mask.dim() == 3:
-                attention_mask = attention_mask.unsqueeze(1)
-            elif attention_mask.dim() == 2:
-                attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
-
-            if attention_mask.size(0) != bsz or attention_mask.size(-1) != kv_seq_len:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                 raise ValueError(
-                    f"Attention mask should be of size (
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                 )
+
         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
         # Reference: https://github.com/pytorch/pytorch/issues/112577.
         if query_states.device.type == "cuda" and attention_mask is not None:
@@ -719,12 +761,6 @@ class QuietSdpaAttention(QuietAttention):
             key_states = key_states.contiguous()
             value_states = value_states.contiguous()
 
-
-        # Cast query_states, key_states, and value_states to the same data type as attention_mask
-        query_states = query_states.to(attention_mask.dtype)
-        key_states = key_states.to(attention_mask.dtype)
-        value_states = value_states.to(attention_mask.dtype)
-
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,
@@ -1291,28 +1327,8 @@ class QuietForCausalLM(QuietPreTrainedModel):
         # Generate the continuation
         continuation_length = self.n_ahead - 2
         new_key_values = past_key_values
+        generated_tokens = []
 
-        if self.n_ahead != 1 or self.n_ahead_talk != 1 or self.comparison_mode:
-            if attention_mask is None:
-                base_attention_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=0).to(input_ids.device)
-                base_attention_mask = base_attention_mask.view(1, 1, seq_len, seq_len)
-                base_attention_mask = base_attention_mask.repeat(input_ids.shape[0], 1, 1, 1)
-                attention_mask = base_attention_mask
-            elif attention_mask.dim() == 2:
-                if seq_len + past_key_values_length != attention_mask.shape[-1]:
-                    attention_mask = torch.cat(
-                        [torch.ones((attention_mask.shape[0], past_key_values_length), dtype=attention_mask.dtype, device=attention_mask.device), attention_mask],
-                        dim=-1
-                    )
-                attention_mask = _prepare_4d_causal_attention_mask(
-                    attention_mask,
-                    (batch_size, seq_len),
-                    inputs_embeds,
-                    past_key_values_length,
-                    sliding_window=self.config.sliding_window,
-                )
-
-        start_time = time.time()
         for continuation_idx in range(continuation_length):
             outputs = self.model(
                 input_ids=input_ids if continuation_idx == 0 else next_token_id.unsqueeze(-1).to(input_ids.device),
@@ -1326,9 +1342,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
                 return_dict=return_dict,
             )
             new_key_values = outputs.past_key_values
-
             hidden_states = outputs[0]
-
             logits = self.lm_head(hidden_states)
             logits = logits[:, -1, :] # Only consider the last token
 
@@ -1338,12 +1352,17 @@ class QuietForCausalLM(QuietPreTrainedModel):
 
             # Append the generated token to the input sequence
            input_ids = torch.cat([input_ids, next_token_id.unsqueeze(-1).to(input_ids.device)], dim=-1)
+            generated_tokens.append(next_token_id)
             seq_len += 1
 
             # Update the attention mask
             if attention_mask is not None:
                 attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
 
+            # Update the position ids
+            if position_ids is not None:
+                position_ids = torch.cat([position_ids, (position_ids[:, -1] + 1).unsqueeze(-1)], dim=-1)
+
         # Append the end thought token to the input sequence
         end_thought_token_id = self.tokenizer.convert_tokens_to_ids("<|endthought|>")
         input_ids = torch.cat([input_ids, torch.tensor([[end_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1)
@@ -1389,7 +1408,12 @@ class QuietForCausalLM(QuietPreTrainedModel):
 
         # Apply the language model head to get the final logits
         logits = self.lm_head(mixed_hidden_states)
-
+
+        # Decode the logits to get the generated text
+        generated_tokens = torch.cat(generated_tokens, dim=-1)
+        generated_text = self.tokenizer.decode(generated_tokens.squeeze(), skip_special_tokens=True)
+
+        return generated_text
 
     @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
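With the generated_tokens bookkeeping added earlier, this path now decodes the collected thought tokens and returns plain text. A small sketch of the same cat-and-decode step in isolation, with placeholder token ids and a placeholder tokenizer checkpoint:

# Sketch only: mirrors the torch.cat + tokenizer.decode pattern added above.
# The token ids and the checkpoint are made-up placeholders.
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")  # placeholder checkpoint

# One (batch, 1) tensor per generated step, as collected in generated_tokens.
generated_tokens = [torch.tensor([[1763]]), torch.tensor([[349]]), torch.tensor([[264]])]
ids = torch.cat(generated_tokens, dim=-1)  # (batch, num_generated)
text = tokenizer.decode(ids.squeeze(), skip_special_tokens=True)
print(text)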
@@ -1662,9 +1686,6 @@ class QuietForCausalLM(QuietPreTrainedModel):
             prev_rm_logits = rm_logits # for policy gradient
             prev_rm_tokens = cur_rm_tokens # for policy gradient
 
-            hidden_states_lm = hidden_states
-            logits = self.lm_head(hidden_states_lm)
-
             if ahead_idx == 0:
                 hidden_states_lm = hidden_states
                 logits = self.lm_head(hidden_states_lm)
@@ -1682,16 +1703,14 @@ class QuietForCausalLM(QuietPreTrainedModel):
                 assert self.no_residual
                 residual_logits = self.lm_head(hidden_states)
                 talk_hidden_states = hidden_states
-                if 'hidden_states_lm' not in locals():
-                    hidden_states_lm = hidden_states
-                rm_hidden_states = hidden_states
-                if ahead_idx > self.n_ahead - 1:
-                    cur_base_hidden = torch.cat([
-                        base_hidden_states[..., ahead_idx - self.n_ahead + 1:, :],
-                        base_hidden_states[..., :ahead_idx - self.n_ahead + 1, :]
-                    ], dim=-2)
             else:
-                cur_base_hidden = base_hidden_states
+                if ahead_idx > self.n_ahead - 1:
+                    cur_base_hidden = torch.cat([
+                        base_hidden_states[..., ahead_idx - self.n_ahead + 1:, :],
+                        base_hidden_states[..., :ahead_idx - self.n_ahead + 1, :]
+                    ], dim=-2)
+                else:
+                    cur_base_hidden = base_hidden_states
 
                 if self.use_concat_talk_head:
                     # concatenate the hidden states with the original hidden states
@@ -1782,7 +1801,6 @@ class QuietForCausalLM(QuietPreTrainedModel):
             if not self.comparison_mode and not (self.optimize_lm_head_only_at_start and (self.n_ahead + self.n_ahead_talk > 2)) or self.original_mode:
                 loss_list.append(loss)
                 talk_loss_list.append(nonzero_mean(loss).detach())
-
 
             if not attempted or self.comparison_mode:
                 rm_hidden_states = hidden_states
@@ -2366,4 +2384,4 @@ class QuietForSequenceClassification(QuietPreTrainedModel):
             past_key_values=transformer_outputs.past_key_values,
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
-        )
+        )