Corey Morris commited on
Commit
8488477
·
1 Parent(s): ca8e784

Hiding filters unless box is selected. Removed model name column because it is the index of the table

Browse files
Files changed (1) hide show
  1. app.py +46 -56
app.py CHANGED
@@ -45,6 +45,9 @@ class MultiURLData:
45
  cols = cols[-1:] + cols[:-1]
46
  data = data[cols]
47
 
 
 
 
48
  # create a new column that averages the results from each of the columns with a name that start with MMLU
49
  data['MMLU_average'] = data.filter(regex='MMLU').mean(axis=1)
50
 
@@ -56,109 +59,96 @@ class MultiURLData:
56
 
57
  return data
58
 
59
-
60
-
61
  def get_data(self, selected_models):
62
- filtered_data = self.data[self.data['Model Name'].isin(selected_models)]
63
  return filtered_data
64
 
65
  data_provider = MultiURLData()
66
 
67
- st.title('Leaderboard')
68
 
69
- # TODO actually use these checkboxes as filters
70
- ## Desired behavior
71
- ## model and column selection is hidden by default
72
- ## when the user clicks the checkbox, the model and column selection appears
73
  filters = st.checkbox('Add filters')
74
 
75
- # Create checkboxes for each column
76
- selected_columns = st.multiselect(
77
- 'Select Columns',
78
- data_provider.data.columns.tolist(),
79
- default=data_provider.data.columns.tolist()
80
- )
81
 
82
- selected_models = st.multiselect(
83
- 'Select Models',
84
- data_provider.data['Model Name'].tolist(),
85
- default=data_provider.data['Model Name'].tolist()
86
- )
 
 
87
 
 
 
 
 
 
88
 
89
  # Get the filtered data and display it in a table
90
  st.header('Sortable table')
91
  filtered_data = data_provider.get_data(selected_models)
92
- st.dataframe(filtered_data)
93
 
94
- def create_plot(df, model_column, arc_column, moral_column, models=None):
95
- # Filter the dataframe if specific models are provided
 
 
 
 
 
96
  if models is not None:
97
- df = df[df[model_column].isin(models)]
98
 
99
- # Create a plot with new data
100
  plot_data = pd.DataFrame({
101
- 'Model': list(df[model_column]),
102
- arc_column: list(df[arc_column]),
103
- moral_column: list(df[moral_column]),
104
  })
105
 
106
- # Calculate color column
107
  plot_data['color'] = 'purple'
108
-
109
- # # TODO maybe change this
110
- # plot_data.loc[plot_data[moral_column] < plot_data[arc_column], 'color'] = 'red'
111
- # plot_data.loc[plot_data[moral_column] > plot_data[arc_column], 'color'] = 'blue'
112
-
113
- # Create the scatter plot with trendline
114
- fig = px.scatter(plot_data, x=arc_column, y=moral_column, color='color', hover_data=['Model'], trendline="ols") #other option ols
115
- fig.update_layout(showlegend=False, # hide legend
116
- xaxis_title=arc_column,
117
- yaxis_title=moral_column,
118
- xaxis = dict(),
119
- yaxis = dict())
120
 
121
  return fig
122
 
123
 
124
- # models_to_plot = ['Model1', 'Model2', 'Model3']
125
- # fig = create_plot(filtered_data, 'Model Name', 'arc:challenge|25', 'moral_scenarios|5', models=models_to_plot)
126
 
127
  st.header('Overall benchmark comparison')
128
 
129
- fig = create_plot(filtered_data, 'Model Name', 'arc:challenge|25', 'hellaswag|10')
130
  st.plotly_chart(fig)
131
 
132
- fig = create_plot(filtered_data, 'Model Name', 'arc:challenge|25', 'MMLU_average')
133
  st.plotly_chart(fig)
134
 
135
- fig = create_plot(filtered_data, 'Model Name', 'hellaswag|10', 'MMLU_average')
136
  st.plotly_chart(fig)
137
 
138
- # create a new dataframe that only has the 50 highest performing models on MMLU_average
139
  st.header('Top 50 models on MMLU_average')
140
  top_50 = filtered_data.nlargest(50, 'MMLU_average')
141
- fig = create_plot(top_50, 'Model Name', 'arc:challenge|25', 'MMLU_average')
142
  st.plotly_chart(fig)
143
 
144
- # Add heading to page to say Moral Scenarios
145
  st.header('Moral Scenarios')
146
 
147
- fig = create_plot(filtered_data, 'Model Name', 'arc:challenge|25', 'MMLU_moral_scenarios')
148
  st.plotly_chart(fig)
149
 
150
-
151
- fig = create_plot(filtered_data, 'Model Name', 'MMLU_moral_disputes', 'MMLU_moral_scenarios')
152
  st.plotly_chart(fig)
153
 
154
- fig = create_plot(filtered_data, 'Model Name', 'MMLU_average', 'MMLU_moral_scenarios')
155
  st.plotly_chart(fig)
156
 
157
- # create a histogram of moral scenarios
158
  fig = px.histogram(filtered_data, x="MMLU_moral_scenarios", marginal="rug", hover_data=filtered_data.columns)
159
  st.plotly_chart(fig)
160
 
161
- # create a histogram of moral disputes
162
  fig = px.histogram(filtered_data, x="MMLU_moral_disputes", marginal="rug", hover_data=filtered_data.columns)
163
- st.plotly_chart(fig)
164
-
 
45
  cols = cols[-1:] + cols[:-1]
46
  data = data[cols]
47
 
48
+ # remove the Model Name column
49
+ data = data.drop(['Model Name'], axis=1)
50
+
51
  # create a new column that averages the results from each of the columns with a name that start with MMLU
52
  data['MMLU_average'] = data.filter(regex='MMLU').mean(axis=1)
53
 
 
59
 
60
  return data
61
 
62
+ # filter data based on the index
 
63
  def get_data(self, selected_models):
64
+ filtered_data = self.data[self.data.index.isin(selected_models)]
65
  return filtered_data
66
 
67
  data_provider = MultiURLData()
68
 
69
+ st.title('Hugging Face Model Benchmarking including MMLU by task data')
70
 
 
 
 
 
71
  filters = st.checkbox('Add filters')
72
 
73
+ # Create defaults for selected columns and models
74
+ selected_columns = data_provider.data.columns.tolist()
75
+ selected_models = data_provider.data.index.tolist()
 
 
 
76
 
77
+ if filters:
78
+ # Create checkboxes for each column
79
+ selected_columns = st.multiselect(
80
+ 'Select Columns',
81
+ data_provider.data.columns.tolist(),
82
+ default=selected_columns
83
+ )
84
 
85
+ selected_models = st.multiselect(
86
+ 'Select Models',
87
+ data_provider.data.index.tolist(),
88
+ default=selected_models
89
+ )
90
 
91
  # Get the filtered data and display it in a table
92
  st.header('Sortable table')
93
  filtered_data = data_provider.get_data(selected_models)
 
94
 
95
+ # sort the table by the MMLU_average column
96
+ filtered_data = filtered_data.sort_values(by=['MMLU_average'], ascending=False)
97
+ st.dataframe(filtered_data[selected_columns])
98
+
99
+ # The rest of your plotting code...
100
+
101
+ def create_plot(df, arc_column, moral_column, models=None):
102
  if models is not None:
103
+ df = df[df.index.isin(models)]
104
 
 
105
  plot_data = pd.DataFrame({
106
+ 'Model': df.index,
107
+ arc_column: df[arc_column],
108
+ moral_column: df[moral_column],
109
  })
110
 
 
111
  plot_data['color'] = 'purple'
112
+ fig = px.scatter(plot_data, x=arc_column, y=moral_column, color='color', hover_data=['Model'], trendline="ols")
113
+ fig.update_layout(showlegend=False,
114
+ xaxis_title=arc_column,
115
+ yaxis_title=moral_column,
116
+ xaxis = dict(),
117
+ yaxis = dict())
 
 
 
 
 
 
118
 
119
  return fig
120
 
121
 
 
 
122
 
123
  st.header('Overall benchmark comparison')
124
 
125
+ fig = create_plot(filtered_data, 'arc:challenge|25', 'hellaswag|10')
126
  st.plotly_chart(fig)
127
 
128
+ fig = create_plot(filtered_data, 'arc:challenge|25', 'MMLU_average')
129
  st.plotly_chart(fig)
130
 
131
+ fig = create_plot(filtered_data, 'hellaswag|10', 'MMLU_average')
132
  st.plotly_chart(fig)
133
 
 
134
  st.header('Top 50 models on MMLU_average')
135
  top_50 = filtered_data.nlargest(50, 'MMLU_average')
136
+ fig = create_plot(top_50, 'arc:challenge|25', 'MMLU_average')
137
  st.plotly_chart(fig)
138
 
 
139
  st.header('Moral Scenarios')
140
 
141
+ fig = create_plot(filtered_data, 'arc:challenge|25', 'MMLU_moral_scenarios')
142
  st.plotly_chart(fig)
143
 
144
+ fig = create_plot(filtered_data, 'MMLU_moral_disputes', 'MMLU_moral_scenarios')
 
145
  st.plotly_chart(fig)
146
 
147
+ fig = create_plot(filtered_data, 'MMLU_average', 'MMLU_moral_scenarios')
148
  st.plotly_chart(fig)
149
 
 
150
  fig = px.histogram(filtered_data, x="MMLU_moral_scenarios", marginal="rug", hover_data=filtered_data.columns)
151
  st.plotly_chart(fig)
152
 
 
153
  fig = px.histogram(filtered_data, x="MMLU_moral_disputes", marginal="rug", hover_data=filtered_data.columns)
154
+ st.plotly_chart(fig)