# standard-library imports
import json
import os

# other external imports
import jieba
import pandas as pd
from nltk.translate.bleu_score import sentence_bleu
from sacrebleu.metrics import CHRF
from tqdm import tqdm

# torch imports
import torch
from torch.utils.data import DataLoader, Subset

# transformers imports
from transformers import BertTokenizer, EncoderDecoderModel

# internal imports
from model_and_train import MyDataset, prepare_dataset_df, prepare_tokenizer

chrf = CHRF(word_order=2)  # word_order=2 corresponds to chrF++.

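# restrict this script to a single GPU (device index 5 on this machine)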
os.environ["CUDA_VISIBLE_DEVICES"] = "5"

# hyper-parameters.
## for model.
MAX_TGT_LEN = 512
MAX_SRC_LEN = 512

## for decoding.
output_dir = "./mytest"
os.makedirs(output_dir, exist_ok=True)
early_stopping = True
num_beams = 2
length_penalty = 1.0
batch_size = 16
metric_res_filepath = os.path.join(output_dir, "metric_res.json")
decoding_res_filepath = os.path.join(output_dir, "decoding_res.json")
trained_model_dir = "/home/zychen/hwproject/my_modeling_phase_1/train.lr_0.0001.bsz_28.step_400000.layer_12-12/checkpoint-64000"

dataset_dir = "/home/zychen/hwproject/my_modeling_phase_1/dataset"
data_file = f"{dataset_dir}/testset_10k.jsonl"


def no_blank(sen):
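    """Remove all whitespace from a sentence."""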
    return "".join(sen.split())


if __name__ == "__main__":

    encoder_ckpt_dir = "/home/zychen/hwproject/my_modeling_phase_1/Tokenizer_PretrainedWeights/lilt-roberta-en-base"

    tgt_tokenizer_dir = "/home/zychen/hwproject/my_modeling_phase_1/Tokenizer_PretrainedWeights/bert-base-chinese-tokenizer"

    src_tokenizer, tgt_tokenizer = prepare_tokenizer(
        src_tokenizer_dir=encoder_ckpt_dir,
        tgt_tokenizer_dir=tgt_tokenizer_dir,
    )
    dataset_df = prepare_dataset_df(data_file=data_file)
    my_dataset = MyDataset(df=dataset_df,
                           src_tokenizer=src_tokenizer,
                           tgt_tokenizer=tgt_tokenizer,
                           max_src_length=MAX_SRC_LEN,
                           max_target_length=MAX_TGT_LEN)
    print(f"dataset size: {len(my_dataset)}")
    num_test = 5000  # evaluate on the first 5k of the 10k test examples
    my_dataset = Subset(my_dataset, range(0, num_test))
    my_dataloader = DataLoader(
        my_dataset,
        batch_size=batch_size,
        shuffle=False,
    )

    # load the trained encoder-decoder model from the checkpoint folder
    model = EncoderDecoderModel.from_pretrained(trained_model_dir)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    print(model)

    # decode the test set
    pred_res_list = []
    gt_list = []

    for batch in tqdm(my_dataloader):
        # predict with model.generate()
        with torch.no_grad():
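            # run the LiLT encoder explicitly so the layout (bbox) inputs are
            # used; generate() then reuses these pre-computed encoder outputs.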
            encoder_outputs = model.encoder(
                input_ids=batch["input_ids"].to(device),
                bbox=batch["bbox"].to(device),
                attention_mask=batch["attention_mask"].to(device),
            )
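            # beam-search decoding on top of the pre-computed encoder outputs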
            outputs = model.generate(
                input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
                encoder_outputs=encoder_outputs,
                max_length=MAX_TGT_LEN,
                early_stopping=early_stopping,
                num_beams=num_beams,
                length_penalty=length_penalty,
                use_cache=True,
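                # note: the decoder start token id is hard-coded here and
                # should match the decoder's configured start token.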
                decoder_start_token_id=0)

        # decode predictions and references back to text
        pred_str = tgt_tokenizer.batch_decode(outputs,
                                              skip_special_tokens=True)
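        # labels mark padded positions with -100 (ignored by the loss);
        # map them back to the pad token id so they can be decoded.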
        labels = batch["labels"]
        labels[labels == -100] = tgt_tokenizer.pad_token_id
        label_str = tgt_tokenizer.batch_decode(labels,
                                               skip_special_tokens=True)

        pred_res_list += pred_str
        gt_list += label_str

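    # strip all whitespace so predictions and references are compared on the
    # raw Chinese text before re-segmentation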
    gt_list = [no_blank(sen) for sen in gt_list]
    pred_res_list = [no_blank(sen) for sen in pred_res_list]

    # write the decoding results and compute the metrics.
    img_name_list = dataset_df["img_path"].iloc[0:num_test].tolist()
    text_src_list = dataset_df["text_src"].iloc[0:num_test].tolist()
    bleu_list = []
    chrf_list = []

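    # re-segment with jieba so that word-level n-grams (BLEU, and the word
    # part of chrF++) are meaningful for Chinese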
    pred_res_seg_list = [" ".join(jieba.cut(item)) for item in pred_res_list]
    gt_seg_list = [" ".join(jieba.cut(item)) for item in gt_list]
    print(len(text_src_list), len(pred_res_seg_list), len(gt_seg_list))
    # print(img_name_list, pred_res_list, gt_seg_list)
    assert len(img_name_list) == len(pred_res_seg_list) == len(gt_seg_list)

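    # write one JSON record per line (JSONL) with per-instance scores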
    with open(decoding_res_filepath, "w") as decoding_res_file:
        for img_name, text_src, pred_res_seg, gt_seg in zip(
                img_name_list, text_src_list, pred_res_seg_list, gt_seg_list):

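            # note: sentence_bleu is used without smoothing, so hypotheses
            # with no higher-order n-gram overlap get a BLEU of 0.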
            instance_bleu = sentence_bleu([gt_seg.split()],
                                          pred_res_seg.split())
            bleu_list.append(instance_bleu)

            instance_chrf = chrf.sentence_score(
                hypothesis=pred_res_seg,
                references=[gt_seg],
            ).score
            chrf_list.append(instance_chrf)

            res_dict = {
                "img_name": img_name,
                "text_src": text_src,
                "instance_bleu": instance_bleu,
                "instance_chrf": instance_chrf,
                "trans_res_seg": pred_res_seg,
                "gt_seg": gt_seg,
            }

            record = f"{json.dumps(res_dict, ensure_ascii=False)}\n"
            decoding_res_file.write(record)

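    # corpus-level scores: mean of the per-sentence BLEU / chrF++ values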
    trans_avg_bleu = sum(bleu_list) / len(bleu_list)
    trans_avg_chrf = sum(chrf_list) / len(chrf_list)
    with open(metric_res_filepath, "w") as metric_res_file:
        eval_res_dict = {
            "trans_avg_bleu": trans_avg_bleu,
            "trans_avg_chrf": trans_avg_chrf,
        }
        json.dump(eval_res_dict, metric_res_file, indent=4, ensure_ascii=False)