BerserkerMother committed · Commit ff8f746 · 1 Parent(s): 190015e
Adds initial files for seq2seq training
elise/src/data/__init__.py
ADDED
File without changes
elise/src/data/t5_dataset.py
ADDED
File without changes
elise/src/excutors/__init__.py
ADDED
File without changes
elise/src/excutors/trainer_seq2seq.py
ADDED
@@ -0,0 +1,208 @@
# Token-classification (NER) fine-tuning on the MIT Restaurant corpus,
# trained with accelerate.
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    DataCollatorForTokenClassification,
    get_scheduler,
)
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
import datasets
from accelerate import Accelerator
import evaluate

from tqdm.auto import tqdm


# Label scheme for the MIT Restaurant corpus (BIO tagging).
ner_tags = {
    "O": 0,
    "B-Rating": 1,
    "I-Rating": 2,
    "B-Amenity": 3,
    "I-Amenity": 4,
    "B-Location": 5,
    "I-Location": 6,
    "B-Restaurant_Name": 7,
    "I-Restaurant_Name": 8,
    "B-Price": 9,
    "B-Hours": 10,
    "I-Hours": 11,
    "B-Dish": 12,
    "I-Dish": 13,
    "B-Cuisine": 14,
    "I-Price": 15,
    "I-Cuisine": 16,
}


label_names = {v: k for k, v in ner_tags.items()}

# Dataset aggregation: fold the validation and test splits into training.
dataset = load_dataset("tner/mit_restaurant")
dataset["train"] = datasets.concatenate_datasets([dataset["train"], dataset["validation"]])
dataset["train"] = datasets.concatenate_datasets([dataset["train"], dataset["test"]])

print(dataset)


tokenizer = AutoTokenizer.from_pretrained(
    "sentence-transformers/all-MiniLM-L6-v2")


def align_labels_with_tokens(labels, word_ids):
    """Expand word-level labels to subword tokens; special tokens get -100."""
    new_labels = []
    current_word = None
    for word_id in word_ids:
        if word_id != current_word:
            # Start of a new word
            current_word = word_id
            label = -100 if word_id is None else labels[word_id]
            new_labels.append(label)
        elif word_id is None:
            # Special token
            new_labels.append(-100)
        else:
            # Same word as the previous token
            label = labels[word_id]
            # If the label is B-XXX, change it to I-XXX
            label_name = label_names[label]
            if label_name.startswith("B"):
                label = ner_tags["I" + label_name[1:]]
            new_labels.append(label)

    return new_labels


def tokenize_and_align_labels(examples):
    tokenized_inputs = tokenizer(
        examples["tokens"], truncation=True, is_split_into_words=True
    )
    all_labels = examples["tags"]
    new_labels = []
    for i, labels in enumerate(all_labels):
        word_ids = tokenized_inputs.word_ids(i)
        new_labels.append(align_labels_with_tokens(labels, word_ids))

    tokenized_inputs["labels"] = new_labels
    return tokenized_inputs


tokenized_datasets = dataset.map(
    tokenize_and_align_labels,
    batched=True,
    remove_columns=dataset["train"].column_names,
)


def train():
    metric = evaluate.load("seqeval")
    data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

    train_dataloader = DataLoader(
        tokenized_datasets["train"],
        shuffle=True,
        collate_fn=data_collator,
        batch_size=128,
    )
    eval_dataloader = DataLoader(
        tokenized_datasets["test"],
        collate_fn=data_collator,
        batch_size=8,
    )

    model = AutoModelForTokenClassification.from_pretrained(
        "sentence-transformers/all-MiniLM-L6-v2",
        id2label=label_names,
        label2id=ner_tags,
    )

    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

    accelerator = Accelerator()
    model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader
    )

    num_train_epochs = 50
    num_update_steps_per_epoch = len(train_dataloader)
    num_training_steps = num_train_epochs * num_update_steps_per_epoch

    lr_scheduler = get_scheduler(
        "linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=num_training_steps,
    )

    def postprocess(predictions, labels):
        predictions = predictions.detach().cpu().clone().numpy()
        labels = labels.detach().cpu().clone().numpy()

        # Remove ignored index (special tokens) and convert to labels
        true_labels = [[label_names[l] for l in label if l != -100]
                       for label in labels]
        true_predictions = [
            [label_names[p] for (p, l) in zip(prediction, label) if l != -100]
            for prediction, label in zip(predictions, labels)
        ]
        return true_labels, true_predictions

    progress_bar = tqdm(range(num_training_steps))

    for epoch in range(num_train_epochs):
        # Training
        model.train()
        for batch in train_dataloader:
            outputs = model(**batch)
            loss = outputs.loss
            accelerator.backward(loss)

            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            progress_bar.update(1)

        # Evaluation
        model.eval()
        for batch in eval_dataloader:
            with torch.no_grad():
                outputs = model(**batch)

            predictions = outputs.logits.argmax(dim=-1)
            labels = batch["labels"]

            # Pad predictions and labels so they can be gathered
            predictions = accelerator.pad_across_processes(
                predictions, dim=1, pad_index=-100)
            labels = accelerator.pad_across_processes(
                labels, dim=1, pad_index=-100)

            predictions_gathered = accelerator.gather(predictions)
            labels_gathered = accelerator.gather(labels)

            true_labels, true_predictions = postprocess(
                predictions_gathered, labels_gathered)
            metric.add_batch(predictions=true_predictions,
                             references=true_labels)

        results = metric.compute()
        print(
            f"epoch {epoch}:",
            {
                key: results[f"overall_{key}"]
                for key in ["precision", "recall", "f1", "accuracy"]
            },
        )

    output_dir = "restaurant_ner"
    # Save model and tokenizer
    accelerator.wait_for_everyone()
    unwrapped_model = accelerator.unwrap_model(model)
    unwrapped_model.save_pretrained(
        output_dir, save_function=accelerator.save)
    if accelerator.is_main_process:
        tokenizer.save_pretrained(output_dir)


train()
elise/src/models/__init__.py
ADDED
File without changes
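Once training finishes, the checkpoint in restaurant_ner/ (the output_dir in the trainer above) can be loaded like any transformers model. A minimal inference sketch, assuming the run completed and wrote to that directory; the pipeline helper and aggregation_strategy option are standard transformers API, not part of this commit, and the example sentence is made up:

from transformers import pipeline

# Load the token-classification checkpoint written by train() above
# (hypothetical usage; assumes output_dir "restaurant_ner" exists).
ner = pipeline(
    "token-classification",
    model="restaurant_ner",
    aggregation_strategy="simple",  # merge B-/I- subword pieces into spans
)
print(ner("any cheap sushi places open until midnight near downtown"))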