j.gilyazev committed · commit c1c5bd9 · parent 0766044

add personalized-chat-bot
personalized-chat-bot/util/__init__.py
ADDED
File without changes
personalized-chat-bot/util/bloom_trainer.py
ADDED
@@ -0,0 +1,91 @@
import numpy as np
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import get_scheduler

from util.metrics import perplexity


class BloomTrainer:
    """Prompt-tuning loop: only the model's prompt embeddings are saved and restored."""

    DEFAULT_VAL_FREQ = 5
    ITERATION_LIMIT = 150

    def __init__(self, model, config, train_dataset, val_dataset, wandb_run=None, prompt_path=None, val_freq=None):
        self.model = model
        self.config = config
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.wandb_run = wandb_run
        self.val_freq = val_freq
        if self.val_freq is None:
            self.val_freq = self.DEFAULT_VAL_FREQ
        self.prompt_path = prompt_path

        self.best_loss = np.inf

        self.train_loader = DataLoader(self.train_dataset,
                                       shuffle=True,
                                       batch_size=config.BATCH_SIZE,
                                       drop_last=True)
        self.val_loader = DataLoader(self.val_dataset,
                                     shuffle=True,
                                     batch_size=config.BATCH_SIZE,
                                     drop_last=False)

        self.optimizer = AdamW(self.model.parameters(), lr=config.LR, weight_decay=config.WEIGHT_DECAY)

        self.lr_scheduler = get_scheduler(
            name="linear",
            optimizer=self.optimizer,
            num_warmup_steps=0,
            num_training_steps=len(self.train_loader) * self.config.N_EPOCH
        )

    def train(self):
        self.model.train()
        iter_counter = 0
        for epoch in range(self.config.N_EPOCH):
            for batch in self.train_loader:
                # The default collate_fn turns per-item token lists into a list of
                # per-position tensors, so stack and transpose to (batch, seq_len).
                batch = {'input_ids': torch.stack(batch['input_ids']).T.to(self.config.DEVICE),
                         'labels': torch.stack(batch['labels']).T.to(self.config.DEVICE)}
                outputs = self.model(**batch)
                loss = outputs.loss
                loss.backward()
                self.optimizer.step()
                self.lr_scheduler.step()
                self.optimizer.zero_grad()
                self.wandb_run.log({'loss': loss})
                iter_counter += 1
                if (iter_counter + 1) % self.val_freq == 0:
                    eval_perplexity = self.evaluate(perplexity)
                    self.model.train()  # evaluate() leaves the model in eval mode
                    self.wandb_run.log({'perplexity': eval_perplexity})
                    if loss.item() < self.best_loss:
                        self.best_loss = loss.item()
                        self.save_model(self.prompt_path)
                        print('Model saved')
                if iter_counter >= self.ITERATION_LIMIT:
                    return

    def evaluate(self, eval_fn):
        logits = []
        labels = []
        self.model.eval()
        with torch.no_grad():
            for batch in self.val_loader:
                batch = {'input_ids': torch.stack(batch['input_ids']).T.to(self.config.DEVICE),
                         'labels': torch.stack(batch['labels']).T.to(self.config.DEVICE)}
                outputs = self.model(**batch)
                labels.extend(batch['input_ids'])
                logits.extend(outputs.logits)
        metric = eval_fn(logits, labels)
        return metric

    def save_model(self, path):
        torch.save(self.model.transformer.prompt_embeddings.state_dict(), path)

    def load_model(self, path):
        self.model.transformer.prompt_embeddings.load_state_dict(torch.load(path))
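Not part of the commit: a minimal sketch of how this trainer could be driven end to end. The tiny model, dataset, and logging stub below are invented stand-ins that only mimic the interfaces BloomTrainer touches (a forward pass returning .loss/.logits, a transformer.prompt_embeddings module for checkpointing, items carrying 'input_ids'/'labels', and an object with a .log() method); the config field names are the ones the class actually reads.

from types import SimpleNamespace

import torch
from torch import nn
from torch.utils.data import Dataset

from util.bloom_trainer import BloomTrainer  # assumes running from personalized-chat-bot/


class TinyPromptLM(nn.Module):
    """Stand-in exposing what BloomTrainer uses: forward(**batch) -> .loss/.logits
    and a transformer.prompt_embeddings module for save_model()/load_model()."""

    def __init__(self, vocab_size=16, dim=8):
        super().__init__()
        self.transformer = nn.Module()
        self.transformer.prompt_embeddings = nn.Embedding(4, dim)
        self.embed = nn.Embedding(vocab_size, dim)
        self.head = nn.Linear(dim, vocab_size)

    def forward(self, input_ids, labels):
        hidden = self.embed(input_ids) + self.transformer.prompt_embeddings.weight.mean(dim=0)
        logits = self.head(hidden)
        loss = nn.functional.cross_entropy(logits.flatten(0, 1), labels.flatten())
        return SimpleNamespace(loss=loss, logits=logits)


class TinyLMDataset(Dataset):
    """Items shaped like OnePersonaDataset's: plain token-id lists for 'input_ids' and 'labels'."""

    def __init__(self, n=32, seq_len=6, vocab_size=16):
        self.rows = [torch.randint(0, vocab_size, (seq_len,)).tolist() for _ in range(n)]

    def __getitem__(self, idx):
        return {'input_ids': self.rows[idx], 'labels': self.rows[idx]}

    def __len__(self):
        return len(self.rows)


config = SimpleNamespace(BATCH_SIZE=4, LR=1e-2, WEIGHT_DECAY=0.0, N_EPOCH=1, DEVICE='cpu')
run = SimpleNamespace(log=print)  # stands in for a wandb run; only .log() is needed

trainer = BloomTrainer(TinyPromptLM(), config, TinyLMDataset(), TinyLMDataset(n=8),
                       wandb_run=run, prompt_path='toy_prompt.pt', val_freq=4)
trainer.train()
print('best training loss at validation time:', trainer.best_loss)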
personalized-chat-bot/util/data.py
ADDED
@@ -0,0 +1,74 @@
import itertools

import numpy as np
from joblib import Parallel, delayed
from torch.utils.data import Dataset


class OnePersonaDataset(Dataset):
    """Dialogues for a single persona cluster, tokenized as newline-joined histories."""

    def __init__(self, data, tokenizer, transforms=None, positive_candidates=True, n_jobs=8):
        super().__init__()

        self.data = data
        if len(data) == 0:
            self.input_ids = []
            self.history = []
            self.labels = []
            return

        if positive_candidates:
            # Keep only the ground-truth reply (the last candidate) for each history.
            self.history = [row['history'] + [row['candidates'][-1], ] for row in data]
            self.labels = np.ones(len(self.history), dtype=int)
        else:
            # Expand every candidate reply; only the last candidate of each row is positive.
            self.history = [row['history'] + [candidate, ] for row in data
                            for candidate in row['candidates']]
            self.labels = itertools.chain.from_iterable([0] * (len(row['candidates']) - 1) + [1]
                                                        for row in data)
            # Materialize the iterator before converting, otherwise np.array gets a bare generator.
            self.labels = np.array(list(self.labels), dtype=int)

        if transforms is None:
            self.history = ["\n".join(item) for item in self.history]
        else:
            self.history = Parallel(n_jobs=n_jobs)(delayed(transforms)(item) for item in self.history)
        self.input_ids = tokenizer(self.history, padding='max_length', truncation=True)["input_ids"]

    def __getitem__(self, idx):
        return {'input_ids': self.input_ids[idx],
                'labels': self.input_ids[idx],
                'example': self.history[idx],
                'class': self.labels[idx]}

    def __len__(self):
        return len(self.data)


class PersonaChatDataset(Dataset):
    DEFAULT_DATASET_NAME = "bavard/personachat_truecased"

    def __init__(self, clustering, dataset, tokenizer):
        super().__init__()

        self.dataset = dataset
        self.clustering = clustering

        # Map every persona sentence to its cluster id.
        all_personalities = list(set([sent for item in self.dataset
                                      for sent in item['personality']]))
        predicted_centers = self.clustering.predict(all_personalities)
        self.all_personalities_to_id = {persona: center
                                        for persona, center in zip(all_personalities, predicted_centers)}
        self.personalities = self.clustering._cluster_centers

        # Group the raw dialogues by the cluster of their persona sentences.
        subdataset_data_by_personality = [[] for _ in range(len(self.personalities))]

        for i in range(len(self.dataset)):
            item = self.dataset[i]
            cur_persona_ids = [self.all_personalities_to_id[persona] for persona in item['personality']]
            for persona_id in cur_persona_ids:
                subdataset_data_by_personality[persona_id].append(item)

        self.subdatasets = [OnePersonaDataset(cur_data, tokenizer) for cur_data in subdataset_data_by_personality]

    def __getitem__(self, persona_id):
        return self.subdatasets[persona_id]

    def __len__(self):
        return len(self.subdatasets)
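Not part of the commit: a small illustration of the row layout OnePersonaDataset expects. The fields mirror what the code indexes in bavard/personachat_truecased ('history', 'candidates' with the true reply last, 'personality'); the concrete sentences and the tokenizer choice are made up. PersonaChatDataset additionally needs a clustering object exposing predict() and _cluster_centers, which is outside this file, so it is not shown.

from transformers import AutoTokenizer

from util.data import OnePersonaDataset

rows = [
    {'history': ['Hi! How are you?', 'Great, just got back from a hike.'],
     'candidates': ['I do not know.', 'Nice! I love being outdoors too.'],  # last one is the true reply
     'personality': ['I love hiking.']},
]

tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')  # any padding-capable tokenizer works

positives = OnePersonaDataset(rows, tokenizer)   # ground-truth replies only
item = positives[0]
print(item['class'])           # 1
print(item['example'])         # history joined with newlines, ending in the true reply
print(len(item['input_ids']))  # padded to the tokenizer's max length

all_candidates = OnePersonaDataset(rows, tokenizer, positive_candidates=False)
print(list(all_candidates.labels))  # [0, 1]: one entry per candidate, positive last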
personalized-chat-bot/util/dialogue_manager.py
ADDED
@@ -0,0 +1,27 @@
from torch import nn
from transformers import DistilBertForSequenceClassification


class DialogueManagerModel(nn.Module):
    """Frozen DistilBERT encoder with a freshly initialized classification head."""

    DEFAULT_MODEL = "distilbert-base-uncased"

    def __init__(self, n_classes, model_name=None, device='cpu'):
        super().__init__()
        if model_name is None:
            self.model = DistilBertForSequenceClassification.from_pretrained(self.DEFAULT_MODEL)
        else:
            raise NotImplementedError()
        self.model.to(device)
        self.n_classes = n_classes
        # Freeze the pretrained weights, then replace the classifier so only it is trainable.
        self.freeze_layers()
        self.model.classifier = nn.Linear(self.model.classifier.in_features, self.n_classes,
                                          device=device)

        for param in self.model.classifier.parameters():
            param.requires_grad = True

    def freeze_layers(self):
        for param in self.model.parameters():
            param.requires_grad = False

    def forward(self, X):
        return self.model(X)
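Not part of the commit: a hedged sketch of querying the manager once its head has been trained. The number of classes, the example utterances, and reading a persona cluster off the argmax are assumptions for illustration.

import torch
from transformers import AutoTokenizer

from util.dialogue_manager import DialogueManagerModel

tokenizer = AutoTokenizer.from_pretrained(DialogueManagerModel.DEFAULT_MODEL)
manager = DialogueManagerModel(n_classes=4)  # e.g. one class per persona cluster

inputs = tokenizer('Hi! How are you?\nGreat, just got back from a hike.', return_tensors='pt')
with torch.no_grad():
    logits = manager(inputs['input_ids']).logits  # shape (1, n_classes)
print('predicted persona cluster:', logits.argmax(dim=-1).item())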
personalized-chat-bot/util/metrics.py
ADDED
@@ -0,0 +1,27 @@
import numpy as np
import scipy.special
import torch


def _perplexity(logits, labels, pad_token=3):
    # Drop trailing padding so it does not contribute to the score;
    # fall back to the full sequence if everything is padding.
    last_not_pad_id = len(labels) - 1
    for i in range(len(labels) - 1, -1, -1):
        if labels[i] != pad_token:
            last_not_pad_id = i
            break
    logits = logits[:last_not_pad_id + 1]
    labels = labels[:last_not_pad_id + 1]
    log_probas = scipy.special.log_softmax(logits, axis=1).astype(np.float32)
    log_probas = [log_probas[i][labels[i]] for i in range(len(labels))]
    l = np.mean(log_probas)
    # log_softmax is a natural log, so exponentiate the mean negative log-likelihood.
    return np.exp(-l)


def perplexity(logits, labels, pad_token=3):
    pp = []
    if isinstance(logits, torch.Tensor):
        logits = logits.detach().cpu().numpy()
    if isinstance(labels, torch.Tensor):
        labels = labels.detach().cpu().numpy()
    for cur_logits, cur_labels in zip(logits, labels):
        pp.append(_perplexity(np.array(cur_logits), np.array(cur_labels).astype(int), pad_token))
    return np.mean(pp)
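Not part of the commit: a quick sanity check on made-up inputs. With uniform logits every token is equally likely, so the per-sequence perplexity should equal the vocabulary size, and the trailing pad_token positions are ignored.

import numpy as np

from util.metrics import perplexity

vocab_size = 8
labels = np.array([[5, 1, 7, 3, 3]])   # one sequence; trailing 3s are padding
logits = np.zeros((1, 5, vocab_size))  # uniform distribution at every position

print(perplexity(logits, labels))      # -> 8.0, the vocabulary size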