Detsutut commited on
Commit
f388ec1
·
verified ·
1 Parent(s): 990c0de

Upload 2 files

Browse files
Files changed (2) hide show
  1. evaluation.py +54 -0
  2. explainer.py +153 -0
evaluation.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import evaluate
2
+ import torch
3
+ from enum import Enum
4
+ from scripts.gputils import print_gpu_utilization, clear_gpu_mem
5
+ from tqdm import tqdm
6
+
7
+
8
+ class AssertionType(Enum):
9
+ PRESENT = 0
10
+ ABSENT = 1
11
+ POSSIBLE = 2
12
+
13
+
14
+ class EntityWithAssertion:
15
+ def __init__(self, entity: str, assertion_type: AssertionType):
16
+ self.entity = entity
17
+ self.assertion_type = assertion_type
18
+
19
+ def __repr__(self) -> str:
20
+ return f"{self.assertion_type.name}: {self.entity}"
21
+
22
+
23
+ def classify_assertions_in_sentences(sentences, model, tokenizer, batch_size=32):
24
+ predictions = []
25
+ for i in tqdm(range(0, len(sentences), batch_size)):
26
+ batch = tokenizer(sentences[i:i + batch_size], return_tensors="pt", padding=True, truncation=True).to("cuda")
27
+ with torch.no_grad():
28
+ outputs = model(**batch)
29
+ predicted_labels = torch.argmax(outputs.logits, dim=1)
30
+ predictions.append(predicted_labels)
31
+ print_gpu_utilization()
32
+ return torch.cat(predictions)
33
+
34
+
35
+ def input_classification(model, tokenizer, x: str = None, all_classes = False):
36
+ if x is None:
37
+ x = input("Write your sentence and press Enter to continue")
38
+ tokenized_x = tokenizer(x, return_tensors="pt", padding=True, truncation=True)
39
+ with torch.no_grad():
40
+ outputs = model(**tokenized_x)
41
+ predicted_label = torch.argmax(outputs.logits, dim=1)
42
+ if all_classes:
43
+ return {model.config.id2label[i]:float(k) for i,k in enumerate(torch.softmax(outputs.logits, dim=1)[0])}
44
+ return model.config.id2label[int(predicted_label)]
45
+
46
+
47
+ def compute_results(y, y_hat):
48
+ metric_f1 = evaluate.load("f1")
49
+ metric_acc = evaluate.load("accuracy")
50
+ return {
51
+ "macro-f1": metric_f1.compute(predictions=y_hat, references=y, average="macro")["f1"],
52
+ "micro-f1": metric_f1.compute(predictions=y_hat, references=y, average="micro")["f1"],
53
+ "accuracy": metric_acc.compute(predictions=y_hat, references=y)["accuracy"]
54
+ }
explainer.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers_interpret import SequenceClassificationExplainer
2
+ from captum.attr import visualization as viz
3
+ import html
4
+
5
+
6
+ class CustomExplainer(SequenceClassificationExplainer):
7
+ def __init__(self, model, tokenizer):
8
+ super().__init__(model, tokenizer)
9
+
10
+ def visualize(self, html_filepath: str = None, true_class: str = None):
11
+ """
12
+ Visualizes word attributions. If in a notebook table will be displayed inline.
13
+
14
+ Otherwise pass a valid path to `html_filepath` and the visualization will be saved
15
+ as a html file.
16
+
17
+ If the true class is known for the text that can be passed to `true_class`
18
+
19
+ """
20
+ tokens = [token.replace("Ġ", "") for token in self.decode(self.input_ids)]
21
+ attr_class = self.id2label[self.selected_index]
22
+
23
+ if self._single_node_output:
24
+ if true_class is None:
25
+ true_class = round(float(self.pred_probs))
26
+ predicted_class = round(float(self.pred_probs))
27
+ attr_class = round(float(self.pred_probs))
28
+
29
+ else:
30
+ if true_class is None:
31
+ true_class = self.selected_index
32
+ predicted_class = self.predicted_class_name
33
+
34
+ score_viz = self.attributions.visualize_attributions( # type: ignore
35
+ self.pred_probs,
36
+ predicted_class,
37
+ true_class,
38
+ attr_class,
39
+ tokens,
40
+ )
41
+
42
+ print(score_viz)
43
+
44
+ html = viz.visualize_text([score_viz])
45
+
46
+ if html_filepath:
47
+ if not html_filepath.endswith(".html"):
48
+ html_filepath = html_filepath + ".html"
49
+ with open(html_filepath, "w") as html_file:
50
+ html_file.write("<meta charset='UTF-8'>" + html.data)
51
+ return html
52
+
53
+ def merge_attributions(self, token_level_attributions):
54
+ final = []
55
+ scores = []
56
+ for i, elem in enumerate(token_level_attributions):
57
+ token = elem[0]
58
+ score = elem[1]
59
+ if token.startswith("##"):
60
+ final[-1] = final[-1] + token.replace("##", "")
61
+ scores[-1] = scores[-1] + score
62
+ else:
63
+ final.append(token)
64
+ scores.append(score)
65
+ attr = [(final[i], scores[i]) for i in range(len(final))]
66
+ return attr
67
+
68
+ def visualize_wordwise(self, sentence: str, path: str, true_class: str):
69
+ pred_class = self.predicted_class_name
70
+ if pred_class == true_class:
71
+ legend_sent = f"against {pred_class}"
72
+ else:
73
+ legend_sent = f"against {pred_class} and towards {true_class}"
74
+ attribution_weights = self.merge_attributions(self(sentence))
75
+ min_weight = min([float(abs(w)) for _, w in attribution_weights])
76
+ max_weight = max([float(abs(w)) for _, w in attribution_weights])
77
+ attention_html = []
78
+ for word, weight in attribution_weights:
79
+ hue = 5 if weight < 0 else 147
80
+ sat = "100%" if weight < 0 else "50%"
81
+
82
+ # Logarithmic mapping to scale weight values
83
+ scaled_weight = (min_weight + abs(weight)) / (max_weight - min_weight)
84
+
85
+ # Adjust brightness and saturation for better contrast
86
+ lightness = f"{100 - 50 * scaled_weight}%"
87
+
88
+ color = f"hsl({hue},{sat},{lightness})"
89
+
90
+ attention_html.append(
91
+ f"<span class='word-box' style='background-color: {color};''>{word}</span><span>&nbsp;</span>")
92
+
93
+ attention_html = html.unescape("".join(attention_html))
94
+
95
+ final_html = f"""
96
+ <!DOCTYPE html>
97
+ <html>
98
+ <head>
99
+ <title>Attention Visualization</title>
100
+ <style>
101
+ span {{
102
+ font-family: sans-serif;
103
+ font-size: 16px;
104
+ }}
105
+ </style>
106
+ <style>
107
+ /* Color legend */
108
+ .color-legend {{
109
+ display: inline-block;
110
+ margin: 10px 0;
111
+ padding: 10px 15px;
112
+ border: 1px solid #ccc;
113
+ border-radius: 5px;
114
+ }}
115
+
116
+ .word-box {{
117
+ display: inline-block;
118
+ border-radius: 5px;
119
+ padding: 0.2em;
120
+ }}
121
+
122
+ .color-legend span {{
123
+ display: inline-block;
124
+ margin: 0 5px;
125
+ }}
126
+
127
+ .positive-weight {{
128
+ color: green;
129
+ }}
130
+
131
+ .negative-weight {{
132
+ color: red;
133
+ }}
134
+
135
+ .color-legend span:first-child {{
136
+ margin-left: 0;
137
+ }}
138
+ </style>
139
+ <meta charset="utf-8" />
140
+ </head>
141
+ <body>
142
+ <div class="color-legend">
143
+ <p>PREDICTED LABEL: <b>{pred_class}</b><br>TRUE LABEL: <b>{true_class}</b></p>
144
+ <p><span class='word-box' style='background-color: hsl(5,100%,50%)';>Disagreement</span> ({legend_sent})</p>
145
+ <p><span class='word-box' style='background-color: hsl(147,50%,50%)';>Agreement</span> (towards {pred_class})</p>
146
+ </div>
147
+ <div>{attention_html}</div>
148
+ </body>
149
+ </html>
150
+ """
151
+
152
+ with open(path, "w", encoding="utf-8") as f:
153
+ f.write(final_html)