BerserkerMother committed
Commit ff8f746 · 1 Parent(s): 190015e

Adds initial files for seq2seq training

elise/src/data/__init__.py ADDED
File without changes
elise/src/data/t5_dataset.py ADDED
File without changes
elise/src/excutors/__init__.py ADDED
File without changes
elise/src/excutors/trainer_seq2seq.py ADDED
@@ -0,0 +1,208 @@
+ import torch
+ from torch.utils.data import DataLoader
+
+ import datasets
+ import evaluate
+ from accelerate import Accelerator
+ from datasets import load_dataset
+ from tqdm.auto import tqdm
+ from transformers import (
+     AutoModelForTokenClassification,
+     AutoTokenizer,
+     DataCollatorForTokenClassification,
+     get_scheduler,
+ )
+
+
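+ # BIO tag <-> id mapping for the MIT Restaurant corpus (tner/mit_restaurant)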
+ ner_tags = {
+     "O": 0,
+     "B-Rating": 1,
+     "I-Rating": 2,
+     "B-Amenity": 3,
+     "I-Amenity": 4,
+     "B-Location": 5,
+     "I-Location": 6,
+     "B-Restaurant_Name": 7,
+     "I-Restaurant_Name": 8,
+     "B-Price": 9,
+     "B-Hours": 10,
+     "I-Hours": 11,
+     "B-Dish": 12,
+     "I-Dish": 13,
+     "B-Cuisine": 14,
+     "I-Price": 15,
+     "I-Cuisine": 16,
+ }
+
+
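+ # Inverse mapping (id -> tag name), used when decoding predictions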
+ label_names = {v: k for k, v in ner_tags.items()}
+
+ # Dataset aggregation: fold the validation and test splits into the
+ # training split so the model trains on all available data.
+ dataset = load_dataset("tner/mit_restaurant")
+ dataset["train"] = datasets.concatenate_datasets(
+     [dataset["train"], dataset["validation"], dataset["test"]]
+ )
+
+ print(dataset)
+
+
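+ # The tokenizer must come from the same checkpoint as the model below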
+ tokenizer = AutoTokenizer.from_pretrained(
+     "sentence-transformers/all-MiniLM-L6-v2"
+ )
+
+
+ def align_labels_with_tokens(labels, word_ids):
+     new_labels = []
+     current_word = None
+     for word_id in word_ids:
+         if word_id != current_word:
+             # Start of a new word!
+             current_word = word_id
+             label = -100 if word_id is None else labels[word_id]
+             new_labels.append(label)
+         elif word_id is None:
+             # Special token
+             new_labels.append(-100)
+         else:
+             # Same word as previous token
+             label = labels[word_id]
+             # If the label is B-XXX we change it to I-XXX
+             label_name = label_names[label]
+             if label_name.startswith("B"):
+                 label = ner_tags["I" + label_name[1:]]
+             new_labels.append(label)
+
+     return new_labels
+
+
+ def tokenize_and_align_labels(examples):
+     tokenized_inputs = tokenizer(
+         examples["tokens"], truncation=True, is_split_into_words=True
+     )
+     all_labels = examples["tags"]
+     new_labels = []
+     for i, labels in enumerate(all_labels):
+         word_ids = tokenized_inputs.word_ids(i)
+         new_labels.append(align_labels_with_tokens(labels, word_ids))
+
+     tokenized_inputs["labels"] = new_labels
+     return tokenized_inputs
+
+
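+ # Tokenize every split in one batched pass; the original word-level
+ # columns are dropped so the collator only sees model inputs.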
+ tokenized_datasets = dataset.map(
+     tokenize_and_align_labels,
+     batched=True,
+     remove_columns=dataset["train"].column_names,
+ )
+
+
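+ # End-to-end training: dataloaders, model, optimizer, scheduler, and an
+ # Accelerate-driven train/eval loop with per-epoch checkpointing.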
+ def train():
+     metric = evaluate.load("seqeval")
+     data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)
+
+     train_dataloader = DataLoader(
+         tokenized_datasets["train"],
+         shuffle=True,
+         collate_fn=data_collator,
+         batch_size=128,
+     )
+     # Note: the test split was folded into train above, so these eval
+     # scores are measured on data the model has already seen.
+     eval_dataloader = DataLoader(
+         tokenized_datasets["test"],
+         collate_fn=data_collator,
+         batch_size=8,
+     )
+
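+     # all-MiniLM-L6-v2 ships without a token-classification head, so
+     # from_pretrained initializes one randomly on top of the encoder.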
+     model = AutoModelForTokenClassification.from_pretrained(
+         "sentence-transformers/all-MiniLM-L6-v2",
+         id2label=label_names,
+         label2id=ner_tags,
+     )
+
+     optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
+
+     accelerator = Accelerator()
+     model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
+         model, optimizer, train_dataloader, eval_dataloader
+     )
+
+     num_train_epochs = 50
+     num_update_steps_per_epoch = len(train_dataloader)
+     num_training_steps = num_train_epochs * num_update_steps_per_epoch
+
+     lr_scheduler = get_scheduler(
+         "linear",
+         optimizer=optimizer,
+         num_warmup_steps=0,
+         num_training_steps=num_training_steps,
+     )
+
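+     # Convert padded id tensors to per-sentence lists of tag names,
+     # dropping the -100 positions that the loss function also ignores.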
+     def postprocess(predictions, labels):
+         predictions = predictions.detach().cpu().clone().numpy()
+         labels = labels.detach().cpu().clone().numpy()
+
+         # Remove ignored index (special tokens) and convert to labels
+         true_labels = [
+             [label_names[l] for l in label if l != -100] for label in labels
+         ]
+         true_predictions = [
+             [label_names[p] for (p, l) in zip(prediction, label) if l != -100]
+             for prediction, label in zip(predictions, labels)
+         ]
+         return true_labels, true_predictions
+
+     progress_bar = tqdm(range(num_training_steps))
+
+     for epoch in range(num_train_epochs):
+         # Training
+         model.train()
+         for batch in train_dataloader:
+             outputs = model(**batch)
+             loss = outputs.loss
+             accelerator.backward(loss)
+
+             optimizer.step()
+             lr_scheduler.step()
+             optimizer.zero_grad()
+             progress_bar.update(1)
+
+         # Evaluation
+         model.eval()
+         for batch in eval_dataloader:
+             with torch.no_grad():
+                 outputs = model(**batch)
+
+             predictions = outputs.logits.argmax(dim=-1)
+             labels = batch["labels"]
+
+             # Necessary to pad predictions and labels for being gathered
+             predictions = accelerator.pad_across_processes(
+                 predictions, dim=1, pad_index=-100
+             )
+             labels = accelerator.pad_across_processes(
+                 labels, dim=1, pad_index=-100
+             )
+
+             predictions_gathered = accelerator.gather(predictions)
+             labels_gathered = accelerator.gather(labels)
+
+             # postprocess returns (labels, predictions); unpack in that order
+             # so predictions and references are not swapped in the metric.
+             true_labels, true_predictions = postprocess(
+                 predictions_gathered, labels_gathered
+             )
+             metric.add_batch(
+                 predictions=true_predictions, references=true_labels
+             )
+
+         results = metric.compute()
+         print(
+             f"epoch {epoch}:",
+             {
+                 key: results[f"overall_{key}"]
+                 for key in ["precision", "recall", "f1", "accuracy"]
+             },
+         )
+
+         # Save a checkpoint at the end of every epoch
+         output_dir = "restaurant_ner"
+         accelerator.wait_for_everyone()
+         unwrapped_model = accelerator.unwrap_model(model)
+         unwrapped_model.save_pretrained(
+             output_dir, save_function=accelerator.save
+         )
+         if accelerator.is_main_process:
+             tokenizer.save_pretrained(output_dir)
+
+     # Final save once all epochs have finished
+     accelerator.wait_for_everyone()
+     unwrapped_model = accelerator.unwrap_model(model)
+     unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)
+
+
+ train()
elise/src/models/__init__.py ADDED
File without changes