maximuspowers committed on
Commit
eccfda3
·
verified ·
1 Parent(s): bdb7580

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +136 -27
app.py CHANGED
@@ -3,13 +3,13 @@ import torch
3
  from transformers import BertTokenizerFast, BertForTokenClassification
4
  import gradio as gr
5
 
6
- # init important things
7
  tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
8
- model = BertForTokenClassification.from_pretrained('maximuspowers/bias-detection-ner')
9
  model.eval()
10
  model.to('cuda' if torch.cuda.is_available() else 'cpu')
11
 
12
- # ids to labels we want to display
13
  id2label = {
14
  0: 'O',
15
  1: 'B-STEREO',
@@ -20,8 +20,35 @@ id2label = {
20
  6: 'I-UNFAIR'
21
  }
22
 
23
- # predict function you'll want to use if using in your own code
24
- def predict_ner_tags(sentence):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True, max_length=128)
26
  input_ids = inputs['input_ids'].to(model.device)
27
  attention_mask = inputs['attention_mask'].to(model.device)
@@ -30,29 +57,111 @@ def predict_ner_tags(sentence):
30
  outputs = model(input_ids=input_ids, attention_mask=attention_mask)
31
  logits = outputs.logits
32
  probabilities = torch.sigmoid(logits)
33
- predicted_labels = (probabilities > 0.5).int() # remember to try your own threshold
34
 
35
- result = []
36
  tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
 
37
  for i, token in enumerate(tokens):
38
  if token not in tokenizer.all_special_tokens:
39
- label_indices = (predicted_labels[0][i] == 1).nonzero(as_tuple=False).squeeze(-1)
40
- labels = [id2label[idx.item()] for idx in label_indices] if label_indices.numel() > 0 else ['O']
41
- result.append({"token": token, "labels": labels})
42
-
43
- return json.dumps(result, indent=4)
44
-
45
- # startup gradio
46
- iface = gr.Interface(
47
- fn=predict_ner_tags,
48
- inputs="text",
49
- outputs="text",
50
- title="Social Bias Named Entity Recognition (with BERT) 🕵",
51
- description=("Enter a sentence to predict biased parts of speech tags. This model uses multi-label BertForTokenClassification, to label the entities: (GEN)eralizations, (UNFAIR)ness, and (STEREO)types. Labels follow BIO format. Try it out :)."
52
- "<br><br>Read more about how this model was trained in this <a href='https://huggingface.co/blog/maximuspowers/bias-entity-recognition' target='_blank'>blog post</a>."
53
- "<br>Model Page: <a href='https://huggingface.co/maximuspowers/bias-detection-ner' target='_blank'>Bias Detection NER</a>."),
54
- allow_flagging="never"
55
- )
56
-
57
- if __name__ == "__main__":
58
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  from transformers import BertTokenizerFast, BertForTokenClassification
4
  import gradio as gr
5
 
6
+ # Initialize tokenizer and model
7
  tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
8
+ model = BertForTokenClassification.from_pretrained('ethical-spectacle/social-bias-ner')
9
  model.eval()
10
  model.to('cuda' if torch.cuda.is_available() else 'cpu')
11
 
12
+ # Mapping IDs to labels
13
  id2label = {
14
  0: 'O',
15
  1: 'B-STEREO',
 
20
  6: 'I-UNFAIR'
21
  }
22
 
23
+ # Entity colors for highlights
24
# Translucent background colour used to highlight each entity category
# in the rendered HTML matrix.
label_colors = dict(
    STEREO="rgba(255, 0, 0, 0.2)",  # light red
    GEN="rgba(0, 0, 255, 0.2)",  # light blue
    UNFAIR="rgba(0, 255, 0, 0.2)",  # light green
)
29
+
30
+ # Post-process entity tags
31
def post_process_entities(result):
    """Normalise B-/I- prefixes so each entity type forms a valid BIO sequence.

    Every entity type is tracked independently, because a token may carry
    several labels at once (multi-label classification):

    * ``B-X`` directly after a token that already carried ``X`` becomes ``I-X``.
    * ``I-X`` whose preceding token did not carry ``X`` becomes ``B-X``.
    * Any other label (for example ``O``) is left untouched.

    Args:
        result: List of ``{"token": str, "labels": [{"label": str,
            "confidence": float}, ...]}`` dicts, in sentence order.

    Returns:
        The same list object; each item's ``"labels"`` entry is replaced by
        the corrected list.
    """
    # Entity types present on the immediately preceding token. A single
    # "previous type" variable is not enough: it would be overwritten by each
    # label of a multi-label token, and would survive across unlabeled tokens.
    prev_types = set()

    for token_data in result:
        current_types = set()
        new_labels = []

        for label_data in token_data["labels"]:
            label = label_data["label"]
            prefix, entity_type = label[:2], label[2:]

            if prefix not in ("B-", "I-"):
                # Not a BIO entity tag (e.g. "O"): keep it and do not count it
                # as an entity type for the next token.
                new_labels.append(label_data)
                continue

            current_types.add(entity_type)
            continues_entity = entity_type in prev_types

            if prefix == "B-" and continues_entity:
                new_labels.append(
                    {"label": f"I-{entity_type}", "confidence": label_data["confidence"]}
                )
            elif prefix == "I-" and not continues_entity:
                new_labels.append(
                    {"label": f"B-{entity_type}", "confidence": label_data["confidence"]}
                )
            else:
                new_labels.append(label_data)

        token_data["labels"] = new_labels
        # A token without entity labels resets the state, so a later I- tag
        # cannot continue an entity across the gap.
        prev_types = current_types

    return result
49
+
50
+ # Generate HTML matrix and JSON results with probabilities
51
def _render_entity_matrix(result):
    """Render per-token predictions as an HTML table, one row per category."""
    from html import escape

    def cells(values):
        return "".join(f"<td>{value}</td>" for value in values)

    # Tokens originate from user input, so they are escaped before being
    # placed into markup.
    rows = [
        (
            "Text Sequence",
            [
                f"<span style='font-weight:bold;'>{escape(token_data['token'])}</span>"
                for token_data in result
            ],
        )
    ]

    # Row order matches the displayed matrix: GEN, UNFAIR, STEREO.
    for title, category in (
        ("Generalizations", "GEN"),
        ("Unfairness", "UNFAIR"),
        ("Stereotypes", "STEREO"),
    ):
        row = []
        for token_data in result:
            tags = [
                f"{label_data['label'][2:]} ({label_data['confidence']}%)"
                for label_data in token_data["labels"]
                if category in label_data["label"]
            ]
            if tags:
                row.append(
                    f"<span style='background:{label_colors[category]}; "
                    f"border-radius:6px; padding:2px 5px;'>{', '.join(tags)}</span>"
                )
            else:
                # Keep the column aligned when the token has no tag here.
                row.append("&nbsp;")
        rows.append((title, row))

    body = "\n".join(
        f"<tr>\n<td><strong>{title}</strong></td>\n{cells(values)}\n</tr>"
        for title, values in rows
    )
    return (
        "\n<table style='border-collapse:collapse; width:100%; "
        "font-family:monospace; text-align:left;'>\n"
        f"{body}\n</table>\n"
    )


def predict_ner_tags_with_json(sentence, threshold=0.52):
    """Tag a sentence with bias entities and return an HTML report.

    Args:
        sentence: Input text to analyse.
        threshold: Sigmoid probability a label must exceed to be reported for
            a token (defaults to the previously hard-coded 0.52).

    Returns:
        An HTML string: the entity matrix table followed by a ``<pre>`` block
        holding the JSON dump of the per-token labels and confidences (in
        percent, rounded to two decimals).
    """
    from html import escape

    inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True, max_length=128)
    input_ids = inputs['input_ids'].to(model.device)
    attention_mask = inputs['attention_mask'].to(model.device)

    # NOTE(review): the diff this block was recovered from elides two file
    # lines at this point; the forward pass is assumed to run under
    # torch.no_grad() -- confirm against the full file.
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        probabilities = torch.sigmoid(logits)

    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    # Looked up once instead of once per token.
    special_tokens = set(tokenizer.all_special_tokens)

    result = []
    for i, token in enumerate(tokens):
        if token in special_tokens:
            continue
        token_probs = probabilities[0][i]
        # Multi-label: every class whose probability clears the threshold.
        label_indices = (token_probs > threshold).nonzero(as_tuple=False).squeeze(-1)
        labels = [
            {
                "label": id2label[idx.item()],
                "confidence": round(token_probs[idx].item() * 100, 2),
            }
            for idx in label_indices
        ]
        result.append({"token": token.replace("##", ""), "labels": labels})

    result = post_process_entities(result)

    matrix_html = _render_entity_matrix(result)

    # The JSON embeds raw tokens; escape it so it is shown as text inside <pre>.
    json_result = json.dumps(result, indent=4)

    return f"{matrix_html}<br><pre>{escape(json_result, quote=False)}</pre>"
141
+
142
+ # Gradio Interface
143
# Introductory Markdown shown at the top of the demo page.
INTRO_MARKDOWN = """
# GUS-Net 🕵
[GUS-Net](https://huggingface.co/ethical-spectacle/social-bias-ner) is a `BertForTokenClassification` based model, trained on the [GUS dataset](https://huggingface.co/datasets/ethical-spectacle/gus-dataset-v1). It performs multi-label named-entity recognition of socially biased entities, intended to reveal the underlying structure of bias rather than a one-size fits all definition.
You can find the full collection of resources introduced in our paper [here](https://huggingface.co/collections/ethical-spectacle/gus-net-66edfe93801ea45d7a26a10f).
This [blog post](https://huggingface.co/blog/maximuspowers/bias-entity-recognition) walks through the training and architecture of the model.
Enter a sentence for named-entity recognition of biased entities:
- **Generalizations (GEN)**
- **Unfairness (UNFAIR)**
- **Stereotypes (STEREO)**
Labels follow the BIO format. Try it out:
"""

# Gradio interface: description, input textbox, and HTML output stacked in rows.
iface = gr.Blocks()

with iface:
    with gr.Row():
        gr.Markdown(INTRO_MARKDOWN)
    with gr.Row():
        input_box = gr.Textbox(label="Input Sentence")
    with gr.Row():
        output_box = gr.HTML(label="Entity Matrix and JSON Output")

    # Re-run the prediction whenever the textbox content changes.
    input_box.change(predict_ner_tags_with_json, inputs=[input_box], outputs=[output_box])

# Only start the server when executed as a script, so importing this module
# (e.g. to reuse the prediction function) does not launch the app.
if __name__ == "__main__":
    iface.launch(share=True)