|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations |
|
import collections |
|
from math import sqrt |
|
|
|
import scipy.stats |
|
|
|
import torch |
|
from torch import Tensor |
|
from tokenizers import Tokenizer |
|
from transformers import LogitsProcessor |
|
|
|
from nltk.util import ngrams |
|
|
|
from normalizers import normalization_strategy_lookup |
|
|
|
class WatermarkBase:
    """Shared state and green-list machinery for watermark generation/detection.

    The watermark pseudorandomly partitions the vocabulary into a "green" list
    (a `gamma` fraction of tokens) re-seeded from the preceding token, so that
    the generator and the detector derive identical green lists from the same
    context.
    """

    def __init__(
        self,
        vocab: list[int] = None,
        gamma: float = 0.5,
        delta: float = 2.0,
        seeding_scheme: str = "simple_1",
        hash_key: int = 15485863,  # large prime; any fixed key works, but generator and detector must agree
        select_green_tokens: bool = True,
    ):
        """
        Args:
            vocab: full list of token ids in the tokenizer vocabulary (required).
            gamma: fraction of the vocabulary placed on the green list.
            delta: logit bias added to green tokens at generation time.
            seeding_scheme: how the RNG is seeded from context; only "simple_1"
                (seed from the single previous token) is implemented.
            hash_key: multiplier mixed with the previous token id to seed the RNG.
            select_green_tokens: take the green list from the head of the
                permutation (True, canonical) or from the tail (False, legacy).
        """
        if vocab is None:
            # Fail with a clear message instead of `len(None)` raising TypeError.
            raise ValueError("Must supply the tokenizer vocabulary (list of token ids)")
        self.vocab = vocab
        self.vocab_size = len(vocab)
        self.gamma = gamma
        self.delta = delta
        self.seeding_scheme = seeding_scheme
        self.rng = None  # torch.Generator; created lazily once a device is known
        self.hash_key = hash_key
        self.select_green_tokens = select_green_tokens

    def _seed_rng(self, input_ids: torch.LongTensor, seeding_scheme: str = None) -> None:
        """Seed `self.rng` deterministically from the context `input_ids`.

        Raises:
            NotImplementedError: for any scheme other than "simple_1".
        """
        if seeding_scheme is None:
            seeding_scheme = self.seeding_scheme

        if seeding_scheme == "simple_1":
            assert input_ids.shape[-1] >= 1, f"seeding_scheme={seeding_scheme} requires at least a 1 token prefix sequence to seed rng"
            # Seed depends only on the immediately preceding token.
            prev_token = input_ids[-1].item()
            self.rng.manual_seed(self.hash_key * prev_token)
        else:
            raise NotImplementedError(f"Unexpected seeding_scheme: {seeding_scheme}")

    def _get_greenlist_ids(self, input_ids: torch.LongTensor) -> torch.LongTensor:
        """Return the green-list token ids for the context `input_ids`.

        Note: returns a 1-D LongTensor of size int(vocab_size * gamma), not a
        Python list (the original annotation said `list[int]`, which was wrong).
        """
        self._seed_rng(input_ids)

        greenlist_size = int(self.vocab_size * self.gamma)
        vocab_permutation = torch.randperm(self.vocab_size, device=input_ids.device, generator=self.rng)
        if self.select_green_tokens:
            # Canonical: green list is the head of the seeded permutation.
            greenlist_ids = vocab_permutation[:greenlist_size]
        else:
            # Legacy: green list is the tail of the seeded permutation.
            greenlist_ids = vocab_permutation[(self.vocab_size - greenlist_size) :]
        return greenlist_ids
|
|
|
|
|
class WatermarkLogitsProcessor(WatermarkBase, LogitsProcessor):
    """LogitsProcessor that nudges generation toward the green list.

    At every decoding step the green list is re-derived from each row's prefix
    and `delta` is added to the logits of the green tokens.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _calc_greenlist_mask(self, scores: torch.FloatTensor, greenlist_token_ids) -> torch.BoolTensor:
        """Build a boolean mask over `scores`; True marks a green token."""
        mask = torch.zeros_like(scores)
        for row, green_ids in enumerate(greenlist_token_ids):
            mask[row][green_ids] = 1
        return mask.bool()

    def _bias_greenlist_logits(self, scores: torch.Tensor, greenlist_mask: torch.Tensor, greenlist_bias: float) -> torch.Tensor:
        """Add `greenlist_bias` (in place) to the masked positions of `scores`."""
        scores[greenlist_mask] = scores[greenlist_mask] + greenlist_bias
        return scores

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Apply the green-list bias to one decoding step's logits."""
        # Lazily create the generator on the same device as the inputs.
        if self.rng is None:
            self.rng = torch.Generator(device=input_ids.device)

        # Each batch row gets its own green list, seeded from its own prefix.
        batched_greenlist_ids = [self._get_greenlist_ids(prefix) for prefix in input_ids]

        green_tokens_mask = self._calc_greenlist_mask(scores=scores, greenlist_token_ids=batched_greenlist_ids)
        return self._bias_greenlist_logits(scores=scores, greenlist_mask=green_tokens_mask, greenlist_bias=self.delta)
|
|
|
|
|
class WatermarkDetector(WatermarkBase):
    """Detects the presence of the green-list watermark in a piece of text.

    Replays the generation-time seeding scheme to recover the green list at
    every position, counts green-token hits, and converts the count into a
    one-sided z-score against the null hypothesis of unwatermarked text.
    """

    def __init__(
        self,
        *args,
        device: torch.device = None,
        tokenizer: Tokenizer = None,
        z_threshold: float = 4.0,
        normalizers: list[str] = None,
        ignore_repeated_bigrams: bool = False,
        **kwargs,
    ):
        """
        Args:
            device: device used for the detector's RNG and prefix tensors.
            tokenizer: the tokenizer used at generation time (required).
            z_threshold: z-score above which text is declared watermarked.
            normalizers: names of text normalizations applied before scoring;
                None means the historical default of ["unicode"].
            ignore_repeated_bigrams: score each unique (prefix, token) bigram
                once instead of every token occurrence.
        """
        super().__init__(*args, **kwargs)

        assert device, "Must pass device"
        assert tokenizer, "Need an instance of the generating tokenizer to perform detection"

        self.tokenizer = tokenizer
        self.device = device
        self.z_threshold = z_threshold
        self.rng = torch.Generator(device=self.device)

        if self.seeding_scheme == "simple_1":
            self.min_prefix_len = 1
        else:
            raise NotImplementedError(f"Unexpected seeding_scheme: {self.seeding_scheme}")

        # BUGFIX: the default was a mutable list literal shared across calls;
        # use a None sentinel with the same effective default.
        if normalizers is None:
            normalizers = ["unicode"]
        self.normalizers = [normalization_strategy_lookup(strategy) for strategy in normalizers]

        self.ignore_repeated_bigrams = ignore_repeated_bigrams
        if self.ignore_repeated_bigrams:
            assert self.seeding_scheme == "simple_1", "No repeated bigram credit variant assumes the single token seeding scheme."

    def _compute_z_score(self, observed_count, T):
        """One-sided z-score for `observed_count` green hits out of `T` tokens.

        Under H0 (no watermark) each scored token is green with probability
        `gamma`, so the count is Binomial(T, gamma); this applies the normal
        approximation. Requires 0 < gamma < 1 and T > 0 (denominator nonzero).
        """
        expected_count = self.gamma
        numer = observed_count - expected_count * T
        denom = sqrt(T * expected_count * (1 - expected_count))
        return numer / denom

    def _compute_p_value(self, z):
        """Upper-tail p-value of the standard normal null for z-score `z`."""
        return scipy.stats.norm.sf(z)

    def _score_sequence(
        self,
        input_ids: Tensor,
        return_num_tokens_scored: bool = True,
        return_num_green_tokens: bool = True,
        return_green_fraction: bool = True,
        return_green_token_mask: bool = False,
        return_z_score: bool = True,
        return_p_value: bool = True,
    ):
        """Score a token sequence and return a dict of requested statistics.

        Raises:
            ValueError: if fewer than min_prefix_len + 1 tokens are supplied
                (standard scoring path only).
        """
        if self.ignore_repeated_bigrams:
            # Each unique (prefix, token) bigram is scored once: whether the
            # token is green given that prefix is deterministic, so repeats
            # carry no additional evidence.
            assert not return_green_token_mask, "Can't return the green/red mask when ignoring repeats."
            bigram_table = {}
            token_bigram_generator = ngrams(input_ids.cpu().tolist(), 2)
            freq = collections.Counter(token_bigram_generator)
            num_tokens_scored = len(freq.keys())
            for bigram in freq.keys():
                prefix = torch.tensor([bigram[0]], device=self.device)
                greenlist_ids = self._get_greenlist_ids(prefix)
                bigram_table[bigram] = bigram[1] in greenlist_ids
            green_token_count = sum(bigram_table.values())
        else:
            num_tokens_scored = len(input_ids) - self.min_prefix_len
            if num_tokens_scored < 1:
                raise ValueError((f"Must have at least {1} token to score after "
                                  f"the first min_prefix_len={self.min_prefix_len} tokens required by the seeding scheme."))
            # Standard scoring: for each position past the minimum prefix,
            # recompute the green list from the preceding context and test
            # membership of the current token.
            green_token_count, green_token_mask = 0, []
            for idx in range(self.min_prefix_len, len(input_ids)):
                curr_token = input_ids[idx]
                greenlist_ids = self._get_greenlist_ids(input_ids[:idx])
                if curr_token in greenlist_ids:
                    green_token_count += 1
                    green_token_mask.append(True)
                else:
                    green_token_mask.append(False)

        score_dict = dict()
        if return_num_tokens_scored:
            score_dict.update(dict(num_tokens_scored=num_tokens_scored))
        if return_num_green_tokens:
            score_dict.update(dict(num_green_tokens=green_token_count))
        if return_green_fraction:
            score_dict.update(dict(green_fraction=(green_token_count / num_tokens_scored)))
        if return_z_score:
            score_dict.update(dict(z_score=self._compute_z_score(green_token_count, num_tokens_scored)))
        if return_p_value:
            # Reuse the z-score if it was already computed above.
            z_score = score_dict.get("z_score")
            if z_score is None:
                z_score = self._compute_z_score(green_token_count, num_tokens_scored)
            score_dict.update(dict(p_value=self._compute_p_value(z_score)))
        if return_green_token_mask:
            score_dict.update(dict(green_token_mask=green_token_mask))

        return score_dict

    def detect(
        self,
        text: str = None,
        tokenized_text: list[int] = None,
        return_prediction: bool = True,
        return_scores: bool = True,
        z_threshold: float = None,
        **kwargs,
    ) -> dict:
        """Run watermark detection on raw `text` or pre-`tokenized_text`.

        Exactly one of `text` / `tokenized_text` must be given. Extra kwargs
        are forwarded to `_score_sequence`. Returns a dict with the requested
        scores and, when `return_prediction` is True, a boolean "prediction"
        (plus "confidence" when positive).
        """
        assert (text is not None) ^ (tokenized_text is not None), "Must pass either the raw or tokenized string"
        if return_prediction:
            kwargs["return_p_value"] = True  # a prediction needs the p-value for its confidence

        # BUGFIX: only normalize when raw text was supplied; previously the
        # normalizers were applied to None whenever tokenized input was passed.
        if text is not None:
            for normalizer in self.normalizers:
                text = normalizer(text)
            if len(self.normalizers) > 0:
                print(f"Text after normalization:\n\n{text}\n")

        if tokenized_text is None:
            # BUGFIX: the assert message was a comma-separated tuple of strings
            # (always truthy, never joined); concatenate into one message.
            assert self.tokenizer is not None, (
                "Watermark detection on raw string "
                "requires an instance of the tokenizer "
                "that was used at generation time."
            )
            tokenized_text = self.tokenizer(text, return_tensors="pt", add_special_tokens=False)["input_ids"][0].to(self.device)
            if tokenized_text[0] == self.tokenizer.bos_token_id:
                tokenized_text = tokenized_text[1:]
        else:
            # Pre-tokenized input may still carry a BOS token; strip it so the
            # seeding replay lines up with generation.
            if (self.tokenizer is not None) and (tokenized_text[0] == self.tokenizer.bos_token_id):
                tokenized_text = tokenized_text[1:]

        output_dict = {}
        score_dict = self._score_sequence(tokenized_text, **kwargs)
        if return_scores:
            output_dict.update(score_dict)

        if return_prediction:
            # BUGFIX: compare against None so an explicit z_threshold of 0.0
            # is honored instead of silently falling back to the default.
            z_threshold = z_threshold if z_threshold is not None else self.z_threshold
            assert z_threshold is not None, "Need a threshold in order to decide outcome of detection test"
            output_dict["prediction"] = score_dict["z_score"] > z_threshold
            if output_dict["prediction"]:
                output_dict["confidence"] = 1 - score_dict["p_value"]

        return output_dict
|
|
|
|