# Pretraining script for Chinese PEGASUS. Pseudo-summary (gap-sentence
# generation) targets are built on the fly by the collator below, and the
# model is trained from scratch with PyTorch Lightning.
import argparse
import os
import sys

# Make the local fengshen package importable before importing from it.
sys.path.append('../../')

import torch
from pytorch_lightning import Trainer, loggers, LightningModule
from pytorch_lightning.callbacks import LearningRateMonitor
from transformers import PegasusForConditionalGeneration, PegasusConfig

from fengshen.models.model_utils import add_module_args
from tokenizers_pegasus import PegasusTokenizer
from utils import UniversalCheckpoint
from data.universal_datamodule import UniversalDataModule
from data_utils import (
    get_input_mask, pseudo_summary_f1, shift_tokens_right,
    padding_to_maxlength, load_stopwords, text_segmentate)
|
class FakeAbstractCollator: |
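    """Builds PEGASUS pretraining batches from raw text chunks.

    Sentences scoring highest under ROUGE-L against the rest of the
    document are extracted as a pseudo summary (gap-sentence generation);
    the remaining text becomes the encoder input.
    """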

    def __init__(self, tokenizer, stopwords_dict, max_enc_length):
        self.tokenizer = tokenizer
        self.max_seq_length = max_enc_length
        self.stopwords_dict = stopwords_dict
|
def __call__(self, samples): |
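        """Tokenize, build pseudo-summary targets, pad, and collate.

        Returns a dict of (batch, max_seq_length) tensors: input_ids,
        attention_mask, labels, decoder_input_ids, decoder_attention_mask.
        """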

        labels = []
        attn_mask = []
        decoder_attn_mask = []
        source_inputs = []
|
for text in samples: |
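            # Segment the document, choose pseudo-summary sentences by
            # ROUGE-L, then tokenize source and target.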
            texts = text["chunks"]
            text = text_segmentate(texts)
            sentence_id_vec, source, target, source_idxs, target_idxs = pseudo_summary_f1(
                text, self.stopwords_dict, self.tokenizer, self.max_seq_length,
                "rouge-l")
            source_idxs, target_idxs = get_input_mask(sentence_id_vec,
                                                      target_idxs)
            if len(source_idxs) > self.max_seq_length:
                # Truncate only when no EOS sits in the cut-off tail;
                # otherwise drop the sample. (The original hard-coded the
                # EOS id as 2; the assignment below suggests
                # eos_token_id was intended.)
                if self.tokenizer.eos_token_id not in source_idxs[self.max_seq_length - 1:]:
                    source_idxs = source_idxs[:self.max_seq_length]
                    source_idxs[-1] = self.tokenizer.eos_token_id
                    sys.stderr.write("Warning split long line: " + source +
                                     "\n")
                else:
                    continue
|
            # Pad both sides to max_seq_length.
            source_idxs, attention_mask = padding_to_maxlength(
                source_idxs, self.max_seq_length, self.tokenizer.pad_token_id)
            label, target_attention_mask = padding_to_maxlength(
                target_idxs, self.max_seq_length, self.tokenizer.pad_token_id)

            source_inputs.append(source_idxs)
            attn_mask.append(attention_mask)
            decoder_attn_mask.append(target_attention_mask)
            labels.append(label)

        labels = torch.tensor(labels)
        # Build decoder inputs before masking labels; pad_token_id doubles
        # as the decoder start token here.
        decode_input_idxs = shift_tokens_right(labels,
                                               self.tokenizer.pad_token_id,
                                               self.tokenizer.pad_token_id)
|
        # Everything after EOS is padding; set it to -100 so the loss
        # ignores those positions.
        end_token_index = torch.where(labels == self.tokenizer.eos_token_id)[1]
        for idx, end_idx in enumerate(end_token_index):
            labels[idx][end_idx + 1:] = -100

        return {
            "input_ids": torch.tensor(source_inputs),
            "attention_mask": torch.tensor(attn_mask),
            "labels": labels,
            "decoder_input_ids": decode_input_idxs,
            "decoder_attention_mask": torch.tensor(decoder_attn_mask)
        }

|
class PegasusChineseModel(LightningModule): |
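    """LightningModule that pretrains PegasusForConditionalGeneration from scratch."""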

    def __init__(self, args, **kwargs):
        super().__init__()
        self.args = args
        self.save_hyperparameters(args)
        config = PegasusConfig.from_json_file(
            os.path.join(args.model_path, "config.json"))
        print("vocab_size: ", config.vocab_size)
        # Only the config is read from model_path; the weights are freshly
        # initialized.
        self.model = PegasusForConditionalGeneration(config=config)
        print("model.num_parameters: ", self.model.num_parameters())
|
def setup(self, stage) -> None: |
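        """Estimate the total number of optimizer steps so the LR scheduler can be sized."""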
        if stage == 'fit':
            # Note: reaches into pytorch-lightning private API (1.x layout).
            train_loader = self.trainer._data_connector._train_dataloader_source.dataloader(
            )

            # Samples consumed per optimizer step across all GPUs; guard
            # against trainer.gpus being None on CPU runs.
            tb_size = self.hparams.train_batchsize * max(1, self.trainer.gpus or 1)
            # Steps per epoch after gradient accumulation, times epochs.
            # (The original divided by max_epochs, undercounting the total.)
            steps_per_epoch = (len(train_loader.dataset) //
                               tb_size) // self.trainer.accumulate_grad_batches
            self.total_steps = int(steps_per_epoch *
                                   float(self.trainer.max_epochs))
            print('Total training steps:', self.total_steps)
|
def configure_optimizers(self): |
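        """Delegate optimizer and LR-scheduler creation to fengshen's shared helper."""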
        from fengshen.models.model_utils import configure_optimizers
        return configure_optimizers(self)
|
def training_step(self, batch, batch_idx): |
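        """Run a forward pass and log/return the LM cross-entropy loss."""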
        output = self.model(**batch)
        self.log('train_loss', output.loss, sync_dist=True)
        return output.loss
|
    def compute_metrics(self, logits, labels):
        """Token accuracy over positions not masked with -100."""
        y_pred = torch.argmax(logits, dim=-1).view(-1)
        y_true = labels.view(-1)
        mask = y_true != -100
        corr = torch.eq(y_pred[mask], y_true[mask]).float()
        # Average over the counted positions (the original divided by the
        # batch size, which could push the value above 1).
        return torch.sum(corr) / mask.sum()
|
def validation_step(self, batch, batch_idx): |
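        """Log validation loss and token accuracy."""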
        output = self.model(**batch)
        acc = self.compute_metrics(output.logits, batch['labels'])
        self.log('val_loss', output.loss, sync_dist=True)
        self.log('val_acc', acc, sync_dist=True)
|
def on_save_checkpoint(self, checkpoint) -> None: |
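        """Also export a HuggingFace-format copy of the model (global rank 0 only)."""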
        if self.trainer._accelerator_connector.cluster_environment.global_rank(
        ) == 0:
            self.model.save_pretrained(
                os.path.join(
                    self.trainer.checkpoint_callback.dirpath,
                    'hf_pretrained_epoch{}_step{}'.format(
                        checkpoint['epoch'], checkpoint['global_step'])))
|
def main(): |
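    """Parse arguments, assemble data/model/trainer, and run pretraining."""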
    args_parser = argparse.ArgumentParser("Pegasus Task")

    args_parser = UniversalDataModule.add_data_specific_args(args_parser)
    args_parser = Trainer.add_argparse_args(args_parser)
    args_parser = UniversalCheckpoint.add_argparse_args(args_parser)
    args_parser = add_module_args(args_parser)
    args_parser.add_argument('--deepspeed', default=None, type=str,
                             help='path to a DeepSpeed JSON config file')
    args_parser.add_argument(
        '--stopword_path',
        default="/cognitive_comp/dongxiaoqun/project/pegasus/own/pegasus/stopwords",
        type=str)
    args_parser.add_argument('--max_seq_length', default=1024, type=int)
    args = args_parser.parse_args()
|
    tokenizer = PegasusTokenizer.from_pretrained(args.model_path)
    stopwords_dict = load_stopwords(args.stopword_path)
    collator = FakeAbstractCollator(tokenizer, stopwords_dict,
                                    args.max_seq_length)
    data_module = UniversalDataModule(tokenizer=tokenizer,
                                      args=args,
                                      collate_fn=collator)
    module = PegasusChineseModel(args)
    lr_monitor = LearningRateMonitor(logging_interval='step')
    logger = loggers.TensorBoardLogger(
        save_dir=os.path.join(args.default_root_dir, 'logs/'),
        name=os.path.basename(os.path.dirname(args.model_path)))
    checkpoint_callback = UniversalCheckpoint(args).callbacks
|
    if args.deepspeed is not None:
        # Lightning's DeepSpeed strategy picks the config path up from
        # this environment variable.
        os.environ['PL_DEEPSPEED_CONFIG_PATH'] = args.deepspeed
|
    trainer = Trainer.from_argparse_args(
        args, logger=logger, callbacks=[lr_monitor, checkpoint_callback])

    trainer.fit(module, data_module)

|
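# Example launch (illustrative only: the script name is assumed, all paths
# are placeholders, and --model_path plus the data flags come from
# add_module_args / UniversalDataModule, so exact names may differ):
#
#   python pretrain_pegasus.py \
#       --model_path /path/to/pegasus_config_dir \
#       --stopword_path /path/to/stopwords \
#       --max_seq_length 1024 \
#       --gpus 1 --max_epochs 5 \
#       --default_root_dir ./exp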
|
if __name__ == '__main__': |
    main()