# Export (img_name, text_src) pairs for the first num_test test samples so that
# decoding results can later be aligned with their source images and texts.
import json
import os

from torch.utils.data import DataLoader, Subset

from model_and_train import MyDataset, prepare_dataset_df, prepare_tokenizer

dataset_dir = "/home/zychen/hwproject/my_modeling_phase_1/dataset"
data_file = f"{dataset_dir}/testset_10k.jsonl"

if __name__ == "__main__":
    # Source-side tokenizer comes from the LiLT encoder checkpoint; the target
    # side uses a Chinese BERT tokenizer.
    encoder_ckpt_dir = "/home/zychen/hwproject/my_modeling_phase_1/Tokenizer_PretrainedWeights/lilt-roberta-en-base"
    tgt_tokenizer_dir = "/home/zychen/hwproject/my_modeling_phase_1/Tokenizer_PretrainedWeights/bert-base-chinese-tokenizer"

    src_tokenizer, tgt_tokenizer = prepare_tokenizer(
        src_tokenizer_dir=encoder_ckpt_dir,
        tgt_tokenizer_dir=tgt_tokenizer_dir,
    )
    # Build the full dataset, then evaluate only its first num_test samples.
    dataset_df = prepare_dataset_df(data_file=data_file)
    my_dataset = MyDataset(
        df=dataset_df,
        src_tokenizer=src_tokenizer,
        tgt_tokenizer=tgt_tokenizer,
        max_src_length=512,
        max_target_length=512,
    )
    print(f"Total samples: {len(my_dataset)}")

    num_test = 5000
    my_dataset = Subset(my_dataset, range(num_test))
    # Write one JSON record per line ({"img_name": ..., "text_src": ...}) for
    # the evaluated slice, in the same order as the Subset above.
    img_name_list = dataset_df["img_path"].iloc[:num_test].tolist()
    text_src_list = dataset_df["text_src"].iloc[:num_test].tolist()

    os.makedirs("./mytest", exist_ok=True)  # ensure the output directory exists
    with open("./mytest/text_src.jsonl", "w", encoding="utf-8") as decoding_res_file:
        for img_name, text_src in zip(img_name_list, text_src_list):
            res_dict = {
                "img_name": img_name,
                "text_src": text_src,
            }
            record = f"{json.dumps(res_dict, ensure_ascii=False)}\n"
            decoding_res_file.write(record)
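    # A minimal sketch, not part of the original script: the evaluation subset
    # can be batched with the DataLoader imported above. batch_size=8 is an
    # arbitrary assumption, and actually iterating the loader assumes
    # MyDataset returns default-collatable items (tensors or dicts of
    # tensors); otherwise a custom collate_fn would be needed.
    test_loader = DataLoader(my_dataset, batch_size=8, shuffle=False)
    print(f"DataLoader batches over the {num_test}-sample subset: {len(test_loader)}")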