# coding=utf8
import argparse
import sys
import os
from concurrent.futures import ProcessPoolExecutor
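
# This script splits a (possibly already tokenized) dataset into train/test
# and caches each train shard to disk in Arrow format via a process pool.
#
# Example invocation (the script name and paths below are placeholders, not
# values taken from this file):
#
#   python save_arrow_cache.py \
#       --train_data_path /path/to/tokenized_data \
#       --saved_train_data_path /path/to/train_cache \
#       --saved_test_data_path /path/to/test_cache \
#       --saved_data_shards 800
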
def _generate_cache_arrow(index, ds, path):
    print('saving dataset shard {}'.format(index))
    ds.save_to_disk(os.path.join(path, 'part_{}'.format(index)))
    return 'saving dataset shard {} done'.format(index)

def generate_arrow_cache(ds, args) -> None:
    '''
    Load the raw data (e.g. wudao_180g) or its tokenized version, do a
    train/test split, shuffle with seed 42, and cache the result to disk.
    '''
    ds = ds.train_test_split(train_size=args.train_split_size, seed=42)
    print(ds)
    p = ProcessPoolExecutor(max_workers=args.preprocessing_num_workers)
    res = []
    train_shard_part = args.saved_data_shards
    # Save each train shard in its own worker process.
    for i in range(0, train_shard_part):
        res.append(p.submit(_generate_cache_arrow, i,
                            ds['train'].shard(train_shard_part, i), args.saved_train_data_path))
    p.shutdown(wait=True)
    for future in res:
        print(future.result(), flush=True)
    # The (small) test split is saved as a single dataset.
    ds['test'].save_to_disk(args.saved_test_data_path)
    print('done')
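
# The cached shards can later be reloaded with the `datasets` library, for
# example (a minimal sketch; the variables below are placeholders that mirror
# the directory layout written by _generate_cache_arrow):
#
#   from datasets import load_from_disk, concatenate_datasets
#   shards = [load_from_disk(os.path.join(saved_train_data_path, 'part_{}'.format(i)))
#             for i in range(saved_data_shards)]
#   train_ds = concatenate_datasets(shards)
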
if __name__ == '__main__':
    total_parser = argparse.ArgumentParser("Save data Task")
    total_parser.add_argument(
        '--new_vocab_path', default='/cognitive_comp/ganruyi/hf_models/t5_cn_small/sentencepiece_cn.model', type=str)
    total_parser.add_argument('--preprocessing_num_workers', default=30, type=int)
    total_parser.add_argument(
        '--train_data_path', default='/cognitive_comp/common_data/test_wudao_180g_mt5_tokenized/', type=str)
    total_parser.add_argument('--saved_data_shards', default=800, type=int)
    total_parser.add_argument('--saved_train_data_path', default=None, type=str)
    total_parser.add_argument('--saved_test_data_path', default=None, type=str)
    total_parser.add_argument('--max_seq_length', default=512, type=int)
    total_parser.add_argument('--train_split_size', default=0.999, type=float)
    total_parser.add_argument('--pretrained_model_path', default=None, type=str)
    total_parser.add_argument('--tokenizer_type', default='t5_tokenizer', choices=['t5_tokenizer', 'bert_tokenizer'])
    total_parser.add_argument('--text_column_name', default='text')
    total_parser.add_argument('--remove_columns', nargs='+', default=[])
    # * Args for data preprocessing
    args = total_parser.parse_args()

    sys.path.append('../../../')
    from fengshen.data.t5_dataloader.t5_datasets import UnsuperviseT5Dataset
    ds = UnsuperviseT5Dataset(args.train_data_path, args)
    print(ds)
    generate_arrow_cache(ds.data, args=args)
    # ds = UnsuperviseT5Dataset(args.train_data_path, args, load_data_type=0)
    for i in range(0, 2):
        print(ds.data[i])
        print(ds.tokenizer.decode(ds.data[i]['input_ids']))
    print(ds.data)