# Datasets-Metrics-Viewer — src/logic/data_processing.py
import heapq
import json
import os
import re
import tempfile
from collections import defaultdict
from datetime import datetime
from typing import Dict, List, Literal, Tuple

import gradio as gr
import numpy as np

from datatrove.utils.stats import MetricStatsDict
from src.logic.graph_settings import Grouping
# Allowed partition directions for grouped plots; the literal strings must
# stay in sync with the UI control that supplies the `direction` argument.
PARTITION_OPTIONS = Literal["Top", "Bottom", "Most frequent (n_docs)"]
def prepare_for_non_grouped_plotting(metric: Dict[str, MetricStatsDict], normalization: bool, rounding: int) -> Dict[float, float]:
    """Bin a value->stats metric into a histogram keyed by rounded value.

    Each key of `metric` is parsed as a float and rounded to `rounding`
    decimals; the `.total` of every entry falling into the same rounded bin
    is summed. With `normalization` the bin sums are divided by their grand
    total so they form a distribution.

    Returns a dict mapping rounded key -> (optionally normalized) total.
    """
    n = len(metric)
    raw_keys = np.fromiter((float(k) for k in metric), dtype=float, count=n)
    totals = np.fromiter((v.total for v in metric.values()), dtype=float, count=n)
    binned = np.round(raw_keys, rounding)
    # `return_inverse` gives, for every original key, the index of its bin,
    # so scatter-add accumulates duplicate rounded keys in one vectorized pass.
    bins, bin_of = np.unique(binned, return_inverse=True)
    sums = np.zeros(bins.shape[0], dtype=float)
    np.add.at(sums, bin_of, totals)
    if normalization:
        sums = sums / np.sum(sums)
    return dict(zip(bins, sums))
def prepare_for_group_plotting(metric: Dict[str, MetricStatsDict], top_k: int, direction: PARTITION_OPTIONS, regex: str | None, rounding: int) -> Tuple[List[str], List[float], List[float]]:
    """Select up to `top_k` groups of a grouped metric for bar plotting.

    Groups are optionally filtered by `regex` (prefix-anchored via `.match`;
    NOTE(review): use `fullmatch` if exact-name filtering is intended — confirm
    with the UI), then ranked by mean rounded to `rounding` decimals ("Top" /
    "Bottom") or by document count ("Most frequent (n_docs)").

    Returns three parallel lists: group keys, rounded means, and standard
    deviations of the selected groups.
    """
    regex_compiled = re.compile(regex) if regex else None
    filtered_metric = {key: value for key, value in metric.items() if regex_compiled is None or regex_compiled.match(key)}
    keys = np.array(list(filtered_metric.keys()))
    means = np.array([float(value.mean) for value in filtered_metric.values()])
    stds = np.array([value.standard_deviation for value in filtered_metric.values()])
    rounded_means = np.round(means, rounding)
    top_k = max(top_k, 0)  # a negative k would silently select from the wrong end
    if direction == "Top":
        # Reverse the ascending argsort *before* slicing: `[-top_k:]` would
        # return every row when top_k == 0 (Python's `[-0:]` is the full slice);
        # for top_k >= 1 this ordering is identical to the old `[-top_k:][::-1]`.
        top_indices = np.argsort(rounded_means)[::-1][:top_k]
    elif direction == "Most frequent (n_docs)":
        totals = np.array([int(value.n) for value in filtered_metric.values()])
        top_indices = np.argsort(totals)[::-1][:top_k]
    else:
        top_indices = np.argsort(rounded_means)[:top_k]
    return keys[top_indices].tolist(), rounded_means[top_indices].tolist(), stds[top_indices].tolist()
def export_data(exported_data: Dict[str, MetricStatsDict], metric_name: str, grouping: Grouping):
    """Serialize the plotted metric data to a timestamped JSON file.

    Returns a visible `gr.File` pointing at the written file, or None when
    there is nothing to export. Each group's stats dict is flattened to a
    list of rows sorted by their "value" key.
    """
    if not exported_data:
        return None
    file_name = f"{metric_name}_{grouping}_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.json"
    # Write into the system temp directory rather than the CWD: the working
    # directory may be read-only in deployment and would otherwise accumulate
    # one file per export. (`tempfile` was already imported for this purpose.)
    file_path = os.path.join(tempfile.gettempdir(), file_name)
    with open(file_path, 'w') as f:
        json.dump({
            name: sorted([{"value": key, **value} for key, value in dt.to_dict().items()], key=lambda x: x["value"])
            for name, dt in exported_data.items()
        }, f, indent=2)
    return gr.File(value=file_path, visible=True)