Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,250 Bytes
d62afec |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
import torch
import numpy as np
def process_attn(attention, rng, attn_func):
    """Build a per-layer, per-head heatmap from last-position attention.

    For every layer tensor the row ``layer[0, :, -1, :]`` is taken (first
    entry of axis 0, every index of axis 1, last index of axis 2) and the
    slice of axis 3 given by ``rng[0]`` is reduced to one value per axis-1
    index.

    Args:
        attention: Sequence of 4-D torch tensors, one per layer. Presumably
            laid out as (batch, heads, query, key) -- confirm against the
            caller. Tensors may require grad or live on an accelerator; they
            are detached and copied to the CPU before conversion.
        rng: Indexable of ``(start, end)`` pairs. ``rng[0]`` bounds the span
            that is reduced (the original code calls it the "inst" span);
            ``rng[1]`` bounds a second span (called "data") that is only
            read when normalizing.
        attn_func: Mode string. Must contain ``"sum"`` or ``"max"`` (``"sum"``
            wins if both are present). If it also contains ``"normalize"``,
            each value is divided by the summed attention over both spans.

    Returns:
        ``np.ndarray`` of shape ``(len(attention), attention[0].shape[1])``
        with NaNs replaced by ``0.0``.

    Raises:
        NotImplementedError: If ``attn_func`` names neither ``"sum"`` nor
            ``"max"``.
    """
    heatmap = np.zeros((len(attention), attention[0].shape[1]))

    # The mode string does not change between layers, so decode it once.
    if "sum" in attn_func:
        reduce_fn = np.sum
    elif "max" in attn_func:
        reduce_fn = np.max
    else:
        raise NotImplementedError(
            f"attn_func must contain 'sum' or 'max', got {attn_func!r}"
        )
    normalize = "normalize" in attn_func

    inst_span = slice(rng[0][0], rng[0][1])
    # The second span is only needed for the normalization denominator.
    data_span = slice(rng[1][0], rng[1][1]) if normalize else None
    epsilon = 1e-8  # keeps the denominator non-zero when both spans sum to 0

    for i, attn_layer in enumerate(attention):
        # detach()/cpu() make the NumPy conversion valid for tensors that
        # track gradients or are stored on a non-CPU device.
        layer = attn_layer.detach().cpu().to(torch.float32).numpy()
        last_row = layer[0, :, -1, :]

        attn = reduce_fn(last_row[:, inst_span], axis=1)
        if normalize:
            # Reuse the reduction when it already is the span sum.
            if reduce_fn is np.sum:
                inst_sum = attn
            else:
                inst_sum = np.sum(last_row[:, inst_span], axis=1)
            data_sum = np.sum(last_row[:, data_span], axis=1)
            attn = attn / (inst_sum + data_sum + epsilon)

        heatmap[i, :] = attn

    return np.nan_to_num(heatmap, nan=0.0)
def calc_attn_score(heatmap, heads):
    """Average the heatmap entries selected by ``heads``.

    Args:
        heatmap: Array indexed as ``heatmap[layer, head]``.
        heads: Iterable of ``(layer, head)`` index pairs to include.

    Returns:
        The mean, taken along axis 0, of the selected entries.
    """
    selected = []
    for layer_idx, head_idx in heads:
        selected.append(heatmap[layer_idx, head_idx])
    return np.mean(selected, axis=0)
|