maximuspowers
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -3,13 +3,13 @@ import torch
|
|
3 |
from transformers import BertTokenizerFast, BertForTokenClassification
|
4 |
import gradio as gr
|
5 |
|
6 |
-
#
|
7 |
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
|
8 |
-
model = BertForTokenClassification.from_pretrained('
|
9 |
model.eval()
|
10 |
model.to('cuda' if torch.cuda.is_available() else 'cpu')
|
11 |
|
12 |
-
#
|
13 |
id2label = {
|
14 |
0: 'O',
|
15 |
1: 'B-STEREO',
|
@@ -20,8 +20,35 @@ id2label = {
|
|
20 |
6: 'I-UNFAIR'
|
21 |
}
|
22 |
|
23 |
-
#
|
24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True, max_length=128)
|
26 |
input_ids = inputs['input_ids'].to(model.device)
|
27 |
attention_mask = inputs['attention_mask'].to(model.device)
|
@@ -30,29 +57,111 @@ def predict_ner_tags(sentence):
|
|
30 |
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
|
31 |
logits = outputs.logits
|
32 |
probabilities = torch.sigmoid(logits)
|
33 |
-
predicted_labels = (probabilities > 0.5).int() # remember to try your own threshold
|
34 |
|
35 |
-
result = []
|
36 |
tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
|
|
|
37 |
for i, token in enumerate(tokens):
|
38 |
if token not in tokenizer.all_special_tokens:
|
39 |
-
label_indices = (
|
40 |
-
labels = [
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
from transformers import BertTokenizerFast, BertForTokenClassification
|
4 |
import gradio as gr
|
5 |
|
6 |
+
# Initialize tokenizer and model
|
7 |
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
|
8 |
+
model = BertForTokenClassification.from_pretrained('ethical-spectacle/social-bias-ner')
|
9 |
model.eval()
|
10 |
model.to('cuda' if torch.cuda.is_available() else 'cpu')
|
11 |
|
12 |
+
# Mapping IDs to labels
|
13 |
id2label = {
|
14 |
0: 'O',
|
15 |
1: 'B-STEREO',
|
|
|
20 |
6: 'I-UNFAIR'
|
21 |
}
|
22 |
|
23 |
+
# Entity colors for highlights
|
24 |
+
label_colors = {
|
25 |
+
"STEREO": "rgba(255, 0, 0, 0.2)", # Light Red
|
26 |
+
"GEN": "rgba(0, 0, 255, 0.2)", # Light Blue
|
27 |
+
"UNFAIR": "rgba(0, 255, 0, 0.2)" # Light Green
|
28 |
+
}
|
29 |
+
|
30 |
+
# Post-process entity tags
|
31 |
+
def post_process_entities(result):
|
32 |
+
prev_entity_type = None
|
33 |
+
for token_data in result:
|
34 |
+
labels = token_data["labels"]
|
35 |
+
|
36 |
+
# Handle sequence rules
|
37 |
+
new_labels = []
|
38 |
+
for label_data in labels:
|
39 |
+
label = label_data['label']
|
40 |
+
if label.startswith("B-") and prev_entity_type == label[2:]:
|
41 |
+
new_labels.append({"label": f"I-{label[2:]}", "confidence": label_data["confidence"]})
|
42 |
+
elif label.startswith("I-") and prev_entity_type != label[2:]:
|
43 |
+
new_labels.append({"label": f"B-{label[2:]}", "confidence": label_data["confidence"]})
|
44 |
+
else:
|
45 |
+
new_labels.append(label_data)
|
46 |
+
prev_entity_type = label[2:]
|
47 |
+
token_data["labels"] = new_labels
|
48 |
+
return result
|
49 |
+
|
50 |
+
# Generate HTML matrix and JSON results with probabilities
|
51 |
+
def predict_ner_tags_with_json(sentence):
|
52 |
inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True, max_length=128)
|
53 |
input_ids = inputs['input_ids'].to(model.device)
|
54 |
attention_mask = inputs['attention_mask'].to(model.device)
|
|
|
57 |
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
|
58 |
logits = outputs.logits
|
59 |
probabilities = torch.sigmoid(logits)
|
|
|
60 |
|
|
|
61 |
tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
|
62 |
+
result = []
|
63 |
for i, token in enumerate(tokens):
|
64 |
if token not in tokenizer.all_special_tokens:
|
65 |
+
label_indices = (probabilities[0][i] > 0.52).nonzero(as_tuple=False).squeeze(-1)
|
66 |
+
labels = [
|
67 |
+
{
|
68 |
+
"label": id2label[idx.item()],
|
69 |
+
"confidence": round(probabilities[0][i][idx].item() * 100, 2)
|
70 |
+
}
|
71 |
+
for idx in label_indices
|
72 |
+
]
|
73 |
+
result.append({"token": token.replace("##", ""), "labels": labels})
|
74 |
+
|
75 |
+
result = post_process_entities(result)
|
76 |
+
|
77 |
+
# Create table rows
|
78 |
+
word_row = []
|
79 |
+
stereo_row = []
|
80 |
+
gen_row = []
|
81 |
+
unfair_row = []
|
82 |
+
|
83 |
+
for token_data in result:
|
84 |
+
token = token_data["token"]
|
85 |
+
labels = token_data["labels"]
|
86 |
+
|
87 |
+
word_row.append(f"<span style='font-weight:bold;'>{token}</span>")
|
88 |
+
|
89 |
+
# STEREO
|
90 |
+
stereo_labels = [
|
91 |
+
f"{label_data['label'][2:]} ({label_data['confidence']}%)" for label_data in labels if "STEREO" in label_data["label"]
|
92 |
+
]
|
93 |
+
stereo_row.append(
|
94 |
+
f"<span style='background:{label_colors['STEREO']}; border-radius:6px; padding:2px 5px;'>{', '.join(stereo_labels)}</span>"
|
95 |
+
if stereo_labels else " "
|
96 |
+
)
|
97 |
+
|
98 |
+
# GEN
|
99 |
+
gen_labels = [
|
100 |
+
f"{label_data['label'][2:]} ({label_data['confidence']}%)" for label_data in labels if "GEN" in label_data["label"]
|
101 |
+
]
|
102 |
+
gen_row.append(
|
103 |
+
f"<span style='background:{label_colors['GEN']}; border-radius:6px; padding:2px 5px;'>{', '.join(gen_labels)}</span>"
|
104 |
+
if gen_labels else " "
|
105 |
+
)
|
106 |
+
|
107 |
+
# UNFAIR
|
108 |
+
unfair_labels = [
|
109 |
+
f"{label_data['label'][2:]} ({label_data['confidence']}%)" for label_data in labels if "UNFAIR" in label_data["label"]
|
110 |
+
]
|
111 |
+
unfair_row.append(
|
112 |
+
f"<span style='background:{label_colors['UNFAIR']}; border-radius:6px; padding:2px 5px;'>{', '.join(unfair_labels)}</span>"
|
113 |
+
if unfair_labels else " "
|
114 |
+
)
|
115 |
+
|
116 |
+
matrix_html = f"""
|
117 |
+
<table style='border-collapse:collapse; width:100%; font-family:monospace; text-align:left;'>
|
118 |
+
<tr>
|
119 |
+
<td><strong>Text Sequence</strong></td>
|
120 |
+
{''.join(f"<td>{word}</td>" for word in word_row)}
|
121 |
+
</tr>
|
122 |
+
<tr>
|
123 |
+
<td><strong>Generalizations</strong></td>
|
124 |
+
{''.join(f"<td>{cell}</td>" for cell in gen_row)}
|
125 |
+
</tr>
|
126 |
+
<tr>
|
127 |
+
<td><strong>Unfairness</strong></td>
|
128 |
+
{''.join(f"<td>{cell}</td>" for cell in unfair_row)}
|
129 |
+
</tr>
|
130 |
+
<tr>
|
131 |
+
<td><strong>Stereotypes</strong></td>
|
132 |
+
{''.join(f"<td>{cell}</td>" for cell in stereo_row)}
|
133 |
+
</tr>
|
134 |
+
</table>
|
135 |
+
"""
|
136 |
+
|
137 |
+
# JSON string
|
138 |
+
json_result = json.dumps(result, indent=4)
|
139 |
+
|
140 |
+
return f"{matrix_html}<br><pre>{json_result}</pre>"
|
141 |
+
|
142 |
+
# Gradio Interface
|
143 |
+
iface = gr.Blocks()
|
144 |
+
|
145 |
+
with iface:
|
146 |
+
with gr.Row():
|
147 |
+
gr.Markdown(
|
148 |
+
"""
|
149 |
+
# GUS-Net 🕵
|
150 |
+
[GUS-Net](https://huggingface.co/ethical-spectacle/social-bias-ner) is a `BertForTokenClassification` based model, trained on the [GUS dataset](https://huggingface.co/datasets/ethical-spectacle/gus-dataset-v1). It preforms multi-label named-entity recognition of socially biased entities, intended to reveal the underlying structure of bias rather than a one-size fits all definition.
|
151 |
+
You can find the full collection of resources introduced in our paper [here](https://huggingface.co/collections/ethical-spectacle/gus-net-66edfe93801ea45d7a26a10f).
|
152 |
+
This [blog post](https://huggingface.co/blog/maximuspowers/bias-entity-recognition) walks through the training and architecture of the model.
|
153 |
+
Enter a sentence for named-entity recognition of biased entities:
|
154 |
+
- **Generalizations (GEN)**
|
155 |
+
- **Unfairness (UNFAIR)**
|
156 |
+
- **Stereotypes (STEREO)**
|
157 |
+
Labels follow the BIO format. Try it out:
|
158 |
+
"""
|
159 |
+
)
|
160 |
+
with gr.Row():
|
161 |
+
input_box = gr.Textbox(label="Input Sentence")
|
162 |
+
with gr.Row():
|
163 |
+
output_box = gr.HTML(label="Entity Matrix and JSON Output")
|
164 |
+
|
165 |
+
input_box.change(predict_ner_tags_with_json, inputs=[input_box], outputs=[output_box])
|
166 |
+
|
167 |
+
iface.launch(share=True)
|