ShortGpt / short_gpt /metrics.py
deepnet's picture
Upload folder using huggingface_hub
821537b verified
raw
history blame contribute delete
717 Bytes
import torch
def block_influence(
    input_hidden_state: torch.Tensor,
    output_hidden_state: torch.Tensor,
    angular: bool = False,
) -> torch.Tensor:
    """Score how much a block changes each token's hidden state.

    Computes the cosine similarity between every token's input and output
    hidden vector and turns it into a distance, so a larger score means the
    block changed that token's representation more.

    Args:
        input_hidden_state: Hidden states entering the block, typically of
            shape ``(B, S, D)``. Any number of leading dimensions is
            accepted; the last dimension is treated as the feature dim ``D``.
        output_hidden_state: Hidden states leaving the block; must have the
            same shape as ``input_hidden_state``.
        angular: If True, return the angular distance ``arccos(cos) / pi``
            (in ``[0, 1]``); otherwise return ``1 - cos`` (in ``[0, 2]``).

    Returns:
        A 1-D tensor with one score per token (length ``B * S`` for 3-D
        inputs). Tokens where either vector has zero norm are assigned a
        cosine similarity of 0.5 before conversion to a distance.

    Raises:
        ValueError: If the two tensors do not have the same shape.
    """
    if input_hidden_state.shape != output_hidden_state.shape:
        raise ValueError(
            "input_hidden_state and output_hidden_state must have the same "
            f"shape, got {tuple(input_hidden_state.shape)} and "
            f"{tuple(output_hidden_state.shape)}"
        )
    d = input_hidden_state.shape[-1]
    x = input_hidden_state.reshape(-1, d)
    y = output_hidden_state.reshape(-1, d)
    # Row-wise dot product: same values as (x @ y.T).diagonal(), but O(N*D)
    # instead of materializing an N x N matrix (N = number of tokens).
    dot = (x * y).sum(dim=-1)
    sim = dot / (x.norm(dim=-1) * y.norm(dim=-1))
    # Zero-norm vectors give 0/0 = NaN; keep the original 0.5 fallback.
    # Clamp because rounding can push |cos| slightly above 1, which would
    # otherwise make arccos return NaN.
    sim = sim.nan_to_num(nan=0.5).clamp(-1.0, 1.0)
    if angular:
        return torch.arccos(sim) / torch.pi
    return 1 - sim