# fengshen/examples/pegasus/pretrain_pegasus.py
# -*- coding: utf-8 -*-
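
# Pretraining script for a Chinese PEGASUS model: raw documents are converted
# into gap-sentence-generation (pseudo-summary) pairs on the fly and the model
# is trained with PyTorch Lightning, optionally under DeepSpeed.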
import argparse
import os
import sys

# Make the fengshen package and the sibling helper modules importable when the
# script is launched from fengshen/examples/pegasus/.
sys.path.append('../../')

import torch
from transformers import PegasusForConditionalGeneration, PegasusConfig
from pytorch_lightning import Trainer, loggers, LightningModule
from pytorch_lightning.callbacks import LearningRateMonitor

from fengshen.models.model_utils import add_module_args
from tokenizers_pegasus import PegasusTokenizer
from utils import UniversalCheckpoint
from data.universal_datamodule import UniversalDataModule
from data_utils import (
    get_input_mask, pseudo_summary_f1, shift_tokens_right,
    padding_to_maxlength, load_stopwords, text_segmentate)

# os.environ["CUDA_VISIBLE_DEVICES"] = '6'
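
# FakeAbstractCollator builds the PEGASUS pretraining batches on the fly:
# pseudo_summary_f1 scores each sentence against the rest of the document with
# ROUGE-L and uses the best-scoring sentences as the pseudo summary (labels),
# while get_input_mask derives the encoder-side input ids (see data_utils for
# the exact masking scheme).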
class FakeAbstractCollator:
def __init__(self, tokenizer, stopwords_dict, max_enc_length):
self.tokenizer = tokenizer
self.max_seq_length = max_enc_length
self.stopwords_dict = stopwords_dict
def __call__(self, samples):
# print("samples: ", samples)
labels = []
attn_mask = []
decoder_attn_mask = []
source_inputs = []
        for sample in samples:
            # Split the raw document into sentences before choosing the pseudo
            # summary.
            sentences = text_segmentate(sample["chunks"])
            sentence_id_vec, source, target, source_idxs, target_idxs = pseudo_summary_f1(
                sentences, self.stopwords_dict, self.tokenizer, self.max_seq_length,
                "rouge-l")
source_idxs, target_idxs = get_input_mask(sentence_id_vec,
target_idxs)
            if len(source_idxs) > self.max_seq_length:
                if 2 not in source_idxs[self.max_seq_length - 1:]:
                    # Truncation does not drop the special token with id 2, so
                    # clip the source and force an eos at the last position.
                    source_idxs = source_idxs[:self.max_seq_length]
                    source_idxs[-1] = self.tokenizer.eos_token_id
                    sys.stderr.write("Warning: truncating long input: " +
                                     source + "\n")
                else:
                    # A token that must be kept lies beyond the length limit;
                    # skip this sample instead of corrupting it.
                    continue
source_idxs, attention_mask = padding_to_maxlength(
source_idxs, self.max_seq_length, self.tokenizer.pad_token_id)
label, target_attention_mask = padding_to_maxlength(
target_idxs, self.max_seq_length, self.tokenizer.pad_token_id)
# print("sample len: ", len(source_idxs))
source_inputs.append(source_idxs)
attn_mask.append(attention_mask)
decoder_attn_mask.append(target_attention_mask)
labels.append(label)
labels = torch.tensor(labels)
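        # Decoder inputs are the labels shifted one position to the right;
        # pad_token_id doubles as the decoder start token here.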
decode_input_idxs = shift_tokens_right(labels,
self.tokenizer.pad_token_id,
self.tokenizer.pad_token_id)
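        # Everything after the first eos in each label row is set to -100 so
        # that the cross-entropy loss ignores those positions.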
end_token_index = torch.where(labels == self.tokenizer.eos_token_id)[1]
for idx, end_idx in enumerate(end_token_index):
labels[idx][end_idx + 1:] = -100
# print("call samples: ")
return {
"input_ids": torch.tensor(source_inputs),
"attention_mask": torch.tensor(attn_mask),
"labels": labels,
"decoder_input_ids": decode_input_idxs,
"decoder_attention_mask": torch.tensor(decoder_attn_mask)
}
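
# A rough usage sketch (illustrative only; real samples come from
# UniversalDataModule and "chunks" appears to carry the raw document text):
#   collator = FakeAbstractCollator(tokenizer, stopwords_dict, max_enc_length=512)
#   batch = collator([{"chunks": raw_document}])
#   loss = model(**batch).loss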
class PegasusChineseModel(LightningModule):
def __init__(self, args, **kwargs):
super().__init__()
self.args = args
self.save_hyperparameters(args)
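        # Pretraining starts from a randomly initialised PEGASUS model built
        # from the config file; args.model_path only supplies the config and
        # tokenizer, no pretrained weights are loaded.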
config = PegasusConfig.from_json_file(
os.path.join(args.model_path, "config.json"))
print("vocab_size: ", config.vocab_size)
self.model = PegasusForConditionalGeneration(config=config)
print("model.num_parameters: ", self.model.num_parameters())
def setup(self, stage) -> None:
if stage == 'fit':
train_loader = self.trainer._data_connector._train_dataloader_source.dataloader(
)
            # Estimate the total number of training steps from the dataset
            # size, the effective batch size and the number of epochs.
            tb_size = self.hparams.train_batchsize * max(1, self.trainer.gpus)
            ab_size = self.trainer.accumulate_grad_batches * float(
                self.trainer.max_epochs)
            self.total_steps = (len(train_loader.dataset) //
                                tb_size) // ab_size
            print('Total training steps:', self.total_steps)
def configure_optimizers(self):
from fengshen.models.model_utils import configure_optimizers
return configure_optimizers(self)
def training_step(self, batch, batch_idx):
output = self.model(**batch)
self.log('train_loss', output.loss, sync_dist=True)
return output.loss
def comput_metrix(self, logits, labels):
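        # Token-level accuracy proxy: the number of argmax predictions that
        # match the labels, divided by the batch size.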
y_pred = torch.argmax(logits, dim=-1)
y_pred = y_pred.view(size=(-1, ))
y_true = labels.view(size=(-1, )).float()
corr = torch.eq(y_pred, y_true)
acc = torch.sum(corr.float()) / labels.size()[0]
return acc
def validation_step(self, batch, batch_idx):
output = self.model(**batch)
acc = self.comput_metrix(output.logits, batch['labels'])
self.log('val_loss', output.loss, sync_dist=True)
self.log('val_acc', acc, sync_dist=True)
def on_save_checkpoint(self, checkpoint) -> None:
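        # On the global rank-0 process, also export the weights in Hugging Face
        # format next to the Lightning checkpoints so they can be reloaded with
        # from_pretrained().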
if self.trainer._accelerator_connector.cluster_environment.global_rank(
) == 0:
self.model.save_pretrained(
os.path.join(
self.trainer.checkpoint_callback.dirpath,
'hf_pretrained_epoch{}_step{}'.format(
checkpoint['epoch'], checkpoint['global_step'])))
def main():
args_parser = argparse.ArgumentParser("Pegasus Task")
args_parser = UniversalDataModule.add_data_specific_args(args_parser)
args_parser = Trainer.add_argparse_args(args_parser)
args_parser = UniversalCheckpoint.add_argparse_args(args_parser)
args_parser = add_module_args(args_parser)
args_parser.add_argument('--deepspeed')
args_parser.add_argument(
'--stopword_path',
default="/cognitive_comp/dongxiaoqun/project/pegasus/own/pegasus/stopwords",
type=str)
args_parser.add_argument('--max_seq_length', default=1024, type=int)
args = args_parser.parse_args()
tokenizer = PegasusTokenizer.from_pretrained(args.model_path)
stopwords_dict = load_stopwords(args.stopword_path)
collator = FakeAbstractCollator(tokenizer, stopwords_dict,
args.max_seq_length)
data_module = UniversalDataModule(tokenizer=tokenizer,
args=args,
collate_fn=collator)
module = PegasusChineseModel(args)
lr_monitor = LearningRateMonitor(logging_interval='step')
logger = loggers.TensorBoardLogger(
save_dir=os.path.join(args.default_root_dir, 'logs/'),
name=os.path.basename(os.path.dirname(args.model_path)))
checkpoint_callback = UniversalCheckpoint(args).callbacks
    # Point PyTorch Lightning's DeepSpeed strategy at the provided config file.
    if args.deepspeed is not None:
        os.environ['PL_DEEPSPEED_CONFIG_PATH'] = args.deepspeed
trainer = Trainer.from_argparse_args(
args, logger=logger, callbacks=[lr_monitor, checkpoint_callback])
trainer.fit(module, data_module)
if __name__ == '__main__':
main()
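
# A minimal launch sketch (paths and flag values are illustrative, not taken
# from the repository):
#
#   python pretrain_pegasus.py \
#       --model_path /path/to/pegasus_config_dir \
#       --stopword_path /path/to/stopwords \
#       --max_seq_length 512 \
#       --train_batchsize 4 \
#       --gpus 1 --max_epochs 1 \
#       --default_root_dir ./pegasus_pretrain_output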