# Hugging Face Space status at capture time: Runtime error
import altair as alt | |
import gradio as gr | |
import pandas as pd | |
from functools import partial | |
from datasets import load_dataset | |
def get_data(): | |
model_id = "ybelkada/model_cards_correct_tag" | |
dataset = load_dataset(model_id, split="train").to_pandas() | |
# Convert dataset to a pandas DataFrame and sort by commit_dates | |
df = pd.DataFrame(dataset) | |
df["commit_dates"] = pd.to_datetime(df["commit_dates"]) # Convert commit_dates to datetime format | |
df = df.sort_values(by="commit_dates") | |
melted_df = pd.melt(df, id_vars=['commit_dates'], value_vars=['total_transformers_model', 'missing_library_name'], var_name='type') | |
df['ratio'] = (1 - df['missing_library_name'] / df['total_transformers_model']) * 100 | |
ratio_df = df[['commit_dates', 'ratio']].copy() | |
return ratio_df, melted_df | |
# Fetched once at import time; make_plot reads these module-level frames and
# replaces them when called with refresh=True.
ratio_df, melted_df = get_data()


def _value_domain(values):
    """Return ``(min, max)`` over the finite entries of *values*.

    NaN and +/-inf entries are ignored so they never reach the Vega-Lite spec.
    If nothing finite remains (e.g. an empty frame), ``alt.Undefined`` is
    returned so Altair falls back to its automatic scale domain.
    """
    finite = values[values.abs() != float("inf")].dropna()
    if finite.empty:
        return alt.Undefined
    return (finite.min(), finite.max())


def _highlighted_line_chart(data, y_field, y_title, highlight_field, color=None):
    """Build a layered points + lines chart of *y_field* over ``commit_dates``.

    Args:
        data: DataFrame with a ``commit_dates`` column and numeric *y_field*.
        y_field: Column plotted on the quantitative y axis.
        y_title: Y-axis title.
        highlight_field: Field the mouseover selection is keyed on.
        color: Optional Altair color encoding (e.g. ``'type:N'``); omitted
            entirely when None.

    Returns:
        The layered chart ``points + lines``.
    """
    # NOTE: alt.selection(type='single') / .add_selection are the Altair 4 API
    # (deprecated in Altair 5 in favour of selection_point / add_params); kept
    # unchanged so the chart works with whichever Altair version is installed.
    highlight = alt.selection(type='single', on='mouseover',
                              fields=[highlight_field], nearest=True)
    encodings = {
        "x": alt.X('commit_dates:T', title='Date'),
        "y": alt.Y(f'{y_field}:Q',
                   scale=alt.Scale(domain=_value_domain(data[y_field])),
                   title=y_title),
    }
    if color is not None:
        encodings["color"] = color
    base = alt.Chart(data).encode(**encodings)
    points = base.mark_circle().encode(
        opacity=alt.value(1),
    ).add_selection(
        highlight
    ).properties(
        width=1200,
        height=800,
    )
    # Data outside the mouseover selection is drawn with size 1, selected with 3.
    lines = base.mark_line().encode(
        size=alt.condition(~highlight, alt.value(1), alt.value(3))
    )
    return points + lines


def make_plot(plot_type, refresh=False):
    """Return the Altair chart for the selected *plot_type*.

    Args:
        plot_type: Radio label. ``"Total models with missing 'transformers' tag"``
            plots both raw count series from ``melted_df``; any other value
            plots the ratio series from ``ratio_df``.
        refresh: If True, call ``get_data()`` first and replace the
            module-level ``ratio_df`` / ``melted_df`` with the result.

    Returns:
        A layered Altair chart (points + lines).
    """
    global ratio_df, melted_df
    if refresh:
        ratio_df, melted_df = get_data()
    if plot_type == "Total models with missing 'transformers' tag":
        return _highlighted_line_chart(melted_df, 'value', "Count", 'type', color='type:N')
    return _highlighted_line_chart(
        ratio_df,
        'ratio',
        "(1 - missing_library_name / total_transformers_model) * 100 - Higher is better",
        'ratio',
    )
# Gradio UI: a radio selects which chart to show, a button re-fetches the data.
with gr.Blocks() as demo:
    plot_choices = [
        "Total models with missing 'transformers' tag",
        "Proportion of models correctly tagged with 'transformers' tag",
    ]
    # Components render in creation order: selector, refresh button, plot.
    button = gr.Radio(
        choices=plot_choices,
        value=plot_choices[0],
        label="Plot type",
    )
    refresh_button = gr.Button(value="Fetch latest data")
    plot = gr.Plot(label="Plot")

    # Every trigger redraws the plot from the currently selected plot type;
    # only the refresh button asks make_plot to re-download the dataset first.
    for trigger, handler in (
        (button.change, make_plot),
        (refresh_button.click, partial(make_plot, refresh=True)),
        (demo.load, make_plot),
    ):
        trigger(handler, inputs=[button], outputs=[plot])

if __name__ == "__main__":
    demo.launch()