# responsibleGPT / app.py — author: kyleledbetter
# Commit 0ec25a0: "feat(app): support more models and datasets" (12.4 kB)
import gradio as gr
import requests
import json
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification, AutoModelForQuestionAnswering
from datasets import load_dataset
import datasets
import plotly.io as pio
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import pandas as pd
from sklearn.metrics import confusion_matrix
import importlib
import torch
from dash import Dash, html, dcc
import numpy as np
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score
def load_model(model_type: str, model_name_or_path: str, dataset_name: str, config_name: str):
    """Load a tokenizer and a task-specific model for *model_name_or_path*.

    For ``text_classification`` and ``token_classification`` the dataset
    ``dataset_name``/``config_name`` is loaded so the classification head can be
    sized to the number of labels in its ``train`` split. For
    ``question_answering`` the dataset arguments are ignored.

    Args:
        model_type: one of "text_classification", "token_classification",
            "question_answering".
        model_name_or_path: Hugging Face Hub id or local checkpoint path.
        dataset_name: dataset passed to ``load_dataset``.
        config_name: dataset configuration passed to ``load_dataset``.

    Returns:
        ``(tokenizer, model)``.

    Raises:
        ValueError: if ``model_type`` is not supported. This is checked before
            anything is downloaded.
    """
    supported = ("text_classification", "token_classification", "question_answering")
    if model_type not in supported:
        raise ValueError(f"Invalid model type: {model_type}")

    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

    if model_type == "text_classification":
        dataset = load_dataset(dataset_name, config_name)
        num_labels = len(dataset["train"].features["label"].names)
        # The Auto class resolves the concrete architecture (RoBERTa,
        # XLM-RoBERTa, ...) from the checkpoint's config, so no name-based
        # special case is needed.
        model = AutoModelForSequenceClassification.from_pretrained(
            model_name_or_path, num_labels=num_labels)
    elif model_type == "token_classification":
        dataset = load_dataset(dataset_name, config_name)
        num_labels = len(dataset["train"].features["ner_tags"].feature.names)
        model = AutoModelForTokenClassification.from_pretrained(
            model_name_or_path, num_labels=num_labels)
    else:  # "question_answering"
        model = AutoModelForQuestionAnswering.from_pretrained(model_name_or_path)
    return tokenizer, model
def test_model(tokenizer, model, test_data: list, label_map: dict):
    """Run *model* on each example and collect its predicted label names.

    Args:
        tokenizer: callable invoked as ``tokenizer(text, return_tensors="pt",
            truncation=True, padding=True)``. Its result is passed to *model*
            as keyword arguments.
        model: callable returning an object with a ``logits`` tensor.
        test_data: iterable of ``(text, context, true_label)`` tuples;
            ``context`` is ignored.
        label_map: maps the argmax class index to a label name.

    Returns:
        list of ``(text, true_label, pred_label)`` tuples.
    """
    results = []
    # Pure inference: disable autograd so no computation graph is built per example.
    with torch.no_grad():
        for text, _context, true_label in test_data:
            inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
            outputs = model(**inputs)
            # int() needs a single-element tensor, i.e. one row of logits
            # (sequence classification). Multi-token logits would raise here.
            pred_index = int(outputs.logits.argmax(dim=-1))
            results.append((text, true_label, label_map[pred_index]))
    return results


# The name matches pytest's ``test_*`` discovery pattern. Opting out stops a test
# module that imports this file from collecting this helper as a test case.
test_model.__test__ = False
def generate_label_map(dataset):
    """Map integer class ids to label names for *dataset*'s ``label`` feature.

    Returns ``{}`` when the dataset has no ``label`` feature (or it is ``None``).
    For a ``datasets.ClassLabel`` feature, ids follow its ``names`` list.
    Otherwise the distinct values of the ``label`` column are numbered in
    sorted order.
    """
    label_feature = dataset.features.get("label")
    if label_feature is None:
        return {}
    if isinstance(label_feature, datasets.ClassLabel):
        return dict(enumerate(label_feature.names))
    # Sort so the id assignment is reproducible. Iterating a set of str values
    # follows per-process hash randomisation and changes between runs.
    distinct = set(dataset["label"])
    try:
        ordered = sorted(distinct)
    except TypeError:  # mutually incomparable value types
        ordered = sorted(distinct, key=repr)
    return dict(enumerate(ordered))
def calculate_fairness_score(results, label_map):
    """Return ``(accuracy, score)`` for classification *results*.

    Args:
        results: sequence of ``(text, true_label, pred_label)`` tuples.
        label_map: mapping whose values are the label names; each label name
            is treated as one "group".

    Returns:
        accuracy: ``accuracy_score`` over all results.
        score: sum over pairs of groups (group1 before group2 in ``label_map``
            order) of half the summed absolute difference between their
            confusion matrices, divided by group1's sample count. Groups with
            no samples are skipped; dividing by their zero count would give
            nan/inf.
    """
    true_labels = [r[1] for r in results]
    pred_labels = [r[2] for r in results]

    accuracy = accuracy_score(true_labels, pred_labels)

    group_names = list(label_map.values())
    # One confusion matrix per group, built only from samples whose true
    # label is that group.
    group_cms = {}
    for group in group_names:
        member_idx = [i for i, label in enumerate(true_labels) if label == group]
        group_cms[group] = confusion_matrix(
            [true_labels[i] for i in member_idx],
            [pred_labels[i] for i in member_idx],
            labels=group_names,
        )

    # NOTE(review): assuming distinct label names, each group's matrix only has
    # entries in that group's own (true-label) row. So |cm1 - cm2| sums to
    # n1 + n2, and each term equals (n1 + n2) / (2 * n1), which depends on class
    # counts, not on predictions. Confirm the intended fairness metric.
    score = 0
    for i, group1 in enumerate(group_names):
        cm1 = group_cms[group1]
        n1 = cm1.sum()
        if n1 == 0:
            continue
        for group2 in group_names[i + 1:]:
            diff = np.abs(cm1 - group_cms[group2])
            score += (diff.sum() / 2) / n1
    return accuracy, score
def calculate_per_class_metrics(true_labels, pred_labels, label_map, metric='accuracy'):
    """Compute one score per label, in ``label_map.values()`` order.

    Callers zip the returned list with ``label_map.values()``, so the order must
    match that iteration order. The previous alphabetical ordering mislabelled
    classes whenever the label names were not already sorted.

    Args:
        true_labels: sequence of gold label names.
        pred_labels: sequence of predicted label names, same length.
        label_map: mapping whose values are the label names to report on.
        metric:
            'accuracy': per label, the fraction of samples with that true label
                that were predicted correctly (``nan`` if the label has no
                samples).
            'f1': sklearn per-label F1 (0.0 when undefined).

    Returns:
        list of floats, one per ``label_map`` value.

    Raises:
        ValueError: for any other ``metric``.
    """
    labels = list(label_map.values())
    if metric == 'accuracy':
        from collections import Counter

        totals = Counter(true_labels)
        hits = Counter(t for t, p in zip(true_labels, pred_labels) if t == p)
        return [hits[label] / totals[label] if totals[label] else float('nan')
                for label in labels]
    if metric == 'f1':
        # zero_division=0 keeps sklearn's default value for undefined F1, minus the warning.
        return f1_score(true_labels, pred_labels, labels=labels, average=None,
                        zero_division=0).tolist()
    raise ValueError(f"Invalid metric: {metric}")
def generate_visualization(visualization_type, results, label_map):
    """Return the Plotly figure selected by *visualization_type*.

    "confusion_matrix" returns the heatmap built by generate_report_card.
    "per_class_accuracy" and "per_class_f1" draw one coloured bar per label,
    iterating ``label_map.values()``.

    Raises:
        ValueError: for any other visualization type.
    """
    gold = [row[1] for row in results]
    predicted = [row[2] for row in results]

    if visualization_type == "confusion_matrix":
        return generate_report_card(results, label_map)["fig"]

    # (metric passed to calculate_per_class_metrics, figure title, y-axis title)
    if visualization_type == "per_class_accuracy":
        chart_spec = ("accuracy", 'Per-Class Accuracy', 'Accuracy')
    elif visualization_type == "per_class_f1":
        chart_spec = ("f1", 'Per-Class F1-Score', 'F1-Score')
    else:
        raise ValueError(f"Invalid visualization type: {visualization_type}")
    metric_name, chart_title, value_axis = chart_spec

    scores = calculate_per_class_metrics(gold, predicted, label_map, metric=metric_name)
    palette = px.colors.qualitative.Plotly

    figure = go.Figure()
    for position, class_name in enumerate(label_map.values()):
        bar = go.Bar(
            x=[class_name],
            y=[scores[position]],
            name=class_name,
            marker_color=palette[position % len(palette)],
        )
        figure.add_trace(bar)
    figure.update_layout(title=chart_title, xaxis_title='Class', yaxis_title=value_axis)
    return figure
def generate_report_card(results, label_map):
    """Compute evaluation metrics and a confusion-matrix figure for *results*.

    Args:
        results: sequence of ``(text, true_label, pred_label)`` tuples.
        label_map: mapping whose values are the label names. Their order sets
            the row/column order of the confusion matrix.

    Returns:
        dict with keys:
            "fig": Plotly heatmap of the confusion matrix (y = true labels,
                x = predicted labels);
            "accuracy": fraction of correct predictions (``accuracy_score``);
            "num_correct": number of correct predictions;
            "fairness_score": the ``(accuracy, score)`` pair from
                calculate_fairness_score;
            "per_class_accuracy", "per_class_f1": per-label lists from
                calculate_per_class_metrics.
    """
    true_labels = [r[1] for r in results]
    pred_labels = [r[2] for r in results]
    labels = list(label_map.values())

    cm = confusion_matrix(true_labels, pred_labels, labels=labels)
    fig = make_subplots(rows=1, cols=1)
    fig.add_trace(go.Heatmap(
        z=cm,
        x=labels,
        y=labels,
        colorscale='RdYlGn',
        colorbar=dict(title='# of Samples'),
    ))
    fig.update_layout(
        height=500, width=600,
        title='Confusion Matrix',
        xaxis=dict(title='Predicted Labels'),
        # Reversed so the first label is the top row, as in a printed matrix.
        yaxis=dict(title='True Labels', autorange='reversed'),
    )

    # A fraction, consistent with fairness_score[0]. The raw count of correct
    # predictions (what normalize=False returned) is reported as num_correct.
    accuracy = accuracy_score(true_labels, pred_labels)
    num_correct = sum(t == p for t, p in zip(true_labels, pred_labels))
    fairness_score = calculate_fairness_score(results, label_map)
    per_class_accuracy = calculate_per_class_metrics(
        true_labels, pred_labels, label_map, metric='accuracy')
    per_class_f1 = calculate_per_class_metrics(
        true_labels, pred_labels, label_map, metric='f1')

    return {
        "fig": fig,
        "accuracy": accuracy,
        "num_correct": num_correct,
        "fairness_score": fairness_score,
        "per_class_accuracy": per_class_accuracy,
        "per_class_f1": per_class_f1,
    }
# Column names tried, in order, for each example's input text.
_TEXT_COLUMNS = ("sentence", "text")


def _build_test_data(dataset):
    """Turn labelled dataset rows into ``(text, context, true_label_name)`` tuples.

    The text comes from the first column in ``_TEXT_COLUMNS`` that the dataset
    has, and ``context`` is always ``None``. Rows with a negative ``label`` are
    skipped: such ids (presumably hidden test labels) map to no class, and
    ``names[-1]`` would silently yield the last class name.

    Raises:
        ValueError: if no known text column exists, or no labelled rows remain.
    """
    text_column = next(
        (name for name in _TEXT_COLUMNS if name in dataset.column_names), None)
    if text_column is None:
        raise ValueError(
            f"Dataset has none of the text columns {list(_TEXT_COLUMNS)}; "
            f"available columns: {dataset.column_names}")
    label_names = dataset.features["label"].names
    test_data = [
        (item[text_column], None, label_names[item["label"]])
        for item in dataset
        if item["label"] >= 0
    ]
    if not test_data:
        raise ValueError(
            "The selected dataset slice has no labelled examples; "
            "choose a split that has gold labels (e.g. 'validation').")
    return test_data


def app(model_type: str, model_name_or_path: str, dataset_name: str, config_name: str, dataset_split: str, num_samples: int, visualization_type: str):
    """Gradio callback: evaluate a model on a dataset slice and report metrics.

    Loads the model and the first *num_samples* rows of *dataset_split*, then
    runs the model on every labelled example.

    Returns:
        ``(summary_text, figure)``. The summary holds overall accuracy, the
        fairness score, and per-class accuracy/F1. The figure is selected by
        *visualization_type*.

    Raises:
        ValueError: if ``num_samples`` is below 1, or the slice has no usable
            text column or labelled examples. Invalid model or visualization
            types are reported by the helpers.
    """
    # gr.Number may deliver a float; split slicing needs an int.
    num_samples = int(num_samples)
    if num_samples < 1:
        raise ValueError(f"Number of samples must be at least 1, got {num_samples}")

    tokenizer, model = load_model(
        model_type, model_name_or_path, dataset_name, config_name)
    dataset = load_dataset(
        dataset_name, config_name, split=f"{dataset_split}[:{num_samples}]")

    test_data = _build_test_data(dataset)
    label_map = generate_label_map(dataset)
    results = test_model(tokenizer, model, test_data, label_map)

    report_card = generate_report_card(results, label_map)
    if visualization_type == "confusion_matrix":
        # Same figure generate_visualization would rebuild from a fresh report card.
        visualization = report_card["fig"]
    else:
        visualization = generate_visualization(visualization_type, results, label_map)

    per_class_metrics_str = "\n".join(
        f"{label}: Acc {acc:.2f}, F1 {f1:.2f}"
        for label, acc, f1 in zip(label_map.values(),
                                  report_card['per_class_accuracy'],
                                  report_card['per_class_f1']))
    summary = (f"Accuracy: {report_card['accuracy']}, Fairness Score: {report_card['fairness_score'][1]:.2f}\n\n"
               f"Per-Class Metrics:\n{per_class_metrics_str}")
    return summary, visualization
# Top-level Gradio components with ``value=`` (the same API as the gr.Plot output
# below), replacing the deprecated gr.inputs / gr.outputs namespaces and
# their ``default=`` argument.
interface = gr.Interface(
    fn=app,
    inputs=[
        gr.Radio(
            choices=["text_classification", "token_classification", "question_answering"],
            label="Model Type",
            value="text_classification",
        ),
        gr.Textbox(
            lines=1,
            label="Model Name or Path",
            placeholder="ex: distilbert-base-uncased-finetuned-sst-2-english",
            value="distilbert-base-uncased-finetuned-sst-2-english",
        ),
        gr.Textbox(lines=1, label="Dataset Name", placeholder="ex: glue", value="glue"),
        # Default matches the SST-2 fine-tuned default model (and the placeholder).
        gr.Textbox(lines=1, label="Config Name", placeholder="ex: sst2", value="sst2"),
        gr.Dropdown(
            choices=["train", "validation", "test"],
            label="Dataset Split",
            value="validation",
        ),
        gr.Number(value=100, label="Number of Samples", precision=0),
        gr.Dropdown(
            choices=["confusion_matrix", "per_class_accuracy", "per_class_f1"],
            label="Visualization Type",
            value="confusion_matrix",
        ),
    ],
    outputs=[
        gr.Textbox(label="Fairness and Bias Metrics"),
        gr.Plot(label="Graph"),
    ],
    title="Fairness and Bias Testing",
    description="Enter a model and dataset to test for fairness and bias.",
)

# Not read by the functions in this file: each takes ``label_map`` as a
# parameter, and app() derives its own from the dataset. Kept in case other
# code imports it.
label_map = {0: "negative", 1: "positive"}

if __name__ == "__main__":
    interface.launch()