|
""" |
|
BERT Score |
|
--------------------- |
|
BERTScore was introduced in the paper "BERTScore: Evaluating Text Generation with BERT" (`arxiv link`_).
|
|
|
.. _arxiv link: https://arxiv.org/abs/1904.09675 |
|
|
|
BERTScore measures token similarity between two texts using contextual embeddings.
|
|
|
To decide which tokens to compare, it greedily matches each token in one text to the most similar token in the other text.
|
|
|
""" |
|
|
|
import bert_score |
|
|
|
from textattack.constraints import Constraint |
|
from textattack.shared import utils |
|
|
|
|
|
class BERTScore(Constraint):
    """A constraint requiring a minimum BERTScore between two texts.

    Args:
        min_bert_score (float): Minimum BERTScore (between 0.0 and 1.0) that
            the transformed text must reach for the constraint to pass.
        model_name (str): Name of the model used by ``bert_score`` to compute
            contextual embeddings.
        num_layers (int, optional): Forwarded to ``bert_score.BERTScorer``;
            selects which layer's representations are used. ``None`` lets
            ``bert_score`` choose its default for ``model_name``.
        score_type (str): Pick one of the following three choices:

            -(1) ``precision``: match words from candidate text to reference text

            -(2) ``recall``: match words from reference text to candidate text

            -(3) ``f1``: harmonic mean of precision and recall (recommended)

        compare_against_original (bool):
            If ``True``, compare new ``x_adv`` against the original ``x``.
            Otherwise, compare it against the previous ``x_adv``.

    Raises:
        TypeError: If ``min_bert_score`` is not a float.
        ValueError: If ``min_bert_score`` is outside ``[0.0, 1.0]`` (or NaN),
            or if ``score_type`` is not one of the supported choices.
    """

    # Position of each score type in the (precision, recall, f1) result
    # returned by ``bert_score.BERTScorer.score``.
    SCORE_TYPE2IDX = {"precision": 0, "recall": 1, "f1": 2}

    def __init__(
        self,
        min_bert_score,
        model_name="bert-base-uncased",
        num_layers=None,
        score_type="f1",
        compare_against_original=True,
    ):
        super().__init__(compare_against_original)
        if not isinstance(min_bert_score, float):
            raise TypeError("min_bert_score must be a float")
        # Chained comparison also rejects NaN, which would make every check fail.
        if not 0.0 <= min_bert_score <= 1.0:
            raise ValueError("min_bert_score must be a value between 0.0 and 1.0")
        # Validate eagerly: an unknown score type would otherwise only surface
        # as a KeyError on the first constraint check, after the model loaded.
        if score_type not in self.SCORE_TYPE2IDX:
            raise ValueError(
                f"score_type must be one of {sorted(self.SCORE_TYPE2IDX)}, "
                f"got {score_type!r}"
            )

        self.min_bert_score = min_bert_score
        self.model = model_name
        self.score_type = score_type

        # idf=False: tokens are weighted uniformly rather than by idf.
        self._bert_scorer = bert_score.BERTScorer(
            model_type=model_name, idf=False, device=utils.device, num_layers=num_layers
        )

    def _check_constraint(self, transformed_text, reference_text):
        """Return ``True`` if the BERTScore between ``transformed_text`` and
        ``reference_text`` is at least ``min_bert_score``.

        ``transformed_text`` is scored as the candidate and ``reference_text``
        as the reference; both are read through their ``.text`` attribute.
        """
        candidate = transformed_text.text
        reference = reference_text.text
        scores = self._bert_scorer.score([candidate], [reference])
        # Exactly one candidate/reference pair was scored, so ``.item()``
        # extracts the single value for the selected score type.
        score = scores[self.SCORE_TYPE2IDX[self.score_type]].item()
        return score >= self.min_bert_score

    def extra_repr_keys(self):
        """Attribute names included in this constraint's printed repr."""
        return ["min_bert_score", "model", "score_type"] + super().extra_repr_keys()
|
|