j.gilyazev committed on
Commit
c1c5bd9
·
1 Parent(s): 0766044

add personalized-chat-bot

Browse files
personalized-chat-bot/util/__init__.py ADDED
File without changes
personalized-chat-bot/util/bloom_trainer.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+
3
+ import numpy as np
4
+ from torch.utils.data import DataLoader
5
+ from torch.optim import AdamW
6
+ from transformers import get_scheduler
7
+ import torch
8
+
9
+
10
+ from util.metrics import perplexity
11
+
12
+
13
class BloomTrainer:
    """Training loop for the prompt embeddings of a causal language model.

    The trainer builds the data loaders, an AdamW optimizer and a linear
    learning-rate schedule, runs the optimisation loop, periodically evaluates
    on the validation set and checkpoints ``model.transformer.prompt_embeddings``.

    ``config`` must expose ``BATCH_SIZE``, ``LR``, ``WEIGHT_DECAY``, ``N_EPOCH``
    and ``DEVICE``.
    """

    # Validation cadence used when ``val_freq`` is not given.
    DEFAULT_VAL_FREQ = 5
    # Hard cap on the number of optimisation steps performed by ``train``.
    ITERATION_LIMIT = 150

    def __init__(self, model, config, train_dataset, val_dataset, wandb_run=None, prompt_path=None, val_freq=None):
        """Store the collaborators and build loaders, optimizer and scheduler.

        Args:
            model: model to train; called as ``model(input_ids=..., labels=...)``
                and expected to return an object with ``loss`` and ``logits``.
            config: settings object (see the class docstring for required fields).
            train_dataset: dataset yielding dicts with ``input_ids`` and ``labels``.
            val_dataset: dataset with the same item layout, used by ``evaluate``.
            wandb_run: optional object with a ``log(dict)`` method; when ``None``
                metrics are simply not logged.
            prompt_path: optional checkpoint path; when ``None`` no checkpoint
                is written.
            val_freq: validation cadence in steps; defaults to ``DEFAULT_VAL_FREQ``.
        """
        self.model = model
        self.config = config
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.wandb_run = wandb_run
        self.val_freq = self.DEFAULT_VAL_FREQ if val_freq is None else val_freq
        self.prompt_path = prompt_path

        self.best_loss = np.inf

        self.train_loader = DataLoader(self.train_dataset,
                                       shuffle=True,
                                       batch_size=config.BATCH_SIZE,
                                       drop_last=True)
        self.val_loader = DataLoader(self.val_dataset,
                                     shuffle=True,
                                     batch_size=config.BATCH_SIZE,
                                     drop_last=False)

        self.optimizer = AdamW(self.model.parameters(), lr=config.LR, weight_decay=config.WEIGHT_DECAY)

        self.lr_scheduler = get_scheduler(
            name="linear",
            optimizer=self.optimizer,
            num_warmup_steps=0,
            num_training_steps=len(self.train_loader) * self.config.N_EPOCH
        )

    def _prepare_batch(self, batch):
        """Turn a collated batch into the keyword arguments passed to the model.

        ``batch['input_ids']`` and ``batch['labels']`` are sequences of tensors
        that are stacked and transposed before being moved to ``config.DEVICE``.
        """
        # presumably each list entry holds one token position for the whole
        # batch, so stack + transpose yields (batch, seq) — TODO confirm
        device = self.config.DEVICE
        return {'input_ids': torch.stack(batch['input_ids']).T.to(device),
                'labels': torch.stack(batch['labels']).T.to(device)}

    def _log(self, payload):
        """Forward ``payload`` to ``wandb_run.log`` when a run was supplied."""
        if self.wandb_run is not None:
            self.wandb_run.log(payload)

    def train(self):
        """Run the optimisation loop.

        Stops after ``config.N_EPOCH`` epochs or ``ITERATION_LIMIT`` steps,
        whichever comes first. Validation runs whenever ``iter_counter + 1`` is
        a multiple of ``val_freq``; at those points the prompt embeddings are
        checkpointed if the current training-batch loss is the best seen so far.
        """
        self.model.train()
        iter_counter = 0
        for epoch in range(self.config.N_EPOCH):
            for batch in self.train_loader:
                outputs = self.model(**self._prepare_batch(batch))
                loss = outputs.loss
                loss.backward()
                self.optimizer.step()
                self.lr_scheduler.step()
                self.optimizer.zero_grad()

                # Log a plain float so no reference to the autograd graph is kept.
                loss_value = loss.item()
                self._log({'loss': loss_value})

                iter_counter += 1
                if (iter_counter + 1) % self.val_freq == 0:
                    # evaluate() restores training mode before returning.
                    eval_perplexity = self.evaluate(perplexity)
                    self._log({'perplexity': eval_perplexity})
                    # NOTE(review): checkpointing is keyed on the training-batch
                    # loss and is assumed to belong to the validation branch —
                    # confirm against the upstream file.
                    if loss_value < self.best_loss:
                        self.best_loss = loss_value
                        if self.prompt_path is not None:
                            self.save_model(self.prompt_path)
                            print('Model saved')
                if iter_counter >= self.ITERATION_LIMIT:
                    return

    def evaluate(self, eval_fn):
        """Run the model over the validation loader and score it with ``eval_fn``.

        Args:
            eval_fn: callable ``eval_fn(logits, labels)`` receiving two lists of
                per-sample CPU tensors (model logits and the input token ids).

        Returns:
            Whatever ``eval_fn`` returns.
        """
        logits = []
        labels = []
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                for batch in self.val_loader:
                    batch = self._prepare_batch(batch)
                    outputs = self.model(**batch)
                    # Move results off the device so validation does not pile
                    # up accelerator memory and so they can become NumPy arrays.
                    labels.extend(batch['input_ids'].cpu())
                    logits.extend(outputs.logits.detach().cpu())
        finally:
            # Without this, training would silently continue in eval mode.
            self.model.train(was_training)
        return eval_fn(logits, labels)

    def save_model(self, path):
        """Write the state dict of ``model.transformer.prompt_embeddings`` to ``path``."""
        torch.save(self.model.transformer.prompt_embeddings.state_dict(), path)

    def load_model(self, path):
        """Load the prompt-embedding state dict stored at ``path`` into the model."""
        # map_location='cpu' lets a checkpoint written on an accelerator be read
        # anywhere; load_state_dict copies values into the existing parameters.
        state = torch.load(path, map_location='cpu')
        self.model.transformer.prompt_embeddings.load_state_dict(state)
personalized-chat-bot/util/data.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ from torch.utils.data import Dataset
3
+ import numpy as np
4
+ from joblib import Parallel, delayed
5
+
6
+
7
class OnePersonaDataset(Dataset):
    """Tokenised dialogue samples built from rows with ``history`` and ``candidates``.

    Each sample is a row's history with one candidate reply appended. Item
    ``i`` is a dict with ``input_ids`` and ``labels`` (the same token ids),
    ``example`` (the text given to the tokenizer) and ``class`` (1 for the
    last candidate of a row, 0 for the others).
    """

    def __init__(self, data, tokenizer, transforms=None, positive_candidates=True, n_jobs=8):
        """Build the samples and tokenise them.

        Args:
            data: sized iterable of rows; each row has ``history`` and
                ``candidates`` lists.
            tokenizer: callable invoked as
                ``tokenizer(texts, padding='max_length', truncation=True)``
                whose result has an ``"input_ids"`` entry.
            transforms: optional callable turning one list of utterances into
                the text to tokenise; by default utterances are joined with
                newlines.
            positive_candidates: when true only the last candidate of each row
                is used (all labelled 1); otherwise every candidate becomes its
                own sample.
            n_jobs: joblib worker count used when ``transforms`` is given.
        """
        super().__init__()

        self.data = data
        if len(data) == 0:
            self.input_ids = []
            self.history = []
            self.labels = []
            return

        if positive_candidates:
            self.history = [row['history'] + [row['candidates'][-1]] for row in data]
            self.labels = np.ones(len(self.history), dtype=int)
        else:
            self.history = [row['history'] + [candidate]
                            for row in data
                            for candidate in row['candidates']]
            # Materialise the flags in a list first: np.array cannot consume a
            # lazy iterator. Rows without candidates add no samples, so they
            # must add no flags either.
            flags = []
            for row in data:
                n_candidates = len(row['candidates'])
                if n_candidates:
                    flags.extend([0] * (n_candidates - 1) + [1])
            self.labels = np.array(flags, dtype=int)

        if transforms is None:
            self.history = ["\n".join(item) for item in self.history]
        else:
            self.history = Parallel(n_jobs=n_jobs)(delayed(transforms)(item) for item in self.history)
        self.input_ids = tokenizer(self.history, padding='max_length', truncation=True)["input_ids"]

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a dict (see the class docstring)."""
        return {'input_ids': self.input_ids[idx],
                'labels': self.input_ids[idx],
                'example': self.history[idx],
                'class': self.labels[idx]}

    def __len__(self):
        """Return the number of samples, which can exceed the number of rows."""
        return len(self.history)
42
+
43
+
44
class PersonaChatDataset(Dataset):
    """Collection of per-cluster ``OnePersonaDataset`` objects.

    Every distinct personality sentence in ``dataset`` is assigned to a cluster
    with ``clustering.predict``; each dialogue item is then added to the
    sub-dataset of every cluster its personality sentences map to. Indexing
    returns the sub-dataset for a cluster id.
    """

    DEFAULT_DATASET_NAME = "bavard/personachat_truecased"

    def __init__(self, clustering, dataset, tokenizer):
        """Group ``dataset`` items by predicted personality cluster.

        Args:
            clustering: object with ``predict(list_of_sentences)`` and a
                ``_cluster_centers`` attribute whose length gives the number
                of clusters.
            dataset: iterable of items that each have a ``personality`` list.
            tokenizer: tokenizer handed to every ``OnePersonaDataset``.
        """
        super().__init__()

        self.dataset = dataset
        self.clustering = clustering

        all_personalities = list({sent
                                  for item in self.dataset
                                  for sent in item['personality']})
        predicted_centers = self.clustering.predict(all_personalities)
        self.all_personalities_to_id = dict(zip(all_personalities, predicted_centers))
        self.personalities = self.clustering._cluster_centers

        subdataset_data_by_personality = [[] for _ in range(len(self.personalities))]

        # NOTE(review): an item whose sentences fall into the same cluster more
        # than once is appended once per sentence — confirm this is intended.
        for item in self.dataset:
            for persona in item['personality']:
                persona_id = self.all_personalities_to_id[persona]
                subdataset_data_by_personality[persona_id].append(item)

        self.subdatasets = [OnePersonaDataset(cur_data, tokenizer)
                            for cur_data in subdataset_data_by_personality]

    def __getitem__(self, persona_id):
        """Return the sub-dataset that belongs to cluster ``persona_id``."""
        return self.subdatasets[persona_id]

    def __len__(self):
        """Return the number of per-cluster sub-datasets."""
        return len(self.subdatasets)
personalized-chat-bot/util/dialogue_manager.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import DistilBertForSequenceClassification
2
+ from torch import nn
3
+
4
class DialogueManagerModel(nn.Module):
    """DistilBERT sequence classifier whose only trainable part is a new head.

    The pretrained model is loaded, all of its parameters are frozen, and its
    ``classifier`` layer is replaced by a fresh linear layer with ``n_classes``
    outputs that is left trainable.
    """

    DEFAULT_MODEL = "distilbert-base-uncased"

    def __init__(self, n_classes, model_name=None, device='cpu'):
        """Load the default checkpoint and attach an ``n_classes``-way head.

        Args:
            n_classes: number of output classes of the new classifier layer.
            model_name: must be ``None``; other values are not supported yet.
            device: device the model and the new head are placed on.

        Raises:
            NotImplementedError: if ``model_name`` is not ``None``.
        """
        super().__init__()
        if model_name is not None:
            raise NotImplementedError()

        self.model = DistilBertForSequenceClassification.from_pretrained(self.DEFAULT_MODEL)
        self.model.to(device)
        self.n_classes = n_classes

        # Freeze everything that was loaded; only the head created below trains.
        # NOTE(review): this also freezes every other layer of the loaded model,
        # including any that the checkpoint did not initialise — confirm.
        self.freeze_layers()

        # NOTE(review): only the layer is swapped; the wrapped model's own
        # label-count setting is left untouched — verify nothing relies on it.
        head_inputs = self.model.classifier.in_features
        self.model.classifier = nn.Linear(head_inputs, self.n_classes, device=device)
        self.model.classifier.requires_grad_(True)

    def freeze_layers(self):
        """Disable gradients for every parameter of the wrapped model."""
        self.model.requires_grad_(False)

    def forward(self, X):
        """Return the wrapped model's output for ``X``."""
        outputs = self.model(X)
        return outputs
personalized-chat-bot/util/metrics.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import scipy
3
+ import torch
4
+
5
+
6
+ def _perplexity(logits, labels, pad_token=3):
7
+ for i in range(len(labels)-1, -1, -1):
8
+ if labels[i] != pad_token:
9
+ last_not_pad_id = i
10
+ break
11
+ logits = logits[:last_not_pad_id + 1]
12
+ labels = labels[:last_not_pad_id + 1]
13
+ log_probas = scipy.special.log_softmax(logits, axis=1).astype(np.float32)
14
+ log_probas = [log_probas[i][labels[i]] for i in range(len(labels))]
15
+ l = np.mean(log_probas)
16
+ return 2 ** (-l)
17
+
18
+
19
def perplexity(logits, labels, pad_token=3):
    """Return the mean per-sequence perplexity of a batch.

    Args:
        logits: a tensor/array, or a sequence of per-sample tensors/arrays,
            holding one 2-D score matrix per sample.
        labels: a tensor/array, or a sequence of per-sample tensors/arrays,
            holding the integer ids for each sample.
        pad_token: padding id forwarded to ``_perplexity``.

    Returns:
        The arithmetic mean of the per-sample perplexities.
    """
    def _to_numpy(value):
        # np.array() cannot read tensors that live on an accelerator or track
        # gradients, so tensors always go through detach().cpu() first.
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().numpy()
        return np.asarray(value)

    # Whole-batch tensors are converted once; lists are handled per sample.
    if isinstance(logits, torch.Tensor):
        logits = _to_numpy(logits)
    if isinstance(labels, torch.Tensor):
        labels = _to_numpy(labels)

    scores = [
        _perplexity(_to_numpy(cur_logits), _to_numpy(cur_labels).astype(int), pad_token)
        for cur_logits, cur_labels in zip(logits, labels)
    ]
    return np.mean(scores)
+ return np.mean(pp)