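"""Export the source text of the test split to JSONL.

Loads the 10k-sample test set, builds the source/target tokenizers, and writes
the image path and source text of the first `num_test` samples to
./mytest/text_src.jsonl, one JSON record per line.
"""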
import json
import os

from model_and_train import MyDataset, prepare_dataset_df, prepare_tokenizer
from torch.utils.data import DataLoader, Subset
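# Location of the 10k-sample test split in JSON Lines format.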
dataset_dir = "/home/zychen/hwproject/my_modeling_phase_1/dataset"
data_file = f"{dataset_dir}/testset_10k.jsonl"
if __name__ == "__main__":
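    # Local checkpoints: a LiLT (RoBERTa-based) encoder for the source side and
    # a bert-base-chinese tokenizer for the target side.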
    encoder_ckpt_dir = "/home/zychen/hwproject/my_modeling_phase_1/Tokenizer_PretrainedWeights/lilt-roberta-en-base"
    tgt_tokenizer_dir = "/home/zychen/hwproject/my_modeling_phase_1/Tokenizer_PretrainedWeights/bert-base-chinese-tokenizer"
    src_tokenizer, tgt_tokenizer = prepare_tokenizer(
        src_tokenizer_dir=encoder_ckpt_dir,
        tgt_tokenizer_dir=tgt_tokenizer_dir,
    )
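    # Build the evaluation dataset; source and target sequences are capped at 512 tokens.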
    dataset_df = prepare_dataset_df(data_file=data_file)
    my_dataset = MyDataset(
        df=dataset_df,
        src_tokenizer=src_tokenizer,
        tgt_tokenizer=tgt_tokenizer,
        max_src_length=512,
        max_target_length=512,
    )
    print(len(my_dataset))  # sanity check on the dataset size
    num_test = 5000  # evaluate the first half of the 10k-sample test set
    my_dataset = Subset(my_dataset, range(0, num_test))
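    # Optional: wrap the subset in a DataLoader for batched decoding
    # (batch_size is not defined in this script).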
    # my_dataloader = DataLoader(
    #     my_dataset,
    #     batch_size=batch_size,
    #     shuffle=False,
    # )
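    # Dump each sample's image path and source text as one JSON record per line.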
    img_name_list = dataset_df["img_path"].iloc[0:num_test].tolist()
    text_src_list = dataset_df["text_src"].iloc[0:num_test].tolist()
    os.makedirs("./mytest", exist_ok=True)  # make sure the output directory exists
    with open("./mytest/text_src.jsonl", "w", encoding="utf-8") as decoding_res_file:
        for img_name, text_src in zip(img_name_list, text_src_list):
            res_dict = {
                "img_name": img_name,
                "text_src": text_src,
            }
            record = f"{json.dumps(res_dict, ensure_ascii=False)}\n"
            decoding_res_file.write(record)