hynky HF staff commited on
Commit
579c976
·
1 Parent(s): e906b0b

Revert "tmp revert lateer"

Browse files

This reverts commit e906b0bf37789596a54cc78a37fdfbd859aa4fc5.

Files changed (2) hide show
  1. app.py +566 -571
  2. requirements.txt +3 -0
app.py CHANGED
@@ -1,577 +1,572 @@
1
- # from concurrent.futures import ThreadPoolExecutor
2
- # import enum
3
- # from functools import partial
4
- # import json
5
- # import os
6
- # from pathlib import Path
7
- # import re
8
- # import heapq
9
- # import tempfile
10
- # from typing import Literal
11
- # import gradio as gr
12
-
13
- # from collections import defaultdict
14
- # from datatrove.io import get_datafolder
15
- # import plotly.graph_objects as go
16
- # from datatrove.utils.stats import MetricStats, MetricStatsDict
17
- # import plotly.express as px
18
- # import tenacity
19
-
20
- # import gradio as gr
21
- # PARTITION_OPTIONS = Literal[ "Top", "Bottom", "Most frequent (n_docs)"]
22
- # METRICS_LOCATION_DEFAULT = os.getenv("METRICS_LOCATION_DEFAULT", "s3://fineweb-stats/summary/")
23
-
24
-
25
- # def find_folders(base_folder, path):
26
- # base_folder = get_datafolder(base_folder)
27
- # if not base_folder.exists(path):
28
- # return []
29
- # return sorted(
30
- # [
31
- # folder["name"]
32
- # for folder in base_folder.ls(path, detail=True)
33
- # if folder["type"] == "directory" and not folder["name"].rstrip("/") == path
34
- # ]
35
- # )
36
-
37
-
38
- # def find_metrics_folders(base_folder: str):
39
- # base_data_folder = get_datafolder(base_folder)
40
- # # First find all metric.json using globing for metric.json
41
- # metrics_merged = base_data_folder.glob("**/metric.json")
42
-
43
- # # Then for each of metrics.merged take the all but last two parts of the path (grouping/metric_name)
44
- # metrics_folders = [str(Path(x).parent.parent.parent) for x in metrics_merged]
45
- # # Finally get the unique paths
46
- # return sorted(list(set(metrics_folders)))
47
-
48
-
49
- # def fetch_datasets(base_folder: str):
50
- # datasets = sorted(find_metrics_folders(base_folder))
51
- # return datasets, gr.update(choices=datasets, value=None), fetch_groups(base_folder, datasets, None, "union")
52
-
53
-
54
- # def export_data(exported_data: MetricStatsDict, metric_name: str):
55
- # if not exported_data:
56
- # return None
57
- # # Assuming exported_data is a dictionary where the key is the dataset name and the value is the data to be exported
58
- # temp_dir = tempfile.mkdtemp()
59
- # temp_path = os.path.join(temp_dir, metric_name + ".json")
60
- # with open(temp_path, "w") as temp_file:
61
- # json.dump({
62
- # name: dt.to_dict()
63
- # for name, dt in exported_data.items()
64
- # }, temp_file)
65
- # return gr.update(visible=True, value=temp_path)
66
-
67
-
68
- # def fetch_groups(base_folder, datasets, old_groups, type="intersection"):
69
- # if not datasets:
70
- # return gr.update(choices=[], value=None)
71
-
72
- # with ThreadPoolExecutor() as executor:
73
- # GROUPS = list(executor.map(lambda run: [Path(x).name for x in find_folders(base_folder, run)], datasets))
74
- # if len(GROUPS) == 0:
75
- # return gr.update(choices=[], value=None)
76
-
77
- # if type == "intersection":
78
- # new_choices = set.intersection(*(set(g) for g in GROUPS))
79
- # else:
80
- # new_choices = set.union(*(set(g) for g in GROUPS))
81
- # value = None
82
- # if old_groups:
83
- # value = list(set.intersection(new_choices, {old_groups}))
84
- # value = value[0] if value else None
85
-
86
- # # now take the intersection of all grups
87
- # return gr.update(choices=sorted(list(new_choices)), value=value)
88
-
89
-
90
- # def fetch_metrics(base_folder, datasets, group, old_metrics, type="intersection"):
91
- # with ThreadPoolExecutor() as executor:
92
- # metrics = list(executor.map(lambda run: [Path(x).name for x in find_folders(base_folder, f"{run}/{group}")], datasets))
93
- # if len(metrics) == 0:
94
- # return gr.update(choices=[], value=None)
95
-
96
- # if type == "intersection":
97
- # new_possibles_choices = set.intersection(*(set(s) for s in metrics))
98
- # else:
99
- # new_possibles_choices = set.union(*(set(s) for s in metrics))
100
- # value = None
101
- # if old_metrics:
102
- # value = list(set.intersection(new_possibles_choices, {old_metrics}))
103
- # value = value[0] if value else None
104
-
105
- # return gr.update(choices=sorted(list(new_possibles_choices)), value=value)
106
-
107
-
108
- # def reverse_search(base_folder, possible_datasets, grouping, metric_name):
109
- # with ThreadPoolExecutor() as executor:
110
- # found_datasets = list(executor.map(lambda dataset: dataset if metric_exists(base_folder, dataset, metric_name, grouping) else None, possible_datasets))
111
- # found_datasets = [dataset for dataset in found_datasets if dataset is not None]
112
- # return "\n".join(found_datasets)
113
-
114
-
115
- # def reverse_search_add(datasets, reverse_search_results):
116
- # datasets = datasets or []
117
- # return sorted(list(set(datasets + reverse_search_results.strip().split("\n"))))
118
-
119
-
120
-
121
- # def metric_exists(base_folder, path, metric_name, group_by):
122
- # base_folder = get_datafolder(base_folder)
123
- # return base_folder.exists(f"{path}/{group_by}/{metric_name}/metric.json")
124
-
125
- # @tenacity.retry(stop=tenacity.stop_after_attempt(5))
126
- # def load_metrics(base_folder, path, metric_name, group_by):
127
- # base_folder = get_datafolder(base_folder)
128
- # with base_folder.open(
129
- # f"{path}/{group_by}/{metric_name}/metric.json",
130
- # ) as f:
131
- # json_metric = json.load(f)
132
- # # No idea why this is necessary, but it is, otheriwse the Metric StatsDict is malformed
133
- # return MetricStatsDict.from_dict(json_metric)
134
-
135
-
136
- # def prepare_for_non_grouped_plotting(metric, normalization, rounding):
137
- # metrics_rounded = defaultdict(lambda: 0)
138
- # for key, value in metric.items():
139
- # metrics_rounded[round(float(key), rounding)] += value.total
140
- # if normalization:
141
- # normalizer = sum(metrics_rounded.values())
142
- # metrics_rounded = {k: v / normalizer for k, v in metrics_rounded.items()}
143
- # # check that the sum of the values is 1
144
- # summed = sum(metrics_rounded.values())
145
- # assert abs(summed - 1) < 0.01, summed
146
- # return metrics_rounded
147
-
148
-
149
- # def load_data(dataset_path, base_folder, grouping, metric_name):
150
- # metrics = load_metrics(base_folder, dataset_path, metric_name, grouping)
151
- # return metrics
152
-
153
- # def prepare_for_group_plotting(metric, top_k, direction: PARTITION_OPTIONS, regex: str | None, rounding: int):
154
- # regex_compiled = re.compile(regex) if regex else None
155
- # metric = {key: value for key, value in metric.items() if not regex or regex_compiled.match(key)}
156
- # means = {key: round(float(value.mean), rounding) for key, value in metric.items()}
157
- # # Use heap to get top_k keys
158
- # if direction == "Top":
159
- # keys = heapq.nlargest(top_k, means, key=means.get)
160
- # elif direction == "Most frequent (n_docs)":
161
- # totals = {key: int(value.n) for key, value in metric.items()}
162
- # keys = heapq.nlargest(top_k, totals, key=totals.get)
163
- # else:
164
- # keys = heapq.nsmallest(top_k, means, key=means.get)
165
 
166
 
167
- # means = [means[key] for key in keys]
168
- # stds = [metric[key].standard_deviation for key in keys]
169
- # return keys, means, stds
170
-
171
-
172
- # def set_alpha(color, alpha):
173
- # """
174
- # Takes a hex color and returns
175
- # rgba(r, g, b, a)
176
- # """
177
- # if color.startswith('#'):
178
- # r, g, b = int(color[1:3], 16), int(color[3:5], 16), int(color[5:7], 16)
179
- # else:
180
- # r, g, b = 0, 0, 0 # Fallback to black if the color format is not recognized
181
- # return f"rgba({r}, {g}, {b}, {alpha})"
182
-
183
-
184
- # def plot_scatter(
185
- # data: dict[str, dict[float, float]],
186
- # metric_name: str,
187
- # log_scale_x: bool,
188
- # log_scale_y: bool,
189
- # normalization: bool,
190
- # rounding: int,
191
- # progress: gr.Progress,
192
- # ):
193
- # fig = go.Figure()
194
-
195
- # # First sort the histograms, by their name
196
- # data = {name: histogram for name, histogram in sorted(data.items())}
197
- # for i, (name, histogram) in enumerate(progress.tqdm(data.items(), total=len(data), desc="Plotting...")):
198
- # histogram_prepared = prepare_for_non_grouped_plotting(histogram, normalization, rounding)
199
- # x = sorted(histogram_prepared.keys())
200
- # y = [histogram_prepared[k] for k in x]
201
-
202
- # fig.add_trace(
203
- # go.Scatter(
204
- # x=x,
205
- # y=y,
206
- # mode="lines",
207
- # name=name,
208
- # marker=dict(color=set_alpha(px.colors.qualitative.Plotly[i % len(px.colors.qualitative.Plotly)], 0.5)),
209
- # )
210
- # )
211
-
212
- # yaxis_title = "Frequency" if normalization else "Total"
213
-
214
- # fig.update_layout(
215
- # title=f"Line Plots for {metric_name}",
216
- # xaxis_title=metric_name,
217
- # yaxis_title=yaxis_title,
218
- # xaxis_type="log" if log_scale_x and len(x) > 1 else None,
219
- # yaxis_type="log" if log_scale_y and len(y) > 1 else None,
220
- # width=1200,
221
- # height=600,
222
- # showlegend=True,
223
- # )
224
-
225
- # return fig
226
-
227
-
228
- # def plot_bars(
229
- # data: dict[str, list[dict[str, float]]],
230
- # metric_name: str,
231
- # top_k: int,
232
- # direction: PARTITION_OPTIONS,
233
- # regex: str | None,
234
- # rounding: int,
235
- # log_scale_x: bool,
236
- # log_scale_y: bool,
237
- # progress: gr.Progress,
238
- # ):
239
- # fig = go.Figure()
240
- # x = []
241
- # y = []
242
-
243
- # for i, (name, histogram) in enumerate(progress.tqdm(data.items(), total=len(data), desc="Plotting...")):
244
- # x, y, stds = prepare_for_group_plotting(histogram, top_k, direction, regex, rounding)
245
-
246
- # fig.add_trace(go.Bar(
247
- # x=x,
248
- # y=y,
249
- # name=f"{name} Mean",
250
- # marker=dict(color=set_alpha(px.colors.qualitative.Plotly[i % len(px.colors.qualitative.Plotly)], 0.5)),
251
- # # error_y=dict(type='data', array=stds, visible=True)
252
- # ))
253
-
254
- # fig.update_layout(
255
- # title=f"Bar Plots for {metric_name}",
256
- # xaxis_title=metric_name,
257
- # yaxis_title="Avg. value",
258
- # xaxis_type="log" if log_scale_x and len(x) > 1 else None,
259
- # yaxis_type="log" if log_scale_y and len(y) > 1 else None,
260
- # autosize=True,
261
- # width=1200,
262
- # height=600,
263
- # showlegend=True,
264
- # )
265
-
266
- # return fig
267
-
268
-
269
- # def update_graph(
270
- # base_folder,
271
- # datasets,
272
- # metric_name,
273
- # grouping,
274
- # log_scale_x,
275
- # log_scale_y,
276
- # rounding,
277
- # normalization,
278
- # top_k,
279
- # direction,
280
- # regex,
281
- # progress=gr.Progress(),
282
- # ):
283
- # if len(datasets) <= 0 or not metric_name or not grouping:
284
- # return None
285
- # # Placeholder for logic to rerender the graph based on the inputs
286
-
287
- # with ThreadPoolExecutor() as pool:
288
- # data = list(
289
- # progress.tqdm(
290
- # pool.map(
291
- # partial(load_data, base_folder=base_folder, metric_name=metric_name, grouping=grouping),
292
- # datasets,
293
- # ),
294
- # total=len(datasets),
295
- # desc="Loading data...",
296
- # )
297
- # )
298
-
299
- # data = {path: result for path, result in zip(datasets, data)}
300
- # return plot_data(data, metric_name, normalization, rounding, grouping, top_k, direction, regex, log_scale_x, log_scale_y, progress), data, export_data(data, metric_name)
301
-
302
- # def plot_data(data, metric_name, normalization, rounding, grouping, top_k, direction, regex, log_scale_x, log_scale_y, progress=gr.Progress()):
303
- # if rounding is None or top_k is None:
304
- # return None
305
- # graph_fc = (
306
- # partial(plot_scatter, normalization=normalization, rounding=rounding)
307
- # if grouping == "histogram"
308
- # else partial(plot_bars, top_k=top_k, direction=direction, regex=regex, rounding=rounding)
309
- # )
310
- # return graph_fc(data=data, metric_name=metric_name, progress=progress, log_scale_x=log_scale_x, log_scale_y=log_scale_y)
311
-
312
-
313
-
314
- # # Create the Gradio interface
315
- # with gr.Blocks() as demo:
316
- # datasets = gr.State([])
317
- # exported_data = gr.State([])
318
- # metrics_headline = gr.Markdown(value="# Metrics Exploration")
319
- # with gr.Row():
320
- # with gr.Column(scale=2):
321
- # with gr.Row():
322
- # with gr.Column(scale=1):
323
- # base_folder = gr.Textbox(
324
- # label="Metrics Location",
325
- # value=METRICS_LOCATION_DEFAULT,
326
- # )
327
- # datasets_refetch = gr.Button("Fetch Datasets")
328
-
329
- # with gr.Column(scale=1):
330
- # regex_select = gr.Text(label="Regex filter", value=".*")
331
- # regex_button = gr.Button("Search")
332
- # with gr.Row():
333
- # datasets_selected = gr.Dropdown(
334
- # choices=[],
335
- # label="Datasets",
336
- # multiselect=True,
337
- # )
338
-
339
- # # add a readme description
340
- # readme_description = gr.Markdown(
341
- # label="Readme",
342
- # value="""
343
- # ## How to use:
344
- # 1) Specify Metrics location (Stats block `output_folder` without the last path segment) and click "Fetch Datasets"
345
- # 2) Select datasets you are interested in using the dropdown or regex filter
346
- # 3) Specify Grouping (global average/value/fqdn/suffix) and Metric name
347
- # 4) Click "Update Graph"
348
-
349
-
350
- # ## Groupings:
351
- # - **histogram**: Creates a line plot of values with their frequencies. If normalization is on, the frequencies sum to 1.
352
- # * normalize:
353
- # - **(fqdn/suffix)**: Creates a bar plot of the avg. values of the metric for full qualifed domain name/suffix of domain.
354
- # * k: the number of groups to show
355
- # * Top/Bottom/Most frequent (n_docs): Groups with the top/bottom k values/most prevalant docs are shown
356
- # - **none**: Shows the average value of given metric
357
-
358
- # ## Reverse search:
359
- # To search for datasets containing a grouping and certain metric, use the Reverse search section.
360
- # Specify the search parameters and click "Search". This will show you found datasets in the "Found datasets" textbox. You can modify the selection after search by removing unwanted lines and clicking "Add to selection".
361
-
362
- # ## Note:
363
- # The data might not be 100% representative, due to the sampling and optimistic merging of the metrics (fqdn/suffix).
364
- # """,
365
- # )
366
- # with gr.Column(scale=1):
367
- # # Define the dropdown for grouping
368
- # grouping_dropdown = gr.Dropdown(
369
- # choices=[],
370
- # label="Grouping",
371
- # multiselect=False,
372
- # )
373
- # # Define the dropdown for metric_name
374
- # metric_name_dropdown = gr.Dropdown(
375
- # choices=[],
376
- # label="Metric name",
377
- # multiselect=False,
378
- # )
379
-
380
-
381
- # update_button = gr.Button("Update Graph", variant="primary")
382
-
383
- # with gr.Row():
384
- # with gr.Column(scale=1):
385
- # log_scale_x_checkbox = gr.Checkbox(
386
- # label="Log scale x",
387
- # value=False,
388
- # )
389
- # log_scale_y_checkbox = gr.Checkbox(
390
- # label="Log scale y",
391
- # value=False,
392
- # )
393
- # rounding = gr.Number(
394
- # label="Rounding",
395
- # value=2,
396
- # )
397
- # normalization_checkbox = gr.Checkbox(
398
- # label="Normalize",
399
- # value=True, # Default value
400
- # visible=False
401
- # )
402
- # with gr.Row():
403
- # # export_data_button = gr.Button("Export data", visible=True, link=export_data_json)
404
- # export_data_json = gr.File(visible=False)
405
- # with gr.Column(scale=4):
406
- # with gr.Row(visible=False) as group_choices:
407
- # with gr.Column(scale=2):
408
- # group_regex = gr.Text(
409
- # label="Group Regex",
410
- # value=None,
411
- # )
412
- # with gr.Row():
413
- # top_select = gr.Number(
414
- # label="N Groups",
415
- # value=100,
416
- # interactive=True,
417
- # )
418
 
419
- # direction_checkbox = gr.Radio(
420
- # label="Partition",
421
- # choices=[
422
- # "Top",
423
- # "Bottom",
424
- # "Most frequent (n_docs)",
425
- # ],
426
- # value="Most frequent (n_docs)",
427
- # )
428
- # # Define the graph output
429
- # with gr.Row():
430
- # graph_output = gr.Plot(label="Graph")
431
 
432
- # with gr.Row():
433
- # reverse_search_headline = gr.Markdown(value="# Reverse metrics search")
434
 
435
- # with gr.Row():
436
- # with gr.Column(scale=1):
437
- # # Define the dropdown for grouping
438
- # reverse_grouping_dropdown = gr.Dropdown(
439
- # choices=[],
440
- # label="Grouping",
441
- # multiselect=False,
442
- # )
443
- # # Define the dropdown for metric_name
444
- # reverse_metric_name_dropdown = gr.Dropdown(
445
- # choices=[],
446
- # label="Stat name",
447
- # multiselect=False,
448
- # )
449
 
450
- # with gr.Column(scale=1):
451
- # reverse_search_button = gr.Button("Search")
452
- # reverse_search_add_button = gr.Button("Add to selection")
453
-
454
- # with gr.Column(scale=2):
455
- # reverse_search_results = gr.Textbox(
456
- # label="Found datasets",
457
- # lines=10,
458
- # placeholder="Found datasets containing the group/metric name. You can modify the selection after search by removing unwanted lines and clicking Add to selection"
459
- # )
460
-
461
-
462
- # update_button.click(
463
- # fn=update_graph,
464
- # inputs=[
465
- # base_folder,
466
- # datasets_selected,
467
- # metric_name_dropdown,
468
- # grouping_dropdown,
469
- # log_scale_x_checkbox,
470
- # log_scale_y_checkbox,
471
- # rounding,
472
- # normalization_checkbox,
473
- # top_select,
474
- # direction_checkbox,
475
- # group_regex,
476
- # ],
477
- # outputs=[graph_output, exported_data, export_data_json],
478
- # )
479
-
480
- # for inp in [normalization_checkbox, rounding, group_regex, direction_checkbox, top_select, log_scale_x_checkbox, log_scale_y_checkbox]:
481
- # inp.change(
482
- # fn=plot_data,
483
- # inputs=[
484
- # exported_data,
485
- # metric_name_dropdown,
486
- # normalization_checkbox,
487
- # rounding,
488
- # grouping_dropdown,
489
- # top_select,
490
- # direction_checkbox,
491
- # group_regex,
492
- # log_scale_x_checkbox,
493
- # log_scale_y_checkbox,
494
- # ],
495
- # outputs=[graph_output],
496
- # )
497
-
498
-
499
-
500
- # datasets_selected.change(
501
- # fn=fetch_groups,
502
- # inputs=[base_folder, datasets_selected, grouping_dropdown],
503
- # outputs=grouping_dropdown,
504
- # )
505
-
506
- # grouping_dropdown.select(
507
- # fn=fetch_metrics,
508
- # inputs=[base_folder, datasets_selected, grouping_dropdown, metric_name_dropdown],
509
- # outputs=metric_name_dropdown,
510
- # )
511
-
512
- # reverse_grouping_dropdown.select(
513
- # fn=partial(fetch_metrics, type="union"),
514
- # inputs=[base_folder, datasets, reverse_grouping_dropdown, reverse_metric_name_dropdown],
515
- # outputs=reverse_metric_name_dropdown,
516
- # )
517
-
518
- # reverse_search_button.click(
519
- # fn=reverse_search,
520
- # inputs=[base_folder, datasets, reverse_grouping_dropdown, reverse_metric_name_dropdown],
521
- # outputs=reverse_search_results,
522
- # )
523
-
524
- # reverse_search_add_button.click(
525
- # fn=reverse_search_add,
526
- # inputs=[datasets_selected, reverse_search_results],
527
- # outputs=datasets_selected,
528
- # )
529
-
530
-
531
- # datasets_refetch.click(
532
- # fn=fetch_datasets,
533
- # inputs=[base_folder],
534
- # outputs=[datasets, datasets_selected, reverse_grouping_dropdown],
535
- # )
536
-
537
- # def update_datasets_with_regex(regex, selected_runs, all_runs):
538
- # if not regex:
539
- # return
540
- # new_dsts = {run for run in all_runs if re.search(regex, run)}
541
- # if not new_dsts:
542
- # return gr.update(value=list(selected_runs))
543
- # dst_union = new_dsts.union(selected_runs or [])
544
- # return gr.update(value=sorted(list(dst_union)))
545
-
546
- # regex_button.click(
547
- # fn=update_datasets_with_regex,
548
- # inputs=[regex_select, datasets_selected, datasets],
549
- # outputs=datasets_selected,
550
- # )
551
-
552
- # def update_grouping_options(grouping):
553
- # if grouping == "histogram":
554
- # return {
555
- # normalization_checkbox: gr.Column(visible=True),
556
- # group_choices: gr.Column(visible=False),
557
- # }
558
- # else:
559
- # return {
560
- # normalization_checkbox: gr.Column(visible=False),
561
- # group_choices: gr.Column(visible=True),
562
- # }
563
-
564
- # grouping_dropdown.select(
565
- # fn=update_grouping_options,
566
- # inputs=[grouping_dropdown],
567
- # outputs=[normalization_checkbox, group_choices],
568
- # )
569
-
570
-
571
- # # Launch the application
572
- # if __name__ == "__main__":
573
- # demo.launch()
574
-
575
-
576
- import os
577
- print(os.environ["AWS_ACCESS_KEY_ID"])
 
1
+ from concurrent.futures import ThreadPoolExecutor
2
+ import enum
3
+ from functools import partial
4
+ import json
5
+ import os
6
+ from pathlib import Path
7
+ import re
8
+ import heapq
9
+ import tempfile
10
+ from typing import Literal
11
+ import gradio as gr
12
+
13
+ from collections import defaultdict
14
+ from datatrove.io import get_datafolder
15
+ import plotly.graph_objects as go
16
+ from datatrove.utils.stats import MetricStats, MetricStatsDict
17
+ import plotly.express as px
18
+ import tenacity
19
+
20
+ import gradio as gr
21
+ PARTITION_OPTIONS = Literal[ "Top", "Bottom", "Most frequent (n_docs)"]
22
+ METRICS_LOCATION_DEFAULT = os.getenv("METRICS_LOCATION_DEFAULT", "s3://fineweb-stats/summary/")
23
+
24
+
25
+ def find_folders(base_folder, path):
26
+ base_folder = get_datafolder(base_folder)
27
+ if not base_folder.exists(path):
28
+ return []
29
+ return sorted(
30
+ [
31
+ folder["name"]
32
+ for folder in base_folder.ls(path, detail=True)
33
+ if folder["type"] == "directory" and not folder["name"].rstrip("/") == path
34
+ ]
35
+ )
36
+
37
+
38
+ def find_metrics_folders(base_folder: str):
39
+ base_data_folder = get_datafolder(base_folder)
40
+ # First find all metric.json using globing for metric.json
41
+ metrics_merged = base_data_folder.glob("**/metric.json")
42
+
43
+ # Then for each of metrics.merged take the all but last two parts of the path (grouping/metric_name)
44
+ metrics_folders = [str(Path(x).parent.parent.parent) for x in metrics_merged]
45
+ # Finally get the unique paths
46
+ return sorted(list(set(metrics_folders)))
47
+
48
+
49
+ def fetch_datasets(base_folder: str):
50
+ datasets = sorted(find_metrics_folders(base_folder))
51
+ return datasets, gr.update(choices=datasets, value=None), fetch_groups(base_folder, datasets, None, "union")
52
+
53
+
54
+ def export_data(exported_data: MetricStatsDict, metric_name: str):
55
+ if not exported_data:
56
+ return None
57
+ # Assuming exported_data is a dictionary where the key is the dataset name and the value is the data to be exported
58
+ with tempfile.NamedTemporaryFile(mode="w", delete=False, prefix=metric_name, suffix=".json") as temp:
59
+ json.dump({
60
+ name: dt.to_dict()
61
+ for name, dt in exported_data.items()
62
+ }, temp)
63
+ temp_path = temp.name
64
+ return gr.update(visible=True, value=temp_path)
65
+
66
+
67
+ def fetch_groups(base_folder, datasets, old_groups, type="intersection"):
68
+ if not datasets:
69
+ return gr.update(choices=[], value=None)
70
+
71
+ with ThreadPoolExecutor() as executor:
72
+ GROUPS = list(executor.map(lambda run: [Path(x).name for x in find_folders(base_folder, run)], datasets))
73
+ if len(GROUPS) == 0:
74
+ return gr.update(choices=[], value=None)
75
+
76
+ if type == "intersection":
77
+ new_choices = set.intersection(*(set(g) for g in GROUPS))
78
+ else:
79
+ new_choices = set.union(*(set(g) for g in GROUPS))
80
+ value = None
81
+ if old_groups:
82
+ value = list(set.intersection(new_choices, {old_groups}))
83
+ value = value[0] if value else None
84
+
85
+ # now take the intersection of all grups
86
+ return gr.update(choices=sorted(list(new_choices)), value=value)
87
+
88
+
89
+ def fetch_metrics(base_folder, datasets, group, old_metrics, type="intersection"):
90
+ with ThreadPoolExecutor() as executor:
91
+ metrics = list(executor.map(lambda run: [Path(x).name for x in find_folders(base_folder, f"{run}/{group}")], datasets))
92
+ if len(metrics) == 0:
93
+ return gr.update(choices=[], value=None)
94
+
95
+ if type == "intersection":
96
+ new_possibles_choices = set.intersection(*(set(s) for s in metrics))
97
+ else:
98
+ new_possibles_choices = set.union(*(set(s) for s in metrics))
99
+ value = None
100
+ if old_metrics:
101
+ value = list(set.intersection(new_possibles_choices, {old_metrics}))
102
+ value = value[0] if value else None
103
+
104
+ return gr.update(choices=sorted(list(new_possibles_choices)), value=value)
105
+
106
+
107
+ def reverse_search(base_folder, possible_datasets, grouping, metric_name):
108
+ with ThreadPoolExecutor() as executor:
109
+ found_datasets = list(executor.map(lambda dataset: dataset if metric_exists(base_folder, dataset, metric_name, grouping) else None, possible_datasets))
110
+ found_datasets = [dataset for dataset in found_datasets if dataset is not None]
111
+ return "\n".join(found_datasets)
112
+
113
+
114
+ def reverse_search_add(datasets, reverse_search_results):
115
+ datasets = datasets or []
116
+ return sorted(list(set(datasets + reverse_search_results.strip().split("\n"))))
117
+
118
+
119
+
120
+ def metric_exists(base_folder, path, metric_name, group_by):
121
+ base_folder = get_datafolder(base_folder)
122
+ return base_folder.exists(f"{path}/{group_by}/{metric_name}/metric.json")
123
+
124
+ @tenacity.retry(stop=tenacity.stop_after_attempt(5))
125
+ def load_metrics(base_folder, path, metric_name, group_by):
126
+ base_folder = get_datafolder(base_folder)
127
+ with base_folder.open(
128
+ f"{path}/{group_by}/{metric_name}/metric.json",
129
+ ) as f:
130
+ json_metric = json.load(f)
131
+ # No idea why this is necessary, but it is, otheriwse the Metric StatsDict is malformed
132
+ return MetricStatsDict.from_dict(json_metric)
133
+
134
+
135
+ def prepare_for_non_grouped_plotting(metric, normalization, rounding):
136
+ metrics_rounded = defaultdict(lambda: 0)
137
+ for key, value in metric.items():
138
+ metrics_rounded[round(float(key), rounding)] += value.total
139
+ if normalization:
140
+ normalizer = sum(metrics_rounded.values())
141
+ metrics_rounded = {k: v / normalizer for k, v in metrics_rounded.items()}
142
+ # check that the sum of the values is 1
143
+ summed = sum(metrics_rounded.values())
144
+ assert abs(summed - 1) < 0.01, summed
145
+ return metrics_rounded
146
+
147
+
148
+ def load_data(dataset_path, base_folder, grouping, metric_name):
149
+ metrics = load_metrics(base_folder, dataset_path, metric_name, grouping)
150
+ return metrics
151
+
152
+ def prepare_for_group_plotting(metric, top_k, direction: PARTITION_OPTIONS, regex: str | None, rounding: int):
153
+ regex_compiled = re.compile(regex) if regex else None
154
+ metric = {key: value for key, value in metric.items() if not regex or regex_compiled.match(key)}
155
+ means = {key: round(float(value.mean), rounding) for key, value in metric.items()}
156
+ # Use heap to get top_k keys
157
+ if direction == "Top":
158
+ keys = heapq.nlargest(top_k, means, key=means.get)
159
+ elif direction == "Most frequent (n_docs)":
160
+ totals = {key: int(value.n) for key, value in metric.items()}
161
+ keys = heapq.nlargest(top_k, totals, key=totals.get)
162
+ else:
163
+ keys = heapq.nsmallest(top_k, means, key=means.get)
 
164
 
165
 
166
+ means = [means[key] for key in keys]
167
+ stds = [metric[key].standard_deviation for key in keys]
168
+ return keys, means, stds
169
+
170
+
171
+ def set_alpha(color, alpha):
172
+ """
173
+ Takes a hex color and returns
174
+ rgba(r, g, b, a)
175
+ """
176
+ if color.startswith('#'):
177
+ r, g, b = int(color[1:3], 16), int(color[3:5], 16), int(color[5:7], 16)
178
+ else:
179
+ r, g, b = 0, 0, 0 # Fallback to black if the color format is not recognized
180
+ return f"rgba({r}, {g}, {b}, {alpha})"
181
+
182
+
183
+ def plot_scatter(
184
+ data: dict[str, dict[float, float]],
185
+ metric_name: str,
186
+ log_scale_x: bool,
187
+ log_scale_y: bool,
188
+ normalization: bool,
189
+ rounding: int,
190
+ progress: gr.Progress,
191
+ ):
192
+ fig = go.Figure()
193
+
194
+ # First sort the histograms, by their name
195
+ data = {name: histogram for name, histogram in sorted(data.items())}
196
+ for i, (name, histogram) in enumerate(progress.tqdm(data.items(), total=len(data), desc="Plotting...")):
197
+ histogram_prepared = prepare_for_non_grouped_plotting(histogram, normalization, rounding)
198
+ x = sorted(histogram_prepared.keys())
199
+ y = [histogram_prepared[k] for k in x]
200
+
201
+ fig.add_trace(
202
+ go.Scatter(
203
+ x=x,
204
+ y=y,
205
+ mode="lines",
206
+ name=name,
207
+ marker=dict(color=set_alpha(px.colors.qualitative.Plotly[i % len(px.colors.qualitative.Plotly)], 0.5)),
208
+ )
209
+ )
210
+
211
+ yaxis_title = "Frequency" if normalization else "Total"
212
+
213
+ fig.update_layout(
214
+ title=f"Line Plots for {metric_name}",
215
+ xaxis_title=metric_name,
216
+ yaxis_title=yaxis_title,
217
+ xaxis_type="log" if log_scale_x and len(x) > 1 else None,
218
+ yaxis_type="log" if log_scale_y and len(y) > 1 else None,
219
+ width=1200,
220
+ height=600,
221
+ showlegend=True,
222
+ )
223
+
224
+ return fig
225
+
226
+
227
+ def plot_bars(
228
+ data: dict[str, list[dict[str, float]]],
229
+ metric_name: str,
230
+ top_k: int,
231
+ direction: PARTITION_OPTIONS,
232
+ regex: str | None,
233
+ rounding: int,
234
+ log_scale_x: bool,
235
+ log_scale_y: bool,
236
+ progress: gr.Progress,
237
+ ):
238
+ fig = go.Figure()
239
+ x = []
240
+ y = []
241
+
242
+ for i, (name, histogram) in enumerate(progress.tqdm(data.items(), total=len(data), desc="Plotting...")):
243
+ x, y, stds = prepare_for_group_plotting(histogram, top_k, direction, regex, rounding)
244
+
245
+ fig.add_trace(go.Bar(
246
+ x=x,
247
+ y=y,
248
+ name=f"{name} Mean",
249
+ marker=dict(color=set_alpha(px.colors.qualitative.Plotly[i % len(px.colors.qualitative.Plotly)], 0.5)),
250
+ error_y=dict(type='data', array=stds, visible=True)
251
+ ))
252
+
253
+ fig.update_layout(
254
+ title=f"Bar Plots for {metric_name}",
255
+ xaxis_title=metric_name,
256
+ yaxis_title="Avg. value",
257
+ xaxis_type="log" if log_scale_x and len(x) > 1 else None,
258
+ yaxis_type="log" if log_scale_y and len(y) > 1 else None,
259
+ autosize=True,
260
+ width=1200,
261
+ height=600,
262
+ showlegend=True,
263
+ )
264
+
265
+ return fig
266
+
267
+
268
+ def update_graph(
269
+ base_folder,
270
+ datasets,
271
+ metric_name,
272
+ grouping,
273
+ log_scale_x,
274
+ log_scale_y,
275
+ rounding,
276
+ normalization,
277
+ top_k,
278
+ direction,
279
+ regex,
280
+ progress=gr.Progress(),
281
+ ):
282
+ if len(datasets) <= 0 or not metric_name or not grouping:
283
+ return None
284
+ # Placeholder for logic to rerender the graph based on the inputs
285
+
286
+ with ThreadPoolExecutor() as pool:
287
+ data = list(
288
+ progress.tqdm(
289
+ pool.map(
290
+ partial(load_data, base_folder=base_folder, metric_name=metric_name, grouping=grouping),
291
+ datasets,
292
+ ),
293
+ total=len(datasets),
294
+ desc="Loading data...",
295
+ )
296
+ )
297
+
298
+ data = {path: result for path, result in zip(datasets, data)}
299
+ return plot_data(data, metric_name, normalization, rounding, grouping, top_k, direction, regex, log_scale_x, log_scale_y, progress), data, export_data(data, metric_name)
300
+
301
+ def plot_data(data, metric_name, normalization, rounding, grouping, top_k, direction, regex, log_scale_x, log_scale_y, progress=gr.Progress()):
302
+ if rounding is None or top_k is None:
303
+ return None
304
+ graph_fc = (
305
+ partial(plot_scatter, normalization=normalization, rounding=rounding)
306
+ if grouping == "histogram"
307
+ else partial(plot_bars, top_k=top_k, direction=direction, regex=regex, rounding=rounding)
308
+ )
309
+ return graph_fc(data=data, metric_name=metric_name, progress=progress, log_scale_x=log_scale_x, log_scale_y=log_scale_y)
310
+
311
+
312
+
313
+ # Create the Gradio interface
314
+ with gr.Blocks() as demo:
315
+ datasets = gr.State([])
316
+ exported_data = gr.State([])
317
+ metrics_headline = gr.Markdown(value="# Metrics Exploration")
318
+ with gr.Row():
319
+ with gr.Column(scale=2):
320
+ with gr.Row():
321
+ with gr.Column(scale=1):
322
+ base_folder = gr.Textbox(
323
+ label="Metrics Location",
324
+ value=METRICS_LOCATION_DEFAULT,
325
+ )
326
+ datasets_refetch = gr.Button("Fetch Datasets")
327
+
328
+ with gr.Column(scale=1):
329
+ regex_select = gr.Text(label="Regex filter", value=".*")
330
+ regex_button = gr.Button("Search")
331
+ with gr.Row():
332
+ datasets_selected = gr.Dropdown(
333
+ choices=[],
334
+ label="Datasets",
335
+ multiselect=True,
336
+ )
337
+
338
+ # add a readme description
339
+ readme_description = gr.Markdown(
340
+ label="Readme",
341
+ value="""
342
+ ## How to use:
343
+ 1) Specify Metrics location (Stats block `output_folder` without the last path segment) and click "Fetch Datasets"
344
+ 2) Select datasets you are interested in using the dropdown or regex filter
345
+ 3) Specify Grouping (global average/value/fqdn/suffix) and Metric name
346
+ 4) Click "Update Graph"
347
+
348
+
349
+ ## Groupings:
350
+ - **histogram**: Creates a line plot of values with their frequencies. If normalization is on, the frequencies sum to 1.
351
+ * normalize:
352
+ - **(fqdn/suffix)**: Creates a bar plot of the avg. values of the metric for full qualifed domain name/suffix of domain.
353
+ * k: the number of groups to show
354
+ * Top/Bottom/Most frequent (n_docs): Groups with the top/bottom k values/most prevalant docs are shown
355
+ - **none**: Shows the average value of given metric
356
+
357
+ ## Reverse search:
358
+ To search for datasets containing a grouping and certain metric, use the Reverse search section.
359
+ Specify the search parameters and click "Search". This will show you found datasets in the "Found datasets" textbox. You can modify the selection after search by removing unwanted lines and clicking "Add to selection".
360
+
361
+ ## Note:
362
+ The data might not be 100% representative, due to the sampling and optimistic merging of the metrics (fqdn/suffix).
363
+ """,
364
+ )
365
+ with gr.Column(scale=1):
366
+ # Define the dropdown for grouping
367
+ grouping_dropdown = gr.Dropdown(
368
+ choices=[],
369
+ label="Grouping",
370
+ multiselect=False,
371
+ )
372
+ # Define the dropdown for metric_name
373
+ metric_name_dropdown = gr.Dropdown(
374
+ choices=[],
375
+ label="Metric name",
376
+ multiselect=False,
377
+ )
378
+
379
+
380
+ update_button = gr.Button("Update Graph", variant="primary")
381
+
382
+ with gr.Row():
383
+ with gr.Column(scale=1):
384
+ log_scale_x_checkbox = gr.Checkbox(
385
+ label="Log scale x",
386
+ value=False,
387
+ )
388
+ log_scale_y_checkbox = gr.Checkbox(
389
+ label="Log scale y",
390
+ value=False,
391
+ )
392
+ rounding = gr.Number(
393
+ label="Rounding",
394
+ value=2,
395
+ )
396
+ normalization_checkbox = gr.Checkbox(
397
+ label="Normalize",
398
+ value=True, # Default value
399
+ visible=False
400
+ )
401
+ with gr.Row():
402
+ # export_data_button = gr.Button("Export data", visible=True, link=export_data_json)
403
+ export_data_json = gr.File(visible=False)
404
+ with gr.Column(scale=4):
405
+ with gr.Row(visible=False) as group_choices:
406
+ with gr.Column(scale=2):
407
+ group_regex = gr.Text(
408
+ label="Group Regex",
409
+ value=None,
410
+ )
411
+ with gr.Row():
412
+ top_select = gr.Number(
413
+ label="N Groups",
414
+ value=100,
415
+ interactive=True,
416
+ )
417
 
418
+ direction_checkbox = gr.Radio(
419
+ label="Partition",
420
+ choices=[
421
+ "Top",
422
+ "Bottom",
423
+ "Most frequent (n_docs)",
424
+ ],
425
+ value="Most frequent (n_docs)",
426
+ )
427
+ # Define the graph output
428
+ with gr.Row():
429
+ graph_output = gr.Plot(label="Graph")
430
 
431
+ with gr.Row():
432
+ reverse_search_headline = gr.Markdown(value="# Reverse metrics search")
433
 
434
+ with gr.Row():
435
+ with gr.Column(scale=1):
436
+ # Define the dropdown for grouping
437
+ reverse_grouping_dropdown = gr.Dropdown(
438
+ choices=[],
439
+ label="Grouping",
440
+ multiselect=False,
441
+ )
442
+ # Define the dropdown for metric_name
443
+ reverse_metric_name_dropdown = gr.Dropdown(
444
+ choices=[],
445
+ label="Stat name",
446
+ multiselect=False,
447
+ )
448
 
449
+ with gr.Column(scale=1):
450
+ reverse_search_button = gr.Button("Search")
451
+ reverse_search_add_button = gr.Button("Add to selection")
452
+
453
+ with gr.Column(scale=2):
454
+ reverse_search_results = gr.Textbox(
455
+ label="Found datasets",
456
+ lines=10,
457
+ placeholder="Found datasets containing the group/metric name. You can modify the selection after search by removing unwanted lines and clicking Add to selection"
458
+ )
459
+
460
+
461
+ update_button.click(
462
+ fn=update_graph,
463
+ inputs=[
464
+ base_folder,
465
+ datasets_selected,
466
+ metric_name_dropdown,
467
+ grouping_dropdown,
468
+ log_scale_x_checkbox,
469
+ log_scale_y_checkbox,
470
+ rounding,
471
+ normalization_checkbox,
472
+ top_select,
473
+ direction_checkbox,
474
+ group_regex,
475
+ ],
476
+ outputs=[graph_output, exported_data, export_data_json],
477
+ )
478
+
479
+ for inp in [normalization_checkbox, rounding, group_regex, direction_checkbox, top_select, log_scale_x_checkbox, log_scale_y_checkbox]:
480
+ inp.change(
481
+ fn=plot_data,
482
+ inputs=[
483
+ exported_data,
484
+ metric_name_dropdown,
485
+ normalization_checkbox,
486
+ rounding,
487
+ grouping_dropdown,
488
+ top_select,
489
+ direction_checkbox,
490
+ group_regex,
491
+ log_scale_x_checkbox,
492
+ log_scale_y_checkbox,
493
+ ],
494
+ outputs=[graph_output],
495
+ )
496
+
497
+
498
+
499
+ datasets_selected.change(
500
+ fn=fetch_groups,
501
+ inputs=[base_folder, datasets_selected, grouping_dropdown],
502
+ outputs=grouping_dropdown,
503
+ )
504
+
505
+ grouping_dropdown.select(
506
+ fn=fetch_metrics,
507
+ inputs=[base_folder, datasets_selected, grouping_dropdown, metric_name_dropdown],
508
+ outputs=metric_name_dropdown,
509
+ )
510
+
511
+ reverse_grouping_dropdown.select(
512
+ fn=partial(fetch_metrics, type="union"),
513
+ inputs=[base_folder, datasets, reverse_grouping_dropdown, reverse_metric_name_dropdown],
514
+ outputs=reverse_metric_name_dropdown,
515
+ )
516
+
517
+ reverse_search_button.click(
518
+ fn=reverse_search,
519
+ inputs=[base_folder, datasets, reverse_grouping_dropdown, reverse_metric_name_dropdown],
520
+ outputs=reverse_search_results,
521
+ )
522
+
523
+ reverse_search_add_button.click(
524
+ fn=reverse_search_add,
525
+ inputs=[datasets_selected, reverse_search_results],
526
+ outputs=datasets_selected,
527
+ )
528
+
529
+
530
+ datasets_refetch.click(
531
+ fn=fetch_datasets,
532
+ inputs=[base_folder],
533
+ outputs=[datasets, datasets_selected, reverse_grouping_dropdown],
534
+ )
535
+
536
+ def update_datasets_with_regex(regex, selected_runs, all_runs):
537
+ if not regex:
538
+ return
539
+ new_dsts = {run for run in all_runs if re.search(regex, run)}
540
+ if not new_dsts:
541
+ return gr.update(value=list(selected_runs))
542
+ dst_union = new_dsts.union(selected_runs or [])
543
+ return gr.update(value=sorted(list(dst_union)))
544
+
545
+ regex_button.click(
546
+ fn=update_datasets_with_regex,
547
+ inputs=[regex_select, datasets_selected, datasets],
548
+ outputs=datasets_selected,
549
+ )
550
+
551
+ def update_grouping_options(grouping):
552
+ if grouping == "histogram":
553
+ return {
554
+ normalization_checkbox: gr.Column(visible=True),
555
+ group_choices: gr.Column(visible=False),
556
+ }
557
+ else:
558
+ return {
559
+ normalization_checkbox: gr.Column(visible=False),
560
+ group_choices: gr.Column(visible=True),
561
+ }
562
+
563
+ grouping_dropdown.select(
564
+ fn=update_grouping_options,
565
+ inputs=[grouping_dropdown],
566
+ outputs=[normalization_checkbox, group_choices],
567
+ )
568
+
569
+
570
+ # Launch the application
571
+ if __name__ == "__main__":
572
+ demo.launch()
 
 
 
 
requirements.txt CHANGED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ gradio
2
+ datatrove[dev] @ git+https://github.com/huggingface/datatrove.git
3
+ plotly