from collections import OrderedDict

import gradio as gr
import plotly.graph_objects as go
from datasets import load_dataset, Dataset
from huggingface_hub import list_datasets

# Download datasets are addressed below as
# "open-source-metrics/<library>-checkpoint-downloads"; the dropdown only
# needs the "<library>" part, so the fixed prefix and suffix are sliced off.
_DATASET_PREFIX = 'open-source-metrics/'
_DATASET_SUFFIX = '-checkpoint-downloads'

pipelines = [
    d.id[len(_DATASET_PREFIX):-len(_DATASET_SUFFIX)]
    for d in list_datasets(author='open-source-metrics')
    if 'checkpoint-downloads' in d.id
]


def sum_with_none(iterator):
    """Return the sum of the values in *iterator*, skipping ``None`` entries."""
    return sum(v for v in iterator if v is not None)


def merge_columns(dataset: Dataset, max_number_of_columns: int):
    """Fold the least-downloaded columns of *dataset* into a 'combined' column.

    Every column except 'dates' is ranked by its total (``None`` entries are
    ignored when summing). Columns ranked after the first
    ``max_number_of_columns`` are removed and added element-wise into a new
    'combined' column. The 'no_arch' column, when present, is always folded
    into 'combined' regardless of its rank.

    Args:
        dataset: Dataset with a 'dates' column plus one column per
            architecture.
        max_number_of_columns: Number of top-ranked columns that are left
            untouched.

    Returns:
        A new ``Dataset`` holding the remaining columns plus 'combined'.
        ``None`` entries count as 0 in 'combined'.
    """
    downloads = {
        col: sum_with_none(dataset[col])
        for col in dataset.column_names
        if col != 'dates'
    }
    # Highest total first; ties keep the dataset's column order.
    ranked = sorted(downloads, key=downloads.get, reverse=True)
    # 'no_arch' seeds the combined column below, so it must not be added twice.
    to_merge = [col for col in ranked[max_number_of_columns:] if col != 'no_arch']

    dictionary = dataset.to_dict()

    # Take 'no_arch' out of the result instead of aliasing it: sharing one
    # list between 'no_arch' and 'combined' made both columns identical and
    # double-counted the data when the traces were stacked.
    no_arch = dictionary.pop('no_arch', None)
    if no_arch is None:
        combined = [0] * dataset.num_rows
    else:
        combined = [0 if value is None else value for value in no_arch]

    for col in to_merge:
        values = dictionary.pop(col)
        combined = [
            total + (0 if value is None else value)
            for total, value in zip(combined, values)
        ]

    dictionary['combined'] = combined
    return Dataset.from_dict(dictionary)


def plot(library: str, stacked: bool, number_of_architectures_to_show: int):
    """Build a Plotly figure with one download trace per column for *library*.

    Args:
        library: Library name used to build the dataset id
            ``open-source-metrics/<library>-checkpoint-downloads``.
        stacked: When true, all traces share one ``stackgroup``.
        number_of_architectures_to_show: When the dataset has more
            non-date columns than this, the extra ones are merged with
            ``merge_columns``.

    Returns:
        A ``plotly.graph_objects.Figure`` with one scatter trace per column
        other than 'dates'.
    """
    dataset = load_dataset(f"open-source-metrics/{library}-checkpoint-downloads")['train']

    n_archs = len(dataset.column_names) - 1  # every column except 'dates'
    if n_archs > number_of_architectures_to_show:
        dataset = merge_columns(dataset, number_of_architectures_to_show)

    dates = dataset['dates']
    stackgroup = 'one' if stacked else None

    fig = go.Figure()
    for name in dataset.column_names:
        if name == 'dates':
            continue
        fig.add_trace(
            go.Scatter(
                x=dates,
                y=dataset[name],
                mode='lines+markers',
                name=name,
                stackgroup=stackgroup,
            )
        )
    return fig


with gr.Blocks() as demo:
    # Order matters: the components map positionally onto plot()'s parameters.
    inputs = [
        gr.Dropdown(pipelines),
        gr.Checkbox(label='Stacked'),
        gr.Slider(
            minimum=1,
            maximum=len(pipelines),
            value=3,
            step=1,
            label="Max number of architectures to show",
        ),
    ]
    submit = gr.Button('Submit')
    with gr.Row():
        outputs = [gr.Plot()]

    submit.click(fn=plot, inputs=inputs, outputs=outputs)

demo.launch()