LysandreJik's picture
Max number of pipelines
ffeea82
raw
history blame
2.34 kB
from collections import OrderedDict
import gradio as gr
import plotly.graph_objects as go
from datasets import load_dataset, Dataset
from huggingface_hub import list_datasets
# Library names for the dropdown. Each dataset id of the form
# "open-source-metrics/<library>-checkpoint-downloads" contributes "<library>",
# which is the same id shape that plot() later passes to load_dataset.
# Matching the exact prefix and suffix (rather than a substring test plus
# hard-coded slice offsets) skips ids such as "...-checkpoint-downloads-v2",
# which would otherwise be sliced into bogus library names.
_DATASET_PREFIX = 'open-source-metrics/'
_DATASET_SUFFIX = '-checkpoint-downloads'
pipelines = [
    d.id[len(_DATASET_PREFIX):-len(_DATASET_SUFFIX)]
    for d in list_datasets(author='open-source-metrics')
    if d.id.startswith(_DATASET_PREFIX) and d.id.endswith(_DATASET_SUFFIX)
]
def sum_with_none(iterator):
    """Return the total of ``iterator``, skipping any ``None`` entries.

    An empty iterable, or one containing only ``None``, sums to 0.
    """
    present = (value for value in iterator if value is not None)
    return sum(present)
def merge_columns(dataset: Dataset, max_number_of_columns: int) -> Dataset:
    """Fold the lowest-ranked columns of ``dataset`` into one 'combined' column.

    Every column except 'dates' is ranked by its total (``None`` entries are
    ignored). Columns ranked beyond the first ``max_number_of_columns`` are
    added element-wise into a 'combined' column, with ``None`` counted as 0.
    'combined' is seeded with the 'no_arch' column, or with zeros when there
    is no 'no_arch' column, and is placed last.

    Args:
        dataset: table with a 'dates' column plus value columns (presumably
            per-architecture download counts; only 'dates' and 'no_arch' are
            referenced by name here).
        max_number_of_columns: how many top-ranked columns are left as-is.

    Returns:
        A new ``Dataset`` built from the merged columns.
    """
    downloads = {
        col: sum_with_none(dataset[col])
        for col in dataset.column_names
        if col != 'dates'
    }
    # Stable sort with the highest total first; tied columns keep their order.
    ranked = sorted(downloads, key=downloads.get, reverse=True)
    # 'no_arch' already seeds 'combined' below. Merging it a second time would
    # pop a key that no longer exists (KeyError) when it ranks outside the top N.
    to_merge = [col for col in ranked[max_number_of_columns:] if col != 'no_arch']

    dictionary = dataset.to_dict()
    combined = dictionary.pop('no_arch', None)
    if combined is None:
        combined = [0] * dataset.num_rows
    for col in to_merge:
        to_add = dictionary.pop(col)
        combined = [
            (0 if a is None else a) + (0 if b is None else b)
            for a, b in zip(combined, to_add)
        ]
    # Inserted after the pops, so 'combined' ends up as the last column.
    dictionary['combined'] = combined
    return Dataset.from_dict(dictionary)
def plot(library: str, stacked: bool, number_of_pipelines_to_show: int):
    """Build a downloads-over-time line chart for ``library``.

    Loads the 'train' split of
    ``open-source-metrics/<library>-checkpoint-downloads``. When the split has
    more non-'dates' columns than ``number_of_pipelines_to_show``, the extra
    columns are folded into one 'combined' column via ``merge_columns``. Each
    remaining column becomes one trace plotted against the 'dates' column.

    Args:
        library: library name as listed in the dropdown.
        stacked: when True, all traces share one stack group.
        number_of_pipelines_to_show: column budget taken from the slider. It
            is coerced to ``int`` because ``merge_columns`` uses it as a slice
            bound.

    Returns:
        The ``plotly.graph_objects.Figure``, which is displayed by the Gradio
        ``Plot`` output.
    """
    # Slider values may arrive as floats (e.g. 3.0), and slicing needs an int.
    number_of_pipelines_to_show = int(number_of_pipelines_to_show)
    dataset = load_dataset(f"open-source-metrics/{library}-checkpoint-downloads")['train']
    n_archs = len(dataset.column_names) - 1  # every column except 'dates'
    if n_archs > number_of_pipelines_to_show:
        dataset = merge_columns(dataset, number_of_pipelines_to_show)

    dates = dataset['dates']
    # Build a fresh list rather than mutating the one returned by column_names.
    axis = [col for col in dataset.column_names if col != 'dates']

    fig = go.Figure()
    for name in axis:
        fig.add_trace(
            go.Scatter(
                x=dates,
                y=dataset[name],
                mode='lines+markers',
                name=name,
                stackgroup='one' if stacked else None,
            )
        )
    # No fig.show(): the figure is returned to the gr.Plot output, so also
    # showing it here would make the server process try to render it.
    return fig
with gr.Blocks() as demo:
    # Controls are rendered in creation order. Their values are passed
    # positionally to plot(library, stacked, number_of_pipelines_to_show).
    library_dropdown = gr.Dropdown(pipelines)
    stacked_checkbox = gr.Checkbox(label='Stacked')
    # NOTE(review): the maximum is the number of libraries in the dropdown,
    # not the number of columns in the selected dataset. Confirm this is the
    # intended bound.
    pipelines_slider = gr.Slider(
        minimum=1,
        maximum=len(pipelines),
        value=3,
        step=1,
        label="Max number of pipelines to show",
    )
    inputs = [library_dropdown, stacked_checkbox, pipelines_slider]

    submit = gr.Button('Submit')
    with gr.Row():
        outputs = [gr.Plot()]
    submit.click(fn=plot, inputs=inputs, outputs=outputs)

demo.launch()