from logging import basicConfig
import torch
from torch import nn
import json
from tqdm import tqdm
import os
import numpy as np
from transformers import BertTokenizer
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import trainer, loggers
from torch.utils.data import Dataset, DataLoader
from transformers.optimization import get_linear_schedule_with_warmup
from transformers import BertForMaskedLM
from transformers import AutoConfig
from transformers.pipelines.base import Pipeline
from transformers import MegatronBertForMaskedLM
import argparse
import copy
from fengshen.utils.universal_checkpoint import UniversalCheckpoint
import warnings
from transformers import TextClassificationPipeline as HuggingfacePipe


class TCBertDataset(Dataset):

    def __init__(self, data, tokenizer, args, prompt, label_classes):
        super().__init__()

        self.tokenizer = tokenizer
        self.max_length = args.max_length
        self.num_labels = args.num_labels
        self.data = data
        self.args = args
        self.label_classes = label_classes
        self.prompt = prompt

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.encode(self.data[index])

    def encode(self, item, labeled=True):

        if labeled:
            # Build two copies of the prompt: one filled with the gold label and one
            # with the label slot replaced by the same number of [MASK] tokens.
            ori_texta = self.prompt.format(item['label']) + item['content']
            mask_texta = self.prompt.format("[MASK]" * len(item['label'])) + item['content']

            labels = self.label_classes[item['label']]

            ori_encode_dict = self.tokenizer.encode_plus(ori_texta,
                                                         max_length=self.max_length,
                                                         padding="longest",
                                                         truncation=True)

            mask_encode_dict = self.tokenizer.encode_plus(mask_texta,
                                                          max_length=self.max_length,
                                                          padding="longest",
                                                          truncation=True)

            ori_input_ids = torch.tensor(ori_encode_dict['input_ids']).long()
            token_type_ids = torch.tensor(ori_encode_dict['token_type_ids']).long()
            attention_mask = torch.tensor(ori_encode_dict['attention_mask']).float()

            mask_input_ids = torch.tensor(mask_encode_dict['input_ids']).long()
            # MLM targets: the original label token ids at the [MASK] positions, -100 (ignored) elsewhere.
            mlmlabels = torch.where(mask_input_ids == self.tokenizer.mask_token_id, ori_input_ids, -100)

            labels = torch.tensor(labels).long()
            mlmlabels = mlmlabels.long()

            encoded = {
                "sentence": item["content"],
                "input_ids": mask_input_ids,
                "token_type_ids": token_type_ids,
                "attention_mask": attention_mask,
                "labels": labels,
                "mlmlabels": mlmlabels,
            }

        else:
            # Unlabeled inference: mask a fixed number of label positions in the prompt.
            texta = self.prompt.format("[MASK]" * self.args.fixed_lablen) + item['content']

            encode_dict = self.tokenizer.encode_plus(texta,
                                                     max_length=self.max_length,
                                                     padding="longest",
                                                     truncation=True)

            input_ids = encode_dict['input_ids']
            token_type_ids = encode_dict['token_type_ids']
            attention_mask = encode_dict['attention_mask']

            encoded = {
                "sentence": item["content"],
                "input_ids": torch.tensor(input_ids).long(),
                "token_type_ids": torch.tensor(token_type_ids).long(),
                "attention_mask": torch.tensor(attention_mask).float(),
            }
        return encoded
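
# Illustrative sketch only (toy prompt and item, not part of the original file): with
# prompt = "这是一条关于{}的文本：" and item = {"label": "体育", "content": "湖人大胜勇士"},
# encode() tokenizes "这是一条关于[MASK][MASK]的文本：湖人大胜勇士" as the model input, while
# mlmlabels holds the token ids of "体育" at the two [MASK] positions and -100 everywhere else,
# so only the masked label tokens contribute to the MLM loss.
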
class TCBertDataModel(pl.LightningDataModule):

    @staticmethod
    def add_data_specific_args(parent_args):
        parser = parent_args.add_argument_group('TCBert DataModel')
        parser.add_argument('--num_workers', default=8, type=int)
        parser.add_argument('--batchsize', default=16, type=int)
        parser.add_argument('--max_length', default=512, type=int)
        parser.add_argument('--fixed_lablen', default=2, type=int)
        return parent_args

    def __init__(self, train_data, val_data, tokenizer, args, prompt, prompt_label):
        super().__init__()
        self.batchsize = args.batchsize
        self.label_classes = self.get_label_classes(prompt_label)
        args.num_labels = len(self.label_classes)

        self.train_data = TCBertDataset(train_data, tokenizer, args, prompt, self.label_classes)
        self.valid_data = TCBertDataset(val_data, tokenizer, args, prompt, self.label_classes)

    def get_label_classes(self, prompt_label):
        # Map each label name to a consecutive integer class index.
        label_classes = {}
        for i, key in enumerate(prompt_label.keys()):
            label_classes[key] = i
        print("label_classes:", label_classes)
        return label_classes

    def train_dataloader(self):
        return DataLoader(self.train_data, shuffle=True, collate_fn=self.collate_fn, batch_size=self.batchsize, pin_memory=False)

    def val_dataloader(self):
        return DataLoader(self.valid_data, shuffle=False, collate_fn=self.collate_fn, batch_size=self.batchsize, pin_memory=False)

    def collate_fn(self, batch):
        '''
        Aggregate a batch of data.
        batch = [ins1_dict, ins2_dict, ..., insN_dict]
        batch_data = {'sentence': [ins1_sentence, ins2_sentence, ...], 'input_ids': [ins1_input_ids, ins2_input_ids, ...], ...}
        '''
        batch_data = {}
        for key in batch[0]:
            batch_data[key] = [example[key] for example in batch]
        input_ids = batch_data['input_ids']
        attention_mask = batch_data['attention_mask']
        token_type_ids = batch_data["token_type_ids"]
        labels = None
        if 'labels' in batch_data:
            labels = torch.LongTensor(batch_data['labels'])

        mlmlabels = None
        if 'mlmlabels' in batch_data:
            mlmlabels = nn.utils.rnn.pad_sequence(batch_data['mlmlabels'],
                                                  batch_first=True,
                                                  padding_value=-100)

        input_ids = nn.utils.rnn.pad_sequence(input_ids,
                                              batch_first=True,
                                              padding_value=0)

        token_type_ids = nn.utils.rnn.pad_sequence(token_type_ids,
                                                   batch_first=True,
                                                   padding_value=0)
        attention_mask = nn.utils.rnn.pad_sequence(attention_mask,
                                                   batch_first=True,
                                                   padding_value=0)

        batch_data = {
            "sentence": batch_data["sentence"],
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
            "labels": labels,
            "mlmlabels": mlmlabels
        }

        return batch_data
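
# Illustrative sketch only (toy arguments and data; the prompt template, prompt_label mapping and
# checkpoint path below are assumptions, not part of the original file):
#
#     parser = argparse.ArgumentParser()
#     parser = TCBertDataModel.add_data_specific_args(parser)
#     parser = TCBertLitModel.add_model_specific_args(parser)
#     args = parser.parse_args([])
#     tokenizer = BertTokenizer.from_pretrained("path/to/tcbert-checkpoint")  # hypothetical path
#     prompt = "这是一条关于{}的文本："
#     prompt_label = {"体育": "体育", "财经": "财经"}
#     train_data = [{"content": "湖人大胜勇士", "label": "体育"}]
#     data_model = TCBertDataModel(train_data, train_data, tokenizer, args, prompt, prompt_label)
#     batch = next(iter(data_model.train_dataloader()))  # padded tensors ready for TCBertLitModel
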
class TCBertModel(nn.Module):
    def __init__(self, pre_train_dir, nlabels):
        super().__init__()
        self.config = AutoConfig.from_pretrained(pre_train_dir)
        print("pre_train_dir", pre_train_dir)

        # Checkpoints with "1.3B" in their name use the MegatronBert architecture;
        # the smaller checkpoints are plain BERT.
        if "1.3B" in pre_train_dir:
            self.bert = MegatronBertForMaskedLM.from_pretrained(pre_train_dir)
        else:
            self.bert = BertForMaskedLM.from_pretrained(pre_train_dir)

        self.dropout = nn.Dropout(0.1)
        self.nlabels = nlabels
        self.linear_classifier = nn.Linear(self.config.hidden_size, self.nlabels)

    def forward(self, input_ids, attention_mask, token_type_ids, mlmlabels=None):

        outputs = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,
                            labels=mlmlabels,
                            output_hidden_states=True)

        mlm_logits = outputs.logits
        # Classify from the [CLS] representation of the last hidden layer.
        hidden_states = outputs.hidden_states[-1]
        cls_logits = hidden_states[:, 0]
        cls_logits = self.dropout(cls_logits)

        logits = self.linear_classifier(cls_logits)

        # (MLM loss or None, classification logits, MLM logits)
        return outputs.loss, logits, mlm_logits
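
# Illustrative sketch only (dummy tensors, hypothetical checkpoint path; not part of the original file):
#
#     model = TCBertModel("path/to/tcbert-checkpoint", nlabels=2)
#     ids = torch.randint(0, model.config.vocab_size, (4, 32))
#     mask = torch.ones_like(ids).float()
#     segs = torch.zeros_like(ids)
#     mlm_loss, cls_logits, mlm_logits = model(ids, mask, segs, mlmlabels=ids)
#     # cls_logits: (4, 2); mlm_logits: (4, 32, vocab_size); mlm_loss is None when mlmlabels is None
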
class TCBertLitModel(pl.LightningModule):

    @staticmethod
    def add_model_specific_args(parent_args):
        parser = parent_args.add_argument_group('BaseModel')

        parser.add_argument('--learning_rate', default=1e-5, type=float)
        parser.add_argument('--weight_decay', default=0.1, type=float)
        parser.add_argument('--warmup', default=0.01, type=float)
        parser.add_argument('--num_labels', default=2, type=int)

        return parent_args

    def __init__(self, args, model_path, nlabels):
        super().__init__()
        self.args = args
        self.loss_fn = torch.nn.CrossEntropyLoss()
        self.model = TCBertModel(model_path, nlabels)

    def setup(self, stage) -> None:
        if stage == 'fit':
            num_gpus = self.trainer.gpus if self.trainer.gpus is not None else 0
            # NOTE: self.num_data (the number of training examples) is not set in __init__;
            # the caller is expected to assign it before trainer.fit() so the scheduler
            # length can be computed here.
            self.total_step = int(self.trainer.max_epochs * self.num_data /
                                  (max(1, num_gpus) * self.trainer.accumulate_grad_batches))
            print('Total training step:', self.total_step)

    def train_inputs(self, batch):
        inputs = {
            'input_ids': batch['input_ids'],
            'attention_mask': batch['attention_mask'],
            'token_type_ids': batch['token_type_ids'],
            'mlmlabels': batch['mlmlabels']
        }
        return inputs

    def training_step(self, batch, batch_idx):
        labels = batch['labels']
        batch = self.train_inputs(batch)
        mlm_loss, logits, _ = self.model(**batch)
        if labels is not None:
            cls_loss = self.loss_fn(logits, labels.view(-1))

        # Joint objective: sentence-level classification loss plus the prompt MLM loss.
        loss = cls_loss + mlm_loss

        ntotal = logits.size(0)
        ncorrect = (logits.argmax(dim=-1) == labels).long().sum()
        acc = ncorrect / ntotal

        self.log('train_loss', loss, on_step=True, prog_bar=True)
        self.log("train_acc", acc, on_step=True, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx):
        labels = batch['labels']
        batch = self.train_inputs(batch)
        mlm_loss, logits, _ = self.model(**batch)
        predict = logits.argmax(dim=-1).cpu().tolist()

        if labels is not None:
            cls_loss = self.loss_fn(logits, labels.view(-1))

        loss = cls_loss + mlm_loss
        ntotal = logits.size(0)

        ncorrect = int((logits.argmax(dim=-1) == labels).long().sum())
        acc = ncorrect / ntotal

        self.log('valid_loss', loss, on_step=True, prog_bar=True)
        self.log("valid_acc", acc, on_step=True, prog_bar=True)

        return int(ncorrect), int(ntotal)

    def configure_optimizers(self):

        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        paras = list(
            filter(lambda p: p[1].requires_grad, self.named_parameters()))
        paras = [{
            'params':
            [p for n, p in paras if not any(nd in n for nd in no_decay)],
            'weight_decay': self.args.weight_decay
        }, {
            'params': [p for n, p in paras if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0
        }]
        optimizer = torch.optim.AdamW(paras, lr=self.args.learning_rate)
        scheduler = get_linear_schedule_with_warmup(
            optimizer, int(self.total_step * self.args.warmup),
            self.total_step)

        return [{
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'step',
                'frequency': 1
            }
        }]


class TCBertPredict:
    def __init__(self, model, tokenizer, args, prompt, prompt_label):
        self.tokenizer = tokenizer
        self.args = args
        self.data_model = TCBertDataModel(
            [], [], tokenizer, args, prompt, prompt_label)
        self.model = model

    def predict_inputs(self, batch):
        inputs = {
            'input_ids': batch['input_ids'].cuda(),
            'attention_mask': batch['attention_mask'].cuda(),
            'token_type_ids': batch['token_type_ids'].cuda(),
        }
        return inputs

    def predict(self, batch_data):
        batch = [self.data_model.train_data.encode(
            sample, labeled=False) for sample in batch_data]
        batch = self.data_model.collate_fn(batch)
        batch = self.predict_inputs(batch)
        # Inference only: no gradients needed.
        with torch.no_grad():
            _, logits, _ = self.model.model(**batch)
        probs = torch.nn.functional.softmax(logits, dim=-1)
        predicts = torch.argmax(probs, dim=-1).detach().cpu().numpy()

        return predicts
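
# End-to-end sketch only, continuing the data-module example above (checkpoint path, prompt and
# data are hypothetical assumptions; the original file does not contain this wiring):
#
#     lit_model = TCBertLitModel(args, "path/to/tcbert-checkpoint", nlabels=args.num_labels)
#     lit_model.num_data = len(train_data)  # see the note in setup(): num_data must be set before fit
#     trainer = pl.Trainer(max_epochs=1, gpus=1)
#     trainer.fit(lit_model, data_model)
#
#     predictor = TCBertPredict(lit_model.cuda(), tokenizer, args, prompt, prompt_label)
#     print(predictor.predict([{"content": "湖人大胜勇士"}]))  # class indices per label_classes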