|
import subprocess |
|
import gradio as gr |
|
import pandas as pd |
|
from apscheduler.schedulers.background import BackgroundScheduler |
|
from huggingface_hub import snapshot_download |
|
|
|
from src.about import ( |
|
CITATION_BUTTON_LABEL, |
|
CITATION_BUTTON_TEXT, |
|
EVALUATION_QUEUE_TEXT, |
|
INTRODUCTION_TEXT, |
|
LLM_BENCHMARKS_TEXT, |
|
TITLE, |
|
) |
|
from src.display.css_html_js import custom_css |
|
from src.display.utils import ( |
|
BENCHMARK_COLS, |
|
COLS, |
|
EVAL_COLS, |
|
EVAL_TYPES, |
|
NUMERIC_INTERVALS, |
|
TYPES, |
|
AutoEvalColumn, |
|
ModelType, |
|
fields, |
|
WeightType, |
|
Precision |
|
) |
|
from src.envs import API, EVAL_REQUESTS_PATH, EVAL_RESULTS_PATH, QUEUE_REPO, REPO_ID, RESULTS_REPO, TOKEN |
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_data(data_path):
    """Load one benchmark CSV and shape it for display in the leaderboard.

    Rows containing any missing value (in any CSV column) are dropped,
    ``Post-ASR`` is rounded to a whole number, rows are ordered by
    ``Post-ASR`` from highest to lowest, and only the five leaderboard
    columns are kept, in display order.

    Args:
        data_path: Anything ``pandas.read_csv`` accepts (path or file-like).

    Returns:
        A DataFrame with columns ``Unlearned_Methods``, ``Pre-ASR``,
        ``Post-ASR``, ``FID`` and ``CLIP-Score``.
    """
    columns_sorted = ['Unlearned_Methods', 'Pre-ASR', 'Post-ASR', 'FID', 'CLIP-Score']

    # Copy explicitly so the column assignment below writes into a frame we
    # own rather than a possible view of the dropna() source.
    df = pd.read_csv(data_path).dropna().copy()
    df['Post-ASR'] = df['Post-ASR'].round(0)
    df = df.sort_values(by='Post-ASR', ascending=False)
    return df[columns_sorted]
|
|
|
def restart_space():
    """Restart the repository identified by ``REPO_ID`` through the shared ``API`` client.

    Takes no arguments and returns nothing; it is registered as a periodic
    job with the background scheduler at the bottom of this module.
    """
    space_id = REPO_ID
    API.restart_space(repo_id=space_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Column layout shared by every leaderboard table: the method name followed by
# the four evaluation metrics.
all_columns = ['Unlearned_Methods', 'Pre-ASR', 'Post-ASR', 'FID', 'CLIP-Score']
show_columns = list(all_columns)

# Gradio datatype per displayed column: one text column, then numeric metrics.
# (Deliberately overrides the TYPES imported from src.display.utils.)
TYPES = ['str'] + ['number'] * (len(show_columns) - 1)

# CSV stems under ./assets/, one leaderboard table per entry.
files = ['nudity', 'vangogh', 'church', 'garbage', 'parachute', 'tench']

# Eagerly load the first benchmark at import time.
csv_path = f'./assets/{files[0]}.csv'
df_results = load_data(csv_path)

methods = list(set(df_results['Unlearned_Methods']))

df_results_init = df_results.copy()[show_columns]
|
|
|
def update_table(hidden_df: pd.DataFrame, model1_column: list, query: str):
    """Recompute the visible leaderboard from the hidden, unfiltered copy.

    Args:
        hidden_df: Full leaderboard frame (never mutated; a copy is filtered).
        model1_column: Metric column names ticked in the selector.
        query: Search-box text; multiple terms are separated by ``;``.

    Returns:
        The column-selected, query-filtered frame with duplicate rows removed
        (several ``;`` terms can match the same row).
    """
    working = select_columns(hidden_df.copy(), model1_column)
    working = filter_queries(query, working)
    return working.drop_duplicates()
|
|
|
|
|
def search_table(df: pd.DataFrame, query: str) -> pd.DataFrame:
    """Return the rows whose ``Unlearned_Methods`` value contains *query*.

    Matching is a case-insensitive *literal* substring test. The query comes
    straight from the search box, so it must not be interpreted as a regular
    expression: text such as ``C++`` or ``(ours`` would otherwise raise
    ``re.error``, and ``.`` would match every row. Missing names never match.
    """
    mask = df['Unlearned_Methods'].str.contains(query, case=False, regex=False, na=False)
    return df[mask]
|
|
|
|
|
def filter_queries(query: str, filtered_df: pd.DataFrame) -> pd.DataFrame:
    """Keep only rows whose method name matches at least one ``;``-separated term.

    Blank terms are ignored. If the query is empty, or no term matches any
    row, *filtered_df* is returned unchanged. Rows matched by several terms
    appear once per matching term; the caller removes duplicates.
    """
    if query == "":
        return filtered_df

    hits = []
    for term in query.split(";"):
        term = term.strip()
        if term == "":
            continue
        matched = search_table(filtered_df, term)
        if len(matched) > 0:
            hits.append(matched)

    return pd.concat(hits) if hits else filtered_df
|
|
|
def search_table_model(df: pd.DataFrame, query: str) -> pd.DataFrame:
    """Return the rows whose ``Diffusion_Models`` value contains *query*.

    Uses the same case-insensitive *literal* substring matching as
    ``search_table``, so regex metacharacters in user input are matched
    verbatim instead of raising or over-matching. Missing values never match.
    """
    mask = df['Diffusion_Models'].str.contains(query, case=False, regex=False, na=False)
    return df[mask]
|
|
|
|
|
def filter_queries_model(query: 'str | list[str]', filtered_df: pd.DataFrame) -> pd.DataFrame:
    """Keep only rows whose ``Diffusion_Models`` value matches at least one term.

    Args:
        query: Either a list of terms (e.g. values from a checkbox group) or a
            single string whose terms are separated by ``;``. A plain string
            used to be iterated character by character, which matched almost
            every row; it is now split like ``filter_queries`` does.
        filtered_df: Frame to filter.

    Returns:
        The concatenation of the per-term matches. If *query* is empty or
        ``None``, or no term matches, *filtered_df* is returned unchanged.
    """
    if not query:
        return filtered_df
    if isinstance(query, str):
        query = query.split(";")

    matches = []
    for term in query:
        term = term.strip()
        if not term:
            continue
        hit = search_table_model(filtered_df, term)
        if len(hit) > 0:
            matches.append(hit)

    if matches:
        filtered_df = pd.concat(matches)
    return filtered_df
|
|
|
def select_columns(df: pd.DataFrame, columns_1: list) -> pd.DataFrame:
    """Project *df* onto the method-name column plus the chosen metric columns.

    Args:
        df: Leaderboard frame; must contain ``Unlearned_Methods``.
        columns_1: Metric names ticked in the selector. An empty selection (or
            ``None``, e.g. an untouched checkbox group) means "show every
            metric".

    Returns:
        ``Unlearned_Methods`` followed by the selected metrics that exist in
        *df*, always in the canonical metric order regardless of the order in
        which they were ticked.
    """
    always_here_cols = ['Unlearned_Methods']
    metric_cols = ['Pre-ASR', 'Post-ASR', 'FID', 'CLIP-Score']

    if columns_1 is None or len(columns_1) == 0:
        wanted = metric_cols
    else:
        wanted = [c for c in metric_cols if c in columns_1]

    return df[always_here_cols + [c for c in wanted if c in df.columns]]
|
|
|
|
|
# Markdown heading rendered above each benchmark table, keyed by the CSV file
# stem listed in ``files`` (each table is read from ./assets/<stem>.csv).
_SECTION_HEADINGS = {
    "nudity": "### Unlearned Concepts " + " Nudity",
    "vangogh": "### [Unlearned Style] " + " Van Gogh",
    "church": "### [Unlearned Objects] " + " Church",
    "garbage": "### [Unlearned Objects] " + " Garbage",
    "parachute": "### [Unlearned Objects] " + " Parachute",
    "tench": "### [Unlearned Objects] " + " Tench",
}

demo = gr.Blocks(css=custom_css)
with demo:
    gr.HTML(TITLE)
    gr.Markdown(INTRODUCTION_TEXT, elem_classes="markdown-text")
    gr.Markdown(EVALUATION_QUEUE_TEXT, elem_classes="eval-text")
    gr.Markdown(LLM_BENCHMARKS_TEXT, elem_classes="reference-text")

    with gr.Tabs(elem_classes="tab-buttons") as tabs:
        with gr.TabItem("UnlearnDiffAtk Benchmark", elem_id="UnlearnDiffAtk-benchmark-tab-table", id=0):
            # Shared controls: the single search box and metric selector below
            # are wired to every benchmark table rendered in the loop.
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        search_bar = gr.Textbox(
                            placeholder=" π Search for your model (separate multiple queries with `;`) and press ENTER...",
                            show_label=False,
                            elem_id="search-bar",
                        )
                    with gr.Row():
                        model1_column = gr.CheckboxGroup(
                            label="Evaluation Metrics",
                            choices=['Pre-ASR', 'Post-ASR', 'FID', 'CLIP-Score'],
                            interactive=True,
                            elem_id="column-select",
                        )

            for file_stem in files:
                # A stem without a dedicated heading still gets its own table
                # under a generic heading; the former if/elif chain silently
                # reused the previous iteration's heading and CSV path instead.
                name = _SECTION_HEADINGS.get(file_stem, "### " + file_stem)
                csv_path = './assets/' + file_stem + '.csv'

                gr.Markdown(name)
                df_results = load_data(csv_path)

                leaderboard_table = gr.components.Dataframe(
                    value=df_results,
                    datatype=TYPES,
                    elem_id="leaderboard-table",
                    interactive=False,
                    visible=True,
                )

                # Invisible, never-filtered copy of this table: every search or
                # column change re-filters from it rather than from the
                # already-filtered visible table.
                hidden_leaderboard_table_for_search = gr.components.Dataframe(
                    value=df_results,
                    interactive=False,
                    visible=False,
                )

                search_bar.submit(
                    update_table,
                    [hidden_leaderboard_table_for_search, model1_column, search_bar],
                    leaderboard_table,
                )
                model1_column.change(
                    update_table,
                    [hidden_leaderboard_table_for_search, model1_column, search_bar],
                    leaderboard_table,
                )

    with gr.Row():
        with gr.Accordion("π Citation", open=True):
            citation_button = gr.Textbox(
                value=CITATION_BUTTON_TEXT,
                label=CITATION_BUTTON_LABEL,
                lines=10,
                elem_id="citation-button",
                show_copy_button=True,
            )
|
|
|
# Call restart_space every 30 minutes in a background thread, then serve the UI
# with request queuing enabled.
scheduler = BackgroundScheduler()
scheduler.add_job(restart_space, trigger="interval", seconds=30 * 60)
scheduler.start()

queued_demo = demo.queue()
queued_demo.launch(share=True)