# Export the image names and source texts of the first num_test test samples
# to a JSONL file for later decoding/evaluation.
import json

from model_and_train import MyDataset, prepare_dataset_df, prepare_tokenizer
from torch.utils.data import DataLoader, Subset

dataset_dir = "/home/zychen/hwproject/my_modeling_phase_1/dataset"
data_file = f"{dataset_dir}/testset_10k.jsonl"
if __name__ == "__main__":

    encoder_ckpt_dir = "/home/zychen/hwproject/my_modeling_phase_1/Tokenizer_PretrainedWeights/lilt-roberta-en-base"

    tgt_tokenizer_dir = "/home/zychen/hwproject/my_modeling_phase_1/Tokenizer_PretrainedWeights/bert-base-chinese-tokenizer"

    src_tokenizer, tgt_tokenizer = prepare_tokenizer(
        src_tokenizer_dir=encoder_ckpt_dir,
        tgt_tokenizer_dir=tgt_tokenizer_dir,
    )
    dataset_df = prepare_dataset_df(data_file=data_file)
    my_dataset = MyDataset(df=dataset_df,
                           src_tokenizer=src_tokenizer,
                           tgt_tokenizer=tgt_tokenizer,
                           max_src_length=512,
                           max_target_length=512)
    print(f"Total samples: {len(my_dataset)}")

    # Evaluate on the first 5,000 of the 10k test samples.
    num_test = 5000
    my_dataset = Subset(my_dataset, range(num_test))
    # Optional: wrap the subset in a DataLoader for batched iteration
    # (batch_size must be defined before enabling this).
    # my_dataloader = DataLoader(
    #     my_dataset,
    #     batch_size=batch_size,
    #     shuffle=False,
    # )
    img_name_list = dataset_df["img_path"].iloc[0:num_test].tolist()
    text_src_list = dataset_df["text_src"].iloc[0:num_test].tolist()
    # Write one JSON record per test sample: {"img_name": ..., "text_src": ...}.
    with open("./mytest/text_src.jsonl", "w", encoding="utf-8") as decoding_res_file:
        for img_name, text_src in zip(img_name_list, text_src_list):
            res_dict = {
                "img_name": img_name,
                "text_src": text_src,
            }

            record = f"{json.dumps(res_dict, ensure_ascii=False)}\n"
            decoding_res_file.write(record)
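
# Example (sketch): the emitted JSONL can be read back one record per line,
# e.g. in a later decoding/evaluation step.
# with open("./mytest/text_src.jsonl", encoding="utf-8") as f:
#     records = [json.loads(line) for line in f]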