wissamantoun committed on
Commit
7cf7655
·
verified ·
1 Parent(s): a1925cb

added watermarking and quantization exp

Browse files
Files changed (1) hide show
  1. app.py +263 -31
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import json
 
2
 
3
  import numpy as np
4
  import pandas as pd
@@ -11,8 +12,6 @@ from plotly.subplots import make_subplots
11
  from exp_utils import MODELS
12
  from visualize_utils import viridis_rgb
13
 
14
- #
15
-
16
  st.set_page_config(
17
  page_title="Results Viewer",
18
  page_icon="📊",
@@ -23,14 +22,35 @@ st.set_page_config(
23
  MODELS_SIZE_MAPPING = {k: v["model_size"] for k, v in MODELS.items()}
24
  MODELS_FAMILY_MAPPING = {k: v["model_family"] for k, v in MODELS.items()}
25
  MODEL_FAMILES = set([model["model_family"] for model in MODELS.values()])
26
- MODEL_NAMES = list(MODELS.keys())
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  MODEL_NAMES_SORTED_BY_NAME_AND_SIZE = sorted(
29
- MODEL_NAMES, key=lambda x: (MODELS[x]["model_family"], MODELS[x]["model_size"])
 
 
 
 
30
  )
31
 
32
  MODEL_NAMES_SORTED_BY_SIZE = sorted(
33
- MODEL_NAMES, key=lambda x: (MODELS[x]["model_size"], MODELS[x]["model_family"])
 
 
 
 
34
  )
35
 
36
 
@@ -43,7 +63,11 @@ MODELS_SIZE_MAPPING = {
43
  MODELS_SIZE_MAPPING_LIST = list(MODELS_SIZE_MAPPING.keys())
44
 
45
 
46
- CHAT_MODELS = [x for x in MODEL_NAMES_SORTED_BY_NAME_AND_SIZE if MODELS[x]["is_chat"]]
 
 
 
 
47
 
48
 
49
  def clean_dataframe(df: pd.DataFrame) -> pd.DataFrame:
@@ -66,7 +90,11 @@ def clean_dataframe(df: pd.DataFrame) -> pd.DataFrame:
66
  df.columns = df.columns.str.replace("_roc_auc", "")
67
  df.columns = df.columns.str.replace("eval_", "")
68
 
69
- df["model_family"] = df["model_name"].map(MODELS_FAMILY_MAPPING)
 
 
 
 
70
  # create a dict with the model_name and the model_family
71
  model_family_dict = {
72
  k: v
@@ -84,8 +112,16 @@ def clean_dataframe(df: pd.DataFrame) -> pd.DataFrame:
84
  df_std = df_std.drop(columns=["exp_seed"])
85
  df_avg["model_family"] = df_avg.index.map(model_family_dict)
86
  df_std["model_family"] = df_std.index.map(model_family_dict)
87
- df_avg["model_size"] = df_avg.index.map(MODELS_SIZE_MAPPING)
88
- df_std["model_size"] = df_std.index.map(MODELS_SIZE_MAPPING)
 
 
 
 
 
 
 
 
89
 
90
  # sort rows by model family then model size
91
  df_avg = df_avg.sort_values(
@@ -101,10 +137,15 @@ def clean_dataframe(df: pd.DataFrame) -> pd.DataFrame:
101
  availables_rows = [x for x in df_std.columns if x in df_std.index]
102
  df_std = df_std.reindex(availables_rows)
103
 
 
 
 
 
 
104
  return df_avg, df_std
105
 
106
 
107
- def get_data(path):
108
  df, df_std = clean_dataframe(pd.read_csv(path, index_col=0))
109
  return df, df_std
110
 
@@ -117,8 +158,15 @@ def filter_df(
117
  model_size_test: tuple,
118
  is_chat_train: bool,
119
  is_chat_test: bool,
 
 
 
 
120
  sort_by_size: bool,
121
  split_chat_models: bool,
 
 
 
122
  is_debug: bool,
123
  ) -> pd.DataFrame:
124
  # remove all columns and rows that have "pythia-70m" in the name
@@ -143,6 +191,16 @@ def filter_df(
143
  if is_debug:
144
  st.write("Filter is chat train")
145
  st.write(df)
 
 
 
 
 
 
 
 
 
 
146
 
147
  # filter columns
148
  if is_debug:
@@ -150,8 +208,13 @@ def filter_df(
150
  st.write(df)
151
  columns_to_keep = []
152
  for column in df.columns:
153
- if column in MODELS.keys():
154
- model_size = MODELS[column]["model_size"]
 
 
 
 
 
155
  if (
156
  model_size >= model_size_test[0] * 1e9
157
  and model_size <= model_size_test[1] * 1e9
@@ -167,7 +230,12 @@ def filter_df(
167
  columns_to_keep = []
168
  for column in df.columns:
169
  for model_family in model_family_test:
170
- if model_family == MODELS[column]["model_family"]:
 
 
 
 
 
171
  columns_to_keep.append(column)
172
  df = df[list(sorted(list(set(columns_to_keep))))]
173
  if is_debug:
@@ -178,13 +246,44 @@ def filter_df(
178
  # filter columns
179
  columns_to_keep = []
180
  for column in df.columns:
181
- if MODELS[column]["is_chat"] == is_chat_test:
 
 
 
 
 
182
  columns_to_keep.append(column)
183
  df = df[list(sorted(list(set(columns_to_keep))))]
184
  if is_debug:
185
  st.write("Filter is chat test")
186
  st.write(df)
187
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
  df = df.select_dtypes(include="number")
189
  if is_debug:
190
  st.write("Select dtypes to be only numbers")
@@ -227,10 +326,121 @@ def filter_df(
227
  if is_debug:
228
  st.write("Split chat models")
229
  st.write(df)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
  return df
231
 
232
 
233
  df, df_std = get_data("./deberta_results.csv")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
 
235
  with open("./ood_results.json", "r") as f:
236
  ood_results = json.load(f)
@@ -258,11 +468,14 @@ st.write(
258
  )
259
 
260
  # filters
261
- show_diff = st.sidebar.checkbox("Show Diff", value=False)
262
- sort_by_size = st.sidebar.checkbox("Sort by size", value=False)
263
- split_chat_models = st.sidebar.checkbox("Split chat models", value=False)
 
 
264
  add_mean = st.sidebar.checkbox("Add mean", value=False)
265
  show_std = st.sidebar.checkbox("Show std", value=False)
 
266
  model_size_train = st.sidebar.slider(
267
  "Train Model Size in Billion", min_value=0, max_value=100, value=(0, 100), step=1
268
  )
@@ -271,6 +484,18 @@ model_size_test = st.sidebar.slider(
271
  )
272
  is_chat_train = st.sidebar.selectbox("(Train) Is Chat?", [True, False, "Both"], index=2)
273
  is_chat_test = st.sidebar.selectbox("(Test) Is Chat?", [True, False, "Both"], index=2)
 
 
 
 
 
 
 
 
 
 
 
 
274
  model_family_train = st.sidebar.multiselect(
275
  "Model Family Train",
276
  MODEL_FAMILES,
@@ -282,6 +507,8 @@ model_family_test = st.sidebar.multiselect(
282
  default=MODEL_FAMILES,
283
  )
284
 
 
 
285
  add_adversarial = False
286
  if "Adversarial" in model_family_test:
287
  model_family_test.remove("Adversarial")
@@ -304,14 +531,6 @@ if show_std:
304
  else:
305
  selected_df = df.copy()
306
 
307
- if show_diff:
308
- # get those 3 columns {'model_size', 'model_family', 'is_chat'}
309
- columns_to_keep = ["model_size", "model_family", "is_chat"]
310
- to_be_added = selected_df[columns_to_keep]
311
- selected_df = selected_df.drop(columns=columns_to_keep)
312
- selected_df = selected_df.sub(selected_df.values.diagonal(), axis=1)
313
- selected_df = selected_df.join(to_be_added)
314
-
315
 
316
  filtered_df = filter_df(
317
  selected_df,
@@ -321,18 +540,32 @@ filtered_df = filter_df(
321
  model_size_test,
322
  is_chat_train,
323
  is_chat_test,
 
 
 
 
324
  sort_by_size,
325
  split_chat_models,
 
 
 
326
  is_debug,
327
  )
328
 
329
 
330
- # subtract each row by the diagonal
 
 
 
331
 
332
- # if show_diff:
333
- # filtered_df = filtered_df.sub(filtered_df.values.diagonal(), axis=1)
334
  if add_adversarial:
335
- filtered_df = filtered_df.join(ood_results_avg)
 
 
 
 
 
336
 
337
  if add_mean:
338
  col_mean = filtered_df.mean(axis=1)
@@ -341,7 +574,6 @@ if add_mean:
341
  filtered_df["mean"] = col_mean
342
  filtered_df.loc["mean"] = row_mean
343
 
344
-
345
  filtered_df = filtered_df * 100
346
  filtered_df = filtered_df.round(0)
347
 
@@ -364,7 +596,7 @@ fig = px.imshow(
364
  y=list(filtered_df.index),
365
  color_continuous_scale=color_scale,
366
  contrast_rescaling=None,
367
- text_auto=True,
368
  aspect="auto",
369
  )
370
 
 
1
  import json
2
+ from typing import Tuple
3
 
4
  import numpy as np
5
  import pandas as pd
 
12
  from exp_utils import MODELS
13
  from visualize_utils import viridis_rgb
14
 
 
 
15
  st.set_page_config(
16
  page_title="Results Viewer",
17
  page_icon="📊",
 
22
  MODELS_SIZE_MAPPING = {k: v["model_size"] for k, v in MODELS.items()}
23
  MODELS_FAMILY_MAPPING = {k: v["model_family"] for k, v in MODELS.items()}
24
  MODEL_FAMILES = set([model["model_family"] for model in MODELS.values()])
25
+ Q_W_MODELS = [
26
+ "llama-7b",
27
+ "llama-2-7b",
28
+ "llama-13b",
29
+ "llama-2-13b",
30
+ "llama-30b",
31
+ "llama-65b",
32
+ "llama-2-70b",
33
+ ]
34
+ Q_W_MODELS = [f"{model}_quantized" for model in Q_W_MODELS] + [
35
+ f"{model}_watermarked" for model in Q_W_MODELS
36
+ ]
37
+
38
+ MODEL_NAMES = list(MODELS.keys()) + Q_W_MODELS
39
 
40
  MODEL_NAMES_SORTED_BY_NAME_AND_SIZE = sorted(
41
+ MODEL_NAMES,
42
+ key=lambda x: (
43
+ MODELS[x.replace("_quantized", "").replace("_watermarked", "")]["model_family"],
44
+ MODELS[x.replace("_quantized", "").replace("_watermarked", "")]["model_size"],
45
+ ),
46
  )
47
 
48
  MODEL_NAMES_SORTED_BY_SIZE = sorted(
49
+ MODEL_NAMES,
50
+ key=lambda x: (
51
+ MODELS[x.replace("_quantized", "").replace("_watermarked", "")]["model_size"],
52
+ MODELS[x.replace("_quantized", "").replace("_watermarked", "")]["model_family"],
53
+ ),
54
  )
55
 
56
 
 
63
  MODELS_SIZE_MAPPING_LIST = list(MODELS_SIZE_MAPPING.keys())
64
 
65
 
66
+ CHAT_MODELS = [
67
+ x
68
+ for x in MODEL_NAMES_SORTED_BY_NAME_AND_SIZE
69
+ if MODELS[x.replace("_quantized", "").replace("_watermarked", "")]["is_chat"]
70
+ ]
71
 
72
 
73
  def clean_dataframe(df: pd.DataFrame) -> pd.DataFrame:
 
90
  df.columns = df.columns.str.replace("_roc_auc", "")
91
  df.columns = df.columns.str.replace("eval_", "")
92
 
93
+ df["model_family"] = df["model_name"].apply(
94
+ lambda x: MODELS_FAMILY_MAPPING[
95
+ x.replace("_quantized", "").replace("_watermarked", "")
96
+ ]
97
+ )
98
  # create a dict with the model_name and the model_family
99
  model_family_dict = {
100
  k: v
 
112
  df_std = df_std.drop(columns=["exp_seed"])
113
  df_avg["model_family"] = df_avg.index.map(model_family_dict)
114
  df_std["model_family"] = df_std.index.map(model_family_dict)
115
+ df_avg["model_size"] = df_avg.index.map(
116
+ lambda x: MODELS_SIZE_MAPPING[
117
+ x.replace("_quantized", "").replace("_watermarked", "")
118
+ ]
119
+ )
120
+ df_std["model_size"] = df_std.index.map(
121
+ lambda x: MODELS_SIZE_MAPPING[
122
+ x.replace("_quantized", "").replace("_watermarked", "")
123
+ ]
124
+ )
125
 
126
  # sort rows by model family then model size
127
  df_avg = df_avg.sort_values(
 
137
  availables_rows = [x for x in df_std.columns if x in df_std.index]
138
  df_std = df_std.reindex(availables_rows)
139
 
140
+ df_avg["is_quantized"] = df_avg.index.str.contains("quantized")
141
+ df_avg["is_watermarked"] = df_avg.index.str.contains("watermarked")
142
+ df_std["is_quantized"] = df_std.index.str.contains("quantized")
143
+ df_std["is_watermarked"] = df_std.index.str.contains("watermarked")
144
+
145
  return df_avg, df_std
146
 
147
 
148
+ def get_data(path) -> Tuple[pd.DataFrame, pd.DataFrame]:
149
  df, df_std = clean_dataframe(pd.read_csv(path, index_col=0))
150
  return df, df_std
151
 
 
158
  model_size_test: tuple,
159
  is_chat_train: bool,
160
  is_chat_test: bool,
161
+ is_quantized_train: bool,
162
+ is_quantized_test: bool,
163
+ is_watermarked_train: bool,
164
+ is_watermarked_test: bool,
165
  sort_by_size: bool,
166
  split_chat_models: bool,
167
+ split_quantized_models: bool,
168
+ split_watermarked_models: bool,
169
+ filter_empty_col_row: bool,
170
  is_debug: bool,
171
  ) -> pd.DataFrame:
172
  # remove all columns and rows that have "pythia-70m" in the name
 
191
  if is_debug:
192
  st.write("Filter is chat train")
193
  st.write(df)
194
+ if is_quantized_train != "Both":
195
+ df = df.loc[df["is_quantized"] == is_quantized_train]
196
+ if is_debug:
197
+ st.write("Filter is quantized train")
198
+ st.write(df)
199
+ if is_watermarked_train != "Both":
200
+ df = df.loc[df["is_watermarked"] == is_watermarked_train]
201
+ if is_debug:
202
+ st.write("Filter is watermark train")
203
+ st.write(df)
204
 
205
  # filter columns
206
  if is_debug:
 
208
  st.write(df)
209
  columns_to_keep = []
210
  for column in df.columns:
211
+ if (
212
+ column.replace("_quantized", "").replace("_watermarked", "")
213
+ in MODELS.keys()
214
+ ):
215
+ model_size = MODELS[
216
+ column.replace("_quantized", "").replace("_watermarked", "")
217
+ ]["model_size"]
218
  if (
219
  model_size >= model_size_test[0] * 1e9
220
  and model_size <= model_size_test[1] * 1e9
 
230
  columns_to_keep = []
231
  for column in df.columns:
232
  for model_family in model_family_test:
233
+ if (
234
+ model_family
235
+ == MODELS[column.replace("_quantized", "").replace("_watermarked", "")][
236
+ "model_family"
237
+ ]
238
+ ):
239
  columns_to_keep.append(column)
240
  df = df[list(sorted(list(set(columns_to_keep))))]
241
  if is_debug:
 
246
  # filter columns
247
  columns_to_keep = []
248
  for column in df.columns:
249
+ if (
250
+ MODELS[column.replace("_quantized", "").replace("_watermarked", "")][
251
+ "is_chat"
252
+ ]
253
+ == is_chat_test
254
+ ):
255
  columns_to_keep.append(column)
256
  df = df[list(sorted(list(set(columns_to_keep))))]
257
  if is_debug:
258
  st.write("Filter is chat test")
259
  st.write(df)
260
 
261
+ if is_quantized_test != "Both":
262
+ # filter columns
263
+ columns_to_keep = []
264
+ for column in df.columns:
265
+ if "quantized" in column and is_quantized_test:
266
+ columns_to_keep.append(column)
267
+ elif "quantized" not in column and not is_quantized_test:
268
+ columns_to_keep.append(column)
269
+ df = df[list(sorted(list(set(columns_to_keep))))]
270
+ if is_debug:
271
+ st.write("Filter is quantized test")
272
+ st.write(df)
273
+
274
+ if is_watermarked_test != "Both":
275
+ # filter columns
276
+ columns_to_keep = []
277
+ for column in df.columns:
278
+ if "watermark" in column and is_watermarked_test:
279
+ columns_to_keep.append(column)
280
+ elif "watermark" not in column and not is_watermarked_test:
281
+ columns_to_keep.append(column)
282
+ df = df[list(sorted(list(set(columns_to_keep))))]
283
+ if is_debug:
284
+ st.write("Filter is watermark test")
285
+ st.write(df)
286
+
287
  df = df.select_dtypes(include="number")
288
  if is_debug:
289
  st.write("Select dtypes to be only numbers")
 
326
  if is_debug:
327
  st.write("Split chat models")
328
  st.write(df)
329
+
330
+ if split_quantized_models:
331
+ # put quantized models at the end of the columns
332
+ quantized_models = [
333
+ x for x in Q_W_MODELS if x in df.columns and "quantized" in x
334
+ ]
335
+ # sort quantized models by size
336
+ quantized_models = sorted(
337
+ quantized_models,
338
+ key=lambda x: MODELS[
339
+ x.replace("_quantized", "").replace("_watermarked", "")
340
+ ]["model_size"],
341
+ )
342
+ df = df[[x for x in df.columns if x not in quantized_models] + quantized_models]
343
+
344
+ # put quantized models at the end of the rows
345
+ quantized_models = [x for x in Q_W_MODELS if x in df.index and "quantized" in x]
346
+ # sort quantized models by size
347
+ quantized_models = sorted(
348
+ quantized_models,
349
+ key=lambda x: MODELS[
350
+ x.replace("_quantized", "").replace("_watermarked", "")
351
+ ]["model_size"],
352
+ )
353
+ df = df.reindex(
354
+ [x for x in df.index if x not in quantized_models] + quantized_models
355
+ )
356
+
357
+ if split_watermarked_models:
358
+ # put watermarked models at the end of the columns
359
+ watermarked_models = [
360
+ x for x in Q_W_MODELS if x in df.columns and "watermarked" in x
361
+ ]
362
+ # sort watermarked models by size
363
+ watermarked_models = sorted(
364
+ watermarked_models,
365
+ key=lambda x: MODELS[
366
+ x.replace("_quantized", "").replace("_watermarked", "")
367
+ ]["model_size"],
368
+ )
369
+ df = df[
370
+ [x for x in df.columns if x not in watermarked_models] + watermarked_models
371
+ ]
372
+
373
+ # put watermarked models at the end of the rows
374
+ watermarked_models = [
375
+ x for x in Q_W_MODELS if x in df.index and "watermarked" in x
376
+ ]
377
+ # sort watermarked models by size
378
+ watermarked_models = sorted(
379
+ watermarked_models,
380
+ key=lambda x: MODELS[
381
+ x.replace("_quantized", "").replace("_watermarked", "")
382
+ ]["model_size"],
383
+ )
384
+ df = df.reindex(
385
+ [x for x in df.index if x not in watermarked_models] + watermarked_models
386
+ )
387
+
388
+ if is_debug:
389
+ st.write("Split chat models")
390
+ st.write(df)
391
+
392
+ if filter_empty_col_row:
393
+ # drop rows, then columns, whose values are all NaN
394
+ df = df.dropna(axis=0, how="all")
395
+ df = df.dropna(axis=1, how="all")
396
  return df
397
 
398
 
399
  df, df_std = get_data("./deberta_results.csv")
400
+ df_q_w, df_std_q_w = get_data("./results_qantized_watermarked.csv")
401
+
402
+ df = df.merge(
403
+ df_q_w[
404
+ df_q_w.columns[
405
+ df_q_w.columns.str.contains("quantized|watermarked", case=False, regex=True)
406
+ ]
407
+ ],
408
+ how="outer",
409
+ left_index=True,
410
+ right_index=True,
411
+ )
412
+ df_std = df_std.merge(
413
+ df_std_q_w[
414
+ df_std_q_w.columns[
415
+ df_std_q_w.columns.str.contains(
416
+ "quantized|watermarked", case=False, regex=True
417
+ )
418
+ ]
419
+ ],
420
+ how="outer",
421
+ left_index=True,
422
+ right_index=True,
423
+ )
424
+
425
+
426
+ df.columns = df.columns.str.replace("_y", "", regex=True)
427
+ df_std.columns = df_std.columns.str.replace("_y", "", regex=True)
428
+
429
+ df = df.drop(columns=["is_quantized_x", "is_watermarked_x"])
430
+
431
+
432
+ df.update(df_q_w)
433
+ df_std.update(df_std_q_w)
434
+
435
+
436
+ df["is_chat"].fillna(False, inplace=True)
437
+ df_std["is_chat"].fillna(False, inplace=True)
438
+
439
+ df["is_watermarked"].fillna(False, inplace=True)
440
+ df_std["is_watermarked"].fillna(False, inplace=True)
441
+
442
+ df["is_quantized"].fillna(False, inplace=True)
443
+ df_std["is_quantized"].fillna(False, inplace=True)
444
 
445
  with open("./ood_results.json", "r") as f:
446
  ood_results = json.load(f)
 
468
  )
469
 
470
  # filters
471
+ show_diff = st.sidebar.checkbox("Show Diff", value=False)
472
+ sort_by_size = st.sidebar.checkbox("Sort by size", value=True)
473
+ split_chat_models = st.sidebar.checkbox("Split chat models", value=True)
474
+ split_quantized_models = st.sidebar.checkbox("Split quantized models", value=True)
475
+ split_watermarked_models = st.sidebar.checkbox("Split watermarked models", value=True)
476
  add_mean = st.sidebar.checkbox("Add mean", value=False)
477
  show_std = st.sidebar.checkbox("Show std", value=False)
478
+ filter_empty_col_row = st.sidebar.checkbox("Filter empty col/row", value=True)
479
  model_size_train = st.sidebar.slider(
480
  "Train Model Size in Billion", min_value=0, max_value=100, value=(0, 100), step=1
481
  )
 
484
  )
485
  is_chat_train = st.sidebar.selectbox("(Train) Is Chat?", [True, False, "Both"], index=2)
486
  is_chat_test = st.sidebar.selectbox("(Test) Is Chat?", [True, False, "Both"], index=2)
487
+ is_quantized_train = st.sidebar.selectbox(
488
+ "(Train) Is Quantized?", [True, False, "Both"], index=1
489
+ )
490
+ is_quantized_test = st.sidebar.selectbox(
491
+ "(Test) Is Quantized?", [True, False, "Both"], index=1
492
+ )
493
+ is_watermarked_train = st.sidebar.selectbox(
494
+ "(Train) Is Watermark?", [True, False, "Both"], index=1
495
+ )
496
+ is_watermarked_test = st.sidebar.selectbox(
497
+ "(Test) Is Watermark?", [True, False, "Both"], index=1
498
+ )
499
  model_family_train = st.sidebar.multiselect(
500
  "Model Family Train",
501
  MODEL_FAMILES,
 
507
  default=MODEL_FAMILES,
508
  )
509
 
510
+ show_values = st.sidebar.checkbox("Show Values", value=False)
511
+
512
  add_adversarial = False
513
  if "Adversarial" in model_family_test:
514
  model_family_test.remove("Adversarial")
 
531
  else:
532
  selected_df = df.copy()
533
 
 
 
 
 
 
 
 
 
534
 
535
  filtered_df = filter_df(
536
  selected_df,
 
540
  model_size_test,
541
  is_chat_train,
542
  is_chat_test,
543
+ is_quantized_train,
544
+ is_quantized_test,
545
+ is_watermarked_train,
546
+ is_watermarked_test,
547
  sort_by_size,
548
  split_chat_models,
549
+ split_quantized_models,
550
+ split_watermarked_models,
551
+ filter_empty_col_row,
552
  is_debug,
553
  )
554
 
555
 
556
+ if show_diff:
557
+ # take the diagonal of the filtered matrix and subtract diag[j] from every entry of column j
558
+ diag = filtered_df.values.diagonal()
559
+ filtered_df = filtered_df.sub(diag, axis=1)
560
 
561
+ # subtract each row by the diagonal
 
562
  if add_adversarial:
563
+ if show_diff:
564
+ index = filtered_df.index
565
+ ood_results_avg = ood_results_avg.loc[index]
566
+ filtered_df = filtered_df.join(ood_results_avg.sub(diag, axis=0))
567
+ else:
568
+ filtered_df = filtered_df.join(ood_results_avg)
569
 
570
  if add_mean:
571
  col_mean = filtered_df.mean(axis=1)
 
574
  filtered_df["mean"] = col_mean
575
  filtered_df.loc["mean"] = row_mean
576
 
 
577
  filtered_df = filtered_df * 100
578
  filtered_df = filtered_df.round(0)
579
 
 
596
  y=list(filtered_df.index),
597
  color_continuous_scale=color_scale,
598
  contrast_rescaling=None,
599
+ text_auto=show_values,
600
  aspect="auto",
601
  )
602