# test_bleu_chrf.py
# basic imports
import json
import os

# other external imports
import jieba
import pandas as pd
from nltk.translate.bleu_score import sentence_bleu
from sacrebleu.metrics import CHRF
from tqdm import tqdm

# torch imports
import torch
from torch.utils.data import DataLoader, Subset

# transformers imports
from transformers import BertTokenizer, EncoderDecoderModel

# internal imports
from model_and_train import MyDataset, prepare_dataset_df, prepare_tokenizer
chrf = CHRF(word_order=2)  # word_order=2 makes this chrF++.
os.environ["CUDA_VISIBLE_DEVICES"] = "5"
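# CUDA_VISIBLE_DEVICES is only read when CUDA is first initialized, so setting
# it here (after `import torch` but before any CUDA call) still takes effect.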
# hyper-parameters.
## for model.
MAX_TGT_LEN = 512
MAX_SRC_LEN = 512
## for decoding.
output_dir = "./mytest"
os.makedirs(output_dir, exist_ok=True)
early_stopping = True
num_beams = 2
length_penalty = 1.0
batch_size = 16
metric_res_filepath = os.path.join(output_dir, "metric_res.json")
decoding_res_filepath = os.path.join(output_dir, "decoding_res.json")
trained_model_dir = "/home/zychen/hwproject/my_modeling_phase_1/train.lr_0.0001.bsz_28.step_400000.layer_12-12/checkpoint-64000"
dataset_dir = "/home/zychen/hwproject/my_modeling_phase_1/dataset"
data_file = f"{dataset_dir}/testset_10k.jsonl"
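# testset_10k.jsonl is assumed to hold one JSON object per line with at least
# "img_path" and "text_src" fields, which are read from dataset_df below.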
def no_blank(sen):
    """Strip all whitespace so Chinese references/hypotheses compare cleanly."""
    return "".join(sen.split())
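# e.g. no_blank("你 好 ， 世 界") == "你好，世界"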
if __name__ == "__main__":
    encoder_ckpt_dir = "/home/zychen/hwproject/my_modeling_phase_1/Tokenizer_PretrainedWeights/lilt-roberta-en-base"
    tgt_tokenizer_dir = "/home/zychen/hwproject/my_modeling_phase_1/Tokenizer_PretrainedWeights/bert-base-chinese-tokenizer"
    src_tokenizer, tgt_tokenizer = prepare_tokenizer(
        src_tokenizer_dir=encoder_ckpt_dir,
        tgt_tokenizer_dir=tgt_tokenizer_dir,
    )
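    # The source tokenizer presumably comes from the lilt-roberta-en-base
    # encoder checkpoint and the target tokenizer from bert-base-chinese,
    # matching the Chinese decoding and jieba segmentation below.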
    dataset_df = prepare_dataset_df(data_file=data_file)
    my_dataset = MyDataset(df=dataset_df,
                           src_tokenizer=src_tokenizer,
                           tgt_tokenizer=tgt_tokenizer,
                           max_src_length=MAX_SRC_LEN,
                           max_target_length=MAX_TGT_LEN)
    print(f"dataset size: {len(my_dataset)}")
    # evaluate on the first num_test examples (the test set has 10k in total).
    num_test = 5000
    my_dataset = Subset(my_dataset, range(0, num_test))
    my_dataloader = DataLoader(
        my_dataset,
        batch_size=batch_size,
        shuffle=False,
    )
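    # shuffle=False keeps the prediction order aligned with dataset_df rows,
    # which the img_path / text_src zip below relies on.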
    # loading model and config from pretrained folder
    model = EncoderDecoderModel.from_pretrained(trained_model_dir)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    print(model)
    # decoding testset
    pred_res_list = []
    gt_list = []
    for batch in tqdm(my_dataloader):
        # predict with generate
        with torch.no_grad():
            encoder_outputs = model.encoder(
                input_ids=batch["input_ids"].to(device),
                bbox=batch["bbox"].to(device),
                attention_mask=batch["attention_mask"].to(device),
            )
            outputs = model.generate(
                input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
                encoder_outputs=encoder_outputs,
                max_length=MAX_TGT_LEN,
                early_stopping=early_stopping,
                num_beams=num_beams,
                length_penalty=length_penalty,
                use_cache=True,
                decoder_start_token_id=0)
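            # Reusing the precomputed encoder_outputs means generate() does not
            # run the encoder again, so the bbox layout features only pass
            # through the explicit encoder call above.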
        # decode
        pred_str = tgt_tokenizer.batch_decode(outputs,
                                              skip_special_tokens=True)
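        # Labels are padded with -100 (the loss ignore index) during training;
        # map those positions back to the pad token id so batch_decode works.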
        labels = batch["labels"]
        labels[labels == -100] = tgt_tokenizer.pad_token_id
        label_str = tgt_tokenizer.batch_decode(labels,
                                               skip_special_tokens=True)
        pred_res_list += pred_str
        gt_list += label_str
    gt_list = [no_blank(sen) for sen in gt_list]
    pred_res_list = [no_blank(sen) for sen in pred_res_list]
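    # batch_decode typically inserts spaces between (sub)word tokens; stripping
    # all whitespace first lets jieba re-segment the Chinese text consistently
    # for word-level BLEU / chrF++ below.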
    # write the decoding res and compute metric.
    img_name_list = dataset_df["img_path"].iloc[0:num_test].tolist()
    text_src_list = dataset_df["text_src"].iloc[0:num_test].tolist()
    bleu_list = []
    chrf_list = []
    pred_res_seg_list = [" ".join(jieba.cut(item)) for item in pred_res_list]
    gt_seg_list = [" ".join(jieba.cut(item)) for item in gt_list]
    print(len(text_src_list), len(pred_res_seg_list), len(gt_seg_list))
    # print(img_name_list, pred_res_list, gt_seg_list)
    assert len(img_name_list) == len(pred_res_seg_list) == len(gt_seg_list)
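    # Note: sentence_bleu is used without a smoothing function, so instances
    # where some n-gram order has zero overlap score 0 (NLTK emits a warning).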
    with open(decoding_res_filepath, "w") as decoding_res_file:
        for img_name, text_src, pred_res_seg, gt_seg in zip(
                img_name_list, text_src_list, pred_res_seg_list, gt_seg_list):
            instance_bleu = sentence_bleu([gt_seg.split()],
                                          pred_res_seg.split())
            bleu_list.append(instance_bleu)
            instance_chrf = chrf.sentence_score(
                hypothesis=pred_res_seg,
                references=[gt_seg],
            ).score
            chrf_list.append(instance_chrf)
            res_dict = {
                "img_name": img_name,
                "text_src": text_src,
                "instance_bleu": instance_bleu,
                "instance_chrf": instance_chrf,
                "trans_res_seg": pred_res_seg,
                "gt_seg": gt_seg,
            }
            record = f"{json.dumps(res_dict, ensure_ascii=False)}\n"
            decoding_res_file.write(record)
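    # decoding_res.json is effectively JSON Lines: one record per test instance.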
    trans_avg_bleu = sum(bleu_list) / len(bleu_list)
    trans_avg_chrf = sum(chrf_list) / len(chrf_list)
    with open(metric_res_filepath, "w") as metric_res_file:
        eval_res_dict = {
            "trans_avg_bleu": trans_avg_bleu,
            "trans_avg_chrf": trans_avg_chrf,
        }
        json.dump(eval_res_dict, metric_res_file, indent=4, ensure_ascii=False)
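    # also echo the aggregate scores to stdout for quick inspection.
    print(f"avg BLEU: {trans_avg_bleu:.4f}, avg chrF++: {trans_avg_chrf:.4f}")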