Corey Morris committed on
Commit
bdad6e6
·
1 Parent(s): 1f8cc2a

Refactor of create_plot

Browse files
Files changed (1) hide show
  1. app.py +25 -21
app.py CHANGED
@@ -98,36 +98,40 @@ st.download_button(
98
  mime="text/csv",
99
  )
100
 
101
-
102
- def create_plot(df, arc_column, moral_column, models=None):
103
  if models is not None:
104
  df = df[df.index.isin(models)]
105
 
106
  # remove rows with NaN values
107
- df = df.dropna(subset=[arc_column, moral_column])
108
 
109
  plot_data = pd.DataFrame({
110
  'Model': df.index,
111
- arc_column: df[arc_column],
112
- moral_column: df[moral_column],
113
  })
114
 
115
  plot_data['color'] = 'purple'
116
- fig = px.scatter(plot_data, x=arc_column, y=moral_column, color='color', hover_data=['Model'], trendline="ols")
117
- fig.update_layout(showlegend=False,
118
- xaxis_title=arc_column,
119
- yaxis_title=moral_column,
120
- xaxis = dict(),
121
- yaxis = dict())
 
 
 
 
 
122
 
123
- # Add a dashed line at 0.25 for the moral columns
124
- x_min = df[arc_column].min()
125
- x_max = df[arc_column].max()
126
 
127
- y_min = df[moral_column].min()
128
- y_max = df[moral_column].max()
129
 
130
- if arc_column.startswith('MMLU'):
131
  fig.add_shape(
132
  type='line',
133
  x0=0.25, x1=0.25,
@@ -139,7 +143,7 @@ def create_plot(df, arc_column, moral_column, models=None):
139
  )
140
  )
141
 
142
- if moral_column.startswith('MMLU'):
143
  fig.add_shape(
144
  type='line',
145
  x0=x_min, x1=x_max,
@@ -151,9 +155,9 @@ def create_plot(df, arc_column, moral_column, models=None):
151
  )
152
  )
153
 
154
-
155
  return fig
156
 
 
157
  # Custom scatter plots
158
  st.header('Custom scatter plots')
159
  st.write("As expected, there is a strong positive relationship between the number of parameters and average performance on the MMLU evaluation.")
@@ -177,11 +181,11 @@ plot_top_n(filtered_data, 'MMLU_abstract_algebra', 10)
177
  fig = create_plot(filtered_data, 'Parameters', 'MMLU_abstract_algebra')
178
  st.plotly_chart(fig)
179
 
 
180
  st.markdown("### Moral Scenarios Performance")
181
  st.write("While smaller models can perform well at many tasks, the model size threshold for decent performance on moral scenarios is much higher. There are no models with less than 13 billion parameters with performance much better than random chance.")
182
 
183
- st.write("Impact of Parameter Count on Accuracy for Moral Scenarios")
184
- fig = create_plot(filtered_data, 'Parameters', 'MMLU_moral_scenarios')
185
  st.plotly_chart(fig)
186
 
187
  fig = px.histogram(filtered_data, x="MMLU_moral_scenarios", marginal="rug", hover_data=filtered_data.columns)
 
98
  mime="text/csv",
99
  )
100
 
101
+ def create_plot(df, x_values, y_values, models=None, title=None):
 
102
  if models is not None:
103
  df = df[df.index.isin(models)]
104
 
105
  # remove rows with NaN values
106
+ df = df.dropna(subset=[x_values, y_values])
107
 
108
  plot_data = pd.DataFrame({
109
  'Model': df.index,
110
+ x_values: df[x_values],
111
+ y_values: df[y_values],
112
  })
113
 
114
  plot_data['color'] = 'purple'
115
+ fig = px.scatter(plot_data, x=x_values, y=y_values, color='color', hover_data=['Model'], trendline="ols")
116
+ layout_args = dict(
117
+ showlegend=False,
118
+ xaxis_title=x_values,
119
+ yaxis_title=y_values,
120
+ xaxis=dict(),
121
+ yaxis=dict()
122
+ )
123
+ if title is not None: # Only set the title if provided
124
+ layout_args['title'] = title
125
+ fig.update_layout(**layout_args)
126
 
127
+ # Add a dashed line at 0.25 for the y_values
128
+ x_min = df[x_values].min()
129
+ x_max = df[x_values].max()
130
 
131
+ y_min = df[y_values].min()
132
+ y_max = df[y_values].max()
133
 
134
+ if x_values.startswith('MMLU'):
135
  fig.add_shape(
136
  type='line',
137
  x0=0.25, x1=0.25,
 
143
  )
144
  )
145
 
146
+ if y_values.startswith('MMLU'):
147
  fig.add_shape(
148
  type='line',
149
  x0=x_min, x1=x_max,
 
155
  )
156
  )
157
 
 
158
  return fig
159
 
160
+
161
  # Custom scatter plots
162
  st.header('Custom scatter plots')
163
  st.write("As expected, there is a strong positive relationship between the number of parameters and average performance on the MMLU evaluation.")
 
181
  fig = create_plot(filtered_data, 'Parameters', 'MMLU_abstract_algebra')
182
  st.plotly_chart(fig)
183
 
184
+ # Moral scenarios plots
185
  st.markdown("### Moral Scenarios Performance")
186
  st.write("While smaller models can perform well at many tasks, the model size threshold for decent performance on moral scenarios is much higher. There are no models with less than 13 billion parameters with performance much better than random chance.")
187
 
188
+ fig = create_plot(filtered_data, 'Parameters', 'MMLU_moral_scenarios', title="Impact of Parameter Count on Accuracy for Moral Scenarios")
 
189
  st.plotly_chart(fig)
190
 
191
  fig = px.histogram(filtered_data, x="MMLU_moral_scenarios", marginal="rug", hover_data=filtered_data.columns)