natolambert
commited on
Commit
·
0b8c16d
1
Parent(s):
ab74236
upload plot
Browse files- app.py +6 -1
- src/plt.py +53 -0
- src/utils.py +12 -0
app.py
CHANGED
@@ -5,6 +5,7 @@ from apscheduler.schedulers.background import BackgroundScheduler
|
|
5 |
from datasets import load_dataset
|
6 |
from src.utils import load_all_data
|
7 |
from src.md import ABOUT_TEXT, TOP_TEXT
|
|
|
8 |
import numpy as np
|
9 |
|
10 |
api = HfApi()
|
@@ -210,7 +211,11 @@ with gr.Blocks() as app:
|
|
210 |
sample_display = gr.Markdown("{sampled data loads here}")
|
211 |
|
212 |
button.click(fn=random_sample, inputs=[subset_selector], outputs=[sample_display])
|
213 |
-
|
|
|
|
|
|
|
|
|
214 |
|
215 |
# Load data when app starts, TODO make this used somewhere...
|
216 |
# def load_data_on_start():
|
|
|
5 |
from datasets import load_dataset
|
6 |
from src.utils import load_all_data
|
7 |
from src.md import ABOUT_TEXT, TOP_TEXT
|
8 |
+
from src.plt import plot_avg_correlation
|
9 |
import numpy as np
|
10 |
|
11 |
api = HfApi()
|
|
|
211 |
sample_display = gr.Markdown("{sampled data loads here}")
|
212 |
|
213 |
button.click(fn=random_sample, inputs=[subset_selector], outputs=[sample_display])
|
214 |
+
# removed plot because not pretty enough
|
215 |
+
# with gr.TabItem("Model Correlation"):
|
216 |
+
# with gr.Row():
|
217 |
+
# plot = plot_avg_correlation(herm_data_avg, prefs_data)
|
218 |
+
# gr.Plot(plot)
|
219 |
|
220 |
# Load data when app starts, TODO make this used somewhere...
|
221 |
# def load_data_on_start():
|
src/plt.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import matplotlib.pyplot as plt
|
2 |
+
import pandas as pd
|
3 |
+
from .utils import undo_hyperlink
|
4 |
+
|
5 |
+
def plot_avg_correlation(df1, df2):
|
6 |
+
"""
|
7 |
+
Plots the "average" column for each unique model that appears in both dataframes.
|
8 |
+
|
9 |
+
Parameters:
|
10 |
+
- df1: pandas DataFrame containing columns "model" and "average".
|
11 |
+
- df2: pandas DataFrame containing columns "model" and "average".
|
12 |
+
"""
|
13 |
+
# Identify the unique models that appear in both DataFrames
|
14 |
+
common_models = pd.Series(list(set(df1['model']) & set(df2['model'])))
|
15 |
+
|
16 |
+
# Set up the plot
|
17 |
+
plt.figure(figsize=(13, 6), constrained_layout=True)
|
18 |
+
|
19 |
+
# axes from 0 to 1 for x and y
|
20 |
+
plt.xlim(0.475, 0.8)
|
21 |
+
plt.ylim(0.475, 0.8)
|
22 |
+
|
23 |
+
# larger font (16)
|
24 |
+
plt.rcParams.update({'font.size': 12, 'axes.labelsize': 14,'axes.titlesize': 14})
|
25 |
+
# plt.subplots_adjust(left=0.1, right=0.9, top=0.9, bottom=0.1)
|
26 |
+
# plt.tight_layout()
|
27 |
+
# plt.margins(0,0)
|
28 |
+
|
29 |
+
for model in common_models:
|
30 |
+
# Filter data for the current model
|
31 |
+
df1_model_data = df1[df1['model'] == model]['average'].values
|
32 |
+
df2_model_data = df2[df2['model'] == model]['average'].values
|
33 |
+
|
34 |
+
# Plotting
|
35 |
+
plt.scatter(df1_model_data, df2_model_data, label=model)
|
36 |
+
m_name = undo_hyperlink(model)
|
37 |
+
if m_name == "No text found":
|
38 |
+
m_name = "Random"
|
39 |
+
# Add text above each point like
|
40 |
+
# plt.text(x[i] + 0.1, y[i] + 0.1, label, ha='left', va='bottom')
|
41 |
+
plt.text(df1_model_data - .005, df2_model_data, m_name, horizontalalignment='right', verticalalignment='center')
|
42 |
+
|
43 |
+
# add correlation line to scatter plot
|
44 |
+
# first, compute correlation
|
45 |
+
corr = df1['average'].corr(df2['average'])
|
46 |
+
# add correlation line based on corr
|
47 |
+
|
48 |
+
|
49 |
+
|
50 |
+
plt.xlabel('HERM Eval. Set Avg.', fontsize=16)
|
51 |
+
plt.ylabel('Pref. Test Sets Avg.', fontsize=16)
|
52 |
+
# plt.legend(title='Model', bbox_to_anchor=(1.05, 1), loc='upper left')
|
53 |
+
return plt
|
src/utils.py
CHANGED
@@ -3,6 +3,7 @@ from pathlib import Path
|
|
3 |
from datasets import load_dataset
|
4 |
import numpy as np
|
5 |
import os
|
|
|
6 |
|
7 |
# From Open LLM Leaderboard
|
8 |
def model_hyperlink(link, model_name):
|
@@ -10,6 +11,17 @@ def model_hyperlink(link, model_name):
|
|
10 |
return "random"
|
11 |
return f'<a target="_blank" href="{link}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{model_name}</a>'
|
12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
# Define a function to fetch and process data
|
14 |
def load_all_data(data_repo, subdir:str, subsubsets=False): # use HF api to pull the git repo
|
15 |
dir = Path(data_repo)
|
|
|
3 |
from datasets import load_dataset
|
4 |
import numpy as np
|
5 |
import os
|
6 |
+
import re
|
7 |
|
8 |
# From Open LLM Leaderboard
|
9 |
def model_hyperlink(link, model_name):
|
|
|
11 |
return "random"
|
12 |
return f'<a target="_blank" href="{link}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{model_name}</a>'
|
13 |
|
14 |
+
def undo_hyperlink(html_string):
|
15 |
+
# Regex pattern to match content inside > and <
|
16 |
+
pattern = r'>[^<]+<'
|
17 |
+
match = re.search(pattern, html_string)
|
18 |
+
if match:
|
19 |
+
# Extract the matched text and remove leading '>' and trailing '<'
|
20 |
+
return match.group(0)[1:-1]
|
21 |
+
else:
|
22 |
+
return "No text found"
|
23 |
+
|
24 |
+
|
25 |
# Define a function to fetch and process data
|
26 |
def load_all_data(data_repo, subdir:str, subsubsets=False): # use HF api to pull the git repo
|
27 |
dir = Path(data_repo)
|