wissamantoun commited on
Commit
88e0f7f
·
1 Parent(s): 7922adf

Upload 4 files

Browse files
Files changed (4) hide show
  1. app.py +486 -0
  2. deberta_results.csv +0 -0
  3. exp_utils.py +1157 -0
  4. visualize_utils.py +57 -0
app.py ADDED
@@ -0,0 +1,486 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+
3
+ import numpy as np
4
+ import pandas as pd
5
+ import plotly.express as px
6
+ import plotly.figure_factory as ff
7
+ import plotly.graph_objects as go
8
+ import streamlit as st
9
+ from plotly.subplots import make_subplots
10
+
11
+ from exp_utils import MODELS
12
+ from visualize_utils import viridis_rgb
13
+
14
+ #
15
+
16
+ st.set_page_config(
17
+ page_title="Results Viewer",
18
+ page_icon="📊",
19
+ initial_sidebar_state="expanded",
20
+ layout="wide",
21
+ )
22
+
23
+ MODELS_SIZE_MAPPING = {k: v["model_size"] for k, v in MODELS.items()}
24
+ MODELS_FAMILY_MAPPING = {k: v["model_family"] for k, v in MODELS.items()}
25
+ MODEL_FAMILES = set([model["model_family"] for model in MODELS.values()])
26
+ MODEL_NAMES = list(MODELS.keys())
27
+
28
+ MODEL_NAMES_SORTED_BY_NAME_AND_SIZE = sorted(
29
+ MODEL_NAMES, key=lambda x: (MODELS[x]["model_family"], MODELS[x]["model_size"])
30
+ )
31
+
32
+ MODEL_NAMES_SORTED_BY_SIZE = sorted(
33
+ MODEL_NAMES, key=lambda x: (MODELS[x]["model_size"], MODELS[x]["model_family"])
34
+ )
35
+
36
+
37
+ # sort MODELS_SIZE_MAPPING by value then by key
38
+ MODELS_SIZE_MAPPING = {
39
+ k: v
40
+ for k, v in sorted(MODELS_SIZE_MAPPING.items(), key=lambda item: (item[1], item[0]))
41
+ }
42
+
43
+ MODELS_SIZE_MAPPING_LIST = list(MODELS_SIZE_MAPPING.keys())
44
+
45
+
46
+ CHAT_MODELS = [x for x in MODEL_NAMES_SORTED_BY_NAME_AND_SIZE if MODELS[x]["is_chat"]]
47
+
48
+
49
+ def clean_dataframe(df: pd.DataFrame) -> pd.DataFrame:
50
+ # remove all columns that have "_loss" and "_runtime" in them
51
+ words_to_remove = [
52
+ "epoch",
53
+ "loss",
54
+ "runtime",
55
+ "samples_per_second",
56
+ "steps_per_second",
57
+ "samples",
58
+ "results_dir",
59
+ ]
60
+ df = df.loc[
61
+ :,
62
+ ~df.columns.str.contains("|".join(words_to_remove), case=False, regex=True),
63
+ ]
64
+
65
+ # rename the rest of the columns by replacing "_roc_auc" with ""
66
+ df.columns = df.columns.str.replace("_roc_auc", "")
67
+ df.columns = df.columns.str.replace("eval_", "")
68
+
69
+ df["model_family"] = df["model_name"].map(MODELS_FAMILY_MAPPING)
70
+ # create a dict with the model_name and the model_family
71
+ model_family_dict = {
72
+ k: v
73
+ for k, v in zip(
74
+ df["model_name"].values.tolist(), df["model_family"].values.tolist()
75
+ )
76
+ }
77
+
78
+ # average the results over the 5 seeds for each model (seed column is exp_seed)
79
+ df_avg = df.groupby(["model_name"]).mean()
80
+ df_std = df.groupby(["model_name"]).std()
81
+
82
+ # remove the exp_seed column
83
+ df_avg = df_avg.drop(columns=["exp_seed"])
84
+ df_std = df_std.drop(columns=["exp_seed"])
85
+ df_avg["model_family"] = df_avg.index.map(model_family_dict)
86
+ df_std["model_family"] = df_std.index.map(model_family_dict)
87
+ df_avg["model_size"] = df_avg.index.map(MODELS_SIZE_MAPPING)
88
+ df_std["model_size"] = df_std.index.map(MODELS_SIZE_MAPPING)
89
+
90
+ # sort rows by model family then model size
91
+ df_avg = df_avg.sort_values(
92
+ by=["model_family", "model_size"], ascending=[True, True]
93
+ )
94
+ df_std = df_std.sort_values(
95
+ by=["model_family", "model_size"], ascending=[True, True]
96
+ )
97
+
98
+ availables_rows = [x for x in df_avg.columns if x in df_avg.index]
99
+ df_avg = df_avg.reindex(availables_rows)
100
+
101
+ availables_rows = [x for x in df_std.columns if x in df_std.index]
102
+ df_std = df_std.reindex(availables_rows)
103
+
104
+ return df_avg, df_std
105
+
106
+
107
+ def get_data(path):
108
+ df, df_std = clean_dataframe(pd.read_csv(path, index_col=0))
109
+ return df, df_std
110
+
111
+
112
+ def filter_df(
113
+ df: pd.DataFrame,
114
+ model_family_train: list,
115
+ model_family_test: list,
116
+ model_size_train: tuple,
117
+ model_size_test: tuple,
118
+ is_chat_train: bool,
119
+ is_chat_test: bool,
120
+ sort_by_size: bool,
121
+ split_chat_models: bool,
122
+ is_debug: bool,
123
+ ) -> pd.DataFrame:
124
+ # remove all columns and rows that have "pythia-70m" in the name
125
+
126
+ # filter rows
127
+ if is_debug:
128
+ st.write("No filters")
129
+ st.write(df)
130
+ df = df.loc[
131
+ (df["model_size"] >= model_size_train[0] * 1e9)
132
+ & (df["model_size"] <= model_size_train[1] * 1e9)
133
+ ]
134
+ if is_debug:
135
+ st.write("Filter model size train")
136
+ st.write(df)
137
+ df = df.loc[df["model_family"].isin(model_family_train)]
138
+ if is_debug:
139
+ st.write("Filter model family train")
140
+ st.write(df)
141
+ if is_chat_train != "Both":
142
+ df = df.loc[df["is_chat"] == is_chat_train]
143
+ if is_debug:
144
+ st.write("Filter is chat train")
145
+ st.write(df)
146
+
147
+ # filter columns
148
+ if is_debug:
149
+ st.write("No filters")
150
+ st.write(df)
151
+ columns_to_keep = []
152
+ for column in df.columns:
153
+ if column in MODELS.keys():
154
+ model_size = MODELS[column]["model_size"]
155
+ if (
156
+ model_size >= model_size_test[0] * 1e9
157
+ and model_size <= model_size_test[1] * 1e9
158
+ ):
159
+ columns_to_keep.append(column)
160
+
161
+ df = df[list(sorted(list(set(columns_to_keep))))]
162
+ if is_debug:
163
+ st.write("Filter model size test")
164
+ st.write(df)
165
+
166
+ # filter columns
167
+ columns_to_keep = []
168
+ for column in df.columns:
169
+ for model_family in model_family_test:
170
+ if model_family == MODELS[column]["model_family"]:
171
+ columns_to_keep.append(column)
172
+ df = df[list(sorted(list(set(columns_to_keep))))]
173
+ if is_debug:
174
+ st.write("Filter model family test")
175
+ st.write(df)
176
+
177
+ if is_chat_test != "Both":
178
+ # filter columns
179
+ columns_to_keep = []
180
+ for column in df.columns:
181
+ if MODELS[column]["is_chat"] == is_chat_test:
182
+ columns_to_keep.append(column)
183
+ df = df[list(sorted(list(set(columns_to_keep))))]
184
+ if is_debug:
185
+ st.write("Filter is chat test")
186
+ st.write(df)
187
+
188
+ df = df.select_dtypes(include="number")
189
+ if is_debug:
190
+ st.write("Select dtypes to be only numbers")
191
+ st.write(df)
192
+
193
+ if sort_by_size:
194
+ columns_in = [x for x in MODEL_NAMES_SORTED_BY_SIZE if x in df.columns]
195
+ else:
196
+ columns_in = [x for x in MODEL_NAMES_SORTED_BY_NAME_AND_SIZE if x in df.columns]
197
+ df = df[columns_in]
198
+ if is_debug:
199
+ st.write("Sort columns")
200
+ st.write(df)
201
+
202
+ # sort rows by size according the MODELS_SIZE_MAPPING_LIST
203
+ if sort_by_size:
204
+ availables_rows = [x for x in MODEL_NAMES_SORTED_BY_SIZE if x in df.index]
205
+ df = df.reindex(availables_rows)
206
+ else:
207
+ availables_rows = [
208
+ x for x in MODEL_NAMES_SORTED_BY_NAME_AND_SIZE if x in df.index
209
+ ]
210
+ df = df.reindex(availables_rows)
211
+ if is_debug:
212
+ st.write("Sort rows")
213
+ st.write(df)
214
+
215
+ if split_chat_models:
216
+ # put chat models at the end of the columns
217
+ chat_models = [x for x in CHAT_MODELS if x in df.columns]
218
+ # sort chat models by size
219
+ chat_models = sorted(chat_models, key=lambda x: MODELS[x]["model_size"])
220
+ df = df[[x for x in df.columns if x not in chat_models] + chat_models]
221
+
222
+ # put chat models at the end of the rows
223
+ chat_models = [x for x in CHAT_MODELS if x in df.index]
224
+ # sort chat models by size
225
+ chat_models = sorted(chat_models, key=lambda x: MODELS[x]["model_size"])
226
+ df = df.reindex([x for x in df.index if x not in chat_models] + chat_models)
227
+ if is_debug:
228
+ st.write("Split chat models")
229
+ st.write(df)
230
+ return df
231
+
232
+
233
+ df, df_std = get_data("./deberta_results.csv")
234
+
235
+ with open("./ood_results.json", "r") as f:
236
+ ood_results = json.load(f)
237
+
238
+ ood_results = pd.DataFrame(ood_results)
239
+ ood_results = ood_results.set_index("model_name")
240
+ ood_results = ood_results.drop(
241
+ columns=["exp_name", "accuracy", "f1", "precision", "recall"]
242
+ )
243
+ ood_results.columns = ["seed", "Adversarial"]
244
+
245
+ ood_results_avg = ood_results.groupby(["model_name"]).mean()
246
+ ood_results_std = ood_results.groupby(["model_name"]).std()
247
+
248
+ # filters
249
+ show_diff = st.sidebar.checkbox("Show Diff", value=False)
250
+ sort_by_size = st.sidebar.checkbox("Sort by size", value=False)
251
+ split_chat_models = st.sidebar.checkbox("Split chat models", value=False)
252
+ add_mean = st.sidebar.checkbox("Add mean", value=False)
253
+ show_std = st.sidebar.checkbox("Show std", value=False)
254
+ model_size_train = st.sidebar.slider(
255
+ "Train Model Size in Billion", min_value=0, max_value=100, value=(0, 100), step=1
256
+ )
257
+ model_size_test = st.sidebar.slider(
258
+ "Test Model Size in Billion", min_value=0, max_value=100, value=(0, 100), step=1
259
+ )
260
+ is_chat_train = st.sidebar.selectbox("(Train) Is Chat?", [True, False, "Both"], index=2)
261
+ is_chat_test = st.sidebar.selectbox("(Test) Is Chat?", [True, False, "Both"], index=2)
262
+ model_family_train = st.sidebar.multiselect(
263
+ "Model Family Train",
264
+ MODEL_FAMILES,
265
+ default=MODEL_FAMILES,
266
+ )
267
+ model_family_test = st.sidebar.multiselect(
268
+ "Model Family Test",
269
+ list(MODEL_FAMILES) + ["Adversarial"],
270
+ default=MODEL_FAMILES,
271
+ )
272
+
273
+ add_adversarial = False
274
+ if "Adversarial" in model_family_test:
275
+ model_family_test.remove("Adversarial")
276
+ add_adversarial = True
277
+
278
+ sort_by_adversarial = False
279
+ if add_adversarial:
280
+ sort_by_adversarial = st.sidebar.checkbox("Sort by adversarial", value=False)
281
+
282
+ if st.sidebar.checkbox("Use default color scale", value=False):
283
+ color_scale = "Viridis_r"
284
+ else:
285
+ color_scale = viridis_rgb
286
+
287
+
288
+ is_debug = st.sidebar.checkbox("Debug", value=False)
289
+
290
+ if show_std:
291
+ selected_df = df_std.copy()
292
+ else:
293
+ selected_df = df.copy()
294
+
295
+ if show_diff:
296
+ # get those 3 columns {'model_size', 'model_family', 'is_chat'}
297
+ columns_to_keep = ["model_size", "model_family", "is_chat"]
298
+ to_be_added = selected_df[columns_to_keep]
299
+ selected_df = selected_df.drop(columns=columns_to_keep)
300
+ selected_df = selected_df.sub(selected_df.values.diagonal(), axis=1)
301
+ selected_df = selected_df.join(to_be_added)
302
+
303
+
304
+ filtered_df = filter_df(
305
+ selected_df,
306
+ model_family_train,
307
+ model_family_test,
308
+ model_size_train,
309
+ model_size_test,
310
+ is_chat_train,
311
+ is_chat_test,
312
+ sort_by_size,
313
+ split_chat_models,
314
+ is_debug,
315
+ )
316
+
317
+
318
+ # subtract each row by the diagonal
319
+
320
+ # if show_diff:
321
+ # filtered_df = filtered_df.sub(filtered_df.values.diagonal(), axis=1)
322
+ if add_adversarial:
323
+ filtered_df = filtered_df.join(ood_results_avg)
324
+
325
+ if add_mean:
326
+ col_mean = filtered_df.mean(axis=1)
327
+ row_mean = filtered_df.mean(axis=0)
328
+ diag = filtered_df.values.diagonal()
329
+ filtered_df["mean"] = col_mean
330
+ filtered_df.loc["mean"] = row_mean
331
+
332
+
333
+ filtered_df = filtered_df * 100
334
+ filtered_df = filtered_df.round(0)
335
+
336
+ # sort by the column called Adversarial
337
+ if sort_by_adversarial:
338
+ filtered_df = filtered_df.sort_values(by=["Adversarial"], ascending=False)
339
+
340
+ # check if the df has columns and rows
341
+ if filtered_df.shape[0] == 0:
342
+ st.write("No results found")
343
+ st.stop()
344
+
345
+ if filtered_df.shape[1] == 0:
346
+ st.write("No results found")
347
+ st.stop()
348
+
349
+ fig = px.imshow(
350
+ filtered_df.values,
351
+ x=list(filtered_df.columns),
352
+ y=list(filtered_df.index),
353
+ color_continuous_scale=color_scale,
354
+ contrast_rescaling=None,
355
+ text_auto=True,
356
+ aspect="auto",
357
+ )
358
+
359
+
360
+ width = st.sidebar.text_input("Width", "1920")
361
+ height = st.sidebar.text_input("Height", "1080")
362
+ scale = st.sidebar.text_input("Scale", "1.0")
363
+ margin = st.sidebar.text_input("Margin[l,r,b,t]", "200,100,100,100")
364
+ fig.update_traces(textfont_size=9)
365
+ fig.update_layout(
366
+ xaxis={"side": "top"},
367
+ yaxis={"side": "left"},
368
+ margin=dict(
369
+ l=int(margin.split(",")[0]),
370
+ r=int(margin.split(",")[1]),
371
+ b=int(margin.split(",")[2]),
372
+ t=int(margin.split(",")[3]),
373
+ ),
374
+ font=dict(size=10),
375
+ )
376
+ fig.update_xaxes(tickangle=45)
377
+
378
+ fig.update_xaxes(tickmode="linear")
379
+ fig.update_yaxes(tickmode="linear")
380
+ # change the font in the heatmap
381
+ st.plotly_chart(fig, use_container_width=True)
382
+
383
+
384
+ if st.sidebar.button("save", key="save"):
385
+ fig.write_image(
386
+ "fig1.pdf",
387
+ width=int(width),
388
+ height=int(height),
389
+ validate=True,
390
+ scale=float(scale),
391
+ )
392
+
393
+
394
+ # plot the col mean vs model size
395
+ if add_mean and not show_diff:
396
+ # check if any of the chat models are in the filtered df columns and index
397
+ if len([x for x in CHAT_MODELS if x in filtered_df.columns]) > 0 or len(
398
+ [x for x in CHAT_MODELS if x in filtered_df.index]
399
+ ):
400
+ st.warning(
401
+ "Chat models are in the filtered df columns or index."
402
+ "This will cause the mean graph to be skewed."
403
+ )
404
+
405
+ fig3 = px.scatter(
406
+ y=row_mean,
407
+ x=[MODELS[x]["model_size"] for x in filtered_df.columns if x not in ["mean"]],
408
+ # hover_data=[x for x in filtered_df.index if x not in ["mean"]],
409
+ color=[
410
+ MODELS[x]["model_family"] for x in filtered_df.columns if x not in ["mean"]
411
+ ],
412
+ color_discrete_sequence=px.colors.qualitative.Plotly,
413
+ title="",
414
+ # x axis title
415
+ labels={
416
+ "x": "Target Model Size",
417
+ "y": "Average ROC AUC",
418
+ "color": "Model Family",
419
+ },
420
+ log_x=True,
421
+ trendline="ols",
422
+ )
423
+ fig4 = px.scatter(
424
+ y=diag,
425
+ x=[MODELS[x]["model_size"] for x in filtered_df.columns if x not in ["mean"]],
426
+ # hover_data=[x for x in filtered_df.index if x not in ["mean"]],
427
+ color=[
428
+ MODELS[x]["model_family"] for x in filtered_df.columns if x not in ["mean"]
429
+ ],
430
+ color_discrete_sequence=px.colors.qualitative.Plotly,
431
+ title="",
432
+ # x axis title
433
+ labels={
434
+ "x": "Target Model Size",
435
+ "y": "Self ROC AUC",
436
+ "color": "Model Family",
437
+ },
438
+ log_x=True,
439
+ trendline="ols",
440
+ )
441
+
442
+ # put the two plots side by side
443
+ fig_subplot = make_subplots(
444
+ rows=1,
445
+ cols=2,
446
+ shared_yaxes=False,
447
+ subplot_titles=("Self Detection ROC AUC", "Average Target ROC AUC"),
448
+ )
449
+ for i, figure in enumerate([fig4, fig3]):
450
+ for trace in range(len(figure["data"])):
451
+ trace_data = figure["data"][trace]
452
+ if i == 1:
453
+ trace_data["showlegend"] = False
454
+ fig_subplot.append_trace(trace_data, row=1, col=i + 1)
455
+
456
+ fig_subplot.update_xaxes(type="log")
457
+ # y axis range
458
+ fig_subplot.update_yaxes(range=[0.90, 1])
459
+
460
+ fig_subplot.update_layout(
461
+ height=500,
462
+ width=1200,
463
+ )
464
+ # put the legend on the bottom
465
+ fig_subplot.update_layout(
466
+ legend=dict(orientation="h", yanchor="bottom", y=-0.2, x=0.09)
467
+ )
468
+ st.plotly_chart(fig_subplot, use_container_width=True)
469
+
470
+ fig2 = px.scatter(
471
+ y=col_mean,
472
+ x=[MODELS_SIZE_MAPPING[x] for x in filtered_df.index if x not in ["mean"]],
473
+ # hover_data=[x for x in filtered_df.index if x not in ["mean"]],
474
+ color=[
475
+ MODELS_FAMILY_MAPPING[x] for x in filtered_df.index if x not in ["mean"]
476
+ ],
477
+ color_discrete_sequence=px.colors.qualitative.Plotly,
478
+ title="Mean vs Train Model Size",
479
+ log_x=True,
480
+ trendline="ols",
481
+ )
482
+ fig2.update_layout(
483
+ height=600,
484
+ width=900,
485
+ )
486
+ st.plotly_chart(fig2, use_container_width=False)
deberta_results.csv ADDED
The diff for this file is too large to render. See raw diff
 
exp_utils.py ADDED
@@ -0,0 +1,1157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # LLAMA2
2
+ # <s>[INST] <<SYS>>
3
+ # {{ system_prompt }}
4
+ # <</SYS>>
5
+
6
+ # {{ user_msg_1 }} [/INST] {{ model_answer_1 }} </s><s>[INST] {{ user_msg_2 }} [/INST]
7
+
8
+ ZERO_SHOT_PROMPT = """A chat between a curious human and an artificial intelligence assistant.
9
+ The assistant gives helpful, detailed, and polite answers to the human's questions.
10
+ Human: {{ user_message }}
11
+ Assistant: """
12
+
13
+ ZERO_SHOT_STOPWORD = "Human:"
14
+
15
+ LM_PROMPT = """Give the best continuation of the following text: {{ user_message }}"""
16
+
17
+ LLAMA2_PROMPT = """<s>[INST] <<SYS>>
18
+ You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
19
+
20
+ If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
21
+ <</SYS>>
22
+
23
+ {{ user_message }} [/INST] """
24
+
25
+ LLAMA2_STOPWORD = "</s>"
26
+
27
+ MPT_PROMPT_7B = """<|im_start|>system
28
+ - You are a helpful assistant chatbot trained by MosaicML.
29
+ - You answer questions.
30
+ - You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
31
+ - You are more than just an information source, you are also able to write poetry, short stories, and make jokes.<|im_end|>
32
+ <|im_start|>user
33
+ {{ user_message }}<|im_end|>
34
+ <|im_start|>assistant
35
+ """
36
+
37
+ MPT_LM_PROMPT_7B = """<|im_start|>system
38
+ - You are a helpful assistant chatbot trained by MosaicML.
39
+ - You answer questions.
40
+ - You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
41
+ - You are more than just an information source, you are also able to write poetry, short stories, and make jokes.<|im_end|>
42
+ <|im_start|>user
43
+ {{ user_message }}<|im_end|>
44
+ <|im_start|>assistant
45
+ """
46
+
47
+ MPT_PROMPT_30B = """<|im_start|>system
48
+ A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.<|im_end|>
49
+ <|im_start|>user
50
+ {{ user_message }}<|im_end|>
51
+ <|im_start|>assistant
52
+ """
53
+
54
+ MPT_STOPWORD = "<|im_end|>"
55
+
56
+ FALCON_PROMPT = """The following is a conversation between a highly knowledgeable and intelligent AI assistant, called Falcon, and a human user, called User. In the following interactions, User and Falcon will converse in natural language, and Falcon will answer User's questions. Falcon was built to be respectful, polite and inclusive. Falcon was built by the Technology Innovation Institute in Abu Dhabi. Falcon will never decline to answer a question, and always attempts to give an answer that User would be satisfied with. It knows a lot, and always tells the truth. The conversation begins.
57
+ User: {{ user_message }}
58
+ Falcon: """
59
+
60
+ FALCON_STOPWORD = "User:"
61
+
62
+ ALFRED_PROMPT = """Alfred is a large language model trained by LightOn. Knowledge cutoff: November 2022. Current date: 31 July, 2023
63
+
64
+ User: {{ user_message }}
65
+ Alfred: """
66
+
67
+ ALFRED_STOPWORD = "User:"
68
+
69
+ VICUNA_PROMPT = """A chat between a curious user and an artificial intelligence assistant.
70
+ The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {{ user_message }} ASSISTANT: """
71
+
72
+ VICUNA_STOPWORD = ""
73
+
74
+ MODELS = {
75
+ ################################################
76
+ # llama-2 #
77
+ ################################################
78
+ "llama-2-70b": {
79
+ "name": "llama-2-70b",
80
+ "model_name": "NousResearch/llama-2-70b-hf",
81
+ "model_path": "NousResearch-llama-2-70b-hf",
82
+ "num_gpus": 4,
83
+ "batch_size": 2,
84
+ "is_chat": False,
85
+ "max_total_tokens": 2048,
86
+ "max_input_length": 1024,
87
+ "max_batch_prefill_tokens": 1024,
88
+ "to_be_quantized": True,
89
+ "to_be_watermarked": True,
90
+ "model_size": 70e9,
91
+ "model_family": "llama-2",
92
+ },
93
+ "llama-2-13b": {
94
+ "name": "llama-2-13b",
95
+ "model_name": "NousResearch/llama-2-13b-hf",
96
+ "model_path": "NousResearch-llama-2-13b-hf",
97
+ "num_gpus": 2,
98
+ "batch_size": 8,
99
+ "is_chat": False,
100
+ "max_total_tokens": 2048,
101
+ "max_input_length": 1024,
102
+ "max_batch_prefill_tokens": 1024,
103
+ "to_be_quantized": True,
104
+ "to_be_watermarked": True,
105
+ "model_size": 13e9,
106
+ "model_family": "llama-2",
107
+ },
108
+ "llama-2-7b": {
109
+ "name": "llama-2-7b",
110
+ "model_name": "NousResearch/llama-2-7b-hf",
111
+ "model_path": "NousResearch-llama-2-7b-hf",
112
+ "num_gpus": 1,
113
+ "batch_size": 4,
114
+ "is_chat": False,
115
+ "max_total_tokens": 2048,
116
+ "max_input_length": 1024,
117
+ "max_batch_prefill_tokens": 1024,
118
+ "to_be_quantized": True,
119
+ "to_be_watermarked": True,
120
+ "model_size": 7e9,
121
+ "model_family": "llama-2",
122
+ },
123
+ ################################################
124
+ # llama-2 #
125
+ ################################################
126
+ "llama-2-70b-chat": {
127
+ "name": "llama-2-70b-chat",
128
+ "model_name": "NousResearch/llama-2-70b-chat-hf",
129
+ "model_path": "NousResearch-llama-2-70b-chat-hf",
130
+ "num_gpus": 4,
131
+ "batch_size": 2,
132
+ "is_chat": True,
133
+ "prompt": LLAMA2_PROMPT,
134
+ "stopword": LLAMA2_STOPWORD,
135
+ "max_total_tokens": 2048,
136
+ "max_input_length": 1024,
137
+ "max_batch_prefill_tokens": 1024,
138
+ "model_size": 70e9,
139
+ "model_family": "llama-2",
140
+ },
141
+ "llama-2-13b-chat": {
142
+ "name": "llama-2-13b-chat",
143
+ "model_name": "NousResearch/llama-2-13b-chat-hf",
144
+ "model_path": "NousResearch-llama-2-13b-chat-hf",
145
+ "num_gpus": 2,
146
+ "batch_size": 8,
147
+ "is_chat": True,
148
+ "prompt": LLAMA2_PROMPT,
149
+ "stopword": LLAMA2_STOPWORD,
150
+ "max_total_tokens": 2048,
151
+ "max_input_length": 1024,
152
+ "max_batch_prefill_tokens": 1024,
153
+ "model_size": 13e9,
154
+ "model_family": "llama-2",
155
+ },
156
+ "llama-2-7b-chat": {
157
+ "name": "llama-2-7b-chat",
158
+ "model_name": "NousResearch/llama-2-7b-chat-hf",
159
+ "model_path": "NousResearch-llama-2-7b-chat-hf",
160
+ "num_gpus": 1,
161
+ "batch_size": 4,
162
+ "is_chat": True,
163
+ "prompt": LLAMA2_PROMPT,
164
+ "stopword": LLAMA2_STOPWORD,
165
+ "max_total_tokens": 2048,
166
+ "max_input_length": 1024,
167
+ "max_batch_prefill_tokens": 1024,
168
+ "model_size": 7e9,
169
+ "model_family": "llama-2",
170
+ },
171
+ ################################################
172
+ # llama-1 #
173
+ ################################################
174
+ "llama-65b": {
175
+ "name": "llama-65b",
176
+ "model_name": "huggyllama/llama-65b",
177
+ "model_path": "huggyllama-llama-65b",
178
+ "num_gpus": 4,
179
+ "batch_size": 2,
180
+ "is_chat": False,
181
+ "max_total_tokens": 2048,
182
+ "max_input_length": 1024,
183
+ "max_batch_prefill_tokens": 1024,
184
+ "to_be_quantized": True,
185
+ "to_be_watermarked": True,
186
+ "model_size": 65e9,
187
+ "model_family": "llama-1",
188
+ },
189
+ "llama-30b": {
190
+ "name": "llama-30b",
191
+ "model_name": "huggyllama/llama-30b",
192
+ "model_path": "huggyllama-llama-30b",
193
+ "num_gpus": 2,
194
+ "batch_size": 2,
195
+ "is_chat": False,
196
+ "max_total_tokens": 2048,
197
+ "max_input_length": 1024,
198
+ "max_batch_prefill_tokens": 1024,
199
+ "to_be_quantized": True,
200
+ "to_be_watermarked": True,
201
+ "model_size": 30e9,
202
+ "model_family": "llama-1",
203
+ },
204
+ "llama-13b": {
205
+ "name": "llama-13b",
206
+ "model_name": "huggyllama/llama-13b",
207
+ "model_path": "huggyllama-llama-13b",
208
+ "num_gpus": 2,
209
+ "batch_size": 8,
210
+ "is_chat": False,
211
+ "max_total_tokens": 2048,
212
+ "max_input_length": 1024,
213
+ "max_batch_prefill_tokens": 1024,
214
+ "to_be_quantized": True,
215
+ "to_be_watermarked": True,
216
+ "model_size": 13e9,
217
+ "model_family": "llama-1",
218
+ },
219
+ "llama-7b": {
220
+ "name": "llama-7b",
221
+ "model_name": "huggyllama/llama-7b",
222
+ "model_path": "huggyllama-llama-7b",
223
+ "num_gpus": 1,
224
+ "batch_size": 4,
225
+ "is_chat": False,
226
+ "max_total_tokens": 2048,
227
+ "max_input_length": 1024,
228
+ "max_batch_prefill_tokens": 1024,
229
+ "to_be_quantized": True,
230
+ "to_be_watermarked": True,
231
+ "model_size": 7e9,
232
+ "model_family": "llama-1",
233
+ },
234
+ ################################################
235
+ # OPT #
236
+ ################################################
237
+ "opt-66b": {
238
+ "name": "opt-66b",
239
+ "model_name": "facebook/opt-66b",
240
+ "model_path": "facebook-opt-66b",
241
+ "num_gpus": 4,
242
+ "batch_size": 2,
243
+ "is_chat": False,
244
+ "max_total_tokens": 1024,
245
+ "max_input_length": 256,
246
+ "max_batch_prefill_tokens": 1024,
247
+ "model_size": 66e9,
248
+ "model_family": "opt",
249
+ },
250
+ "opt-30b": {
251
+ "name": "opt-30b",
252
+ "model_name": "facebook/opt-30b",
253
+ "model_path": "facebook-opt-30b",
254
+ "num_gpus": 4,
255
+ "batch_size": 1,
256
+ "is_chat": False,
257
+ "no_api": True,
258
+ "model_size": 30e9,
259
+ "model_family": "opt",
260
+ },
261
+ "opt-13b": {
262
+ "name": "opt-13b",
263
+ "model_name": "facebook/opt-13b",
264
+ "model_path": "facebook-opt-13b",
265
+ "num_gpus": 2,
266
+ "batch_size": 1,
267
+ "is_chat": False,
268
+ "no_api": True,
269
+ "model_size": 13e9,
270
+ "model_family": "opt",
271
+ },
272
+ "opt-6.7b": {
273
+ "name": "opt-6.7b",
274
+ "model_name": "facebook/opt-6.7b",
275
+ "model_path": "facebook-opt-6.7b",
276
+ "num_gpus": 1,
277
+ "batch_size": 4,
278
+ "is_chat": False,
279
+ "no_api": True,
280
+ "model_size": 6.7e9,
281
+ "model_family": "opt",
282
+ },
283
+ "opt-2.7b": {
284
+ "name": "opt-2.7b",
285
+ "model_name": "facebook/opt-2.7b",
286
+ "model_path": "facebook-opt-2.7b",
287
+ "num_gpus": 1,
288
+ "batch_size": 16,
289
+ "is_chat": False,
290
+ "max_total_tokens": 1024,
291
+ "max_input_length": 256,
292
+ "max_batch_prefill_tokens": 4096,
293
+ "model_size": 2.7e9,
294
+ "model_family": "opt",
295
+ },
296
+ "opt-1.3b": {
297
+ "name": "opt-1.3b",
298
+ "model_name": "facebook/opt-1.3b",
299
+ "model_path": "facebook-opt-1.3b",
300
+ "num_gpus": 1,
301
+ "batch_size": 16,
302
+ "is_chat": False,
303
+ "use_flash_attention": True,
304
+ "max_total_tokens": 1024,
305
+ "max_input_length": 256,
306
+ "max_batch_prefill_tokens": 4096,
307
+ "model_size": 1.3e9,
308
+ "model_family": "opt",
309
+ },
310
+ "opt-350m": {
311
+ "name": "opt-350m",
312
+ "model_name": "facebook/opt-350m",
313
+ "model_path": "facebook-opt-350m",
314
+ "num_gpus": 1,
315
+ "batch_size": 16,
316
+ "is_chat": False,
317
+ "no_api": True,
318
+ "model_size": 350e6,
319
+ "model_family": "opt",
320
+ },
321
+ "opt-125m": {
322
+ "name": "opt-125m",
323
+ "model_name": "facebook/opt-125m",
324
+ "model_path": "facebook-opt-125m",
325
+ "num_gpus": 1,
326
+ "batch_size": 16,
327
+ "is_chat": False,
328
+ "max_total_tokens": 1024,
329
+ "max_input_length": 256,
330
+ "max_batch_prefill_tokens": 4096,
331
+ "model_size": 125e6,
332
+ "model_family": "opt",
333
+ },
334
+ ################################################
335
+ # MPT #
336
+ ################################################
337
+ "mpt-30b": {
338
+ "name": "mpt-30b",
339
+ "model_name": "mosaicml/mpt-30b",
340
+ "model_path": "mosaicml-mpt-30b",
341
+ "num_gpus": 2,
342
+ "batch_size": 2,
343
+ "is_chat": False,
344
+ "max_total_tokens": 2048,
345
+ "max_input_length": 1024,
346
+ "max_batch_prefill_tokens": 1024,
347
+ "model_size": 30e9,
348
+ "model_family": "mpt",
349
+ },
350
+ "mpt-7b": {
351
+ "name": "mpt-7b",
352
+ "model_name": "mosaicml/mpt-7b",
353
+ "model_path": "mosaicml-mpt-7b",
354
+ "num_gpus": 1,
355
+ "batch_size": 4,
356
+ "is_chat": False,
357
+ "max_total_tokens": 2048,
358
+ "max_input_length": 1024,
359
+ "max_batch_prefill_tokens": 4096,
360
+ "model_size": 7e9,
361
+ "model_family": "mpt",
362
+ },
363
+ ################################################
364
+ # MPT-Chat #
365
+ ################################################
366
+ "mpt-30b-chat": {
367
+ "name": "mpt-30b-chat",
368
+ "model_name": "mosaicml/mpt-30b-chat",
369
+ "model_path": "mosaicml-mpt-30b-chat",
370
+ "num_gpus": 2,
371
+ "batch_size": 2,
372
+ "is_chat": True,
373
+ "prompt": MPT_PROMPT_30B,
374
+ "stopword": MPT_STOPWORD,
375
+ "max_total_tokens": 1024,
376
+ "max_input_length": 256,
377
+ "max_batch_prefill_tokens": 4096,
378
+ "model_size": 30e9,
379
+ "model_family": "mpt",
380
+ },
381
+ "mpt-7b-chat": {
382
+ "name": "mpt-7b-chat",
383
+ "model_name": "mosaicml/mpt-7b-chat",
384
+ "model_path": "mosaicml-mpt-7b-chat",
385
+ "num_gpus": 1,
386
+ "batch_size": 4,
387
+ "is_chat": True,
388
+ "prompt": MPT_PROMPT_7B,
389
+ "stopword": MPT_STOPWORD,
390
+ "max_total_tokens": 2048,
391
+ "max_input_length": 1024,
392
+ "max_batch_prefill_tokens": 4096,
393
+ "model_size": 7e9,
394
+ "model_family": "mpt",
395
+ },
396
+ ################################################
397
+ # OPENLLAMA #
398
+ ################################################
399
+ "openllama-13b": {
400
+ "name": "openllama-13b",
401
+ "model_name": "openlm-research/open_llama_13b",
402
+ "model_path": "openlm-research-open_llama_13b",
403
+ "num_gpus": 2,
404
+ "batch_size": 8,
405
+ "is_chat": False,
406
+ "max_total_tokens": 2048,
407
+ "max_input_length": 1024,
408
+ "max_batch_prefill_tokens": 4096,
409
+ "model_size": 13e9,
410
+ "model_family": "openllama",
411
+ },
412
+ "openllama-7b": {
413
+ "name": "openllama-7b",
414
+ "model_name": "openlm-research/open_llama_7b",
415
+ "model_path": "openlm-research-open_llama_7b",
416
+ "num_gpus": 1,
417
+ "batch_size": 8,
418
+ "is_chat": False,
419
+ "max_total_tokens": 2048,
420
+ "max_input_length": 1024,
421
+ "max_batch_prefill_tokens": 4096,
422
+ "model_size": 7e9,
423
+ "model_family": "openllama",
424
+ },
425
+ "openllama-3b": {
426
+ "name": "openllama-3b",
427
+ "model_name": "openlm-research/open_llama_3b",
428
+ "model_path": "openlm-research-open_llama_3b",
429
+ "num_gpus": 1,
430
+ "batch_size": 16,
431
+ "is_chat": False,
432
+ "use_flash_attention": False,
433
+ "max_total_tokens": 2048,
434
+ "max_input_length": 1024,
435
+ "max_batch_prefill_tokens": 4096,
436
+ "model_size": 3e9,
437
+ "model_family": "openllama",
438
+ },
439
+ ################################################
440
+ # OPENLLAMA-2 #
441
+ ################################################
442
+ # "openllama-2-13b": {
443
+ # "name": "openllama-2-13b",
444
+ # "model_name": "openlm-research/open_llama_13b_v2",
445
+ # "model_path": "openlm-research-open_llama_13b_v2",
446
+ # "num_gpus": 2,
447
+ # "batch_size": 1,
448
+ # "is_chat": False,
449
+ # },
450
+ "openllama-2-7b": {
451
+ "name": "openllama-2-7b",
452
+ "model_name": "openlm-research/open_llama_7b_v2",
453
+ "model_path": "openlm-research-open_llama_7b_v2",
454
+ "num_gpus": 1,
455
+ "batch_size": 8,
456
+ "is_chat": False,
457
+ "max_total_tokens": 2048,
458
+ "max_input_length": 1024,
459
+ "max_batch_prefill_tokens": 4096,
460
+ "model_size": 7e9,
461
+ "model_family": "openllama-2",
462
+ },
463
+ "openllama-2-3b": {
464
+ "name": "openllama-2-3b",
465
+ "model_name": "openlm-research/open_llama_3b_v2",
466
+ "model_path": "openlm-research-open_llama_3b_v2",
467
+ "num_gpus": 1,
468
+ "batch_size": 16,
469
+ "is_chat": False,
470
+ "use_flash_attention": False,
471
+ "max_total_tokens": 2048,
472
+ "max_input_length": 1024,
473
+ "max_batch_prefill_tokens": 4096,
474
+ "model_size": 3e9,
475
+ "model_family": "openllama-2",
476
+ },
477
+ ################################################
478
+ # Pythia #
479
+ ################################################
480
+ "pythia-12b": {
481
+ "name": "pythia-12b",
482
+ "model_name": "EleutherAI/pythia-12b",
483
+ "model_path": "EleutherAI-pythia-12b",
484
+ "num_gpus": 2,
485
+ "batch_size": 8,
486
+ "is_chat": False,
487
+ "max_total_tokens": 2048,
488
+ "max_input_length": 1024,
489
+ "max_batch_prefill_tokens": 4096,
490
+ "model_size": 12e9,
491
+ "model_family": "pythia",
492
+ },
493
+ "pythia-6.9b": {
494
+ "name": "pythia-6.9b",
495
+ "model_name": "EleutherAI/pythia-6.9b",
496
+ "model_path": "EleutherAI-pythia-6.9b",
497
+ "num_gpus": 1,
498
+ "batch_size": 8,
499
+ "is_chat": False,
500
+ "max_total_tokens": 2048,
501
+ "max_input_length": 1024,
502
+ "max_batch_prefill_tokens": 4096,
503
+ "model_size": 6.9e9,
504
+ "model_family": "pythia",
505
+ },
506
+ "pythia-2.8b": {
507
+ "name": "pythia-2.8b",
508
+ "model_name": "EleutherAI/pythia-2.8b",
509
+ "model_path": "EleutherAI-pythia-2.8b",
510
+ "num_gpus": 1,
511
+ "batch_size": 16,
512
+ "is_chat": False,
513
+ "max_total_tokens": 2048,
514
+ "max_input_length": 1024,
515
+ "max_batch_prefill_tokens": 4096,
516
+ "model_size": 2.8e9,
517
+ "model_family": "pythia",
518
+ },
519
+ "pythia-1.4b": {
520
+ "name": "pythia-1.4b",
521
+ "model_name": "EleutherAI/pythia-1.4b",
522
+ "model_path": "EleutherAI-pythia-1.4b",
523
+ "num_gpus": 1,
524
+ "batch_size": 16,
525
+ "is_chat": False,
526
+ "max_total_tokens": 2048,
527
+ "max_input_length": 256,
528
+ "max_batch_prefill_tokens": 4096,
529
+ "model_size": 1.4e9,
530
+ "model_family": "pythia",
531
+ },
532
+ "pythia-1b": {
533
+ "name": "pythia-1b",
534
+ "model_name": "EleutherAI/pythia-1b",
535
+ "model_path": "EleutherAI-pythia-1b",
536
+ "num_gpus": 1,
537
+ "batch_size": 1,
538
+ "is_chat": False,
539
+ "use_flash_attention": False,
540
+ "max_total_tokens": 1024,
541
+ "max_input_length": 256,
542
+ "max_batch_prefill_tokens": 4096,
543
+ "model_size": 1e9,
544
+ "model_family": "pythia",
545
+ },
546
+ "pythia-410m": {
547
+ "name": "pythia-410m",
548
+ "model_name": "EleutherAI/pythia-410m",
549
+ "model_path": "EleutherAI-pythia-410m",
550
+ "num_gpus": 1,
551
+ "batch_size": 16,
552
+ "is_chat": False,
553
+ "max_total_tokens": 2048,
554
+ "max_input_length": 1024,
555
+ "max_batch_prefill_tokens": 4096,
556
+ "model_size": 410e6,
557
+ "model_family": "pythia",
558
+ },
559
+ "pythia-160m": {
560
+ "name": "pythia-160m",
561
+ "model_name": "EleutherAI/pythia-160m",
562
+ "model_path": "EleutherAI-pythia-160m",
563
+ "num_gpus": 1,
564
+ "batch_size": 16,
565
+ "is_chat": False,
566
+ "max_total_tokens": 2048,
567
+ "max_input_length": 1024,
568
+ "max_batch_prefill_tokens": 4096,
569
+ "model_size": 160e6,
570
+ "model_family": "pythia",
571
+ },
572
+ "pythia-70m": {
573
+ "name": "pythia-70m",
574
+ "model_name": "EleutherAI/pythia-70m",
575
+ "model_path": "EleutherAI-pythia-70m",
576
+ "num_gpus": 1,
577
+ "batch_size": 16,
578
+ "is_chat": False,
579
+ "max_total_tokens": 2048,
580
+ "max_input_length": 1024,
581
+ "max_batch_prefill_tokens": 4096,
582
+ "model_size": 70e6,
583
+ "model_family": "pythia",
584
+ },
585
+ ################################################
586
+ # Pythia-deduped #
587
+ ################################################
588
+ "pythia-12b-deduped": {
589
+ "name": "pythia-12b-deduped",
590
+ "model_name": "EleutherAI/pythia-12b-deduped",
591
+ "model_path": "EleutherAI-pythia-12b-deduped",
592
+ "num_gpus": 2,
593
+ "batch_size": 8,
594
+ "is_chat": False,
595
+ "max_total_tokens": 2048,
596
+ "max_input_length": 1024,
597
+ "max_batch_prefill_tokens": 4096,
598
+ "model_family": "pythia-deduped",
599
+ "model_size": 12e9,
600
+ },
601
+ "pythia-6.9b-deduped": {
602
+ "name": "pythia-6.9b-deduped",
603
+ "model_name": "EleutherAI/pythia-6.9b-deduped",
604
+ "model_path": "EleutherAI-pythia-6.9b-deduped",
605
+ "num_gpus": 1,
606
+ "batch_size": 8,
607
+ "is_chat": False,
608
+ "max_total_tokens": 2048,
609
+ "max_input_length": 1024,
610
+ "max_batch_prefill_tokens": 4096,
611
+ "model_family": "pythia-deduped",
612
+ "model_size": 6.9e9,
613
+ },
614
+ "pythia-2.8b-deduped": {
615
+ "name": "pythia-2.8b-deduped",
616
+ "model_name": "EleutherAI/pythia-2.8b-deduped",
617
+ "model_path": "EleutherAI-pythia-2.8b-deduped",
618
+ "num_gpus": 1,
619
+ "batch_size": 16,
620
+ "is_chat": False,
621
+ "max_total_tokens": 2048,
622
+ "max_input_length": 1024,
623
+ "max_batch_prefill_tokens": 4096,
624
+ "model_family": "pythia-deduped",
625
+ "model_size": 2.8e9,
626
+ },
627
+ "pythia-1.4b-deduped": {
628
+ "name": "pythia-1.4b-deduped",
629
+ "model_name": "EleutherAI/pythia-1.4b-deduped",
630
+ "model_path": "EleutherAI-pythia-1.4b-deduped",
631
+ "num_gpus": 1,
632
+ "batch_size": 16,
633
+ "is_chat": False,
634
+ "max_total_tokens": 2048,
635
+ "max_input_length": 1024,
636
+ "max_batch_prefill_tokens": 4096,
637
+ "model_family": "pythia-deduped",
638
+ "model_size": 1.4e9,
639
+ },
640
+ "pythia-1b-deduped": {
641
+ "name": "pythia-1b-deduped",
642
+ "model_name": "EleutherAI/pythia-1b-deduped",
643
+ "model_path": "EleutherAI-pythia-1b-deduped",
644
+ "num_gpus": 1,
645
+ "batch_size": 16,
646
+ "is_chat": False,
647
+ "use_flash_attention": False,
648
+ "max_total_tokens": 2048,
649
+ "max_input_length": 256,
650
+ "max_batch_prefill_tokens": 4096,
651
+ "model_family": "pythia-deduped",
652
+ "model_size": 1e9,
653
+ },
654
+ "pythia-410m-deduped": {
655
+ "name": "pythia-410m-deduped",
656
+ "model_name": "EleutherAI/pythia-410m-deduped",
657
+ "model_path": "EleutherAI-pythia-410m-deduped",
658
+ "num_gpus": 1,
659
+ "batch_size": 16,
660
+ "is_chat": False,
661
+ "max_total_tokens": 2048,
662
+ "max_input_length": 1024,
663
+ "max_batch_prefill_tokens": 4096,
664
+ "model_family": "pythia-deduped",
665
+ "model_size": 410e6,
666
+ },
667
+ "pythia-160m-deduped": {
668
+ "name": "pythia-160m-deduped",
669
+ "model_name": "EleutherAI/pythia-160m-deduped",
670
+ "model_path": "EleutherAI-pythia-160m-deduped",
671
+ "num_gpus": 1,
672
+ "batch_size": 16,
673
+ "is_chat": False,
674
+ "max_total_tokens": 2048,
675
+ "max_input_length": 1024,
676
+ "max_batch_prefill_tokens": 4096,
677
+ "model_family": "pythia-deduped",
678
+ "model_size": 160e6,
679
+ },
680
+ "pythia-70m-deduped": {
681
+ "name": "pythia-70m-deduped",
682
+ "model_name": "EleutherAI/pythia-70m-deduped",
683
+ "model_path": "EleutherAI-pythia-70m-deduped",
684
+ "num_gpus": 1,
685
+ "batch_size": 16,
686
+ "is_chat": False,
687
+ "max_total_tokens": 2048,
688
+ "max_input_length": 1024,
689
+ "max_batch_prefill_tokens": 4096,
690
+ "model_family": "pythia-deduped",
691
+ "model_size": 70e6,
692
+ },
693
+ ################################################
694
+ # GPT2 #
695
+ ################################################
696
+ "gpt2-xl": {
697
+ "name": "gpt2-xl",
698
+ "model_name": "gpt2-xl",
699
+ "model_path": "gpt2-xl",
700
+ "num_gpus": 1,
701
+ "batch_size": 16,
702
+ "is_chat": False,
703
+ "max_total_tokens": 1024,
704
+ "max_input_length": 256,
705
+ "max_batch_prefill_tokens": 4096,
706
+ "model_size": 1.5e9,
707
+ "model_family": "gpt2",
708
+ },
709
+ "gpt2-large": {
710
+ "name": "gpt2-large",
711
+ "model_name": "gpt2-large",
712
+ "model_path": "gpt2-large",
713
+ "num_gpus": 1,
714
+ "batch_size": 16,
715
+ "is_chat": False,
716
+ "max_total_tokens": 1024,
717
+ "max_input_length": 256,
718
+ "max_batch_prefill_tokens": 4096,
719
+ "model_size": 774e6,
720
+ "model_family": "gpt2",
721
+ },
722
+ "gpt2-medium": {
723
+ "name": "gpt2-medium",
724
+ "model_name": "gpt2-medium",
725
+ "model_path": "gpt2-medium",
726
+ "num_gpus": 1,
727
+ "batch_size": 16,
728
+ "is_chat": False,
729
+ "max_total_tokens": 2048,
730
+ "max_input_length": 1024,
731
+ "max_batch_prefill_tokens": 4096,
732
+ "model_size": 355e6,
733
+ "model_family": "gpt2",
734
+ },
735
+ "gpt2": {
736
+ "name": "gpt2",
737
+ "model_name": "gpt2",
738
+ "model_path": "gpt2",
739
+ "num_gpus": 1,
740
+ "batch_size": 16,
741
+ "is_chat": False,
742
+ "max_total_tokens": 2048,
743
+ "max_input_length": 1024,
744
+ "max_batch_prefill_tokens": 4096,
745
+ "model_size": 124e6,
746
+ "model_family": "gpt2",
747
+ },
748
+ ################################################
749
+ # CEREBRAS #
750
+ ################################################
751
+ "cerebras-gpt-13b": { # add 2 gpus but sharded equals to false
752
+ "name": "cerebras-gpt-13b",
753
+ "model_name": "cerebras/Cerebras-GPT-13B",
754
+ "model_path": "cerebras-Cerebras-GPT-13B",
755
+ "num_gpus": 1,
756
+ "batch_size": 8,
757
+ "is_chat": False,
758
+ "max_total_tokens": 2048,
759
+ "max_input_length": 1024,
760
+ "max_batch_prefill_tokens": 4096,
761
+ "model_family": "cerebras",
762
+ "model_size": 13e9,
763
+ },
764
+ "cerebras-gpt-6.7b": {
765
+ "name": "cerebras-gpt-6.7b",
766
+ "model_name": "cerebras/Cerebras-GPT-6.7B",
767
+ "model_path": "cerebras-Cerebras-GPT-6.7B",
768
+ "num_gpus": 1,
769
+ "batch_size": 8,
770
+ "is_chat": False,
771
+ "max_total_tokens": 1024,
772
+ "max_input_length": 256,
773
+ "max_batch_prefill_tokens": 4096,
774
+ "model_family": "cerebras",
775
+ "model_size": 6.7e9,
776
+ },
777
+ "cerebras-gpt-2.7b": {
778
+ "name": "cerebras-gpt-2.7b",
779
+ "model_name": "cerebras/Cerebras-GPT-2.7B",
780
+ "model_path": "cerebras-Cerebras-GPT-2.7B",
781
+ "num_gpus": 1,
782
+ "batch_size": 1,
783
+ "is_chat": False,
784
+ "max_total_tokens": 2048,
785
+ "max_input_length": 1024,
786
+ "max_batch_prefill_tokens": 4096,
787
+ "model_family": "cerebras",
788
+ "model_size": 2.7e9,
789
+ },
790
+ "cerebras-gpt-1.3b": {
791
+ "name": "cerebras-gpt-1.3b",
792
+ "model_name": "cerebras/Cerebras-GPT-1.3B",
793
+ "model_path": "cerebras-Cerebras-GPT-1.3B",
794
+ "num_gpus": 1,
795
+ "batch_size": 1,
796
+ "is_chat": False,
797
+ "max_total_tokens": 1024,
798
+ "max_input_length": 256,
799
+ "max_batch_prefill_tokens": 4096,
800
+ "model_family": "cerebras",
801
+ "model_size": 1.3e9,
802
+ },
803
+ "cerebras-gpt-256m": {
804
+ "name": "cerebras-gpt-256m",
805
+ "model_name": "cerebras/Cerebras-GPT-256M",
806
+ "model_path": "cerebras-Cerebras-GPT-256M",
807
+ "num_gpus": 1,
808
+ "batch_size": 16,
809
+ "is_chat": False,
810
+ "max_total_tokens": 2048,
811
+ "max_input_length": 1024,
812
+ "max_batch_prefill_tokens": 4096,
813
+ "model_family": "cerebras",
814
+ "model_size": 256e6,
815
+ },
816
+ "cerebras-gpt-111m": {
817
+ "name": "cerebras-gpt-111m",
818
+ "model_name": "cerebras/Cerebras-GPT-111M",
819
+ "model_path": "cerebras-Cerebras-GPT-111M",
820
+ "num_gpus": 1,
821
+ "batch_size": 16,
822
+ "is_chat": False,
823
+ "max_total_tokens": 2048,
824
+ "max_input_length": 1024,
825
+ "max_batch_prefill_tokens": 4096,
826
+ "model_family": "cerebras",
827
+ "model_size": 111e6,
828
+ },
829
+ ################################################
830
+ # Bloom #
831
+ ################################################
832
+ "bloom-7.1b": {
833
+ "name": "bloom-7.1b",
834
+ "model_name": "bigscience/bloom-7b1",
835
+ "model_path": "bigscience-bloom-7b1",
836
+ "num_gpus": 1,
837
+ "batch_size": 8,
838
+ "is_chat": False,
839
+ "max_total_tokens": 1024,
840
+ "max_input_length": 256,
841
+ "max_batch_prefill_tokens": 4096,
842
+ "model_size": 7.1e9,
843
+ "model_family": "bloom",
844
+ },
845
+ "bloom-3b": {
846
+ "name": "bloom-3b",
847
+ "model_name": "bigscience/bloom-3b",
848
+ "model_path": "bigscience-bloom-3b",
849
+ "num_gpus": 1,
850
+ "batch_size": 16,
851
+ "is_chat": False,
852
+ "max_total_tokens": 2048,
853
+ "max_input_length": 1024,
854
+ "max_batch_prefill_tokens": 4096,
855
+ "model_size": 3e9,
856
+ "model_family": "bloom",
857
+ },
858
+ "bloom-1.7b": {
859
+ "name": "bloom-1.7b",
860
+ "model_name": "bigscience/bloom-1b7",
861
+ "model_path": "bigscience-bloom-1b7",
862
+ "num_gpus": 1,
863
+ "batch_size": 16,
864
+ "is_chat": False,
865
+ "max_total_tokens": 1024,
866
+ "max_input_length": 256,
867
+ "max_batch_prefill_tokens": 4096,
868
+ "model_size": 1.7e9,
869
+ "model_family": "bloom",
870
+ },
871
+ "bloom-1.1b": {
872
+ "name": "bloom-1.1b",
873
+ "model_name": "bigscience/bloom-1b1",
874
+ "model_path": "bigscience-bloom-1b1",
875
+ "num_gpus": 1,
876
+ "batch_size": 16,
877
+ "is_chat": False,
878
+ "max_total_tokens": 2048,
879
+ "max_input_length": 1024,
880
+ "max_batch_prefill_tokens": 4096,
881
+ "model_size": 1.1e9,
882
+ "model_family": "bloom",
883
+ },
884
+ "bloom-560m": {
885
+ "name": "bloom-560m",
886
+ "model_name": "bigscience/bloom-560m",
887
+ "model_path": "bigscience-bloom-560m",
888
+ "num_gpus": 1,
889
+ "batch_size": 16,
890
+ "is_chat": False,
891
+ "max_total_tokens": 1024,
892
+ "max_input_length": 256,
893
+ "max_batch_prefill_tokens": 4096,
894
+ "model_size": 560e6,
895
+ "model_family": "bloom",
896
+ },
897
+ ################################################
898
+ # Falcon #
899
+ ################################################
900
+ "falcon-40b": {
901
+ "name": "falcon-40b",
902
+ "model_name": "tiiuae/falcon-40b",
903
+ "model_path": "tiiuae-falcon-40b",
904
+ "num_gpus": 4,
905
+ "batch_size": 4,
906
+ "is_chat": False,
907
+ "max_total_tokens": 2048,
908
+ "max_input_length": 1024,
909
+ "max_batch_prefill_tokens": 4096,
910
+ "model_size": 40e9,
911
+ "model_family": "falcon",
912
+ },
913
+ "falcon-7b": {
914
+ "name": "falcon-7b",
915
+ "model_name": "tiiuae/falcon-7b",
916
+ "model_path": "tiiuae-falcon-7b",
917
+ "num_gpus": 1,
918
+ "batch_size": 8,
919
+ "is_chat": False,
920
+ "max_total_tokens": 2048,
921
+ "max_input_length": 1024,
922
+ "max_batch_prefill_tokens": 4096,
923
+ "model_size": 7e9,
924
+ "model_family": "falcon",
925
+ },
926
+ ################################################
927
+ # Falcon-chat #
928
+ ################################################
929
+ "falcon-40b-instruct": {
930
+ "name": "falcon-40b-instruct",
931
+ "model_name": "tiiuae/falcon-40b-instruct",
932
+ "model_path": "tiiuae-falcon-40b-instruct",
933
+ "num_gpus": 4,
934
+ "batch_size": 4,
935
+ "is_chat": True,
936
+ "prompt": FALCON_PROMPT,
937
+ "stopword": FALCON_STOPWORD,
938
+ "max_total_tokens": 2048,
939
+ "max_input_length": 1024,
940
+ "max_batch_prefill_tokens": 4096,
941
+ "model_family": "falcon",
942
+ "model_size": 40e9,
943
+ },
944
+ "falcon-7b-instruct": {
945
+ "name": "falcon-7b-instruct",
946
+ "model_name": "tiiuae/falcon-7b-instruct",
947
+ "model_path": "tiiuae-falcon-7b-instruct",
948
+ "num_gpus": 1,
949
+ "batch_size": 5,
950
+ "is_chat": True,
951
+ "prompt": FALCON_PROMPT,
952
+ "stopword": FALCON_STOPWORD,
953
+ "max_total_tokens": 2048,
954
+ "max_input_length": 1024,
955
+ "max_batch_prefill_tokens": 4096,
956
+ "model_family": "falcon",
957
+ "model_size": 7e9,
958
+ },
959
+ "alfred-40b-0723": {
960
+ "name": "alfred-40b-0723",
961
+ "model_name": "lightonai/alfred-40b-0723",
962
+ "model_path": "lightonai-alfred-40b-0723",
963
+ "num_gpus": 4,
964
+ "batch_size": 4,
965
+ "is_chat": True,
966
+ "prompt": ALFRED_PROMPT,
967
+ "stopword": ALFRED_STOPWORD,
968
+ "max_total_tokens": 2048,
969
+ "max_input_length": 1024,
970
+ "max_batch_prefill_tokens": 4096,
971
+ "model_family": "falcon",
972
+ "model_size": 40e9,
973
+ },
974
+ ################################################
975
+ # Vicuna v1.3 #
976
+ ################################################
977
+ "vicuna-33b-v1.3": {
978
+ "name": "vicuna-33b-v1.3",
979
+ "model_name": "lmsys/vicuna-33b-v1.3",
980
+ "model_path": "lmsys-vicuna-33b-v1.3",
981
+ "num_gpus": 2,
982
+ "batch_size": 2,
983
+ "is_chat": True,
984
+ "prompt": VICUNA_PROMPT,
985
+ "stopword": VICUNA_STOPWORD,
986
+ "max_total_tokens": 2048,
987
+ "max_input_length": 1024,
988
+ "max_batch_prefill_tokens": 4096,
989
+ "model_family": "vicuna",
990
+ "model_size": 33e9,
991
+ },
992
+ "vicuna-13b-v1.3": {
993
+ "name": "vicuna-13b-v1.3",
994
+ "model_name": "lmsys/vicuna-13b-v1.3",
995
+ "model_path": "lmsys-vicuna-13b-v1.3",
996
+ "num_gpus": 2,
997
+ "batch_size": 8,
998
+ "is_chat": True,
999
+ "prompt": VICUNA_PROMPT,
1000
+ "stopword": VICUNA_STOPWORD,
1001
+ "max_total_tokens": 2048,
1002
+ "max_input_length": 1024,
1003
+ "max_batch_prefill_tokens": 4096,
1004
+ "model_family": "vicuna",
1005
+ "model_size": 13e9,
1006
+ },
1007
+ "vicuna-7b-v1.3": {
1008
+ "name": "vicuna-7b-v1.3",
1009
+ "model_name": "lmsys/vicuna-7b-v1.3",
1010
+ "model_path": "lmsys-vicuna-7b-v1.3",
1011
+ "num_gpus": 1,
1012
+ "batch_size": 4,
1013
+ "is_chat": True,
1014
+ "prompt": VICUNA_PROMPT,
1015
+ "stopword": VICUNA_STOPWORD,
1016
+ "max_total_tokens": 2048,
1017
+ "max_input_length": 1024,
1018
+ "max_batch_prefill_tokens": 4096,
1019
+ "model_family": "vicuna",
1020
+ "model_size": 7e9,
1021
+ },
1022
+ }
1023
+
1024
+
1025
+ MODEL_FAMILY_PRETRAINING_DATASETS = {
1026
+ "llama-2": ["UNK-commoncrawl"],
1027
+ "llama-1": [
1028
+ "llama",
1029
+ "c4",
1030
+ "github",
1031
+ "wikipedia",
1032
+ "books3",
1033
+ "gutenberg",
1034
+ "arxiv",
1035
+ "stackexchange",
1036
+ ],
1037
+ "openllama": [
1038
+ "redpajama",
1039
+ "c4",
1040
+ "github",
1041
+ "wikipedia",
1042
+ "books3",
1043
+ "gutenberg",
1044
+ "arxiv",
1045
+ "stackexchange",
1046
+ ],
1047
+ "openllama-2": [
1048
+ "refinedweb",
1049
+ "github",
1050
+ "wikipedia",
1051
+ "books3",
1052
+ "gutenberg",
1053
+ "arxiv",
1054
+ "stackexchange",
1055
+ ],
1056
+ "pythia": [
1057
+ "thepile",
1058
+ "pubmed",
1059
+ "books3",
1060
+ "arxiv",
1061
+ "github",
1062
+ "openwebtext2",
1063
+ "freelaw",
1064
+ "wikipedia",
1065
+ "stackexchange",
1066
+ "uspto",
1067
+ "gutenberg",
1068
+ "opensubtitles",
1069
+ "mathematics",
1070
+ "bookcorpus2",
1071
+ "ubuntuIRC",
1072
+ "europarl",
1073
+ "philpapers",
1074
+ "nih-grants" "hackernews",
1075
+ "enron",
1076
+ ],
1077
+ "gpt2": ["openwebtext"],
1078
+ "cerebras": [
1079
+ "thepile",
1080
+ "pubmed",
1081
+ "books3",
1082
+ "arxiv",
1083
+ "github",
1084
+ "openwebtext2",
1085
+ "freelaw",
1086
+ "wikipedia",
1087
+ "stackexchange",
1088
+ "uspto",
1089
+ "gutenberg",
1090
+ "opensubtitles",
1091
+ "mathematics",
1092
+ "bookcorpus2",
1093
+ "ubuntuIRC",
1094
+ "europarl",
1095
+ "philpapers",
1096
+ "nih-grants" "hackernews",
1097
+ "enron",
1098
+ ],
1099
+ "bloom": [
1100
+ "oscar",
1101
+ "github",
1102
+ "commoncrawl-bloom",
1103
+ ],
1104
+ "falcon": [
1105
+ "refinedweb",
1106
+ "pubmed",
1107
+ "books3",
1108
+ "arxiv",
1109
+ "github",
1110
+ "openwebtext2",
1111
+ "freelaw",
1112
+ "wikipedia",
1113
+ "stackexchange",
1114
+ "uspto",
1115
+ "gutenberg",
1116
+ "opensubtitles",
1117
+ "mathematics",
1118
+ "bookcorpus2",
1119
+ "ubuntuIRC",
1120
+ "europarl",
1121
+ "philpapers",
1122
+ "nih-grants" "hackernews",
1123
+ "enron",
1124
+ ],
1125
+ "mpt": [
1126
+ "c4",
1127
+ "mc4",
1128
+ "redpajama",
1129
+ "github",
1130
+ "wikipedia",
1131
+ "books3",
1132
+ "gutenberg",
1133
+ "arxiv",
1134
+ "stackexchange",
1135
+ ],
1136
+ "opt": [
1137
+ "cc-news",
1138
+ "cc-stories",
1139
+ "thepile",
1140
+ "reddit" "pubmed",
1141
+ "books3",
1142
+ "github",
1143
+ "openwebtext2",
1144
+ "wikipedia",
1145
+ "uspto",
1146
+ "gutenberg",
1147
+ "opensubtitles",
1148
+ "mathematics",
1149
+ "bookcorpus2",
1150
+ "hackernews",
1151
+ ],
1152
+ }
1153
+
1154
+
1155
+ if __name__ == "__main__":
1156
+ print(len(MODELS))
1157
+ print("\n".join(MODELS.keys()))
visualize_utils.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
+ def hex_to_rgb(value):
5
+ """
6
+ Calculates rgb values from a hex color code.
7
+
8
+ :param (string) value: Hex color string
9
+
10
+ :rtype (tuple) (r_value, g_value, b_value): tuple of rgb values
11
+ """
12
+ value = value.lstrip("#")
13
+ hex_total_length = len(value)
14
+ rgb_section_length = hex_total_length // 3
15
+ return tuple(
16
+ int(value[i : i + rgb_section_length], 16)
17
+ for i in range(0, hex_total_length, rgb_section_length)
18
+ )
19
+
20
+
21
+ viridis = [
22
+ [0, "#440154"],
23
+ [0.06274509803921569, "#48186a"],
24
+ [0.12549019607843137, "#472d7b"],
25
+ [0.18823529411764706, "#424086"],
26
+ [0.25098039215686274, "#3b528b"],
27
+ [0.3137254901960784, "#33638d"],
28
+ [0.3764705882352941, "#2c728e"],
29
+ [0.4392156862745098, "#26828e"],
30
+ [0.5019607843137255, "#21918c"],
31
+ [0.5647058823529412, "#1fa088"],
32
+ [0.6274509803921569, "#28ae80"],
33
+ [0.6901960784313725, "#3fbc73"],
34
+ [0.7529411764705882, "#5ec962"],
35
+ [0.8156862745098039, "#84d44b"],
36
+ [0.8784313725490196, "#addc30"],
37
+ [0.9411764705882353, "#d8e219"],
38
+ [1, "#fde725"],
39
+ ]
40
+ # Define the power parameter for the transformation
41
+ power = 0.23 # You can adjust this value as needed
42
+
43
+ # Apply the power transformation to the values in the colorscale
44
+ for i in range(len(viridis)):
45
+ viridis[i][0] = np.power(viridis[i][0], power)
46
+
47
+ # Normalize the transformed values to [0, 1]
48
+ max_value = max(v[0] for v in viridis)
49
+ for i in range(len(viridis)):
50
+ viridis[i][0] /= max_value
51
+
52
+ # Sort the colorscale by the normalized values
53
+ viridis.sort(key=lambda x: x[0])
54
+ viridis_rgb = [[x[0], "rgb" + str(hex_to_rgb(x[1]))] for x in viridis]
55
+
56
+ # reverse the colorscale
57
+ viridis_rgb = [[x[0], y[1]] for x, y in zip(viridis_rgb, viridis_rgb[::-1])]