Upload 2 files

- evaluation.py +54 -0
- explainer.py +153 -0
evaluation.py
ADDED
@@ -0,0 +1,54 @@
import evaluate
import torch
from enum import Enum
from scripts.gputils import print_gpu_utilization, clear_gpu_mem
from tqdm import tqdm


class AssertionType(Enum):
    PRESENT = 0
    ABSENT = 1
    POSSIBLE = 2


class EntityWithAssertion:
    def __init__(self, entity: str, assertion_type: AssertionType):
        self.entity = entity
        self.assertion_type = assertion_type

    def __repr__(self) -> str:
        return f"{self.assertion_type.name}: {self.entity}"


def classify_assertions_in_sentences(sentences, model, tokenizer, batch_size=32):
    """Classify sentences in batches on the GPU and return the predicted label ids."""
    predictions = []
    for i in tqdm(range(0, len(sentences), batch_size)):
        batch = tokenizer(sentences[i:i + batch_size], return_tensors="pt", padding=True, truncation=True).to("cuda")
        with torch.no_grad():
            outputs = model(**batch)
        predicted_labels = torch.argmax(outputs.logits, dim=1)
        predictions.append(predicted_labels)
    print_gpu_utilization()
    return torch.cat(predictions)


def input_classification(model, tokenizer, x: str = None, all_classes=False):
    """Classify a single sentence; prompt for one interactively if none is given."""
    if x is None:
        x = input("Write your sentence and press Enter to continue: ")
    tokenized_x = tokenizer(x, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        outputs = model(**tokenized_x)
    predicted_label = torch.argmax(outputs.logits, dim=1)
    if all_classes:
        # Return the full softmax distribution keyed by label name.
        return {model.config.id2label[i]: float(k) for i, k in enumerate(torch.softmax(outputs.logits, dim=1)[0])}
    return model.config.id2label[int(predicted_label)]


def compute_results(y, y_hat):
    """Compute macro-F1, micro-F1 and accuracy for predictions against references."""
    metric_f1 = evaluate.load("f1")
    metric_acc = evaluate.load("accuracy")
    return {
        "macro-f1": metric_f1.compute(predictions=y_hat, references=y, average="macro")["f1"],
        "micro-f1": metric_f1.compute(predictions=y_hat, references=y, average="micro")["f1"],
        "accuracy": metric_acc.compute(predictions=y_hat, references=y)["accuracy"]
    }
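A minimal usage sketch for evaluation.py: the checkpoint path below is a placeholder for the fine-tuned assertion classifier, the example sentences and label ids are illustrative, and a CUDA device is assumed because classify_assertions_in_sentences moves each batch to "cuda".

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from evaluation import classify_assertions_in_sentences, compute_results, input_classification

# Placeholder checkpoint: replace with the actual fine-tuned model directory or hub id.
checkpoint = "path/to/assertion-classifier"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint).to("cuda").eval()

sentences = [
    "The patient denies chest pain.",
    "Patient presents with fever and cough.",
]
references = [1, 0]  # integer label ids matching model.config.label2id

# Batched GPU inference followed by macro/micro F1 and accuracy.
predictions = classify_assertions_in_sentences(sentences, model, tokenizer).cpu().tolist()
print(compute_results(references, predictions))

# Single-sentence classification returning the full class distribution.
print(input_classification(model, tokenizer, x="No evidence of pneumonia.", all_classes=True))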
explainer.py
ADDED
@@ -0,0 +1,153 @@
from transformers_interpret import SequenceClassificationExplainer
from captum.attr import visualization as viz
import html


class CustomExplainer(SequenceClassificationExplainer):
    def __init__(self, model, tokenizer):
        super().__init__(model, tokenizer)

    def visualize(self, html_filepath: str = None, true_class: str = None):
        """
        Visualizes word attributions. If run in a notebook, the table is displayed inline.

        Otherwise pass a valid path to `html_filepath` and the visualization will be saved
        as an HTML file.

        If the true class is known for the text, it can be passed to `true_class`.
        """
        tokens = [token.replace("Ġ", "") for token in self.decode(self.input_ids)]
        attr_class = self.id2label[self.selected_index]

        if self._single_node_output:
            if true_class is None:
                true_class = round(float(self.pred_probs))
            predicted_class = round(float(self.pred_probs))
            attr_class = round(float(self.pred_probs))
        else:
            if true_class is None:
                true_class = self.selected_index
            predicted_class = self.predicted_class_name

        score_viz = self.attributions.visualize_attributions(  # type: ignore
            self.pred_probs,
            predicted_class,
            true_class,
            attr_class,
            tokens,
        )

        print(score_viz)

        # Captum returns an HTML object; keep the name distinct from the stdlib `html` module.
        html_output = viz.visualize_text([score_viz])

        if html_filepath:
            if not html_filepath.endswith(".html"):
                html_filepath = html_filepath + ".html"
            with open(html_filepath, "w") as html_file:
                html_file.write("<meta charset='UTF-8'>" + html_output.data)
        return html_output

    def merge_attributions(self, token_level_attributions):
        """Merge WordPiece continuation tokens ("##...") back into whole words, summing their scores."""
        final = []
        scores = []
        for token, score in token_level_attributions:
            if token.startswith("##"):
                final[-1] = final[-1] + token.replace("##", "")
                scores[-1] = scores[-1] + score
            else:
                final.append(token)
                scores.append(score)
        return list(zip(final, scores))

    def visualize_wordwise(self, sentence: str, path: str, true_class: str):
        """Write a standalone HTML page colouring each word by its attribution score."""
        # Run attribution first so the predicted class reflects this sentence.
        attribution_weights = self.merge_attributions(self(sentence))
        pred_class = self.predicted_class_name
        if pred_class == true_class:
            legend_sent = f"against {pred_class}"
        else:
            legend_sent = f"against {pred_class} and towards {true_class}"
        min_weight = min(float(abs(w)) for _, w in attribution_weights)
        max_weight = max(float(abs(w)) for _, w in attribution_weights)
        attention_html = []
        for word, weight in attribution_weights:
            # Red hue for negative attributions, green for positive ones.
            hue = 5 if weight < 0 else 147
            sat = "100%" if weight < 0 else "50%"

            # Scale the absolute weight relative to the attribution range so that
            # stronger attributions get darker backgrounds.
            scaled_weight = (min_weight + abs(weight)) / (max_weight - min_weight)
            lightness = f"{100 - 50 * scaled_weight}%"

            color = f"hsl({hue},{sat},{lightness})"

            attention_html.append(
                f"<span class='word-box' style='background-color: {color};'>{word}</span><span> </span>")

        attention_html = html.unescape("".join(attention_html))

        final_html = f"""
        <!DOCTYPE html>
        <html>
        <head>
            <title>Attention Visualization</title>
            <style>
                span {{
                    font-family: sans-serif;
                    font-size: 16px;
                }}
            </style>
            <style>
                /* Color legend */
                .color-legend {{
                    display: inline-block;
                    margin: 10px 0;
                    padding: 10px 15px;
                    border: 1px solid #ccc;
                    border-radius: 5px;
                }}

                .word-box {{
                    display: inline-block;
                    border-radius: 5px;
                    padding: 0.2em;
                }}

                .color-legend span {{
                    display: inline-block;
                    margin: 0 5px;
                }}

                .positive-weight {{
                    color: green;
                }}

                .negative-weight {{
                    color: red;
                }}

                .color-legend span:first-child {{
                    margin-left: 0;
                }}
            </style>
            <meta charset="utf-8" />
        </head>
        <body>
            <div class="color-legend">
                <p>PREDICTED LABEL: <b>{pred_class}</b><br>TRUE LABEL: <b>{true_class}</b></p>
                <p><span class='word-box' style='background-color: hsl(5,100%,50%);'>Disagreement</span> ({legend_sent})</p>
                <p><span class='word-box' style='background-color: hsl(147,50%,50%);'>Agreement</span> (towards {pred_class})</p>
            </div>
            <div>{attention_html}</div>
        </body>
        </html>
        """

        with open(path, "w", encoding="utf-8") as f:
            f.write(final_html)
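A minimal usage sketch for explainer.py: the checkpoint is again a placeholder, and the true_class strings are assumed to match the model's id2label names (for example the AssertionType names used in evaluation.py).

from transformers import AutoModelForSequenceClassification, AutoTokenizer

from explainer import CustomExplainer

# Placeholder checkpoint: the same fine-tuned classifier evaluated in evaluation.py.
checkpoint = "path/to/assertion-classifier"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint)

explainer = CustomExplainer(model, tokenizer)
sentence = "The patient denies chest pain."

# Compute token-level attributions; this also sets the explainer's internal state
# (input_ids, predicted class) that visualize() reads.
token_attributions = explainer(sentence)

# Captum token-level table, saved as an HTML file.
explainer.visualize(html_filepath="token_level.html", true_class="ABSENT")

# Word-level view: WordPiece tokens merged and each word coloured by its attribution.
explainer.visualize_wordwise(sentence, path="word_level.html", true_class="ABSENT")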