from data.bert_dataloader.load import BertDataModule
from transformers import (
    BertTokenizer,
    BertConfig,
    BertForPreTraining,
    BertModel,
    BertForMaskedLM
)
from pytorch_lightning import (
    LightningDataModule,
    LightningModule,
    loggers,
    Trainer,
)
from pytorch_lightning.callbacks import (
    ModelCheckpoint,
    LearningRateMonitor,
)
from typing import Optional
from torch.utils.data import DataLoader
from transformers.optimization import get_linear_schedule_with_warmup
import argparse
import sys
import torch
import os
import re
import jieba
import numpy as np

# If the fengshen module is not installed, add Fengshenbang-LM/fengshen to the Python path.
sys.path.insert(0, '../../../fengshen')

os.environ["CUDA_VISIBLE_DEVICES"] = '0,1'


class DataCollate(object):
    """Collator that applies whole-word / n-gram masking for BERT MLM pretraining."""

    def __init__(self, tokenizer, max_length, mask_rate=0.15, max_ngram=3, if_padding=True) -> None:
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.word_cutter = jieba.cut
        self.vocab_length = len(tokenizer)
        self.mask_rate = mask_rate
        self.ignore_labels = -100
        self.ngrams = np.arange(1, max_ngram + 1, dtype=np.int64)
        pvals = 1. / np.arange(1, max_ngram + 1)
        pvals /= pvals.sum(keepdims=True)
        # p(n) = (1/n) / sum_k(1/k); for max_ngram=3 this gives [6/11, 3/11, 2/11]
        self.pvals = pvals
        self.padding = if_padding

    def token_process(self, token_id):
        # 80% of masked tokens become [MASK], 10% keep the original id, 10% become a random token.
        rand = np.random.random()
        if rand <= 0.8:
            return self.tokenizer.mask_token_id
        elif rand <= 0.9:
            return token_id
        else:
            return np.random.randint(1, self.vocab_length)

    def __call__(self, samples):
        input_ids = []
        attention_mask = []
        token_type_ids = []
        batch_labels = []
        for sample in samples:
            word_list = list(self.word_cutter(sample['text']))
            mask_ids, labels = [], []
            record = []
            for i in range(len(word_list)):
                rands = np.random.random()
                if i in record:
                    continue
                word = word_list[i]
                if rands > self.mask_rate and len(word) < 4:
                    # Keep the word unchanged; its positions are ignored in the loss.
                    word_encode = self.tokenizer.encode(word, add_special_tokens=False)
                    for token in word_encode:
                        mask_ids.append(token)
                        labels.append(self.ignore_labels)
                    record.append(i)
                else:
                    # Mask an n-gram of consecutive words starting at position i.
                    n = np.random.choice(self.ngrams, p=self.pvals)
                    for index in range(n):
                        ind = index + i
                        if ind in record or ind >= len(word_list):
                            continue
                        record.append(ind)
                        word = word_list[ind]
                        word_encode = self.tokenizer.encode(word, add_special_tokens=False)
                        for token in word_encode:
                            mask_ids.append(self.token_process(token))
                            labels.append(token)
            if self.padding:
                if len(mask_ids) > self.max_length:
                    input_ids.append(mask_ids[:self.max_length])
                    batch_labels.append(labels[:self.max_length])
                else:
                    length = len(mask_ids)
                    mask_ids.extend([0] * (self.max_length - length))
                    labels.extend([self.ignore_labels] * (self.max_length - length))
                    input_ids.append(mask_ids)
                    batch_labels.append(labels)
            attention_mask.append([1] * self.max_length)
            token_type_ids.append([0] * self.max_length)
        return {
            'input_ids': torch.tensor(input_ids),
            'labels': torch.tensor(batch_labels),
            'attention_mask': torch.tensor(attention_mask),
            'token_type_ids': torch.tensor(token_type_ids)
        }


class Bert(LightningModule):
    @staticmethod
    def add_module_specific_args(args_parser):
        parser = args_parser.add_argument_group('Bert')
        parser.add_argument('--model_path', type=str, default='')
        parser.add_argument('--learning_rate', default=1e-5, type=float)
        parser.add_argument('--weight_decay', default=0.1, type=float)
        parser.add_argument('--warmup', default=0.01, type=float)
        return args_parser

    def __init__(self, args):
        super().__init__()
        self.save_hyperparameters(args)
        self.bertconfig = BertConfig.from_pretrained(args.model_path)
        # self.model = BertForPreTraining(self.bertconfig)
        self.model = BertForMaskedLM(self.bertconfig)

    def setup(self, stage) -> None:
        if stage == 'fit':
            train_loader = self.trainer._data_connector._train_dataloader_source.dataloader()
            # Calculate total optimization steps for the LR scheduler.
            tb_size = self.hparams.train_batchsize * max(1, self.trainer.gpus)
            ab_size = self.trainer.accumulate_grad_batches * float(self.trainer.max_epochs)
            self.total_steps = (len(train_loader.dataset) // tb_size) // ab_size

    def configure_optimizers(self):
        # Exclude biases and LayerNorm parameters from weight decay.
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        paras = list(
            filter(lambda p: p[1].requires_grad, self.named_parameters()))
        paras = [{
            'params': [p for n, p in paras if not any(nd in n for nd in no_decay)],
            'weight_decay': self.hparams.weight_decay
        }, {
            'params': [p for n, p in paras if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0
        }]
        optimizer = torch.optim.AdamW(paras, lr=self.hparams.learning_rate)
        scheduler = get_linear_schedule_with_warmup(
            optimizer, int(self.total_steps * self.hparams.warmup), self.total_steps)

        return [{
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'step',
                'frequency': 1
            }
        }]

    def training_step(self, batch, batch_idx):
        output = self.model(**batch)
        self.log('train_loss', output.loss)
        return output.loss

    def compute_metrics(self, logits, labels):
        # Masked-token accuracy: positions labelled -100 are excluded via the mask.
        ones = torch.ones_like(labels)
        zero = torch.zeros_like(labels)
        mask = torch.where(labels < 0, zero, ones)
        mask = mask.view(size=(-1,)).float()
        y_pred = torch.argmax(logits, dim=-1)
        y_pred = y_pred.view(size=(-1,))
        y_true = labels.view(size=(-1,)).float()
        corr = torch.eq(y_pred, y_true)
        corr = torch.multiply(corr.float(), mask)
        acc = torch.sum(corr.float()) / torch.sum(mask)
        return acc

    def validation_step(self, batch, batch_idx):
        output = self.model(**batch)
        acc = self.compute_metrics(output.logits, batch['labels'])
        print('val_loss ', output.loss)
        self.log('val_loss', output.loss)
        self.log('val_acc', acc)

    def predict_step(self, batch, batch_idx):
        output = self.model(**batch)
        # BertForMaskedLM returns `logits` (BertForPreTraining would return `prediction_logits`).
        return output.logits


class CustomCKPT:
    @staticmethod
    def add_argparse_args(parent_args):
        parser = parent_args.add_argument_group('ckpt call back')
        parser.add_argument('--monitor', default='train_loss', type=str)
        parser.add_argument('--mode', default='min', type=str)
        parser.add_argument('--dirpath', default='./ckpt/', type=str)
        parser.add_argument(
            '--filename', default='model-{epoch:02d}-{train_loss:.4f}', type=str)
        parser.add_argument('--save_last', action='store_true', default=True)
        parser.add_argument('--save_top_k', default=3, type=int)
        parser.add_argument('--every_n_train_steps', default=100, type=int)
        parser.add_argument('--save_weights_only', action='store_true', default=False)
        return parent_args

    def __init__(self, args):
        self.callbacks = ModelCheckpoint(monitor=args.monitor,
                                         save_top_k=args.save_top_k,
                                         mode=args.mode,
                                         every_n_train_steps=args.every_n_train_steps,
                                         save_weights_only=args.save_weights_only,
                                         dirpath=args.dirpath,
                                         filename=args.filename,
                                         save_last=args.save_last)


if __name__ == '__main__':
    args_parser = argparse.ArgumentParser()
    args_parser = BertDataModule.add_data_specific_args(args_parser)
    args_parser = Trainer.add_argparse_args(args_parser)
    args_parser = Bert.add_module_specific_args(args_parser)
    args_parser = CustomCKPT.add_argparse_args(args_parser)
    args_parser.add_argument('--deepspeed', default=None, type=str)
    args_parser.add_argument('--seq_max_length', default=512, type=int)
    args = args_parser.parse_args()

    tokenizer = BertTokenizer.from_pretrained(args.model_path)
    collate_fn = DataCollate(tokenizer, args.seq_max_length)
    data_module = BertDataModule(tokenizer=tokenizer, args=args, collate_fn=collate_fn)
    print('data load complete')

    model = Bert(args)
    print('model load complete')

    lr_monitor = LearningRateMonitor(logging_interval='step')
    logger = loggers.TensorBoardLogger(
        save_dir=os.path.join(args.default_root_dir, 'logs/'),
        name=os.path.basename(os.path.dirname(args.model_path)))
    checkpoint_callback = CustomCKPT(args).callbacks

    if args.resume_from_checkpoint is not None and \
            not os.path.exists(args.resume_from_checkpoint):
        print('Warning: checkpoint not found, removing --resume_from_checkpoint')
        del args.resume_from_checkpoint

    # DeepSpeed autotuning: point Lightning at the DeepSpeed config file if one is given.
    if args.deepspeed is not None:
        os.environ['PL_DEEPSPEED_CONFIG_PATH'] = args.deepspeed

    trainer = Trainer.from_argparse_args(args,
                                         logger=logger,
                                         callbacks=[lr_monitor, checkpoint_callback])

    trainer.fit(model, data_module)
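
# Illustrative launch command (a sketch, not part of the original script; the file name
# "pretrain_bert.py" is assumed, and the data-side flags added by
# BertDataModule.add_data_specific_args -- corpus path, train_batchsize, etc. -- are
# defined in data.bert_dataloader.load and only hinted at here):
#
#   python pretrain_bert.py \
#       --model_path /path/to/bert-base-chinese \
#       --gpus 2 \
#       --max_epochs 1 \
#       --default_root_dir ./exp/ \
#       --seq_max_length 512 \
#       --every_n_train_steps 1000 \
#       <BertDataModule data arguments ...>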