import io
import os

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from PIL import Image
from nnsight import CONFIG, LanguageModel

# Set up the API key for nnsight's remote backend
api_key = os.getenv('NNSIGHT_API_KEY')
CONFIG.set_default_api_key(api_key)
access_token = os.environ['HUGGING_FACE_HUB_TOKEN']

# Load the language model
llama = LanguageModel("meta-llama/Meta-Llama-3.1-8B", token=access_token)
# Empty placeholder frames, used to render blank plots on startup and after a reset
prompts_with_probs = pd.DataFrame(
    {
        "prompt": [''],
        "layer": [0],
        "results": [''],
        "probs": [0],
        "expected": [''],
    })
prompts_with_ranks = pd.DataFrame(
    {
        "prompt": [''],
        "layer": [0],
        "results": [''],
        "ranks": [0],
        "expected": [''],
    })
def run_lens(model, prompt):
    """Run a logit-lens pass: decode the next-token prediction at every layer.

    Returns the decoded token at each layer, the per-layer probability of the
    model's final predicted ("expected") token, that token decoded, and its
    1-based rank in each layer's distribution.
    """
    logits_lens_token_result_by_layer = []
    logits_lens_probs_by_layer = []
    logits_lens_ranks_by_layer = []
    input_ids = model.tokenizer.encode(prompt)
    with model.trace(input_ids, remote=True):
        for layer in model.model.layers:
            hidden_state = layer.output[0][0]
            # Apply the final norm and the unembedding to read out this layer's logits
            logits_lens_normed_last_token = model.model.norm(hidden_state)
            logits_lens_token_distribution = model.lm_head(logits_lens_normed_last_token)
            logits_lens_last_token_logits = logits_lens_token_distribution[-1:]
            logits_lens_probs = F.softmax(logits_lens_last_token_logits, dim=1).save()
            logits_lens_probs_by_layer.append(logits_lens_probs)
            logits_lens_next_token = torch.argmax(logits_lens_probs, dim=1).save()
            logits_lens_token_result_by_layer.append(logits_lens_next_token)
        # The model's actual final-layer prediction is the "expected" token
        tokens_out = model.lm_head.output.argmax(dim=-1).save()
        expected_token = tokens_out[0][-1].save()
    # Probability assigned to the expected token at each layer
    # (cast to float32 first: bfloat16 tensors cannot be converted to numpy directly)
    logits_lens_all_probs = np.concatenate(
        [probs[:, expected_token].cpu().detach().to(torch.float32).numpy()
         for probs in logits_lens_probs_by_layer])
    # Rank (1-based) of the expected token in each layer's distribution
    for layer_probs in logits_lens_probs_by_layer:
        sorted_probs, sorted_indices = torch.sort(layer_probs, descending=True)
        expected_token_rank = (sorted_indices == expected_token).nonzero(as_tuple=True)[1].item() + 1
        logits_lens_ranks_by_layer.append(expected_token_rank)
    actual_output = model.tokenizer.decode(expected_token.item())
    logits_lens_results = [model.tokenizer.decode(next_token.item())
                           for next_token in logits_lens_token_result_by_layer]
    return logits_lens_results, logits_lens_all_probs, actual_output, logits_lens_ranks_by_layer
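
# A minimal usage sketch (hypothetical prompt; assumes `llama` above has loaded):
#   results, probs, expected, ranks = run_lens(llama, "The capital of France is")
#   for i, (tok, p, r) in enumerate(zip(results, probs, ranks)):
#       print(f"layer {i}: {tok!r} prob={p:.3f} rank={r}")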
def process_file(prompts_data, file_path):
    """Read an uploaded file and merge its prompts with those already in the table."""
    prompts = []
    if file_path is None:
        return prompts
    if file_path.endswith('.csv'):
        # CSV: expect a 'Prompt' column, one prompt per row
        df = pd.read_csv(file_path)
        if 'Prompt' in df.columns:
            prompts = df[['Prompt']].dropna().values.tolist()
    else:
        # Plain text: one prompt per line
        with open(file_path, 'r') as file:
            prompts = [[line] for line in file.read().splitlines()]
    # Keep any non-empty prompts already entered in the table
    for prompt in prompts_data:
        if prompt == ['']:
            continue
        prompts.append(prompt)
    return prompts
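
# Accepted inputs (hypothetical examples): a CSV whose 'Prompt' column holds one
# prompt per row, or a plain-text file with one prompt per line, e.g.:
#   The Eiffel Tower is in the city of
#   The capital of Japan is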
def plot_prob(prompts_with_probs):
    plt.figure(figsize=(10, 6))
    # Plot one probability curve per unique prompt
    for prompt in prompts_with_probs['prompt'].unique():
        prompt_data = prompts_with_probs[prompts_with_probs['prompt'] == prompt]
        plt.plot(prompt_data['layer'], prompt_data['probs'], marker='x', label=prompt)
        # Annotate each point with the token decoded at that layer
        for layer, prob, result in zip(prompt_data['layer'], prompt_data['probs'], prompt_data['results']):
            plt.text(layer, prob, result, fontsize=8)
    plt.xlabel('Layer Number')
    plt.ylabel('Probability of Expected Token')
    plt.title('Prob of expected token across layers\n(annotated with actual decoded output at each layer)')
    plt.grid(True)
    plt.ylim(0.0, 1.0)
    plt.legend(title='Prompts', bbox_to_anchor=(0.5, -0.15), loc='upper center', ncol=1)
    # Render to an in-memory PNG and return it as a PIL image
    buf = io.BytesIO()
    plt.savefig(buf, format='png', bbox_inches='tight')  # bbox_inches avoids cutting off labels
    buf.seek(0)
    img = Image.open(buf)
    plt.close()  # Free the figure's memory
    return img
def plot_rank(prompts_with_ranks):
    plt.figure(figsize=(10, 6))
    # Plot one rank curve per unique prompt
    for prompt in prompts_with_ranks['prompt'].unique():
        prompt_data = prompts_with_ranks[prompts_with_ranks['prompt'] == prompt]
        plt.plot(prompt_data['layer'], prompt_data['ranks'], marker='x', label=prompt)
        # Annotate each point with the token decoded at that layer
        for layer, rank, result in zip(prompt_data['layer'], prompt_data['ranks'], prompt_data['results']):
            plt.text(layer, rank, result, ha='right', va='bottom', fontsize=8)
    plt.xlabel('Layer Number')
    plt.ylabel('Rank of Expected Token')
    plt.title('Rank of expected token across layers\n(annotated with decoded output at each layer)')
    plt.grid(True)
    plt.ylim(bottom=0)  # Ranks are 1-based, so anchor the axis at zero
    plt.legend(title='Prompts', bbox_to_anchor=(0.5, -0.15), loc='upper center', ncol=1)
    # Render to an in-memory PNG and return it as a PIL image
    buf = io.BytesIO()
    plt.savefig(buf, format='png', bbox_inches='tight')  # bbox_inches avoids cutting off labels
    buf.seek(0)
    img = Image.open(buf)
    plt.close()  # Free the figure's memory
    return img
def plot_prob_mean(prompts_with_probs):
    # Mean and variance of the expected token's probability, per prompt
    summary_stats = prompts_with_probs.groupby("prompt")["probs"].agg(
        mean_prob="mean",
        variance="var"
    ).reset_index()
    plt.figure(figsize=(10, 6))
    bars = plt.bar(summary_stats['prompt'], summary_stats['mean_prob'],
                   yerr=summary_stats['variance']**0.5,  # Error bars show one standard deviation
                   capsize=5, color='skyblue')
    plt.xlabel('Prompt')
    plt.ylabel('Mean Probability')
    plt.title('Mean Probability of Expected Token')
    plt.xticks(rotation=45, ha='right')
    plt.grid(axis='y')
    plt.ylim(0, 1)
    # Annotate each bar with its mean and variance
    for bar, mean, var in zip(bars, summary_stats['mean_prob'], summary_stats['variance']):
        yval = bar.get_height()
        plt.text(bar.get_x() + bar.get_width() / 2, yval, f'Mean: {mean:.2f}\nVar: {var:.2f}',
                 ha='center', va='bottom', fontsize=8, color='black')
    # Render to an in-memory PNG and return it as a PIL image
    buf = io.BytesIO()
    plt.savefig(buf, format='png', bbox_inches='tight')  # bbox_inches avoids cutting off labels
    buf.seek(0)
    img = Image.open(buf)
    plt.close()  # Free the figure's memory
    return img
def plot_rank_mean(prompts_with_ranks):
    # Mean and variance of the expected token's rank, per prompt
    summary_stats = prompts_with_ranks.groupby("prompt")["ranks"].agg(
        mean_rank="mean",
        variance="var"
    ).reset_index()
    plt.figure(figsize=(10, 6))
    bars = plt.bar(summary_stats['prompt'], summary_stats['mean_rank'],
                   yerr=summary_stats['variance']**0.5,  # Error bars show one standard deviation
                   capsize=5, color='salmon')
    plt.xlabel('Prompt')
    plt.ylabel('Mean Rank')
    plt.title('Mean Rank of Expected Token')
    plt.xticks(rotation=45, ha='right')
    plt.grid(axis='y')
    # Annotate each bar with its mean and variance
    for bar, mean, var in zip(bars, summary_stats['mean_rank'], summary_stats['variance']):
        yval = bar.get_height()
        plt.text(bar.get_x() + bar.get_width() / 2, yval, f'Mean: {mean:.2f}\nVar: {var:.2f}',
                 ha='center', va='bottom', fontsize=8, color='black')
    # Render to an in-memory PNG and return it as a PIL image
    buf = io.BytesIO()
    plt.savefig(buf, format='png', bbox_inches='tight')  # bbox_inches avoids cutting off labels
    buf.seek(0)
    img = Image.open(buf)
    plt.close()  # Free the figure's memory
    return img
def submit_prompts(prompts_data):
    # Accumulate per-layer results across all prompts
    all_prompts = []
    all_results = []
    all_probs = []
    all_expected = []
    all_layers = []
    all_ranks = []
    for prompt in prompts_data:
        # Each row of the Dataframe is a single-element list; skip empty rows
        prompt = prompt[0]
        if not prompt:
            continue
        # Run the logit lens on this prompt
        lens_output = run_lens(llama, prompt)
        # Record one row per layer
        for layer_idx in range(len(lens_output[1])):
            all_prompts.append(prompt)
            all_results.append(lens_output[0][layer_idx])
            all_probs.append(float(lens_output[1][layer_idx]))
            all_expected.append(lens_output[2])
            all_layers.append(int(layer_idx))
            all_ranks.append(int(lens_output[3][layer_idx]))
    # Build the per-layer probability and rank tables
    prompts_with_probs = pd.DataFrame(
        {
            "prompt": all_prompts,
            "layer": all_layers,
            "results": all_results,
            "probs": all_probs,
            "expected": all_expected,
        })
    prompts_with_ranks = pd.DataFrame(
        {
            "prompt": all_prompts,
            "layer": all_layers,
            "results": all_results,
            "ranks": all_ranks,
            "expected": all_expected,
        })
    return (plot_prob(prompts_with_probs), plot_rank(prompts_with_ranks),
            plot_prob_mean(prompts_with_probs), plot_rank_mean(prompts_with_ranks))
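
# Sketch of direct use outside the UI (hypothetical prompts): rows arrive as
# single-element lists, and four PIL images come back:
#   prob_img, rank_img, prob_mean_img, rank_mean_img = submit_prompts(
#       [['The capital of France is'], ['The Eiffel Tower is in']])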
def clear_all(prompts):
    # Reset the table, the file upload, and all four plots to their empty state
    prompts = [['']]
    prompt_file = None
    prompts_data = gr.Dataframe(headers=["Prompt"], row_count=5, col_count=1,
                                value=prompts, type="array", interactive=True)
    return (prompts_data, prompt_file, plot_prob(prompts_with_probs),
            plot_rank(prompts_with_ranks), plot_prob_mean(prompts_with_probs),
            plot_rank_mean(prompts_with_ranks))
def gradio_interface():
    with gr.Blocks(theme="gradio/monochrome") as demo:
        prompts = [['']]
        with gr.Row():
            with gr.Column(scale=3):
                prompts_data = gr.Dataframe(headers=["Prompt"], row_count=5, col_count=1,
                                            value=prompts, type="array", interactive=True)
            with gr.Column(scale=1):
                prompt_file = gr.File(type="filepath", label="Upload a File with Prompts")
                prompt_file.upload(process_file, inputs=[prompts_data, prompt_file], outputs=[prompts_data])
        with gr.Row():
            clear_btn = gr.Button("Clear")
            submit_btn = gr.Button("Submit")
        # The plots start out blank, rendered from the placeholder frames
        with gr.Row():
            prob_visualization = gr.Image(value=plot_prob(prompts_with_probs), type="pil", label=" ")
            rank_visualization = gr.Image(value=plot_rank(prompts_with_ranks), type="pil", label=" ")
        with gr.Row():
            prob_mean_visualization = gr.Image(value=plot_prob_mean(prompts_with_probs), type="pil", label=" ")
            rank_mean_visualization = gr.Image(value=plot_rank_mean(prompts_with_ranks), type="pil", label=" ")
        clear_btn.click(clear_all, inputs=[prompts_data],
                        outputs=[prompts_data, prompt_file, prob_visualization, rank_visualization,
                                 prob_mean_visualization, rank_mean_visualization])
        submit_btn.click(submit_prompts, inputs=[prompts_data],
                         outputs=[prob_visualization, rank_visualization,
                                  prob_mean_visualization, rank_mean_visualization])
        prompt_file.clear(clear_all, inputs=[prompts_data],
                          outputs=[prompts_data, prompt_file, prob_visualization, rank_visualization,
                                   prob_mean_visualization, rank_mean_visualization])
    demo.launch()

gradio_interface()