Upload utils.py
utils.py
ADDED
@@ -0,0 +1,739 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
import json
import linecache
import math
import os
import pickle
import socket
from logging import getLogger
from pathlib import Path
from typing import Callable, Dict, Iterable, List, Tuple, Union

import git
import numpy as np
import torch
import torch.distributed as dist
from rouge_score import rouge_scorer, scoring
from sacrebleu import corpus_bleu
from torch import nn
from torch.utils.data import Dataset, Sampler

from sentence_splitter import add_newline_to_end_of_each_sentence
from transformers import BartTokenizer, EvalPrediction, PreTrainedTokenizer, T5Tokenizer
from transformers.file_utils import cached_property
from transformers.models.bart.modeling_bart import shift_tokens_right


try:
    from fairseq.data.data_utils import batch_by_size

    FAIRSEQ_AVAILABLE = True
except (ImportError, ModuleNotFoundError):
    FAIRSEQ_AVAILABLE = False
def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
    """From fairseq"""
    if target.dim() == lprobs.dim() - 1:
        target = target.unsqueeze(-1)
    nll_loss = -lprobs.gather(dim=-1, index=target)
    smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
    if ignore_index is not None:
        pad_mask = target.eq(ignore_index)
        nll_loss.masked_fill_(pad_mask, 0.0)
        smooth_loss.masked_fill_(pad_mask, 0.0)
    else:
        nll_loss = nll_loss.squeeze(-1)
        smooth_loss = smooth_loss.squeeze(-1)

    nll_loss = nll_loss.sum()  # mean()? Scared to break other math.
    smooth_loss = smooth_loss.sum()
    eps_i = epsilon / lprobs.size(-1)
    loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
    return loss, nll_loss
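# Usage sketch (hypothetical tensors, not executed here): given decoder logits of shape
# (batch, seq_len, vocab) and integer labels of shape (batch, seq_len), something like
#
#   lprobs = torch.log_softmax(logits, dim=-1)
#   loss, nll = label_smoothed_nll_loss(lprobs, labels, epsilon=0.1, ignore_index=pad_token_id)
#
# gives the smoothed loss for backprop and the un-smoothed NLL sum for logging.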
def lmap(f: Callable, x: Iterable) -> List:
    """list(map(f, x))"""
    return list(map(f, x))


def calculate_bleu(output_lns, refs_lns, **kwargs) -> dict:
    """Uses sacrebleu's corpus_bleu implementation."""
    return {"bleu": round(corpus_bleu(output_lns, [refs_lns], **kwargs).score, 4)}
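# Usage sketch (hypothetical strings, not executed here):
#
#   calculate_bleu(["the cat sat on the mat"], ["there is a cat on the mat"])
#
# returns a dict like {"bleu": <corpus score>}; sacrebleu expects detokenized text on both sides.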
def build_compute_metrics_fn(task_name: str, tokenizer: PreTrainedTokenizer) -> Callable[[EvalPrediction], Dict]:
    def non_pad_len(tokens: np.ndarray) -> int:
        return np.count_nonzero(tokens != tokenizer.pad_token_id)

    def decode_pred(pred: EvalPrediction) -> Tuple[List[str], List[str]]:
        pred_ids = pred.predictions
        label_ids = pred.label_ids
        pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
        label_ids[label_ids == -100] = tokenizer.pad_token_id
        label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
        pred_str = lmap(str.strip, pred_str)
        label_str = lmap(str.strip, label_str)
        return pred_str, label_str

    def summarization_metrics(pred: EvalPrediction) -> Dict:
        pred_str, label_str = decode_pred(pred)
        rouge: Dict = calculate_rouge(pred_str, label_str)
        summ_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
        rouge.update({"gen_len": summ_len})
        return rouge

    def translation_metrics(pred: EvalPrediction) -> Dict:
        pred_str, label_str = decode_pred(pred)
        bleu: Dict = calculate_bleu(pred_str, label_str)
        gen_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
        bleu.update({"gen_len": gen_len})
        return bleu

    compute_metrics_fn = summarization_metrics if "summarization" in task_name else translation_metrics
    return compute_metrics_fn
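# Usage sketch (hypothetical objects, not executed here): the returned closure is meant to be
# plugged into a Trainer-style `compute_metrics` hook, e.g.
#
#   compute_metrics = build_compute_metrics_fn("summarization", tokenizer)
#   metrics = compute_metrics(eval_prediction)  # eval_prediction is an EvalPrediction
#
# yielding ROUGE plus "gen_len" for summarization tasks, or BLEU plus "gen_len" otherwise.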
def trim_batch(
    input_ids,
    pad_token_id,
    attention_mask=None,
):
    """Remove columns that are populated exclusively by pad_token_id"""
    keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
    if attention_mask is None:
        return input_ids[:, keep_column_mask]
    else:
        return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])
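# Worked example (hypothetical values, not executed here): with pad_token_id = 0,
#
#   input_ids = torch.tensor([[5, 6, 0, 0],
#                             [7, 0, 0, 0]])
#   trim_batch(input_ids, 0)  # -> tensor([[5, 6], [7, 0]])
#
# columns made up entirely of padding are dropped, shrinking the batch width.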
class AbstractSeq2SeqDataset(Dataset):
    def __init__(
        self,
        tokenizer,
        data_dir,
        max_source_length,
        max_target_length,
        type_path="train",
        n_obs=None,
        prefix="",
        **dataset_kwargs,
    ):
        super().__init__()
        self.src_file = Path(data_dir).joinpath(type_path + ".source")
        self.tgt_file = Path(data_dir).joinpath(type_path + ".target")
        self.len_file = Path(data_dir).joinpath(type_path + ".len")
        if os.path.exists(self.len_file):
            self.src_lens = pickle_load(self.len_file)
            self.used_char_len = False
        else:
            self.src_lens = self.get_char_lens(self.src_file)
            self.used_char_len = True
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length
        assert min(self.src_lens) > 0, f"found empty line in {self.src_file}"
        self.tokenizer = tokenizer
        self.prefix = prefix if prefix is not None else ""

        if n_obs is not None:
            self.src_lens = self.src_lens[:n_obs]
        self.pad_token_id = self.tokenizer.pad_token_id
        self.dataset_kwargs = dataset_kwargs
        dataset_kwargs.update({"add_prefix_space": True} if isinstance(self.tokenizer, BartTokenizer) else {})

    def __len__(self):
        return len(self.src_lens)

    @staticmethod
    def get_char_lens(data_file):
        return [len(x) for x in Path(data_file).open().readlines()]

    @cached_property
    def tgt_lens(self):
        """Length in characters of target documents"""
        return self.get_char_lens(self.tgt_file)

    def make_sortish_sampler(self, batch_size, distributed=False, shuffle=True, **kwargs):
        if distributed:
            return DistributedSortishSampler(self, batch_size, shuffle=shuffle, **kwargs)
        else:
            return SortishSampler(self.src_lens, batch_size, shuffle=shuffle)

    def make_dynamic_sampler(self, max_tokens_per_batch=1024, **kwargs):
        assert FAIRSEQ_AVAILABLE, "Dynamic batch size requires `pip install fairseq`"
        assert not self.used_char_len, "You must call python make_len_file.py before calling make_dynamic_sampler"
        sorted_indices = list(self.make_sortish_sampler(1024, shuffle=False))

        def num_tokens_in_example(i):
            return min(self.src_lens[i], self.max_target_length)

        # call fairseq cython function
        batch_sampler: List[List[int]] = batch_by_size(
            sorted_indices,
            num_tokens_fn=num_tokens_in_example,
            max_tokens=max_tokens_per_batch,
            required_batch_size_multiple=64,
        )
        shuffled_batches = [batch_sampler[i] for i in np.random.permutation(range(len(batch_sampler)))]
        # move the largest batch to the front to OOM quickly (uses an approximation for padding)
        approximate_toks_per_batch = [max(self.src_lens[i] for i in batch) * len(batch) for batch in shuffled_batches]
        largest_batch_idx = np.argmax(approximate_toks_per_batch)
        shuffled_batches[0], shuffled_batches[largest_batch_idx] = (
            shuffled_batches[largest_batch_idx],
            shuffled_batches[0],
        )
        return shuffled_batches

    def __getitem__(self, item):
        raise NotImplementedError("You must implement this")

    def collate_fn(self, batch):
        raise NotImplementedError("You must implement this")
class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
    def __getitem__(self, index) -> Dict[str, torch.Tensor]:
        """Call tokenizer on src and tgt_lines"""
        index = index + 1  # linecache starts at 1
        source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
        tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
        assert source_line, f"empty source line for index {index}"
        assert tgt_line, f"empty tgt line for index {index}"
        source_inputs = self.encode_line(self.tokenizer, source_line, self.max_source_length)
        target_inputs = self.encode_line(self.tokenizer, tgt_line, self.max_target_length)

        source_ids = source_inputs["input_ids"].squeeze()
        target_ids = target_inputs["input_ids"].squeeze()
        src_mask = source_inputs["attention_mask"].squeeze()
        return {
            "input_ids": source_ids,
            "attention_mask": src_mask,
            "labels": target_ids,
        }

    def encode_line(self, tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"):
        """Only used by LegacyDataset"""
        return tokenizer(
            [line],
            max_length=max_length,
            padding="max_length" if pad_to_max_length else None,
            truncation=True,
            return_tensors=return_tensors,
            **self.dataset_kwargs,
        )

    def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
        input_ids = torch.stack([x["input_ids"] for x in batch])
        masks = torch.stack([x["attention_mask"] for x in batch])
        target_ids = torch.stack([x["labels"] for x in batch])
        pad_token_id = self.pad_token_id
        y = trim_batch(target_ids, pad_token_id)
        source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks)
        batch = {
            "input_ids": source_ids,
            "attention_mask": source_mask,
            "labels": y,
        }
        return batch
class Seq2SeqDataset(AbstractSeq2SeqDataset):
    """A dataset that calls prepare_seq2seq_batch."""

    def __getitem__(self, index) -> Dict[str, str]:
        index = index + 1  # linecache starts at 1
        source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
        tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
        assert source_line, f"empty source line for index {index}"
        assert tgt_line, f"empty tgt line for index {index}"
        return {"tgt_texts": tgt_line, "src_texts": source_line, "id": index - 1}

    def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
        """Call prepare_seq2seq_batch."""
        batch_encoding: Dict[str, torch.Tensor] = self.tokenizer.prepare_seq2seq_batch(
            [x["src_texts"] for x in batch],
            tgt_texts=[x["tgt_texts"] for x in batch],
            max_length=self.max_source_length,
            max_target_length=self.max_target_length,
            return_tensors="pt",
            **self.dataset_kwargs,
        ).data
        batch_encoding["ids"] = torch.tensor([x["id"] for x in batch])
        return batch_encoding
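# Usage sketch (hypothetical path and tokenizer, not executed here): both dataset classes expect
# data_dir to contain line-aligned <type_path>.source / <type_path>.target files and are wired to
# a DataLoader through their own collate_fn, e.g.
#
#   from torch.utils.data import DataLoader
#   ds = Seq2SeqDataset(tokenizer, "path/to/data", max_source_length=1024, max_target_length=56)
#   loader = DataLoader(ds, batch_size=8, collate_fn=ds.collate_fn)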
class Seq2SeqDataCollator:
    def __init__(self, tokenizer, data_args, decoder_start_token_id, tpu_num_cores=None):
        self.tokenizer = tokenizer
        self.pad_token_id = tokenizer.pad_token_id
        self.decoder_start_token_id = decoder_start_token_id
        assert (
            self.pad_token_id is not None
        ), f"pad_token_id is not defined for ({self.tokenizer.__class__.__name__}), it must be defined."
        self.data_args = data_args
        self.tpu_num_cores = tpu_num_cores
        self.dataset_kwargs = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
        if data_args.src_lang is not None:
            self.dataset_kwargs["src_lang"] = data_args.src_lang
        if data_args.tgt_lang is not None:
            self.dataset_kwargs["tgt_lang"] = data_args.tgt_lang

    def __call__(self, batch) -> Dict[str, torch.Tensor]:
        if hasattr(self.tokenizer, "prepare_seq2seq_batch"):
            batch = self._encode(batch)
            input_ids, attention_mask, labels = (
                batch["input_ids"],
                batch["attention_mask"],
                batch["labels"],
            )
        else:
            input_ids = torch.stack([x["input_ids"] for x in batch])
            attention_mask = torch.stack([x["attention_mask"] for x in batch])
            labels = torch.stack([x["labels"] for x in batch])

            labels = trim_batch(labels, self.pad_token_id)
            input_ids, attention_mask = trim_batch(input_ids, self.pad_token_id, attention_mask=attention_mask)

        if isinstance(self.tokenizer, T5Tokenizer):
            decoder_input_ids = self._shift_right_t5(labels)
        else:
            decoder_input_ids = shift_tokens_right(labels, self.pad_token_id, self.decoder_start_token_id)

        batch = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "decoder_input_ids": decoder_input_ids,
            "labels": labels,
        }
        return batch

    def _shift_right_t5(self, input_ids):
        # shift inputs to the right
        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
        shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
        shifted_input_ids[..., 0] = self.pad_token_id
        return shifted_input_ids

    def _encode(self, batch) -> Dict[str, torch.Tensor]:
        batch_encoding = self.tokenizer.prepare_seq2seq_batch(
            [x["src_texts"] for x in batch],
            tgt_texts=[x["tgt_texts"] for x in batch],
            max_length=self.data_args.max_source_length,
            max_target_length=self.data_args.max_target_length,
            padding="max_length" if self.tpu_num_cores is not None else "longest",  # TPU hack
            return_tensors="pt",
            **self.dataset_kwargs,
        )
        return batch_encoding.data
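# Usage sketch (hypothetical objects, not executed here): the collator plays the same role as
# Seq2SeqDataset.collate_fn but is usable as a standalone `collate_fn`/`data_collator`, e.g.
#
#   collator = Seq2SeqDataCollator(tokenizer, data_args, model.config.decoder_start_token_id)
#   loader = DataLoader(ds, batch_size=8, collate_fn=collator)
#
# It yields input_ids, attention_mask, decoder_input_ids and labels ready for a seq2seq model.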
class SortishSampler(Sampler):
    "Go through the text data by order of src length with a bit of randomness. From fastai repo."

    def __init__(self, data, batch_size, shuffle=True):
        self.data, self.bs, self.shuffle = data, batch_size, shuffle

    def __len__(self) -> int:
        return len(self.data)

    def __iter__(self):
        return iter(sortish_sampler_indices(self.data, self.bs, shuffle=self.shuffle))
def sortish_sampler_indices(data: List, bs: int, shuffle=True) -> np.ndarray:
    "Go through the text data by order of src length with a bit of randomness. From fastai repo."
    if not shuffle:
        return np.argsort(np.array(data) * -1)

    def key_fn(i):
        return data[i]

    idxs = np.random.permutation(len(data))
    sz = bs * 50
    ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)]
    sort_idx = np.concatenate([sorted(s, key=key_fn, reverse=True) for s in ck_idx])
    sz = bs
    ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)]
    max_ck = np.argmax([key_fn(ck[0]) for ck in ck_idx])  # find the chunk with the largest key,
    ck_idx[0], ck_idx[max_ck] = ck_idx[max_ck], ck_idx[0]  # then make sure it goes first.
    sort_idx = np.concatenate(np.random.permutation(ck_idx[1:])) if len(ck_idx) > 1 else np.array([], dtype=np.int64)
    sort_idx = np.concatenate((ck_idx[0], sort_idx))
    return sort_idx
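# Usage sketch (hypothetical objects, not executed here): the sortish samplers are normally
# created via AbstractSeq2SeqDataset.make_sortish_sampler so that similarly-sized examples end
# up in the same batch, reducing padding, e.g.
#
#   sampler = ds.make_sortish_sampler(batch_size=8, distributed=False)
#   loader = DataLoader(ds, batch_size=8, sampler=sampler, collate_fn=ds.collate_fn)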
class DistributedSortishSampler(Sampler):
    """Copied from torch DistributedSampler"""

    def __init__(self, dataset, batch_size, num_replicas=None, rank=None, add_extra_examples=True, shuffle=True):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        if add_extra_examples:
            self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
            self.total_size = self.num_samples * self.num_replicas
        else:
            self.total_size = len(dataset)
            self.num_samples = len(self.available_indices)
        self.batch_size = batch_size
        self.add_extra_examples = add_extra_examples
        self.shuffle = shuffle

    def __iter__(self) -> Iterable:
        g = torch.Generator()
        g.manual_seed(self.epoch)

        sortish_data = [self.dataset.src_lens[i] for i in self.available_indices]
        sortish_indices = sortish_sampler_indices(sortish_data, self.batch_size, shuffle=self.shuffle)
        indices = [self.available_indices[i] for i in sortish_indices]
        assert len(indices) == self.num_samples
        return iter(indices)

    @cached_property
    def available_indices(self) -> List[int]:
        indices = list(range(len(self.dataset)))
        # add extra samples to make it evenly divisible
        indices += indices[: (self.total_size - len(indices))]
        assert len(indices) == self.total_size
        # subsample
        available_indices = indices[self.rank : self.total_size : self.num_replicas]
        return available_indices

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
logger = getLogger(__name__)


def use_task_specific_params(model, task):
    """Update config with summarization specific params."""
    task_specific_params = model.config.task_specific_params

    if task_specific_params is not None:
        pars = task_specific_params.get(task, {})
        logger.info(f"setting model.config to task specific params for {task}:\n {pars}")
        logger.info("note: command line args may override some of these")
        model.config.update(pars)
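# Usage sketch (hypothetical model, not executed here): some pretrained configs ship a
# `task_specific_params` dict (e.g. beam size / length overrides for "summarization"), so
#
#   use_task_specific_params(model, "summarization")
#
# copies those values onto model.config before generation; it is a no-op if the dict is absent.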
def pickle_load(path):
    """pickle.load(path)"""
    with open(path, "rb") as f:
        return pickle.load(f)


def pickle_save(obj, path):
    """pickle.dump(obj, path)"""
    with open(path, "wb") as f:
        return pickle.dump(obj, f)


def flatten_list(summary_ids: List[List]):
    return [x for x in itertools.chain.from_iterable(summary_ids)]


def save_git_info(folder_path: str) -> None:
    """Save git information to output_dir/git_log.json"""
    repo_infos = get_git_info()
    save_json(repo_infos, os.path.join(folder_path, "git_log.json"))


def save_json(content, path, indent=4, **json_dump_kwargs):
    with open(path, "w") as f:
        json.dump(content, f, indent=indent, sort_keys=True, **json_dump_kwargs)


def load_json(path):
    with open(path) as f:
        return json.load(f)


def get_git_info():
    try:
        repo = git.Repo(search_parent_directories=True)
        repo_infos = {
            "repo_id": str(repo),
            "repo_sha": str(repo.head.object.hexsha),
            "repo_branch": str(repo.active_branch),
            "hostname": str(socket.gethostname()),
        }
        return repo_infos
    except TypeError:
        return {
            "repo_id": None,
            "repo_sha": None,
            "repo_branch": None,
            "hostname": None,
        }
ROUGE_KEYS = ["rouge1", "rouge2", "rougeL", "rougeLsum"]


def extract_rouge_mid_statistics(dct):
    new_dict = {}
    for k1, v1 in dct.items():
        mid = v1.mid
        new_dict[k1] = {stat: round(getattr(mid, stat), 4) for stat in ["precision", "recall", "fmeasure"]}
    return new_dict


def calculate_rouge(
    pred_lns: List[str],
    tgt_lns: List[str],
    use_stemmer=True,
    rouge_keys=ROUGE_KEYS,
    return_precision_and_recall=False,
    bootstrap_aggregation=True,
    newline_sep=True,
) -> Dict:
    """Calculate rouge using rouge_scorer package.

    Args:
        pred_lns: list of summaries generated by model
        tgt_lns: list of groundtruth summaries (e.g. contents of val.target)
        use_stemmer: Bool indicating whether Porter stemmer should be used to
            strip word suffixes to improve matching.
        rouge_keys: which metrics to compute, defaults to rouge1, rouge2, rougeL, rougeLsum
        return_precision_and_recall: (False) whether to also return precision and recall.
        bootstrap_aggregation: whether to do the typical bootstrap resampling of scores. Defaults to True. If False,
            this function returns a collections.defaultdict mapping each rouge key to its list of per-observation scores.
        newline_sep: (default=True) whether to add a newline between sentences. This is essential for calculating
            rougeL on multi-sentence summaries (CNN/DM dataset).

    Returns:
        Dict[score: value] if aggregate else defaultdict(list) keyed by rouge_keys

    """
    scorer = rouge_scorer.RougeScorer(rouge_keys, use_stemmer=use_stemmer)
    aggregator = scoring.BootstrapAggregator()
    for pred, tgt in zip(tgt_lns, pred_lns):
        # rougeLsum expects "\n" separated sentences within a summary
        if newline_sep:
            pred = add_newline_to_end_of_each_sentence(pred)
            tgt = add_newline_to_end_of_each_sentence(tgt)
        scores = scorer.score(pred, tgt)
        aggregator.add_scores(scores)

    if bootstrap_aggregation:
        result = aggregator.aggregate()
        if return_precision_and_recall:
            return extract_rouge_mid_statistics(result)  # here we return dict
        else:
            return {k: round(v.mid.fmeasure * 100, 4) for k, v in result.items()}

    else:
        return aggregator._scores  # here we return defaultdict(list)
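# Usage sketch (hypothetical strings, not executed here):
#
#   calculate_rouge(["the cat sat on the mat"], ["a cat was sitting on the mat"])
#
# returns F-measures scaled to 0-100 for each key in ROUGE_KEYS; pass
# return_precision_and_recall=True to get precision/recall/fmeasure per key instead.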
# Utilities for freezing parameters and checking whether they are frozen


def freeze_params(model: nn.Module):
    """Set requires_grad=False for each of model.parameters()"""
    for par in model.parameters():
        par.requires_grad = False


def freeze_embeds(model):
    """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
    model_type = model.config.model_type

    if model_type in ["t5", "mt5"]:
        freeze_params(model.shared)
        for d in [model.encoder, model.decoder]:
            freeze_params(d.embed_tokens)
    elif model_type == "fsmt":
        for d in [model.model.encoder, model.model.decoder]:
            freeze_params(d.embed_positions)
            freeze_params(d.embed_tokens)
    else:
        freeze_params(model.model.shared)
        for d in [model.model.encoder, model.model.decoder]:
            freeze_params(d.embed_positions)
            freeze_params(d.embed_tokens)


def grad_status(model: nn.Module) -> Iterable:
    return (par.requires_grad for par in model.parameters())


def any_requires_grad(model: nn.Module) -> bool:
    return any(grad_status(model))


def assert_all_frozen(model):
    model_grads: List[bool] = list(grad_status(model))
    n_require_grad = sum(lmap(int, model_grads))
    npars = len(model_grads)
    assert not any(model_grads), f"{n_require_grad/npars:.1%} of {npars} weights require grad"


def assert_not_all_frozen(model):
    model_grads: List[bool] = list(grad_status(model))
    npars = len(model_grads)
    assert any(model_grads), f"none of {npars} weights require grad"
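# Usage sketch (hypothetical model, not executed here): a common pattern in the training scripts
# is to freeze a sub-module and immediately sanity-check it, e.g.
#
#   freeze_params(model.get_encoder())
#   assert_all_frozen(model.get_encoder())
#
# while assert_not_all_frozen(model) guards against accidentally freezing every parameter.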
def parse_numeric_n_bool_cl_kwargs(unparsed_args: List[str]) -> Dict[str, Union[int, float, bool]]:
    """
    Parse an argv list of unspecified command line args to a dict.
    Assumes all values are either numeric or boolean in the form of true/false.
    """
    result = {}
    assert len(unparsed_args) % 2 == 0, f"got odd number of unparsed args: {unparsed_args}"
    num_pairs = len(unparsed_args) // 2
    for pair_num in range(num_pairs):
        i = 2 * pair_num
        assert unparsed_args[i].startswith("--")
        if unparsed_args[i + 1].lower() == "true":
            value = True
        elif unparsed_args[i + 1].lower() == "false":
            value = False
        else:
            try:
                value = int(unparsed_args[i + 1])
            except ValueError:
                value = float(unparsed_args[i + 1])  # this can raise another informative ValueError

        result[unparsed_args[i][2:]] = value
    return result
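# Worked example (not executed here):
#
#   parse_numeric_n_bool_cl_kwargs(["--num_beams", "4", "--early_stopping", "true", "--length_penalty", "0.8"])
#   # -> {"num_beams": 4, "early_stopping": True, "length_penalty": 0.8}
#
# A value that is neither numeric nor true/false raises a ValueError from the float() fallback.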
def write_txt_file(ordered_tgt, path):
    f = Path(path).open("w")
    for ln in ordered_tgt:
        f.write(ln + "\n")
    f.flush()


def chunks(lst, n):
    """Yield successive n-sized chunks from lst."""
    for i in range(0, len(lst), n):
        yield lst[i : i + n]
def check_output_dir(args, expected_items=0):
    """
    Checks whether to bail out if output_dir already exists and has more than expected_items in it

    `args`: needs to have the following attributes:
      - output_dir
      - do_train
      - overwrite_output_dir

    `expected_items`: normally 0 (default) - i.e. empty dir, but in some cases a few files are expected (e.g. recovery from OOM)
    """
    if (
        os.path.exists(args.output_dir)
        and len(os.listdir(args.output_dir)) > expected_items
        and args.do_train
        and not args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({args.output_dir}) already exists and "
            f"has {len(os.listdir(args.output_dir))} items in it (expected {expected_items} items). "
            "Use --overwrite_output_dir to overcome."
        )