import argparse
import sys
import os
from concurrent.futures import ProcessPoolExecutor


def _generate_cache_arrow(index, ds, path):
    # Save a single dataset shard into its own sub-directory, e.g. <path>/part_0.
    print('saving dataset shard {}'.format(index))
    ds.save_to_disk(os.path.join(path, 'part_{}'.format(index)))
    return 'saving dataset shard {} done'.format(index)


def generate_arrow_cache(ds, args) -> None:
    '''
    Read the raw data (e.g. wudao_180g) or its tokenized version, perform a
    train/test split, shuffle with seed 42, and cache the result to disk.
    A sketch of how the cached shards can be read back follows this function.
    '''
    ds = ds.train_test_split(train_size=args.train_split_size, seed=42)
    print(ds)
    p = ProcessPoolExecutor(max_workers=args.preprocessing_num_workers)
    res = []
    train_shard_part = args.saved_data_shards
    # Save each shard of the training split in a separate worker process.
    for i in range(0, train_shard_part):
        res.append(p.submit(_generate_cache_arrow, i,
                            ds['train'].shard(train_shard_part, i), args.saved_train_data_path))

    p.shutdown(wait=True)
    for future in res:
        print(future.result(), flush=True)

    ds['test'].save_to_disk(args.saved_test_data_path)
    print('done')
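

# A minimal sketch of how the shards written by generate_arrow_cache could be
# read back later, assuming the part_{i} layout used above. The helper name
# `load_cached_dataset` is hypothetical and not part of the original pipeline.
def load_cached_dataset(path, num_shards):
    from datasets import load_from_disk, concatenate_datasets
    shards = [load_from_disk(os.path.join(path, 'part_{}'.format(i)))
              for i in range(num_shards)]
    return concatenate_datasets(shards)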


if __name__ == '__main__':
    total_parser = argparse.ArgumentParser("Save data Task")
    total_parser.add_argument(
        '--new_vocab_path', default='/cognitive_comp/ganruyi/hf_models/t5_cn_small/sentencepiece_cn.model', type=str)
    total_parser.add_argument('--preprocessing_num_workers', default=30, type=int)
    total_parser.add_argument(
        '--train_data_path', default='/cognitive_comp/common_data/test_wudao_180g_mt5_tokenized/', type=str)
    total_parser.add_argument('--saved_data_shards', default=800, type=int)
    total_parser.add_argument('--saved_train_data_path', default=None, type=str)
    total_parser.add_argument('--saved_test_data_path', default=None, type=str)
    total_parser.add_argument('--max_seq_length', default=512, type=int)
    total_parser.add_argument('--train_split_size', default=0.999, type=float)
    total_parser.add_argument('--pretrained_model_path', default=None, type=str)
    total_parser.add_argument('--tokenizer_type', default='t5_tokenizer', choices=['t5_tokenizer', 'bert_tokenizer'])
    total_parser.add_argument('--text_column_name', default='text')
    total_parser.add_argument('--remove_columns', nargs='+', default=[])
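
    # Example invocation (the script name and paths below are illustrative only,
    # not taken from the original setup):
    #   python save_dataset.py \
    #       --train_data_path /path/to/tokenized_data \
    #       --saved_train_data_path /path/to/cache/train \
    #       --saved_test_data_path /path/to/cache/test \
    #       --saved_data_shards 800 \
    #       --preprocessing_num_workers 30
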
    args = total_parser.parse_args()
    sys.path.append('../../../')
    from fengshen.data.t5_dataloader.t5_datasets import UnsuperviseT5Dataset
    ds = UnsuperviseT5Dataset(args.train_data_path, args)
    print(ds)
    generate_arrow_cache(ds.data, args=args)

    # Sanity check: print the first two examples and their decoded text.
    for i in range(0, 2):
        print(ds.data[i])
        print(ds.tokenizer.decode(ds.data[i]['input_ids']))

    print(ds.data)