Spaces:
Sleeping
Sleeping
Merge pull request #51 from project-kxkg/evaluation
Browse filesEvaluation
Former-commit-id: ef5be7bdfbbb13908d071d9a8785d3b06be1143f
- README.md +2 -2
- evaluation/alignment.py +139 -0
- evaluation/evaluation.py +58 -0
- evaluation/readme.md +24 -0
- evaluation/scores/LLM_eval.py +121 -0
- evaluation/scores/__init__.py +0 -0
- evaluation/scores/multi_scores.py +63 -0
- evaluation/scores/score.py +15 -0
- requirement.txt +2 -0
- src/srt_util/srt.py +13 -4
README.md
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
#
|
2 |
|
3 |
## Installation
|
4 |
|
@@ -51,4 +51,4 @@ options:
|
|
51 |
|
52 |
## Notice
|
53 |
if you cannot download youtube video, please follow the link below.
|
54 |
-
https://github.com/pytube/pytube/issues/1498
|
|
|
1 |
+
# Pigeon AI: Automatic Video Translation Toolkit
|
2 |
|
3 |
## Installation
|
4 |
|
|
|
51 |
|
52 |
## Notice
|
53 |
if you cannot download youtube video, please follow the link below.
|
54 |
+
https://github.com/pytube/pytube/issues/1498
|
evaluation/alignment.py
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import numpy as np
|
3 |
+
sys.path.append('../src')
|
4 |
+
from srt_util.srt import SrtScript
|
5 |
+
from srt_util.srt import SrtSegment
|
6 |
+
|
7 |
+
|
8 |
+
# Helper method
|
9 |
+
# Align sub anchor segment pair via greedy approach
|
10 |
+
# Input: anchor segment, SRT segments, output array of sub, index of current sub
|
11 |
+
# Output: updated index of sub
|
12 |
+
def procedure(anchor, subsec, S_arr, subidx):
|
13 |
+
cache_idx = 0
|
14 |
+
while subidx != cache_idx: # Terminate when alignment stablizes
|
15 |
+
cache_idx = subidx
|
16 |
+
# if sub segment runs out during the loop, terminate
|
17 |
+
if subidx >= len(subsec):
|
18 |
+
break
|
19 |
+
sub = subsec[subidx]
|
20 |
+
if anchor.end < sub.start:
|
21 |
+
continue
|
22 |
+
# If next sub has a heavier overlap compartment, add to current alignment
|
23 |
+
if (anchor.start <= sub.start) and (sub.end <= anchor.end) or anchor.end - sub.start > sub.end - anchor.start:
|
24 |
+
S_arr[-1] += sub#.source_text
|
25 |
+
subidx += 1
|
26 |
+
|
27 |
+
return subidx - 1 # Reset last invalid update from loop
|
28 |
+
|
29 |
+
|
30 |
+
# Input: path1, path2
|
31 |
+
# Output: aligned array of SRTsegment corresponding to path1 path2
|
32 |
+
# Note: Modify comment with .source_text to get output array with string only
|
33 |
+
def alignment_obsolete(pred_path, gt_path):
|
34 |
+
empt = SrtSegment([0,'00:00:00,000 --> 00:00:00,000','','',''])
|
35 |
+
pred = SrtScript.parse_from_srt_file(pred_path).segments
|
36 |
+
gt = SrtScript.parse_from_srt_file(gt_path).segments
|
37 |
+
pred_arr, gt_arr = [], []
|
38 |
+
idx_p, idx_t = 0, 0 # idx_p: current index of pred segment, idx_t for ground truth
|
39 |
+
|
40 |
+
while idx_p < len(pred) or idx_t < len(gt):
|
41 |
+
# Check if one srt file runs out while reading
|
42 |
+
ps = pred[idx_p] if idx_p < len(pred) else None
|
43 |
+
gs = gt[idx_t] if idx_t < len(gt) else None
|
44 |
+
|
45 |
+
if not ps:
|
46 |
+
# If ps runs out, align gs segment with filler one by one
|
47 |
+
gt_arr.append(gs)#.source_text
|
48 |
+
pred_arr.append(empt)
|
49 |
+
idx_t += 1
|
50 |
+
continue
|
51 |
+
|
52 |
+
if not gs:
|
53 |
+
# If gs runs out, align ps segment with filler one by one
|
54 |
+
pred_arr.append(ps)#.source_text
|
55 |
+
gt_arr.append(empt)
|
56 |
+
idx_p += 1
|
57 |
+
continue
|
58 |
+
|
59 |
+
ps_dur = ps.end - ps.start
|
60 |
+
gs_dur = gs.end - gs.start
|
61 |
+
|
62 |
+
# Check for duration to decide anchor and sub
|
63 |
+
if ps_dur <= gs_dur:
|
64 |
+
# Detect segment with no overlap
|
65 |
+
if ps.end < gs.start:
|
66 |
+
pred_arr.append(ps)#.source_text
|
67 |
+
gt_arr.append(empt) # append filler
|
68 |
+
idx_t -= 1 # reset ground truth index
|
69 |
+
else:
|
70 |
+
|
71 |
+
if gs.end >= ps.start:
|
72 |
+
gt_arr.append(gs)#.source_text
|
73 |
+
pred_arr.append(ps)#.source_text
|
74 |
+
idx_p = procedure(gs, pred, pred_arr, idx_p + 1)
|
75 |
+
else:
|
76 |
+
gt_arr[len(gt_arr) - 1] += gs#.source_text
|
77 |
+
#pred_arr.append(empt)
|
78 |
+
idx_p -= 1
|
79 |
+
else:
|
80 |
+
# same overlap checking procedure
|
81 |
+
if gs.end < ps.start:
|
82 |
+
gt_arr.append(gs)#.source_text
|
83 |
+
pred_arr.append(empt) # filler
|
84 |
+
idx_p -= 1 # reset
|
85 |
+
else:
|
86 |
+
if ps.end >= gs.start:
|
87 |
+
pred_arr.append(ps)#.source_text
|
88 |
+
gt_arr.append(gs)#.source_text
|
89 |
+
idx_t = procedure(ps, gt, gt_arr, idx_t + 1)
|
90 |
+
else: # filler pairing
|
91 |
+
pred_arr[len(pred_arr) - 1] += ps
|
92 |
+
idx_t -= 1
|
93 |
+
|
94 |
+
idx_p += 1
|
95 |
+
idx_t += 1
|
96 |
+
#for a in gt_arr:
|
97 |
+
# print(a.translation)
|
98 |
+
return zip(pred_arr, gt_arr)
|
99 |
+
|
100 |
+
# Input: path1, path2, threshold = 0.5 sec by default
|
101 |
+
# Output: aligned array of SRTsegment corresponding to path1 path2
|
102 |
+
def alignment(pred_path, gt_path, threshold=0.5):
|
103 |
+
empt = SrtSegment([0, '00:00:00,000 --> 00:00:00,000', '', '', ''])
|
104 |
+
pred = SrtScript.parse_from_srt_file(pred_path).segments
|
105 |
+
gt = SrtScript.parse_from_srt_file(gt_path).segments
|
106 |
+
pred_arr, gt_arr = [], []
|
107 |
+
idx_p, idx_t = 0, 0
|
108 |
+
|
109 |
+
while idx_p < len(pred) or idx_t < len(gt):
|
110 |
+
ps = pred[idx_p] if idx_p < len(pred) else empt
|
111 |
+
gs = gt[idx_t] if idx_t < len(gt) else empt
|
112 |
+
|
113 |
+
# Merging sequence for pred
|
114 |
+
while idx_p + 1 < len(pred) and pred[idx_p + 1].end <= gs.end + threshold:
|
115 |
+
ps += pred[idx_p + 1]
|
116 |
+
idx_p += 1
|
117 |
+
|
118 |
+
# Merging sequence for gt
|
119 |
+
while idx_t + 1 < len(gt) and gt[idx_t + 1].end <= ps.end + threshold:
|
120 |
+
gs += gt[idx_t + 1]
|
121 |
+
idx_t += 1
|
122 |
+
|
123 |
+
# Append to the result arrays
|
124 |
+
pred_arr.append(ps)
|
125 |
+
gt_arr.append(gs)
|
126 |
+
idx_p += 1
|
127 |
+
idx_t += 1
|
128 |
+
|
129 |
+
|
130 |
+
#for a in pred_arr:
|
131 |
+
# print(a.translation)
|
132 |
+
#for a in gt_arr:
|
133 |
+
# print(a.source_text)
|
134 |
+
|
135 |
+
return zip(pred_arr, gt_arr)
|
136 |
+
|
137 |
+
|
138 |
+
# Test Case
|
139 |
+
#alignment('test_translation_s2.srt', 'test_translation_zh.srt')
|
evaluation/evaluation.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import pandas as pd
|
3 |
+
from alignment import alignment
|
4 |
+
from scores.multi_scores import multi_scores
|
5 |
+
|
6 |
+
class Evaluator:
|
7 |
+
def __init__(self, pred_path, gt_path, eval_path, res_path):
|
8 |
+
self.pred_path = pred_path
|
9 |
+
self.gt_path = gt_path
|
10 |
+
self.eval_path = eval_path
|
11 |
+
self.res_path = res_path
|
12 |
+
|
13 |
+
def eval(self):
|
14 |
+
# Align two SRT files
|
15 |
+
aligned_srt = alignment(self.pred_path, self.gt_path)
|
16 |
+
|
17 |
+
# Get sentence scores
|
18 |
+
scorer = multi_scores()
|
19 |
+
result_data = []
|
20 |
+
for (pred_s, gt_s) in aligned_srt:
|
21 |
+
print("pred_s.source_text: ", pred_s.source_text)
|
22 |
+
print("pred_s.translation: ", pred_s.translation)
|
23 |
+
print("gt_s.source_text: ", gt_s.source_text)
|
24 |
+
|
25 |
+
scores_dict = scorer.get_scores(pred_s.source_text, pred_s.translation, gt_s.source_text)
|
26 |
+
print("scores_dict: ", scores_dict)
|
27 |
+
|
28 |
+
scores_dict['Source'] = pred_s.source_text
|
29 |
+
scores_dict['Prediction'] = pred_s.translation
|
30 |
+
scores_dict['Ground Truth'] = gt_s.source_text
|
31 |
+
result_data.append(scores_dict)
|
32 |
+
|
33 |
+
eval_df = pd.DataFrame(result_data)
|
34 |
+
eval_df.to_csv(self.eval_path, index=False, columns=['Source', 'Prediction', 'Ground Truth', 'bleu_score', 'comet_score', 'llm_score', 'llm_explanation'])
|
35 |
+
|
36 |
+
# Get average scores
|
37 |
+
avg_llm = eval_df['llm_score'].mean()
|
38 |
+
avg_bleu = eval_df['bleu_score'].mean()
|
39 |
+
avg_comet = eval_df['comet_score'].mean()
|
40 |
+
|
41 |
+
res_data = {
|
42 |
+
'Metric': ['Avg LLM', 'Avg BLEU', 'Avg COMET'],
|
43 |
+
'Score': [avg_llm, avg_bleu, avg_comet]
|
44 |
+
}
|
45 |
+
res_df = pd.DataFrame(res_data)
|
46 |
+
res_df.to_csv(self.res_path, index=False)
|
47 |
+
|
48 |
+
if __name__ == "__main__":
|
49 |
+
parser = argparse.ArgumentParser(description='Evaluate SRT files.')
|
50 |
+
parser.add_argument('-bi_path', default='evaluation/test5_tiny/test5_bi.srt', help='Path to predicted SRT file')
|
51 |
+
parser.add_argument('-zh_path', default='evaluation/test5_tiny/test5_gt.srt', help='Path to ground truth SRT file')
|
52 |
+
parser.add_argument('-eval_output', default='evaluation/test5_tiny/eval.csv', help='Path to eval CSV file')
|
53 |
+
parser.add_argument('-res_output', default='evaluation/test5_tiny/res.csv', help='Path to result CSV file')
|
54 |
+
args = parser.parse_args()
|
55 |
+
|
56 |
+
evaluator = Evaluator(args.bi_path, args.zh_path, args.eval_output, args.res_output)
|
57 |
+
evaluator.eval()
|
58 |
+
|
evaluation/readme.md
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Evaluation:
|
2 |
+
BLEU (https://github.com/mjpost/sacrebleu)
|
3 |
+
COMET (https://github.com/Unbabel/COMET)
|
4 |
+
LLM eval
|
5 |
+
Eval time stamp
|
6 |
+
|
7 |
+
Sep 18 - Sep 25
|
8 |
+
Proj-t
|
9 |
+
src
|
10 |
+
evaluation
|
11 |
+
- scores
|
12 |
+
- LLM_eval.py (jiaen)
|
13 |
+
- scores.py (wizard)
|
14 |
+
- comet
|
15 |
+
- sacrebleu
|
16 |
+
- alignment.py (david)
|
17 |
+
- evaluation.py (not assigned)
|
18 |
+
- results
|
19 |
+
- mmddyy-HMS-results.csv
|
20 |
+
- logs
|
21 |
+
|
22 |
+
entry:
|
23 |
+
Python3 evaluation/evaluation.py –pred path/to/pred –gt path/to/gt
|
24 |
+
|
evaluation/scores/LLM_eval.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# This script is used to evaluate the performance of Pigeon AI Video Translation system by using Large Language Model.
|
3 |
+
|
4 |
+
# Written by Jiaen LIU, 2023/09/18
|
5 |
+
|
6 |
+
# Import the necessary packages
|
7 |
+
import re
|
8 |
+
from langchain.evaluation import load_evaluator, EvaluatorType
|
9 |
+
from langchain.prompts import PromptTemplate
|
10 |
+
from langchain.chat_models import ChatOpenAI
|
11 |
+
# from src.srt_util.srt import SrtScript
|
12 |
+
|
13 |
+
# Load the evaluator
|
14 |
+
|
15 |
+
def init_evaluator(source_lang="en", target_lang="zh", domain="startcraft2", model="gpt-4-0613"):
|
16 |
+
|
17 |
+
# map the language code to the language name
|
18 |
+
language_map = {
|
19 |
+
"en": "English",
|
20 |
+
"zh": "Chinese",
|
21 |
+
}
|
22 |
+
|
23 |
+
llm = ChatOpenAI(temperature=0, model=model)
|
24 |
+
|
25 |
+
# Completeness is the percentage of the input that is translated
|
26 |
+
# Accuracy is the percentage of the translation that is correct
|
27 |
+
fstring = """
|
28 |
+
You are grading the translation based on following input:
|
29 |
+
{input}
|
30 |
+
if the input is "", that means there is no input sentence.
|
31 |
+
you should grade the translation based on the reference translation:
|
32 |
+
Here is the real answer(reference):
|
33 |
+
{reference}
|
34 |
+
You are grading the following translation:
|
35 |
+
{output}
|
36 |
+
based on the following criteria:
|
37 |
+
{criteria}
|
38 |
+
Give two grades, accuracy and completeness rate them from a scale of 0 to 100, where 0 is the lowest (very low accuracy/completeness) and 100 is the highest (very high accuracy/completeness)?
|
39 |
+
Give explanations for every single one and if the answer if partially correct that is acceptable. However punish the scores for answers that are
|
40 |
+
numerically incorrect this also includes values that have the $ in front
|
41 |
+
Please give the completeness score first followed by the accuracy score.
|
42 |
+
For example:
|
43 |
+
Accuracy: 40. Explanation here
|
44 |
+
Completeness: 80. Explanation here
|
45 |
+
Do not differ from the format ever
|
46 |
+
"""
|
47 |
+
|
48 |
+
if source_lang in language_map and target_lang in language_map:
|
49 |
+
lang_str = f"You are an expert {language_map[source_lang]} to {language_map[target_lang]} translator specialized in {domain}."
|
50 |
+
prompt = PromptTemplate.from_template(lang_str+fstring, template_format="f-string")
|
51 |
+
|
52 |
+
else:
|
53 |
+
print("The language code is not supported, please check the language code.")
|
54 |
+
prompt = PromptTemplate.from_template(fstring, template_format="f-string")
|
55 |
+
|
56 |
+
return load_evaluator("labeled_criteria", llm=llm, prompt=prompt, criteria="correctness")
|
57 |
+
|
58 |
+
# prase the output of the evaluation
|
59 |
+
# example :
|
60 |
+
# 'value': 'Accuracy: 80. The predicted answer is partially correct. The sentence "这是一个测试句子" translates to "This is a test sentence" in English. However, the original sentence is "This is an test sentences" which is grammatically incorrect in English. The correct translation should be "这是一个测试句子" if we correct the English sentence to "This is a test sentence". Therefore, the predicted answer is not entirely wrong, but it does not match the original sentence exactly due to the grammatical error in the original sentence.'
|
61 |
+
# def parse_eval_result(eval_result):
|
62 |
+
# # score = eval_result.score
|
63 |
+
# value = eval_result["value"]
|
64 |
+
# value = value.split("Accuracy: ")[1].split(".")
|
65 |
+
# # combine the rest of the string into the whole explanation
|
66 |
+
# explanation = ".".join(value[1:])
|
67 |
+
# return int(value[0]), explanation
|
68 |
+
|
69 |
+
# def parse_eval_result(eval_result):
|
70 |
+
# # Extract the 'Accuracy' score using a regular expression from the 'reasoning' key
|
71 |
+
# accuracy_match = re.search(r'Accuracy: (\d+)', eval_result['value'])
|
72 |
+
# print(accuracy_match)
|
73 |
+
# if accuracy_match:
|
74 |
+
# accuracy = int(accuracy_match.group(1))
|
75 |
+
# else:
|
76 |
+
# # try to get the accuracy from the 'value' key
|
77 |
+
# accuracy = 0
|
78 |
+
|
79 |
+
# # Directly get the 'Explanation' value from the 'value' key
|
80 |
+
# explanation = eval_result['value']
|
81 |
+
|
82 |
+
# return accuracy, explanation
|
83 |
+
|
84 |
+
def parse_eval_result(data):
|
85 |
+
# Extract the value string
|
86 |
+
value_str = data.get('value', '')
|
87 |
+
reasoning_str = data.get('reasoning', '')
|
88 |
+
|
89 |
+
# Use regex to extract accuracy value and explanation
|
90 |
+
accuracy_match = re.search(r'Accuracy: (\d+)', value_str)
|
91 |
+
acc_explanation_match = re.search(r'Accuracy: \d+\. (.+)', value_str)
|
92 |
+
|
93 |
+
# Use regex to extract completeness value and explanation
|
94 |
+
completeness_match = re.search(r'Completeness: (\d+)', reasoning_str)
|
95 |
+
completeness_explanation_match = re.search(r'Completeness: \d+\. (.+)', reasoning_str)
|
96 |
+
|
97 |
+
# Extract the matched groups
|
98 |
+
completeness = int(completeness_match.group(1)) if completeness_match else None
|
99 |
+
completeness_explanation = completeness_explanation_match.group(1) if completeness_explanation_match else None
|
100 |
+
accuracy = int(accuracy_match.group(1)) if accuracy_match else None
|
101 |
+
acc_explanation = acc_explanation_match.group(1) if acc_explanation_match else None
|
102 |
+
|
103 |
+
return (accuracy, acc_explanation), (completeness, completeness_explanation)
|
104 |
+
|
105 |
+
def evaluate_prediction(input, reference, prediction, evaluator):
|
106 |
+
eval_result = evaluator.evaluate_strings(
|
107 |
+
prediction=prediction,
|
108 |
+
input=input,
|
109 |
+
reference=reference,
|
110 |
+
)
|
111 |
+
# print(eval_result)
|
112 |
+
return parse_eval_result(eval_result)
|
113 |
+
|
114 |
+
if __name__ == "__main__":
|
115 |
+
evaluator = init_evaluator()
|
116 |
+
# For no input english sentence, just put "" in the input
|
117 |
+
accuracy, completeness = evaluate_prediction("this is an test sentences", "这不是一个测试语句。", "这是一个测试句子。", evaluator)
|
118 |
+
print("Accuracy:", accuracy[0])
|
119 |
+
print("Acc_Explanation:", accuracy[1])
|
120 |
+
print("Completeness:", completeness[0])
|
121 |
+
print("Comp_Explanation:", completeness[1])
|
evaluation/scores/__init__.py
ADDED
File without changes
|
evaluation/scores/multi_scores.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from comet import download_model, load_from_checkpoint
|
2 |
+
from sacrebleu.metrics import BLEU, CHRF, TER
|
3 |
+
from scores import LLM_eval
|
4 |
+
# import LLM_eval
|
5 |
+
|
6 |
+
class multi_scores:
|
7 |
+
def __init__(self, source_lang="en", target_lang="zh", domain="starcraft 2") -> None:
|
8 |
+
self.comet_model = load_from_checkpoint(download_model("Unbabel/wmt22-comet-da"))
|
9 |
+
self.bleu_model = BLEU(tokenize=target_lang)
|
10 |
+
self.LLM_model = LLM_eval.init_evaluator(source_lang=source_lang, target_lang=target_lang, domain=domain)
|
11 |
+
# self.score = {}
|
12 |
+
|
13 |
+
def __preprocess(self, src:str, mt:str, ref:str) -> dict:
|
14 |
+
# remove the space in the beginning and end of the sentence\
|
15 |
+
src = src.strip()
|
16 |
+
mt = mt.strip()
|
17 |
+
ref = ref.strip()
|
18 |
+
print(src, mt, ref)
|
19 |
+
return {'src':src, 'mt':mt, 'ref':ref}
|
20 |
+
|
21 |
+
|
22 |
+
|
23 |
+
# The function to get the scores
|
24 |
+
# src: orginal sentence
|
25 |
+
# mt: machine translation
|
26 |
+
# ref: reference translation
|
27 |
+
def calculate_comet_llm(self, src:str, mt:str, ref:str) -> dict:
|
28 |
+
# preprocess the input
|
29 |
+
src, mt, ref = self.__preprocess(src, mt, ref)
|
30 |
+
comet_score = self.comet_model.predict([{"src":src, "mt":mt, "ref":ref}], batch_size=8, gpus=0).scores[0]
|
31 |
+
# bleu_score = self.bleu_model.corpus_score([mt], [ref]).score
|
32 |
+
llm_acc, llm_completeness = LLM_eval.evaluate_prediction(src, ref, mt, self.LLM_model)
|
33 |
+
return {'comet_score':comet_score, 'llm_score':llm_acc[0], 'llm_explanation': llm_acc[1]}
|
34 |
+
# self.score['bleu_score'] = bleu_score
|
35 |
+
# self.score['comet_score'] = comet_score
|
36 |
+
# self.score['llm_score'] = llm_score
|
37 |
+
# self.score['llm_explanation'] = llm_explanation
|
38 |
+
|
39 |
+
def calculate_bleu(self, mts:list, refs:list) -> dict:
|
40 |
+
# src, mt, ref = self.__preprocess(src, mt, ref)
|
41 |
+
# remove the space in the beginning and end of the sentence for each sentence
|
42 |
+
# mts = [mt.strip() for mt in mts]
|
43 |
+
# refs = [ref.strip() for ref in refs]
|
44 |
+
# print(mts, refs)
|
45 |
+
# mt and ref are list of sentences
|
46 |
+
bleu_score = self.bleu_model.corpus_score(mts, refs).score
|
47 |
+
return {'bleu_score':bleu_score}
|
48 |
+
|
49 |
+
def get_scores(self, src:str, mt:str, ref:str) -> dict:
|
50 |
+
comet_score = self.comet_model.predict([{"src":src, "mt":mt, "ref":ref}], batch_size=8, gpus=0).scores[0]
|
51 |
+
bleu_score = self.bleu_model.corpus_score([mt], [[ref]]).score
|
52 |
+
llm_acc, llm_completeness = LLM_eval.evaluate_prediction(src, ref, mt, self.LLM_model)
|
53 |
+
return {'bleu_score':bleu_score ,'comet_score':comet_score, 'llm_score':llm_acc[0], 'llm_explanation': llm_acc[1]}
|
54 |
+
|
55 |
+
|
56 |
+
if __name__ == "__main__":
|
57 |
+
src = "South Korea playing with the Blue Proto's Probes"
|
58 |
+
mt = "位于对角线的另一个角落 使用蓝色的Proto's Probes"
|
59 |
+
ref = " 在对角落里使用蓝色神族探机 他的名字是..."
|
60 |
+
# print(multi_scores().get_scores(src, mt, ref))
|
61 |
+
# print(multi_scores().calculate_comet_llm(src, mt, ref))
|
62 |
+
print(multi_scores().calculate_bleu([mt], [[ref]]))
|
63 |
+
|
evaluation/scores/score.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from comet import download_model, load_from_checkpoint
|
2 |
+
from sacrebleu.metrics import BLEU, CHRF, TER
|
3 |
+
|
4 |
+
def COMETscore(src, mt, ref):
|
5 |
+
data = []
|
6 |
+
for i in enumerate(src):
|
7 |
+
data.append({"src":src[i], "mt":mt[i], "ref":ref[i]})
|
8 |
+
model_path = download_model("Unbabel/wmt22-comet-da")
|
9 |
+
model = load_from_checkpoint(model_path)
|
10 |
+
model_output = model.predict(data, batch_size = 8, gpus=0)
|
11 |
+
return model_output
|
12 |
+
|
13 |
+
def BLEUscore(sys, refs):
|
14 |
+
bleu = BLEU()
|
15 |
+
return bleu.corpus_score(sys, refs)
|
requirement.txt
CHANGED
@@ -38,3 +38,5 @@ tqdm==4.65.0
|
|
38 |
typing_extensions==4.5.0
|
39 |
urllib3==1.26.15
|
40 |
yarl==1.8.2
|
|
|
|
|
|
38 |
typing_extensions==4.5.0
|
39 |
urllib3==1.26.15
|
40 |
yarl==1.8.2
|
41 |
+
sacrebleu==2.3.1
|
42 |
+
unbabel-comet==2.1.0
|
src/srt_util/srt.py
CHANGED
@@ -50,7 +50,10 @@ class SrtSegment(object):
|
|
50 |
self.start = int(start_list[0]) * 3600 + int(start_list[1]) * 60 + int(start_list[2]) + self.start_ms / 100
|
51 |
end_list = self.end_time_str.split(',')[0].split(':')
|
52 |
self.end = int(end_list[0]) * 3600 + int(end_list[1]) * 60 + int(end_list[2]) + self.end_ms / 100
|
53 |
-
|
|
|
|
|
|
|
54 |
|
55 |
def merge_seg(self, seg):
|
56 |
"""
|
@@ -105,10 +108,16 @@ class SrtScript(object):
|
|
105 |
def parse_from_srt_file(cls, path: str):
|
106 |
with open(path, 'r', encoding="utf-8") as f:
|
107 |
script_lines = [line.rstrip() for line in f.readlines()]
|
108 |
-
|
|
|
|
|
109 |
segments = []
|
110 |
-
|
111 |
-
|
|
|
|
|
|
|
|
|
112 |
|
113 |
return cls(segments)
|
114 |
|
|
|
50 |
self.start = int(start_list[0]) * 3600 + int(start_list[1]) * 60 + int(start_list[2]) + self.start_ms / 100
|
51 |
end_list = self.end_time_str.split(',')[0].split(':')
|
52 |
self.end = int(end_list[0]) * 3600 + int(end_list[1]) * 60 + int(end_list[2]) + self.end_ms / 100
|
53 |
+
if len(args[0]) < 5:
|
54 |
+
self.translation = ""
|
55 |
+
else:
|
56 |
+
self.translation = args[0][3]
|
57 |
|
58 |
def merge_seg(self, seg):
|
59 |
"""
|
|
|
108 |
def parse_from_srt_file(cls, path: str):
|
109 |
with open(path, 'r', encoding="utf-8") as f:
|
110 |
script_lines = [line.rstrip() for line in f.readlines()]
|
111 |
+
bilingual = False
|
112 |
+
if script_lines[2] != '' and script_lines[3] != '':
|
113 |
+
bilingual = True
|
114 |
segments = []
|
115 |
+
if bilingual:
|
116 |
+
for i in range(0, len(script_lines), 5):
|
117 |
+
segments.append(list(script_lines[i:i + 5]))
|
118 |
+
else:
|
119 |
+
for i in range(0, len(script_lines), 4):
|
120 |
+
segments.append(list(script_lines[i:i + 4]))
|
121 |
|
122 |
return cls(segments)
|
123 |
|