# attention-tracker / models/attn_model.py
# Last commit: d62afec "create app (#1)" by pinyuchen (verified)
import torch
from .model import Model
from .utils import sample_token, get_last_attn
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch.nn.functional as F
# Target device for model loading: CUDA when PyTorch reports it available, else CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
class AttentionModel(Model):
    """Causal-LM wrapper that records attention maps while sampling a response.

    The model is loaded with ``attn_implementation="eager"`` and every forward
    pass requests ``output_attentions=True``; the per-step attentions are
    reduced with ``utils.get_last_attn`` and returned from ``inference``.
    """

    # Prompt-token offsets per model family, keyed by substrings of
    # ``config["model_info"]["name"]``. Entries are checked in order and the
    # first match wins. Values are (instruction_start, data_tail), meaning the
    # instruction starts at index ``instruction_start`` and the data ends
    # ``data_tail`` tokens before the end of the templated prompt.
    # NOTE(review): these look hand-calibrated against each family's chat
    # template and against ``tokenizer.encode`` (which may add special tokens);
    # re-verify before adding a new family.
    _RANGE_OFFSETS = (
        (("qwen-attn",), (3, 5)),
        (("phi3-attn",), (1, 2)),
        (("llama2-13b", "llama3-8b"), (5, 5)),
    )

    def __init__(self, config):
        """Load the tokenizer/model named in ``config`` and pick attention heads.

        Args:
            config: Mapping providing ``config["model_info"]["name"]``,
                ``config["model_info"]["model_id"]``,
                ``config["params"]["max_output_tokens"]`` and
                ``config["params"]["important_heads"]``. The last is either
                the string ``"all"`` or an explicit list of ``[i, j]`` pairs.
        """
        super().__init__(config)
        self.name = config["model_info"]["name"]
        self.max_output_tokens = int(config["params"]["max_output_tokens"])
        # Sampling settings must exist before get_map_dim() below, because it
        # runs a real generation step through inference().
        self.top_k = 50
        self.top_p = None

        model_id = config["model_info"]["model_id"]
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            device_map=device,
            # Presumably "eager" so that output_attentions=True yields weights.
            attn_implementation="eager",
        ).eval()

        important_heads = config["params"]["important_heads"]
        if important_heads == "all":
            # Every [i, j] index pair of the map; presumably (layer, head).
            dim0, dim1 = self.get_map_dim()
            self.important_heads = [
                [i, j] for i in range(dim0) for j in range(dim1)
            ]
        else:
            self.important_heads = important_heads

    def get_map_dim(self):
        """Return the 2-D index size of a recorded attention map.

        Runs a one-token probe generation and inspects its attention map.

        Returns:
            ``(len(m), m[0].shape[1])`` for the first recorded map ``m``.
            This is presumably (number of layers, number of heads); the exact
            meaning depends on ``utils.get_last_attn``.

        Raises:
            IndexError: if the probe samples EOS as its first token, since no
                map is recorded for an EOS step.
        """
        # One step is enough to read the map's shape; generating more tokens
        # only costs extra full-sequence forward passes.
        _, _, attention_maps, _, _, _ = self.inference(
            "print hi", "", max_output_tokens=1
        )
        first_map = attention_maps[0]
        return len(first_map), first_map[0].shape[1]

    def _data_range(self, instruction_len, data_len):
        """Locate the instruction and data spans inside the templated prompt.

        Args:
            instruction_len: ``len(tokenizer.encode(instruction))``.
            data_len: ``len(tokenizer.encode(data))``.

        Returns:
            ``((start, start + instruction_len), (-tail - data_len, -tail))``
            using the first ``_RANGE_OFFSETS`` entry with a key found in
            ``self.name``.

        Raises:
            NotImplementedError: if no entry matches ``self.name``.
        """
        for keys, (start, tail) in self._RANGE_OFFSETS:
            if any(key in self.name for key in keys):
                return (start, start + instruction_len), (-tail - data_len, -tail)
        raise NotImplementedError(
            f"no prompt token offsets known for model name {self.name!r}"
        )

    def inference(self, instruction, data, max_output_tokens=None):
        """Sample a response to ``instruction``/``data`` while recording attention.

        The instruction becomes the system message and the data the user
        message (prefixed with ``"\\nText: "``). Both are rendered through the
        tokenizer's chat template with a generation prompt. Tokens are then
        sampled one at a time with ``sample_token`` (``self.top_k``,
        ``self.top_p``, temperature 1.0) until EOS or the budget is reached.

        Args:
            instruction: Text placed in the system message.
            data: Text placed in the user message after ``"\\nText: "``.
            max_output_tokens: Token budget for this call. ``None`` falls back
                to ``self.max_output_tokens``.

        Returns:
            A 6-tuple ``(generated_text, output_tokens, attention_maps,
            input_tokens, data_range, generated_probs)``:

            - generated_text: all sampled tokens decoded together, with
              special tokens skipped.
            - output_tokens: each sampled token decoded on its own.
            - attention_maps: one ``get_last_attn`` result per step. Each is
              built from that step's per-layer attentions, moved to CPU as
              float16 with NaNs set to 0. The step that samples EOS is not
              recorded, so this list can be one shorter than
              ``output_tokens``.
            - input_tokens: token strings of the full templated prompt.
            - data_range: ``((instr_start, instr_end), (data_start, data_end))``
              index pairs into ``input_tokens``. The data pair uses negative
              (from-the-end) indices.
            - generated_probs: probability of each sampled token under the
              softmax of the raw last-position logits. This is computed
              independently of sample_token's top-k/top-p settings.

        Raises:
            NotImplementedError: if ``self.name`` matches no entry in
                ``_RANGE_OFFSETS``.
        """
        messages = [
            {"role": "system", "content": instruction},
            {"role": "user", "content": "\nText: " + data},
        ]
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        data_range = self._data_range(
            len(self.tokenizer.encode(instruction)),
            len(self.tokenizer.encode(data)),
        )

        model_inputs = self.tokenizer(
            [text], return_tensors="pt").to(self.model.device)
        input_tokens = self.tokenizer.convert_ids_to_tokens(
            model_inputs["input_ids"][0])

        n_tokens = (
            self.max_output_tokens if max_output_tokens is None
            else max_output_tokens
        )

        input_ids = model_inputs.input_ids
        attention_mask = model_inputs.attention_mask
        generated_tokens = []
        generated_probs = []
        attention_maps = []
        with torch.no_grad():
            for _ in range(n_tokens):
                # NOTE(review): the whole sequence is re-run every step (no KV
                # cache), so cost grows quadratically with output length. Kept
                # because the input shape get_last_attn expects is not visible
                # here.
                output = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    output_attentions=True,
                )
                logits = output.logits[:, -1, :]
                probs = F.softmax(logits, dim=-1)
                next_token = sample_token(
                    logits[0], top_k=self.top_k, top_p=self.top_p,
                    temperature=1.0)[0]
                # Read the sampled id once; each .item() copies to the host.
                token_id = next_token.item()
                generated_probs.append(probs[0, token_id].item())
                generated_tokens.append(token_id)
                if token_id == self.tokenizer.eos_token_id:
                    # No attention map is recorded for the EOS step.
                    break

                # Extend the sequence; new_* keeps each tensor's dtype/device.
                input_ids = torch.cat(
                    (input_ids, input_ids.new_tensor([[token_id]])), dim=-1)
                attention_mask = torch.cat(
                    (attention_mask, attention_mask.new_ones((1, 1))), dim=-1)

                per_layer = [
                    torch.nan_to_num(attention.detach().cpu().half(), nan=0.0)
                    for attention in output["attentions"]
                ]
                attention_maps.append(get_last_attn(per_layer))

        output_tokens = [
            self.tokenizer.decode(token, skip_special_tokens=True)
            for token in generated_tokens
        ]
        generated_text = self.tokenizer.decode(
            generated_tokens, skip_special_tokens=True)
        return (generated_text, output_tokens, attention_maps,
                input_tokens, data_range, generated_probs)