import json
import os

import jieba
import pandas as pd
import torch
from model_and_train import MyDataset, prepare_dataset_df, prepare_tokenizer
from nltk.translate.bleu_score import sentence_bleu
from sacrebleu.metrics import CHRF
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm
from transformers import BertTokenizer, EncoderDecoderModel

# chrF++ metric (character n-grams plus word n-grams up to order 2).
chrf = CHRF(word_order=2)

# Pin the run to a single GPU (must be set before the first CUDA call).
os.environ["CUDA_VISIBLE_DEVICES"] = "5"

# Maximum source/target sequence lengths (in tokens).
MAX_TGT_LEN = 512
MAX_SRC_LEN = 512

# Beam-search decoding settings and output paths.
output_dir = "./mytest"
os.makedirs(output_dir, exist_ok=True)
early_stopping = True
num_beams = 2
length_penalty = 1.0
batch_size = 16
metric_res_filepath = os.path.join(output_dir, "metric_res.json")
decoding_res_filepath = os.path.join(output_dir, "decoding_res.json")
trained_model_dir = "/home/zychen/hwproject/my_modeling_phase_1/train.lr_0.0001.bsz_28.step_400000.layer_12-12/checkpoint-64000"

dataset_dir = "/home/zychen/hwproject/my_modeling_phase_1/dataset"
data_file = f"{dataset_dir}/testset_10k.jsonl"


def no_blank(sen):
    """Remove all whitespace characters from a sentence."""
    return "".join(sen.split())


if __name__ == "__main__":
    encoder_ckpt_dir = "/home/zychen/hwproject/my_modeling_phase_1/Tokenizer_PretrainedWeights/lilt-roberta-en-base"
    tgt_tokenizer_dir = "/home/zychen/hwproject/my_modeling_phase_1/Tokenizer_PretrainedWeights/bert-base-chinese-tokenizer"

    src_tokenizer, tgt_tokenizer = prepare_tokenizer(
        src_tokenizer_dir=encoder_ckpt_dir,
        tgt_tokenizer_dir=tgt_tokenizer_dir,
    )
    dataset_df = prepare_dataset_df(data_file=data_file)
    my_dataset = MyDataset(df=dataset_df,
                           src_tokenizer=src_tokenizer,
                           tgt_tokenizer=tgt_tokenizer,
                           max_src_length=MAX_SRC_LEN,
                           max_target_length=MAX_TGT_LEN)
    print(len(my_dataset))

    # Evaluate only the first num_test examples of the test set.
    num_test = 5000
    my_dataset = Subset(my_dataset, range(0, num_test))
    my_dataloader = DataLoader(
        my_dataset,
        batch_size=batch_size,
        shuffle=False,
    )

    # Load the fine-tuned encoder-decoder checkpoint and switch to eval mode.
    model = EncoderDecoderModel.from_pretrained(trained_model_dir)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    print(model)

    pred_res_list = []
    gt_list = []

    for batch in tqdm(my_dataloader):
        with torch.no_grad():
            # Run the encoder once (it also consumes the layout `bbox` input),
            # then reuse its outputs for beam-search generation.
            encoder_outputs = model.encoder(
                input_ids=batch["input_ids"].to(device),
                bbox=batch["bbox"].to(device),
                attention_mask=batch["attention_mask"].to(device),
            )
            outputs = model.generate(
                input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
                encoder_outputs=encoder_outputs,
                max_length=MAX_TGT_LEN,
                early_stopping=early_stopping,
                num_beams=num_beams,
                length_penalty=length_penalty,
                use_cache=True,
                decoder_start_token_id=0)

        pred_str = tgt_tokenizer.batch_decode(outputs,
                                              skip_special_tokens=True)
        labels = batch["labels"]
        # Replace the loss-masking value (-100) so the labels can be decoded.
        labels[labels == -100] = tgt_tokenizer.pad_token_id
        label_str = tgt_tokenizer.batch_decode(labels,
                                               skip_special_tokens=True)

        pred_res_list += pred_str
        gt_list += label_str

    # Strip all whitespace from references and hypotheses before re-segmentation.
    gt_list = [no_blank(sen) for sen in gt_list]
    pred_res_list = [no_blank(sen) for sen in pred_res_list]

    # Align image names and source texts with the evaluated subset.
    img_name_list = dataset_df["img_path"].iloc[0:num_test].tolist()
    text_src_list = dataset_df["text_src"].iloc[0:num_test].tolist()
    bleu_list = []
    chrf_list = []

    # Re-segment the Chinese text with jieba so BLEU/chrF operate on words.
    pred_res_seg_list = [" ".join(jieba.cut(item)) for item in pred_res_list]
    gt_seg_list = [" ".join(jieba.cut(item)) for item in gt_list]
    print(len(text_src_list), len(pred_res_seg_list), len(gt_seg_list))

    assert len(img_name_list) == len(pred_res_seg_list) == len(gt_seg_list)

    # Score each example and write one JSON object per line.
    with open(decoding_res_filepath, "w") as decoding_res_file:
        for img_name, text_src, pred_res_seg, gt_seg in zip(
                img_name_list, text_src_list, pred_res_seg_list, gt_seg_list):

            instance_bleu = sentence_bleu([gt_seg.split()],
                                          pred_res_seg.split())
            bleu_list.append(instance_bleu)

            instance_chrf = chrf.sentence_score(
                hypothesis=pred_res_seg,
                references=[gt_seg],
            ).score
            chrf_list.append(instance_chrf)

            res_dict = {
                "img_name": img_name,
                "text_src": text_src,
                "instance_bleu": instance_bleu,
                "instance_chrf": instance_chrf,
                "trans_res_seg": pred_res_seg,
                "gt_seg": gt_seg,
            }

            record = f"{json.dumps(res_dict, ensure_ascii=False)}\n"
            decoding_res_file.write(record)

    # Average the sentence-level scores and dump the corpus-level metrics.
    trans_avg_bleu = sum(bleu_list) / len(bleu_list)
    trans_avg_chrf = sum(chrf_list) / len(chrf_list)
    with open(metric_res_filepath, "w") as metric_res_file:
        eval_res_dict = {
            "trans_avg_bleu": trans_avg_bleu,
            "trans_avg_chrf": trans_avg_chrf,
        }
        json.dump(eval_res_dict, metric_res_file, indent=4, ensure_ascii=False)