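"""Gradio Space for testing Hugging Face models for fairness and bias.

The app loads a model and a slice of a dataset, runs the model over the
sampled examples, and reports overall accuracy, a pairwise fairness score,
and per-class accuracy/F1 along with a Plotly visualization.
"""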
import gradio as gr
import requests
import json
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification, AutoModelForQuestionAnswering
from datasets import load_dataset
import datasets
import plotly.io as pio
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import pandas as pd
from sklearn.metrics import confusion_matrix, accuracy_score, f1_score
import importlib
import torch
from dash import Dash, html, dcc
import numpy as np
def load_model(model_type: str, model_name_or_path: str, dataset_name: str, config_name: str):
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    if model_type == "text_classification":
        # Infer the number of labels from the dataset's label feature.
        dataset = load_dataset(dataset_name, config_name)
        num_labels = len(dataset["train"].features["label"].names)
        if "roberta" in model_name_or_path.lower():
            from transformers import RobertaForSequenceClassification
            model = RobertaForSequenceClassification.from_pretrained(
                model_name_or_path, num_labels=num_labels)
        else:
            model = AutoModelForSequenceClassification.from_pretrained(
                model_name_or_path, num_labels=num_labels)
    elif model_type == "token_classification":
        dataset = load_dataset(dataset_name, config_name)
        num_labels = len(
            dataset["train"].features["ner_tags"].feature.names)
        model = AutoModelForTokenClassification.from_pretrained(
            model_name_or_path, num_labels=num_labels)
    elif model_type == "question_answering":
        model = AutoModelForQuestionAnswering.from_pretrained(model_name_or_path)
    else:
        raise ValueError(f"Invalid model type: {model_type}")
    return tokenizer, model
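# Example usage (model/dataset names mirror the interface defaults below):
#   tokenizer, model = load_model(
#       "text_classification",
#       "distilbert-base-uncased-finetuned-sst-2-english",
#       "glue", "cola")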
def test_model(tokenizer, model, test_data: list, label_map: dict):
    results = []
    model.eval()
    for text, _, true_label in test_data:
        inputs = tokenizer(text, return_tensors="pt",
                           truncation=True, padding=True)
        # Run inference without tracking gradients.
        with torch.no_grad():
            outputs = model(**inputs)
        pred_label = label_map[int(outputs.logits.argmax(dim=-1))]
        results.append((text, true_label, pred_label))
    return results
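# Each test_data entry is a (text, context, true_label) tuple; context is
# unused for classification and only kept for a future question-answering path.
# Each result is a (text, true_label, predicted_label) tuple.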
def generate_label_map(dataset):
    if "label" not in dataset.features or dataset.features["label"] is None:
        return {}
    if isinstance(dataset.features["label"], datasets.ClassLabel):
        label_map = {i: label for i, label in enumerate(dataset.features["label"].names)}
    else:
        # Sort so the index -> label mapping is deterministic across runs.
        label_map = {i: label for i, label in enumerate(sorted(set(dataset["label"])))}
    return label_map
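# Example: for a binary ClassLabel dataset this yields something like
# {0: "negative", 1: "positive"} -- integer class indices mapped to label names.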
def calculate_fairness_score(results, label_map):
    true_labels = [r[1] for r in results]
    pred_labels = [r[2] for r in results]
    # Overall accuracy
    accuracy = accuracy_score(true_labels, pred_labels)
    # Calculate a confusion matrix for each group (one per true class)
    group_names = list(label_map.values())
    group_cms = {}
    for group in group_names:
        true_group_indices = [i for i, label in enumerate(true_labels) if label == group]
        pred_group_labels = [pred_labels[i] for i in true_group_indices]
        true_group_labels = [true_labels[i] for i in true_group_indices]
        cm = confusion_matrix(true_group_labels, pred_group_labels, labels=group_names)
        group_cms[group] = cm
    # Calculate fairness score from pairwise differences between group matrices
    score = 0
    for i, group1 in enumerate(group_names):
        for j, group2 in enumerate(group_names):
            if i < j:
                cm1 = group_cms[group1]
                cm2 = group_cms[group2]
                if cm1.sum() == 0:
                    # Skip groups with no samples to avoid dividing by zero.
                    continue
                diff = np.abs(cm1 - cm2)
                score += (diff.sum() / 2) / cm1.sum()
    return accuracy, score
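# The fairness score compares every pair of per-class confusion matrices: each
# pair contributes half the summed absolute difference between the two
# matrices, normalized by the first class's sample count. A score near 0 means
# the error pattern looks similar across classes; larger values mean more disparity.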
def calculate_per_class_metrics(true_labels, pred_labels, label_map, metric='accuracy'):
    unique_labels = sorted(label_map.values())
    metrics = []
    if metric == 'accuracy':
        for label in unique_labels:
            label_indices = [i for i, true_label in enumerate(true_labels) if true_label == label]
            true_label_subset = [true_labels[i] for i in label_indices]
            pred_label_subset = [pred_labels[i] for i in label_indices]
            # Guard against classes with no samples in the evaluated split.
            accuracy = accuracy_score(true_label_subset, pred_label_subset) if label_indices else 0.0
            metrics.append(accuracy)
    elif metric == 'f1':
        f1_scores = f1_score(true_labels, pred_labels, labels=unique_labels, average=None)
        metrics = f1_scores.tolist()
    else:
        raise ValueError(f"Invalid metric: {metric}")
    return metrics
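# Note: metrics are returned in sorted(label_map.values()) order, so callers
# must iterate labels in the same sorted order when pairing labels with values.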
def generate_visualization(visualization_type, results, label_map):
    true_labels = [r[1] for r in results]
    pred_labels = [r[2] for r in results]
    if visualization_type == "confusion_matrix":
        return generate_report_card(results, label_map)["fig"]
    elif visualization_type == "per_class_accuracy":
        per_class_accuracy = calculate_per_class_metrics(
            true_labels, pred_labels, label_map, metric='accuracy')
        colors = px.colors.qualitative.Plotly
        fig = go.Figure()
        # Iterate labels in sorted order to match the ordering returned
        # by calculate_per_class_metrics.
        for i, label in enumerate(sorted(label_map.values())):
            fig.add_trace(go.Bar(
                x=[label],
                y=[per_class_accuracy[i]],
                name=label,
                marker_color=colors[i % len(colors)]
            ))
        fig.update_layout(title='Per-Class Accuracy',
                          xaxis_title='Class', yaxis_title='Accuracy')
        return fig
    elif visualization_type == "per_class_f1":
        per_class_f1 = calculate_per_class_metrics(
            true_labels, pred_labels, label_map, metric='f1')
        colors = px.colors.qualitative.Plotly
        fig = go.Figure()
        for i, label in enumerate(sorted(label_map.values())):
            fig.add_trace(go.Bar(
                x=[label],
                y=[per_class_f1[i]],
                name=label,
                marker_color=colors[i % len(colors)]
            ))
        fig.update_layout(title='Per-Class F1-Score',
                          xaxis_title='Class', yaxis_title='F1-Score')
        return fig
    else:
        raise ValueError(f"Invalid visualization type: {visualization_type}")
def generate_report_card(results, label_map):
    true_labels = [r[1] for r in results]
    pred_labels = [r[2] for r in results]
    cm = confusion_matrix(true_labels, pred_labels,
                          labels=list(label_map.values()))
    # Create the plotly figure
    fig = make_subplots(rows=1, cols=1)
    fig.add_trace(go.Heatmap(
        z=cm,
        x=list(label_map.values()),
        y=list(label_map.values()),
        colorscale='RdYlGn',
        colorbar=dict(title='# of Samples')
    ))
    fig.update_layout(
        height=500, width=600,
        title='Confusion Matrix',
        xaxis=dict(title='Predicted Labels'),
        yaxis=dict(title='True Labels', autorange='reversed')
    )
    # Create the text output
    # Fraction of correct predictions (accuracy_score normalizes by default;
    # normalize=False would return a raw count instead).
    accuracy = accuracy_score(true_labels, pred_labels)
    fairness_score = calculate_fairness_score(results, label_map)
    per_class_accuracy = calculate_per_class_metrics(
        true_labels, pred_labels, label_map, metric='accuracy')
    per_class_f1 = calculate_per_class_metrics(
        true_labels, pred_labels, label_map, metric='f1')
    text_output = html.Div(children=[
        html.H2('Performance Metrics'),
        html.Div(children=[
            html.Div(children=[
                html.H3('Accuracy'),
                html.H4(f'{accuracy:.2f}')
            ], className='metric'),
            html.Div(children=[
                html.H3('Fairness Score'),
                html.H4(
                    f'Accuracy: {fairness_score[0]:.2f}, Score: {fairness_score[1]:.2f}')
            ], className='metric'),
        ], className='metric-container'),
    ], className='text-output')
    # Combine the plot and text output into a Dash container
    # report_card = html.Div([
    #     dcc.Graph(figure=fig),
    #     text_output,
    # ])
    # return report_card
    report_card = {
        "fig": fig,
        "accuracy": accuracy,
        "fairness_score": fairness_score,
        "per_class_accuracy": per_class_accuracy,
        "per_class_f1": per_class_f1
    }
    return report_card
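# generate_report_card returns a plain dict with keys "fig", "accuracy",
# "fairness_score" (an (accuracy, score) tuple), "per_class_accuracy", and
# "per_class_f1". The Dash text_output above is built but not returned.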
def app(model_type: str, model_name_or_path: str, dataset_name: str, config_name: str, dataset_split: str, num_samples: int, visualization_type: str):
    tokenizer, model = load_model(
        model_type, model_name_or_path, dataset_name, config_name)
    # Load the dataset; num_samples may arrive as a float from the UI, so cast it.
    num_samples = int(num_samples)
    dataset = load_dataset(
        dataset_name, config_name, split=f"{dataset_split}[:{num_samples}]")
    test_data = []
    if dataset_name == "glue":
        test_data = [(item["sentence"], None,
                      dataset.features["label"].names[item["label"]]) for item in dataset]
    elif dataset_name == "tweet_eval":
        test_data = [(item["text"], None, dataset.features["label"].names[item["label"]])
                     for item in dataset]
    else:
        test_data = [(item["sentence"], None,
                      dataset.features["label"].names[item["label"]]) for item in dataset]
    # if model_type == "text_classification":
    #     for item in dataset:
    #         text = item["sentence"]
    #         context = None
    #         true_label = item["label"]
    #         test_data.append((text, context, true_label))
    # elif model_type == "question_answering":
    #     for item in dataset:
    #         text = item["question"]
    #         context = item["context"]
    #         true_label = None
    #         test_data.append((text, context, true_label))
    # else:
    #     raise ValueError(f"Invalid model type: {model_type}")
    label_map = generate_label_map(dataset)
    results = test_model(tokenizer, model, test_data, label_map)
    report_card = generate_report_card(results, label_map)
    visualization = generate_visualization(visualization_type, results, label_map)
    # Labels are sorted to line up with the per-class metric ordering.
    per_class_metrics_str = "\n".join([f"{label}: Acc {acc:.2f}, F1 {f1:.2f}" for label, acc, f1 in zip(
        sorted(label_map.values()), report_card['per_class_accuracy'], report_card['per_class_f1'])])
    return (f"Accuracy: {report_card['accuracy']:.2f}, Fairness Score: {report_card['fairness_score'][1]:.2f}\n\n"
            f"Per-Class Metrics:\n{per_class_metrics_str}"), visualization
interface = gr.Interface(
    fn=app,
    inputs=[
        gr.Radio(["text_classification", "token_classification",
                  "question_answering"], label="Model Type", value="text_classification"),
        gr.Textbox(lines=1, label="Model Name or Path",
                   placeholder="ex: distilbert-base-uncased-finetuned-sst-2-english",
                   value="distilbert-base-uncased-finetuned-sst-2-english"),
        gr.Textbox(lines=1, label="Dataset Name",
                   placeholder="ex: glue", value="glue"),
        gr.Textbox(lines=1, label="Config Name",
                   placeholder="ex: sst2", value="cola"),
        gr.Dropdown(
            choices=["train", "validation", "test"], label="Dataset Split", value="validation"),
        gr.Number(value=100, label="Number of Samples"),
        gr.Dropdown(
            choices=["confusion_matrix", "per_class_accuracy", "per_class_f1"],
            label="Visualization Type", value="confusion_matrix"
        ),
    ],
    outputs=[
        gr.Textbox(label="Fairness and Bias Metrics"),
        gr.Plot(label="Graph")
    ],
    title="Fairness and Bias Testing",
    description="Enter a model and dataset to test for fairness and bias.",
)
# Default label map kept for reference; the app builds its own label_map per dataset.
label_map = {0: "negative", 1: "positive"}
if __name__ == "__main__":
    interface.launch()