LMW / demo_watermark.py
NULLNode's picture
Upload 6 files
893a1df
raw
history blame
96 kB
# # coding=utf-8
# # Copyright 2023 Authors of "A Watermark for Large Language Models"
# # available at https://arxiv.org/abs/2301.10226
# #
# # Licensed under the Apache License, Version 2.0 (the "License");
# # you may not use this file except in compliance with the License.
# # You may obtain a copy of the License at
# #
# # http://www.apache.org/licenses/LICENSE-2.0
# #
# # Unless required by applicable law or agreed to in writing, software
# # distributed under the License is distributed on an "AS IS" BASIS,
# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# # See the License for the specific language governing permissions and
# # limitations under the License.
# from __future__ import annotations
# import os
# import argparse
# from argparse import Namespace
# from pprint import pprint
# from functools import partial
#
# import numpy # for gradio hot reload
# import gradio as gr
#
# import torch
#
# from transformers import (AutoTokenizer,
# AutoModelForSeq2SeqLM,
# AutoModelForCausalLM,
# LogitsProcessorList)
#
# from watermark_processor import WatermarkLogitsProcessor, WatermarkDetector
#
# from typing import Iterable
# from gradio.themes.base import Base
# from gradio.themes.utils import colors, fonts, sizes
# import time
#
# def str2bool(v):
# """Util function for user friendly boolean flag args"""
# if isinstance(v, bool):
# return v
# if v.lower() in ('yes', 'true', 't', 'y', '1'):
# return True
# elif v.lower() in ('no', 'false', 'f', 'n', '0'):
# return False
# else:
# raise argparse.ArgumentTypeError('Boolean value expected.')
#
# def parse_args():
# """Command line argument specification"""
#
# parser = argparse.ArgumentParser(description="A minimum working example of applying the watermark to any LLM that supports the huggingface 🤗 `generate` API")
#
# parser.add_argument(
# "--run_gradio",
# type=str2bool,
# default=True,
# help="Whether to launch as a gradio demo. Set to False if not installed and want to just run the stdout version.",
# )
# parser.add_argument(
# "--demo_public",
# type=str2bool,
# default=False,
# help="Whether to expose the gradio demo to the internet.",
# )
# parser.add_argument(
# "--model_name_or_path",
# type=str,
# default="facebook/opt-6.7b",
# help="Main model, path to pretrained model or model identifier from huggingface.co/models.",
# )
# parser.add_argument(
# "--prompt_max_length",
# type=int,
# default=None,
# help="Truncation length for prompt, overrides model config's max length field.",
# )
# parser.add_argument(
# "--max_new_tokens",
# type=int,
# default=200,
# help="Maximmum number of new tokens to generate.",
# )
# parser.add_argument(
# "--generation_seed",
# type=int,
# default=123,
# help="Seed for setting the torch global rng prior to generation.",
# )
# parser.add_argument(
# "--use_sampling",
# type=str2bool,
# default=True,
# help="Whether to generate using multinomial sampling.",
# )
# parser.add_argument(
# "--sampling_temp",
# type=float,
# default=0.7,
# help="Sampling temperature to use when generating using multinomial sampling.",
# )
# parser.add_argument(
# "--n_beams",
# type=int,
# default=1,
# help="Number of beams to use for beam search. 1 is normal greedy decoding",
# )
# parser.add_argument(
# "--use_gpu",
# type=str2bool,
# default=True,
# help="Whether to run inference and watermark hashing/seeding/permutation on gpu.",
# )
# parser.add_argument(
# "--seeding_scheme",
# type=str,
# default="simple_1",
# help="Seeding scheme to use to generate the greenlists at each generation and verification step.",
# )
# parser.add_argument(
# "--gamma",
# type=float,
# default=0.25,
# help="The fraction of the vocabulary to partition into the greenlist at each generation and verification step.",
# )
# parser.add_argument(
# "--delta",
# type=float,
# default=2.0,
# help="The amount/bias to add to each of the greenlist token logits before each token sampling step.",
# )
# parser.add_argument(
# "--normalizers",
# type=str,
# default="",
# help="Single or comma separated list of the preprocessors/normalizer names to use when performing watermark detection.",
# )
# parser.add_argument(
# "--ignore_repeated_bigrams",
# type=str2bool,
# default=False,
# help="Whether to use the detection method that only counts each unqiue bigram once as either a green or red hit.",
# )
# parser.add_argument(
# "--detection_z_threshold",
# type=float,
# default=4.0,
# help="The test statistic threshold for the detection hypothesis test.",
# )
# parser.add_argument(
# "--select_green_tokens",
# type=str2bool,
# default=True,
# help="How to treat the permuation when selecting the greenlist tokens at each step. Legacy is (False) to pick the complement/reds first.",
# )
# parser.add_argument(
# "--skip_model_load",
# type=str2bool,
# default=False,
# help="Skip the model loading to debug the interface.",
# )
# parser.add_argument(
# "--seed_separately",
# type=str2bool,
# default=True,
# help="Whether to call the torch seed function before both the unwatermarked and watermarked generate calls.",
# )
# parser.add_argument(
# "--load_fp16",
# type=str2bool,
# default=False,
# help="Whether to run model in float16 precsion.",
# )
# args = parser.parse_args()
# return args
#
# def load_model(args):
# """Load and return the model and tokenizer"""
#
# args.is_seq2seq_model = any([(model_type in args.model_name_or_path) for model_type in ["t5","T0"]])
# args.is_decoder_only_model = any([(model_type in args.model_name_or_path) for model_type in ["gpt","opt","bloom"]])
# if args.is_seq2seq_model:
# model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name_or_path)
# elif args.is_decoder_only_model:
# if args.load_fp16:
# model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path,torch_dtype=torch.float16, device_map='auto')
# else:
# model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path)
# else:
# raise ValueError(f"Unknown model type: {args.model_name_or_path}")
#
# if args.use_gpu:
# device = "cuda" if torch.cuda.is_available() else "cpu"
# if args.load_fp16:
# pass
# else:
# model = model.to(device)
# else:
# device = "cpu"
# model.eval()
#
# tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
#
# return model, tokenizer, device
#
# def generate(prompt, args, model=None, device=None, tokenizer=None):
# """Instatiate the WatermarkLogitsProcessor according to the watermark parameters
# and generate watermarked text by passing it to the generate method of the model
# as a logits processor. """
#
# print(f"Generating with {args}")
#
# watermark_processor = WatermarkLogitsProcessor(vocab=list(tokenizer.get_vocab().values()),
# gamma=args.gamma,
# delta=args.delta,
# seeding_scheme=args.seeding_scheme,
# select_green_tokens=args.select_green_tokens)
#
# gen_kwargs = dict(max_new_tokens=args.max_new_tokens)
#
# if args.use_sampling:
# gen_kwargs.update(dict(
# do_sample=True,
# top_k=0,
# temperature=args.sampling_temp
# ))
# else:
# gen_kwargs.update(dict(
# num_beams=args.n_beams
# ))
#
# generate_without_watermark = partial(
# model.generate,
# **gen_kwargs
# )
# generate_with_watermark = partial(
# model.generate,
# logits_processor=LogitsProcessorList([watermark_processor]),
# **gen_kwargs
# )
# if args.prompt_max_length:
# pass
# elif hasattr(model.config,"max_position_embedding"):
# args.prompt_max_length = model.config.max_position_embeddings-args.max_new_tokens
# else:
# args.prompt_max_length = 2048-args.max_new_tokens
#
# tokd_input = tokenizer(prompt, return_tensors="pt", add_special_tokens=True, truncation=True, max_length=args.prompt_max_length).to(device)
# truncation_warning = True if tokd_input["input_ids"].shape[-1] == args.prompt_max_length else False
# redecoded_input = tokenizer.batch_decode(tokd_input["input_ids"], skip_special_tokens=True)[0]
#
# torch.manual_seed(args.generation_seed)
# output_without_watermark = generate_without_watermark(**tokd_input)
#
# # optional to seed before second generation, but will not be the same again generally, unless delta==0.0, no-op watermark
# if args.seed_separately:
# torch.manual_seed(args.generation_seed)
# output_with_watermark = generate_with_watermark(**tokd_input)
#
# if args.is_decoder_only_model:
# # need to isolate the newly generated tokens
# output_without_watermark = output_without_watermark[:,tokd_input["input_ids"].shape[-1]:]
# output_with_watermark = output_with_watermark[:,tokd_input["input_ids"].shape[-1]:]
#
# decoded_output_without_watermark = tokenizer.batch_decode(output_without_watermark, skip_special_tokens=True)[0]
# decoded_output_with_watermark = tokenizer.batch_decode(output_with_watermark, skip_special_tokens=True)[0]
#
# return (redecoded_input,
# int(truncation_warning),
# decoded_output_without_watermark,
# decoded_output_with_watermark,
# args)
# # decoded_output_with_watermark)
#
# def format_names(s):
# """Format names for the gradio demo interface"""
# s=s.replace("num_tokens_scored","Tokens Counted (T)")
# s=s.replace("num_green_tokens","# Tokens in Greenlist")
# s=s.replace("green_fraction","Fraction of T in Greenlist")
# s=s.replace("z_score","z-score")
# s=s.replace("p_value","p value")
# s=s.replace("prediction","Prediction")
# s=s.replace("confidence","Confidence")
# return s
#
# def list_format_scores(score_dict, detection_threshold):
# """Format the detection metrics into a gradio dataframe input format"""
# lst_2d = []
# # lst_2d.append(["z-score threshold", f"{detection_threshold}"])
# for k,v in score_dict.items():
# if k=='green_fraction':
# lst_2d.append([format_names(k), f"{v:.1%}"])
# elif k=='confidence':
# lst_2d.append([format_names(k), f"{v:.3%}"])
# elif isinstance(v, float):
# lst_2d.append([format_names(k), f"{v:.3g}"])
# elif isinstance(v, bool):
# lst_2d.append([format_names(k), ("Watermarked" if v else "Human/Unwatermarked")])
# else:
# lst_2d.append([format_names(k), f"{v}"])
# if "confidence" in score_dict:
# lst_2d.insert(-2,["z-score Threshold", f"{detection_threshold}"])
# else:
# lst_2d.insert(-1,["z-score Threshold", f"{detection_threshold}"])
# return lst_2d
#
# def detect(input_text, args, device=None, tokenizer=None):
# """Instantiate the WatermarkDetection object and call detect on
# the input text returning the scores and outcome of the test"""
# watermark_detector = WatermarkDetector(vocab=list(tokenizer.get_vocab().values()),
# gamma=args.gamma,
# seeding_scheme=args.seeding_scheme,
# device=device,
# tokenizer=tokenizer,
# z_threshold=args.detection_z_threshold,
# normalizers=args.normalizers,
# ignore_repeated_bigrams=args.ignore_repeated_bigrams,
# select_green_tokens=args.select_green_tokens)
# if len(input_text)-1 > watermark_detector.min_prefix_len:
# score_dict = watermark_detector.detect(input_text)
# # output = str_format_scores(score_dict, watermark_detector.z_threshold)
# output = list_format_scores(score_dict, watermark_detector.z_threshold)
# else:
# # output = (f"Error: string not long enough to compute watermark presence.")
# output = [["Error","string too short to compute metrics"]]
# output += [["",""] for _ in range(6)]
# return output, args
#
# class Seafoam(Base):
# def __init__(
# self,
# *,
# primary_hue: colors.Color | str = colors.emerald,
# secondary_hue: colors.Color | str = colors.blue,
# neutral_hue: colors.Color | str = colors.blue,
# spacing_size: sizes.Size | str = sizes.spacing_md,
# radius_size: sizes.Size | str = sizes.radius_md,
# text_size: sizes.Size | str = sizes.text_lg,
# font: fonts.Font
# | str
# | Iterable[fonts.Font | str] = (
# fonts.GoogleFont("Quicksand"),
# "ui-sans-serif",
# "sans-serif",
# ),
# font_mono: fonts.Font
# | str
# | Iterable[fonts.Font | str] = (
# fonts.GoogleFont("IBM Plex Mono"),
# "ui-monospace",
# "monospace",
# ),
# ):
# super().__init__(
# primary_hue=primary_hue,
# secondary_hue=secondary_hue,
# neutral_hue=neutral_hue,
# spacing_size=spacing_size,
# radius_size=radius_size,
# text_size=text_size,
# font=font,
# font_mono=font_mono,
# )
# super().set(
# body_background_fill="repeating-linear-gradient(45deg, *primary_200, *primary_200 10px, *primary_50 10px, *primary_50 20px)",
# body_background_fill_dark="repeating-linear-gradient(45deg, *primary_800, *primary_800 10px, *primary_900 10px, *primary_900 20px)",
# button_primary_background_fill="linear-gradient(90deg, *primary_300, *secondary_400)",
# button_primary_background_fill_hover="linear-gradient(90deg, *primary_200, *secondary_300)",
# button_primary_text_color="white",
# button_primary_background_fill_dark="linear-gradient(90deg, *primary_600, *secondary_800)",
# slider_color="*secondary_300",
# slider_color_dark="*secondary_600",
# block_title_text_weight="600",
# block_border_width="3px",
# block_shadow="*shadow_drop_lg",
# button_shadow="*shadow_drop_lg",
# button_large_padding="32px",
# )
#
# seafoam = Seafoam()
#
# def run_gradio(args, model=None, device=None, tokenizer=None):
# """Define and launch the gradio demo interface"""
# generate_partial = partial(generate, model=model, device=device, tokenizer=tokenizer)
# detect_partial = partial(detect, device=device, tokenizer=tokenizer)
#
# # with gr.Blocks(theme="shivi/calm_seafoam") as demo:
# # with gr.Blocks(theme="finlaymacklon/smooth_slate") as demo:
# # with gr.Blocks(theme="freddyaboulton/test-blue") as demo:
# with gr.Blocks(theme="xiaobaiyuan/theme_brief") as demo:
# gr.Markdown(
# """
# # 💧 大语言模型水印 🔍
# """
# )
#
# with gr.Accordion("参数说明", open=False):
# gr.Markdown(
# """
# - `z分数阈值` : 假设检验的截断值。
# - `标记个数 (T)` : 检测算法计算的输出中计数的标记数。
# 在简单的单个标记种子方案中,第一个标记被省略,因为它没有前缀标记,无法为其生成绿色列表。
# 在底部面板中描述的“忽略重复二元组”检测算法下,如果存在大量重复,这个数量可能远小于生成的总标记数。
# - `绿色列表中的标记数目` : 观察到的落在各自绿色列表中的标记数。
# - `T中含有绿色列表标记的比例` : `绿色列表中的标记数目` / `T`。预期对于人类/非水印文本,这个比例大约等于 gamma。
# - `z分数` : 检测假设检验的检验统计量。如果大于 `z分数阈值`,则“拒绝零假设”,即文本是人类/非水印的,推断它是带有水印的。
# - `p值` : 在零假设下观察到计算的 `z-分数` 的概率。
# 这是在不知道水印程序/绿色列表的情况下观察到 'T中含有绿色列表标记的比例' 的概率。
# 如果这个值非常小,我们有信心认为这么多绿色标记不是随机选择的。
# - `预测` : 假设检验的结果,即观察到的 `z分数` 是否高于 `z分数阈值`。
# - `置信度` : 如果我们拒绝零假设,并且 `预测` 是“Watermarked”,那么我们报告 1-`p 值` 来表示基于这个 `z分数` 观察的检测置信度的不可能性。
# """
# )
#
# with gr.Accordion("关于模型能力的说明", open=True):
# gr.Markdown(
# """
# 本演示使用适用于单个 GPU 的开源语言模型。这些模型比专有商业工具(如 ChatGPT、Claude 或 Bard)的能力更弱。
#
# 还有一件事,我们使用语言模型旨在“完成”您的提示,而不是经过微调以遵循指令的模型。
# 为了获得最佳结果,请使用一些组成段落开头的句子提示模型,然后让它“继续”您的段落。
# 一些示例包括维基百科文章的开头段落或故事的前几句话。
# 结尾处中断的较长提示将产生更流畅的生成。
# """
# )
#
# gr.Markdown(f"语言模型: {args.model_name_or_path} {'(float16 mode)' if args.load_fp16 else ''}")
#
# # Construct state for parameters, define updates and toggles
# default_prompt = args.__dict__.pop("default_prompt")
# session_args = gr.State(value=args)
#
# with gr.Tab("生成检测"):
# with gr.Row():
# prompt = gr.Textbox(label=f"提示词", interactive=True,lines=10,max_lines=10, value=default_prompt)
# with gr.Row():
# generate_btn = gr.Button("生成")
# with gr.Row():
# with gr.Column(scale=2):
# with gr.Tab("未嵌入水印输出的文本"):
# output_without_watermark = gr.Textbox(label=None, interactive=False, lines=14,
# max_lines=14, show_label=False)
# with gr.Tab("高亮"):
# highlight_output_without_watermark = gr.Textbox(label=None, interactive=False, lines=14,
# max_lines=14, show_label=False)
# with gr.Column(scale=1):
# # without_watermark_detection_result = gr.Textbox(label="Detection Result", interactive=False,lines=14,max_lines=14)
# without_watermark_detection_result = gr.Dataframe(headers=["参数", "值"], interactive=False,
# row_count=7, col_count=2)
#
#
# with gr.Row():
# with gr.Column(scale=2):
# with gr.Tab("嵌入了水印输出的文本"):
# output_with_watermark = gr.Textbox(label=None, interactive=False, lines=14,
# max_lines=14, show_label=False)
# with gr.Tab("高亮"):
# highlight_output_with_watermark = gr.Textbox(label=None, interactive=False, lines=14,
# max_lines=14, show_label=False)
# with gr.Column(scale=1):
# # with_watermark_detection_result = gr.Textbox(label="Detection Result", interactive=False,lines=14,max_lines=14)
# with_watermark_detection_result = gr.Dataframe(headers=["参数", "值"], interactive=False,
# row_count=7, col_count=2)
#
#
# redecoded_input = gr.Textbox(visible=False)
# truncation_warning = gr.Number(visible=False)
# def truncate_prompt(redecoded_input, truncation_warning, orig_prompt, args):
# if truncation_warning:
# return redecoded_input + f"\n\n[Prompt was truncated before generation due to length...]", args
# else:
# return orig_prompt, args
#
# with gr.Tab("仅检测"):
# with gr.Row():
# with gr.Column(scale=2):
# detection_input = gr.Textbox(label="待分析文本", interactive=True, lines=14, max_lines=14)
# with gr.Column(scale=1):
# # detection_result = gr.Textbox(label="Detection Result", interactive=False,lines=14,max_lines=14)
# detection_result = gr.Dataframe(headers=["参数", "值"], interactive=False, row_count=7, col_count=2)
# with gr.Row():
# detect_btn = gr.Button("检测")
#
# # Parameter selection group
# with gr.Accordion("高级设置", open=False):
# with gr.Row():
# with gr.Column(scale=1):
# gr.Markdown(f"#### 生成参数")
# with gr.Row():
# decoding = gr.Radio(label="解码方式", choices=["multinomial", "greedy"],
# value=("multinomial" if args.use_sampling else "greedy"))
#
# with gr.Row():
# sampling_temp = gr.Slider(label="采样随机性多样性权重", minimum=0.1, maximum=1.0, step=0.1,
# value=args.sampling_temp, visible=True)
# with gr.Row():
# generation_seed = gr.Number(label="生成种子", value=args.generation_seed, interactive=True)
# with gr.Row():
# n_beams = gr.Dropdown(label="束搜索路数", choices=list(range(1, 11, 1)), value=args.n_beams,
# visible=(not args.use_sampling))
# with gr.Row():
# max_new_tokens = gr.Slider(label="生成最大标记数", minimum=10, maximum=1000, step=10,
# value=args.max_new_tokens)
#
# with gr.Column(scale=1):
# gr.Markdown(f"#### 水印参数")
# with gr.Row():
# gamma = gr.Slider(label="gamma", minimum=0.1, maximum=0.9, step=0.05, value=args.gamma)
# with gr.Row():
# delta = gr.Slider(label="delta", minimum=0.0, maximum=10.0, step=0.1, value=args.delta)
# gr.Markdown(f"#### 检测参数")
# with gr.Row():
# detection_z_threshold = gr.Slider(label="z-score 阈值", minimum=0.0, maximum=10.0, step=0.1,
# value=args.detection_z_threshold)
# with gr.Row():
# ignore_repeated_bigrams = gr.Checkbox(label="忽略重复 Bigram")
# with gr.Row():
# normalizers = gr.CheckboxGroup(label="正则化器",
# choices=["unicode", "homoglyphs", "truecase"],
# value=args.normalizers)
# # with gr.Accordion("Actual submitted parameters:",open=False):
# with gr.Row():
# gr.Markdown(
# f"_提示: 滑块更新有延迟。点击滑动条或使用右侧的数字窗口可以帮助更新。下方窗口显示当前的设置。_")
# with gr.Row():
# current_parameters = gr.Textbox(label="当前参数", value=args, interactive=False, lines=6)
# with gr.Accordion("保留设置", open=False):
# with gr.Row():
# with gr.Column(scale=1):
# seed_separately = gr.Checkbox(label="红绿分别生成", value=args.seed_separately)
# with gr.Column(scale=1):
# select_green_tokens = gr.Checkbox(label="从分区中选择'greenlist'",
# value=args.select_green_tokens)
#
# with gr.Accordion("关于设置", open=False):
# gr.Markdown(
# """
# #### 生成参数:
#
# - 解码方法:我们可以使用多项式采样或贪婪解码来从模型中生成标记。
# - 采样温度:如果使用多项式采样,可以设置采样分布的温度。
# 0.0 相当于贪婪解码,而 1.0 是下一个标记分布中的最大变异性/熵。
# 0.7 在保持对模型对前几个候选者的估计准确性的同时增加了多样性。对于贪婪解码无效。
# - 生成种子:在运行生成之前传递给 torch 随机数生成器的整数。使多项式采样策略输出可复现。对于贪婪解码无效。
# - 并行数:当使用贪婪解码时,还可以将并行数设置为 > 1 以启用波束搜索。
# 这在多项式采样中未实现/排除在论文中,但可能会在未来添加。
# - 最大生成标记数:传递给生成方法的 `max_new_tokens` 参数,以在特定数量的新标记处停止输出。
# 请注意,根据提示,模型可以生成较少的标记。
# 这将隐含地将可能的提示标记数量设置为模型的最大输入长度减去 `max_new_tokens`,
# 并且输入将相应地被截断。
#
# #### 水印参数:
#
# - gamma:每次生成步骤将词汇表分成绿色列表的部分。较小的 gamma 值通过使得有水印的模型能够更好地与人类/无水印文本区分,
# 从而创建了更强的水印,因为它会更倾向于从较小的绿色集合中进行采样,使得这些标记不太可能是偶然发生的。
# - delta:在每个生成步骤中,在采样/选择下一个标记之前,为绿色列表中的每个标记的对数概率添加正偏差。
# 较高的 delta 值意味着绿色列表标记更受有水印的模型青睐,并且随着偏差的增大,水印从“软性”过渡到“硬性”。
# 对于硬性水印,几乎所有的标记都是绿色的,但这可能对生成质量产生不利影响,特别是当分布的灵活性有限时。
#
# #### 检测器参数:
#
# - z-score 阈值:假设检验的 z-score 截断值。较高的阈值(例如 4.0)使得预测人类/无水印文本是有水印的
# (_false positives_)的可能性非常低,因为一个真正的包含大量标记的人类文本几乎不可能达到那么高的 z-score。
# 较低的阈值将捕捉更多的真正有水印的文本,因为一些有水印的文本可能包含较少的绿色标记并获得较低的 z-score,
# 但仍然通过较低的门槛被标记为“有水印”。然而,较低的阈值会增加被错误地标记为有水印的具有略高于平均绿色标记数的人类文本的几率。
# 4.0-5.0 提供了极低的误报率,同时仍然准确地捕捉到大多数有水印的文本。
# - 忽略重复的双字母组合:此备用检测算法在检测期间只考虑文本中的唯一双字母组合,
# 根据每对中的第一个计算绿色列表,并检查第二个是否在列表内。
# 这意味着 `T` 现在是文本中唯一的双字母组合的数量,
# 如果文本包含大量重复,那么它将少于生成的总标记数。
# 有关更详细的讨论,请参阅论文。
# - 标准化:我们实现了一些基本的标准化,以防止文本在检测过程中受到各种对抗性扰动。
# 目前,我们支持将所有字符转换为 Unicode,使用规范形式替换同形字符,并标准化大小写。
# """
# )
#
# # gr.HTML("""
# # <p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.
# # Follow the github link at the top and host the demo on your own GPU hardware to test out larger models.
# # <br/>
# # <a href="https://huggingface.co/spaces/tomg-group-umd/lm-watermarking?duplicate=true">
# # <img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
# # <p/>
# # """)
#
# # Register main generation tab click, outputing generations as well as a the encoded+redecoded+potentially truncated prompt and flag
# generate_btn.click(fn=generate_partial, inputs=[prompt,session_args], outputs=[redecoded_input, truncation_warning, output_without_watermark, output_with_watermark,session_args])
# # Show truncated version of prompt if truncation occurred
# redecoded_input.change(fn=truncate_prompt, inputs=[redecoded_input,truncation_warning,prompt,session_args], outputs=[prompt,session_args])
# # Call detection when the outputs (of the generate function) are updated
# output_without_watermark.change(fn=detect_partial, inputs=[output_without_watermark,session_args], outputs=[without_watermark_detection_result,session_args])
# output_with_watermark.change(fn=detect_partial, inputs=[output_with_watermark,session_args], outputs=[with_watermark_detection_result,session_args])
# # Register main detection tab click
# detect_btn.click(fn=detect_partial, inputs=[detection_input,session_args], outputs=[detection_result, session_args])
#
# # State management logic
# # update callbacks that change the state dict
# def update_sampling_temp(session_state, value): session_state.sampling_temp = float(value); return session_state
# def update_generation_seed(session_state, value): session_state.generation_seed = int(value); return session_state
# def update_gamma(session_state, value): session_state.gamma = float(value); return session_state
# def update_delta(session_state, value): session_state.delta = float(value); return session_state
# def update_detection_z_threshold(session_state, value): session_state.detection_z_threshold = float(value); return session_state
# def update_decoding(session_state, value):
# if value == "multinomial":
# session_state.use_sampling = True
# elif value == "greedy":
# session_state.use_sampling = False
# return session_state
# def toggle_sampling_vis(value):
# if value == "multinomial":
# return gr.update(visible=True)
# elif value == "greedy":
# return gr.update(visible=False)
# def toggle_sampling_vis_inv(value):
# if value == "multinomial":
# return gr.update(visible=False)
# elif value == "greedy":
# return gr.update(visible=True)
# def update_n_beams(session_state, value): session_state.n_beams = int(value); return session_state
# def update_max_new_tokens(session_state, value): session_state.max_new_tokens = int(value); return session_state
# def update_ignore_repeated_bigrams(session_state, value): session_state.ignore_repeated_bigrams = value; return session_state
# def update_normalizers(session_state, value): session_state.normalizers = value; return session_state
# def update_seed_separately(session_state, value): session_state.seed_separately = value; return session_state
# def update_select_green_tokens(session_state, value): session_state.select_green_tokens = value; return session_state
# # registering callbacks for toggling the visibilty of certain parameters
# decoding.change(toggle_sampling_vis,inputs=[decoding], outputs=[sampling_temp])
# decoding.change(toggle_sampling_vis,inputs=[decoding], outputs=[generation_seed])
# decoding.change(toggle_sampling_vis_inv,inputs=[decoding], outputs=[n_beams])
# # registering all state update callbacks
# decoding.change(update_decoding,inputs=[session_args, decoding], outputs=[session_args])
# sampling_temp.change(update_sampling_temp,inputs=[session_args, sampling_temp], outputs=[session_args])
# generation_seed.change(update_generation_seed,inputs=[session_args, generation_seed], outputs=[session_args])
# n_beams.change(update_n_beams,inputs=[session_args, n_beams], outputs=[session_args])
# max_new_tokens.change(update_max_new_tokens,inputs=[session_args, max_new_tokens], outputs=[session_args])
# gamma.change(update_gamma,inputs=[session_args, gamma], outputs=[session_args])
# delta.change(update_delta,inputs=[session_args, delta], outputs=[session_args])
# detection_z_threshold.change(update_detection_z_threshold,inputs=[session_args, detection_z_threshold], outputs=[session_args])
# ignore_repeated_bigrams.change(update_ignore_repeated_bigrams,inputs=[session_args, ignore_repeated_bigrams], outputs=[session_args])
# normalizers.change(update_normalizers,inputs=[session_args, normalizers], outputs=[session_args])
# seed_separately.change(update_seed_separately,inputs=[session_args, seed_separately], outputs=[session_args])
# select_green_tokens.change(update_select_green_tokens,inputs=[session_args, select_green_tokens], outputs=[session_args])
# # register additional callback on button clicks that updates the shown parameters window
# generate_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
# detect_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
# # When the parameters change, display the update and fire detection, since some detection params dont change the model output.
# gamma.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
# gamma.change(fn=detect_partial, inputs=[output_without_watermark,session_args], outputs=[without_watermark_detection_result,session_args])
# gamma.change(fn=detect_partial, inputs=[output_with_watermark,session_args], outputs=[with_watermark_detection_result,session_args])
# gamma.change(fn=detect_partial, inputs=[detection_input,session_args], outputs=[detection_result,session_args])
# detection_z_threshold.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
# detection_z_threshold.change(fn=detect_partial, inputs=[output_without_watermark,session_args], outputs=[without_watermark_detection_result,session_args])
# detection_z_threshold.change(fn=detect_partial, inputs=[output_with_watermark,session_args], outputs=[with_watermark_detection_result,session_args])
# detection_z_threshold.change(fn=detect_partial, inputs=[detection_input,session_args], outputs=[detection_result,session_args])
# ignore_repeated_bigrams.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
# ignore_repeated_bigrams.change(fn=detect_partial, inputs=[output_without_watermark,session_args], outputs=[without_watermark_detection_result,session_args])
# ignore_repeated_bigrams.change(fn=detect_partial, inputs=[output_with_watermark,session_args], outputs=[with_watermark_detection_result,session_args])
# ignore_repeated_bigrams.change(fn=detect_partial, inputs=[detection_input,session_args], outputs=[detection_result,session_args])
# normalizers.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
# normalizers.change(fn=detect_partial, inputs=[output_without_watermark,session_args], outputs=[without_watermark_detection_result,session_args])
# normalizers.change(fn=detect_partial, inputs=[output_with_watermark,session_args], outputs=[with_watermark_detection_result,session_args])
# normalizers.change(fn=detect_partial, inputs=[detection_input,session_args], outputs=[detection_result,session_args])
# select_green_tokens.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
# select_green_tokens.change(fn=detect_partial, inputs=[output_without_watermark,session_args], outputs=[without_watermark_detection_result,session_args])
# select_green_tokens.change(fn=detect_partial, inputs=[output_with_watermark,session_args], outputs=[with_watermark_detection_result,session_args])
# select_green_tokens.change(fn=detect_partial, inputs=[detection_input,session_args], outputs=[detection_result,session_args])
#
#
# demo.queue(concurrency_count=3)
#
# if args.demo_public:
# demo.launch(share=True) # exposes app to the internet via randomly generated link
# else:
# demo.launch()
#
# def main(args):
# """Run a command line version of the generation and detection operations
# and optionally launch and serve the gradio demo"""
# # Initial arg processing and log
# args.normalizers = (args.normalizers.split(",") if args.normalizers else [])
# print(args)
#
# if not args.skip_model_load:
# model, tokenizer, device = load_model(args)
# else:
# model, tokenizer, device = None, None, None
#
# # Generate and detect, report to stdout
# if not args.skip_model_load:
# input_text = (
# "The diamondback terrapin or simply terrapin (Malaclemys terrapin) is a "
# "species of turtle native to the brackish coastal tidal marshes of the "
# "Northeastern and southern United States, and in Bermuda.[6] It belongs "
# "to the monotypic genus Malaclemys. It has one of the largest ranges of "
# "all turtles in North America, stretching as far south as the Florida Keys "
# "and as far north as Cape Cod.[7] The name 'terrapin' is derived from the "
# "Algonquian word torope.[8] It applies to Malaclemys terrapin in both "
# "British English and American English. The name originally was used by "
# "early European settlers in North America to describe these brackish-water "
# "turtles that inhabited neither freshwater habitats nor the sea. It retains "
# "this primary meaning in American English.[8] In British English, however, "
# "other semi-aquatic turtle species, such as the red-eared slider, might "
# "also be called terrapins. The common name refers to the diamond pattern "
# "on top of its shell (carapace), but the overall pattern and coloration "
# "vary greatly. The shell is usually wider at the back than in the front, "
# "and from above it appears wedge-shaped. The shell coloring can vary "
# "from brown to grey, and its body color can be grey, brown, yellow, "
# "or white. All have a unique pattern of wiggly, black markings or spots "
# "on their body and head. The diamondback terrapin has large webbed "
# "feet.[9] The species is"
# )
#
# args.default_prompt = input_text
#
# term_width = 80
# print("#"*term_width)
# print("Prompt:")
# print(input_text)
#
# _, _, decoded_output_without_watermark, decoded_output_with_watermark, _ = generate(input_text,
# args,
# model=model,
# device=device,
# tokenizer=tokenizer)
# without_watermark_detection_result = detect(decoded_output_without_watermark,
# args,
# device=device,
# tokenizer=tokenizer)
# with_watermark_detection_result = detect(decoded_output_with_watermark,
# args,
# device=device,
# tokenizer=tokenizer)
#
# print("#"*term_width)
# print("Output without watermark:")
# print(decoded_output_without_watermark)
# print("-"*term_width)
# print(f"Detection result @ {args.detection_z_threshold}:")
# pprint(without_watermark_detection_result)
# print("-"*term_width)
#
# print("#"*term_width)
# print("Output with watermark:")
# print(decoded_output_with_watermark)
# print("-"*term_width)
# print(f"Detection result @ {args.detection_z_threshold}:")
# pprint(with_watermark_detection_result)
# print("-"*term_width)
#
#
# # Launch the app to generate and detect interactively (implements the hf space demo)
# if args.run_gradio:
# run_gradio(args, model=model, tokenizer=tokenizer, device=device)
#
# return
#
# if __name__ == "__main__":
#
# args = parse_args()
# print(args)
#
# main(args)
# coding=utf-8
# Copyright 2023 Authors of "A Watermark for Large Language Models"
# available at https://arxiv.org/abs/2301.10226
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# coding=utf-8
# Copyright 2023 Authors of "A Watermark for Large Language Models"
# available at https://arxiv.org/abs/2301.10226
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import argparse
from pprint import pprint
from functools import partial
import numpy # for gradio hot reload
import gradio as gr
import torch
from transformers import (AutoTokenizer,
AutoModelForSeq2SeqLM,
AutoModelForCausalLM,
LogitsProcessorList)
# from local_tokenizers.tokenization_llama import LLaMATokenizer
from transformers import GPT2TokenizerFast
OPT_TOKENIZER = GPT2TokenizerFast
from watermark_processor import WatermarkLogitsProcessor, WatermarkDetector
# ALPACA_MODEL_NAME = "alpaca"
# ALPACA_MODEL_TOKENIZER = LLaMATokenizer
# ALPACA_TOKENIZER_PATH = "/cmlscratch/jkirchen/llama"
# FIXME correct lengths for all models
# Models served through the hosted HF inference API (generate() routes them to
# generate_with_api) instead of being loaded locally. Per model:
#   max_length   - prompt truncation length used by check_prompt
#   gamma/delta  - watermark settings the UI switches to when the model is selected
API_MODEL_MAP = {
    "google/flan-ul2": dict(max_length=1000, gamma=0.5, delta=2.0),
    "google/flan-t5-xxl": dict(max_length=1000, gamma=0.5, delta=2.0),
    "EleutherAI/gpt-neox-20b": dict(max_length=1000, gamma=0.5, delta=2.0),
    # "bigscience/bloom": dict(max_length=1000, gamma=0.5, delta=2.0),
    # "bigscience/bloomz": dict(max_length=1000, gamma=0.5, delta=2.0),
}
def str2bool(v):
    """Convert a user-friendly boolean flag value for use as an argparse ``type=``.

    Booleans pass through unchanged. Strings are matched case-insensitively
    against the truthy spellings ("yes", "true", "t", "y", "1") and the falsy
    spellings ("no", "false", "f", "n", "0").

    Raises:
        argparse.ArgumentTypeError: if the string matches neither set.
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    truthy = ('yes', 'true', 't', 'y', '1')
    falsy = ('no', 'false', 'f', 'n', '0')
    if lowered in truthy:
        return True
    if lowered in falsy:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
def parse_args(argv=None):
    """Command line argument specification for the watermark demo.

    Args:
        argv: optional list of argument strings to parse instead of the real
            command line. ``None`` (the default) keeps the original behaviour of
            reading ``sys.argv[1:]``. Passing a list lets tests or other scripts
            drive the demo programmatically.

    Returns:
        argparse.Namespace holding every option defined below.
    """
    parser = argparse.ArgumentParser(
        description="A minimum working example of applying the watermark to any LLM that supports the huggingface 🤗 `generate` API")
    parser.add_argument(
        "--run_gradio",
        type=str2bool,
        default=True,
        help="Whether to launch as a gradio demo. Set to False if not installed and want to just run the stdout version.",
    )
    parser.add_argument(
        "--demo_public",
        type=str2bool,
        default=False,
        help="Whether to expose the gradio demo to the internet.",
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default="facebook/opt-6.7b",
        help="Main model, path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--prompt_max_length",
        type=int,
        default=None,
        help="Truncation length for prompt, overrides model config's max length field.",
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=200,
        help="Maximum number of new tokens to generate.",
    )
    parser.add_argument(
        "--generation_seed",
        type=int,
        default=123,
        help="Seed for setting the torch global rng prior to generation.",
    )
    parser.add_argument(
        "--use_sampling",
        type=str2bool,
        default=True,
        help="Whether to generate using multinomial sampling.",
    )
    parser.add_argument(
        "--sampling_temp",
        type=float,
        default=0.7,
        help="Sampling temperature to use when generating using multinomial sampling.",
    )
    parser.add_argument(
        "--n_beams",
        type=int,
        default=1,
        help="Number of beams to use for beam search. 1 is normal greedy decoding",
    )
    parser.add_argument(
        "--use_gpu",
        type=str2bool,
        default=True,
        help="Whether to run inference and watermark hashing/seeding/permutation on gpu.",
    )
    parser.add_argument(
        "--seeding_scheme",
        type=str,
        default="simple_1",
        help="Seeding scheme to use to generate the greenlists at each generation and verification step.",
    )
    parser.add_argument(
        "--gamma",
        type=float,
        default=0.25,
        help="The fraction of the vocabulary to partition into the greenlist at each generation and verification step.",
    )
    parser.add_argument(
        "--delta",
        type=float,
        default=2.0,
        help="The amount/bias to add to each of the greenlist token logits before each token sampling step.",
    )
    parser.add_argument(
        "--normalizers",
        type=str,
        default="",
        help="Single or comma separated list of the preprocessors/normalizer names to use when performing watermark detection.",
    )
    parser.add_argument(
        "--ignore_repeated_bigrams",
        type=str2bool,
        default=False,
        help="Whether to use the detection method that only counts each unique bigram once as either a green or red hit.",
    )
    parser.add_argument(
        "--detection_z_threshold",
        type=float,
        default=4.0,
        help="The test statistic threshold for the detection hypothesis test.",
    )
    parser.add_argument(
        "--select_green_tokens",
        type=str2bool,
        default=True,
        help="How to treat the permutation when selecting the greenlist tokens at each step. Legacy is (False) to pick the complement/reds first.",
    )
    parser.add_argument(
        "--skip_model_load",
        type=str2bool,
        default=False,
        help="Skip the model loading to debug the interface.",
    )
    parser.add_argument(
        "--seed_separately",
        type=str2bool,
        default=True,
        help="Whether to call the torch seed function before both the unwatermarked and watermarked generate calls.",
    )
    parser.add_argument(
        "--load_fp16",
        type=str2bool,
        default=False,
        help="Whether to run model in float16 precision.",
    )
    args = parser.parse_args(argv)
    return args
def load_model(args):
    """Load and return the model, tokenizer and device for ``args.model_name_or_path``.

    The model family is inferred from substrings of the name:
    "t5"/"T0" -> seq2seq, "gpt"/"opt"/"bloom" -> decoder-only (seq2seq wins if
    both match). As a side effect, ``args.is_seq2seq_model`` and
    ``args.is_decoder_only_model`` are set on ``args``.

    Returns:
        (model, tokenizer, device): the model is in eval mode, and device is
        "cuda" or "cpu".

    Raises:
        ValueError: if the name matches neither model family.
    """
    name = args.model_name_or_path
    args.is_seq2seq_model = any(model_type in name for model_type in ["t5", "T0"])
    args.is_decoder_only_model = any(model_type in name for model_type in ["gpt", "opt", "bloom"])

    # True only when the weights were placed via device_map='auto'. Only in that
    # case must the model not be moved again with .to(device). Previously the move
    # was skipped whenever load_fp16 was set, which left seq2seq models (loaded
    # without device_map) on the CPU while inputs were sent to the GPU.
    placed_by_device_map = False
    if args.is_seq2seq_model:
        model = AutoModelForSeq2SeqLM.from_pretrained(name)
    elif args.is_decoder_only_model:
        if args.load_fp16:
            model = AutoModelForCausalLM.from_pretrained(name, torch_dtype=torch.float16,
                                                         device_map='auto')
            placed_by_device_map = True
        else:
            model = AutoModelForCausalLM.from_pretrained(name)
    else:
        raise ValueError(f"Unknown model type: {args.model_name_or_path}")

    if args.use_gpu:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        if not placed_by_device_map:
            model = model.to(device)
    else:
        device = "cpu"
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(name)
    return model, tokenizer, device
from text_generation import InferenceAPIClient
from requests.exceptions import ReadTimeout
def generate_with_api(prompt, args):
    """Stream unwatermarked and watermarked completions of ``prompt`` from the HF inference API.

    The API token is read from the ``HF_API_KEY`` environment variable. Two
    streams are requested with the same generation parameters, one with
    ``watermark=False`` and one with ``watermark=True``, and consumed in lockstep.

    Yields:
        (text_without_watermark, text_with_watermark): the text accumulated so
        far for each stream. If one stream finishes first, its text stays
        fixed while the other keeps growing.

    Raises:
        ValueError: if ``HF_API_KEY`` is not set.
        AssertionError: if beam search (``args.n_beams != 1``) is requested.
    """
    from itertools import zip_longest

    # Credentials must come from the environment, never from source code.
    hf_api_key = os.environ.get("HF_API_KEY")
    if hf_api_key is None:
        raise ValueError("HF_API_KEY environment variable not set, cannot use HF API to generate text.")
    client = InferenceAPIClient(args.model_name_or_path, token=hf_api_key, timeout=60)
    assert args.n_beams == 1, "HF API models do not support beam search."
    generation_params = {
        "max_new_tokens": args.max_new_tokens,
        "do_sample": args.use_sampling,
    }
    if args.use_sampling:
        generation_params["temperature"] = args.sampling_temp
        generation_params["seed"] = args.generation_seed
    timeout_msg = "[Model API timeout error. Try reducing the max_new_tokens parameter or the prompt length.]"

    def _stream_text(watermark):
        """Yield token-text chunks from one API stream; on a read timeout, yield the timeout notice instead."""
        try:
            # A read timeout can surface while the streamed response is being
            # consumed, so the whole iteration (not just the call) is guarded.
            for response in client.generate_stream(prompt, watermark=watermark, **generation_params):
                yield response.token.text
        except ReadTimeout as e:
            print(e)
            yield timeout_msg

    all_without_words, all_with_words = "", ""
    # zip_longest so the longer completion is not cut off when the other stream ends first.
    for without_word, with_word in zip_longest(_stream_text(False), _stream_text(True), fillvalue=""):
        all_without_words += without_word
        all_with_words += with_word
        yield all_without_words, all_with_words
def check_prompt(prompt, args, tokenizer, model=None, device=None):
    """Tokenize ``prompt``, truncating it to the selected model's prompt budget.

    ``args.prompt_max_length`` is (re)computed and stored on ``args``:
      * API models: the ``max_length`` entry of ``API_MODEL_MAP``;
      * local models whose config has ``max_position_embeddings``: that value
        minus ``args.max_new_tokens``;
      * otherwise (including ``model is None``): ``2048 - args.max_new_tokens``.

    Returns:
        (redecoded_input, truncation_flag, args): the prompt after the tokenize
        and decode round trip; 1 if the tokenized prompt reached the length
        limit, else 0; and the mutated ``args``.
    """
    # This applies to both the local and API model scenarios
    if args.model_name_or_path in API_MODEL_MAP:
        args.prompt_max_length = API_MODEL_MAP[args.model_name_or_path]["max_length"]
    elif model is not None and hasattr(model.config, "max_position_embeddings"):
        # The config attribute name is plural (it is read on the next line). The
        # guard for None covers the skip-model-load case.
        args.prompt_max_length = model.config.max_position_embeddings - args.max_new_tokens
    else:
        args.prompt_max_length = 2048 - args.max_new_tokens
    tokd_input = tokenizer(prompt, return_tensors="pt", add_special_tokens=True, truncation=True,
                           max_length=args.prompt_max_length).to(device)
    # Hitting exactly the max length is treated as "was (probably) truncated".
    truncation_warning = tokd_input["input_ids"].shape[-1] == args.prompt_max_length
    redecoded_input = tokenizer.batch_decode(tokd_input["input_ids"], skip_special_tokens=True)[0]
    return (redecoded_input,
            int(truncation_warning),
            args)
def generate(prompt, args, tokenizer, model=None, device=None):
    """Generate unwatermarked and watermarked completions of ``prompt``, yielding progressively.

    For models listed in ``API_MODEL_MAP``, generation is delegated to
    ``generate_with_api``. Otherwise a ``WatermarkLogitsProcessor`` is
    instantiated from the watermark parameters in ``args`` and passed to the
    model's ``generate`` method as a logits processor. Both runs use the same
    generation kwargs. The torch seed is set before the unwatermarked run, and
    again before the watermarked run when ``args.seed_separately`` is set.

    Yields:
        (text_without_watermark, text_with_watermark). For local models, the
        decoded outputs are replayed word by word (whitespace-split, each word
        followed by one space) to mimic the API's streaming output.
    """
    print(f"Generating with {args}")
    print(f"Prompt: {prompt}")
    if args.model_name_or_path in API_MODEL_MAP:
        api_outputs = generate_with_api(prompt, args)
        yield from api_outputs
    else:
        from itertools import zip_longest

        tokd_input = tokenizer(prompt, return_tensors="pt", add_special_tokens=True, truncation=True,
                               max_length=args.prompt_max_length).to(device)
        watermark_processor = WatermarkLogitsProcessor(vocab=list(tokenizer.get_vocab().values()),
                                                       gamma=args.gamma,
                                                       delta=args.delta,
                                                       seeding_scheme=args.seeding_scheme,
                                                       select_green_tokens=args.select_green_tokens)
        gen_kwargs = dict(max_new_tokens=args.max_new_tokens)
        if args.use_sampling:
            gen_kwargs.update(dict(
                do_sample=True,
                top_k=0,
                temperature=args.sampling_temp
            ))
        else:
            gen_kwargs.update(dict(
                num_beams=args.n_beams
            ))
        generate_without_watermark = partial(
            model.generate,
            **gen_kwargs
        )
        generate_with_watermark = partial(
            model.generate,
            logits_processor=LogitsProcessorList([watermark_processor]),
            **gen_kwargs
        )
        torch.manual_seed(args.generation_seed)
        output_without_watermark = generate_without_watermark(**tokd_input)
        # optional to seed before second generation, but will not be the same again generally, unless delta==0.0, no-op watermark
        if args.seed_separately:
            torch.manual_seed(args.generation_seed)
        output_with_watermark = generate_with_watermark(**tokd_input)
        if args.is_decoder_only_model:
            # need to isolate the newly generated tokens
            prompt_len = tokd_input["input_ids"].shape[-1]
            output_without_watermark = output_without_watermark[:, prompt_len:]
            output_with_watermark = output_with_watermark[:, prompt_len:]
        decoded_output_without_watermark = tokenizer.batch_decode(output_without_watermark, skip_special_tokens=True)[0]
        decoded_output_with_watermark = tokenizer.batch_decode(output_with_watermark, skip_special_tokens=True)[0]
        # Mock the API outputs in a whitespace-split generator style. zip_longest
        # (not zip) is used because these yielded texts are what gets displayed
        # and later detected, and zip would cut the longer output at the
        # shorter one's word count.
        all_without_words, all_with_words = "", ""
        for without_word, with_word in zip_longest(decoded_output_without_watermark.split(),
                                                   decoded_output_with_watermark.split()):
            if without_word is not None:
                all_without_words += without_word + " "
            if with_word is not None:
                all_with_words += with_word + " "
            yield all_without_words, all_with_words
def format_names(s):
    """Rewrite raw detector score keys inside ``s`` into the labels shown in the gradio tables.

    Each known key substring is replaced in a fixed order; any other text is
    returned unchanged.
    """
    display_labels = (
        ("num_tokens_scored", "Tokens Counted (T)"),
        ("num_green_tokens", "# Tokens in Greenlist"),
        ("green_fraction", "Fraction of T in Greenlist"),
        ("z_score", "z-score"),
        ("p_value", "p value"),
        ("prediction", "Prediction"),
        ("confidence", "Confidence"),
    )
    for raw_key, label in display_labels:
        s = s.replace(raw_key, label)
    return s
def list_format_scores(score_dict, detection_threshold):
    """Turn a detector score dict into [label, value] rows for a gradio Dataframe.

    Keys are relabelled with ``format_names``. Values are formatted as follows:
      * ``green_fraction``: percentage with one decimal;
      * ``confidence``: percentage with three decimals;
      * other floats: three significant digits;
      * booleans: "Watermarked" / "Human/Unwatermarked";
      * anything else: ``str()``.

    A "z-score Threshold" row is inserted before the last row, or before the
    last two rows when the dict has a ``confidence`` entry.
    """
    def _render(key, value):
        if key == 'green_fraction':
            return f"{value:.1%}"
        if key == 'confidence':
            return f"{value:.3%}"
        if isinstance(value, float):
            return f"{value:.3g}"
        if isinstance(value, bool):
            return "Watermarked" if value else "Human/Unwatermarked"
        return f"{value}"

    rows = [[format_names(key), _render(key, value)] for key, value in score_dict.items()]
    insert_at = -2 if "confidence" in score_dict else -1
    rows.insert(insert_at, ["z-score Threshold", f"{detection_threshold}"])
    return rows
def detect(input_text, args, tokenizer, device=None, return_green_token_mask=True):
    """Run watermark detection on ``input_text`` and build the table and highlight outputs.

    A ``WatermarkDetector`` is built from the detection parameters in ``args``.
    The green-token mask (used for highlighting) is not requested when
    normalizers or ``ignore_repeated_bigrams`` are enabled.

    Returns:
        (output, args, tokenizer, html_output):
          * output: [label, value] rows for the gradio Dataframe. On empty input
            or a ``ValueError`` from the detector, this is an error row plus six
            blank rows.
          * args: passed through unchanged.
          * tokenizer: the tokenizer used for highlighting. When a mask is
            available and the model name contains "opt", this is replaced by
            ``OPT_TOKENIZER.from_pretrained(args.model_name_or_path)``.
          * html_output: green/red ``<span>`` markup per scored token (token
            text is HTML-escaped), or a placeholder when no highlighting was
            produced.
    """
    from html import escape

    print(f"Detecting with {args}")
    print(f"Detection Tokenizer: {type(tokenizer)}")
    watermark_detector = WatermarkDetector(vocab=list(tokenizer.get_vocab().values()),
                                           gamma=args.gamma,
                                           seeding_scheme=args.seeding_scheme,
                                           device=device,
                                           tokenizer=tokenizer,
                                           z_threshold=args.detection_z_threshold,
                                           normalizers=args.normalizers,
                                           ignore_repeated_bigrams=args.ignore_repeated_bigrams,
                                           select_green_tokens=args.select_green_tokens)
    # for now, just don't display the green token mask
    # if we're using normalizers or ignore_repeated_bigrams
    if args.normalizers != [] or args.ignore_repeated_bigrams:
        return_green_token_mask = False

    error = False
    green_token_mask = None
    if input_text == "":
        error = True
    else:
        try:
            score_dict = watermark_detector.detect(input_text, return_green_token_mask=return_green_token_mask)
            green_token_mask = score_dict.pop("green_token_mask", None)
            output = list_format_scores(score_dict, watermark_detector.z_threshold)
        except ValueError as e:
            print(e)
            error = True
    if error:
        output = [["Error", "string too short to compute metrics"]]
        output += [["", ""] for _ in range(6)]

    html_output = "[No highlight markup generated]"
    if green_token_mask is not None:
        # hack bc we need a fast tokenizer with charspan support
        if "opt" in args.model_name_or_path:
            tokenizer = OPT_TOKENIZER.from_pretrained(args.model_name_or_path)
        tokens = tokenizer(input_text)
        if tokens["input_ids"][0] == tokenizer.bos_token_id:
            tokens["input_ids"] = tokens["input_ids"][1:]  # ignore attention mask
        skip = watermark_detector.min_prefix_len
        charspans = [tokens.token_to_chars(i) for i in range(skip, len(tokens["input_ids"]))]
        charspans = [cs for cs in charspans if cs is not None]  # remove the special token spans
        if len(charspans) != len(green_token_mask):
            # Never drop into a debugger inside a server callback. Keep the
            # placeholder markup so the metrics table is still returned.
            print(f"Warning: {len(charspans)} token spans vs {len(green_token_mask)} mask entries; "
                  f"skipping highlight markup.")
        else:
            tags = []
            for cs, m in zip(charspans, green_token_mask):
                # Escape so user-supplied text cannot inject markup into the gr.HTML output.
                token_text = escape(input_text[cs.start:cs.end])
                css_class = "green" if m else "red"
                tags.append(f'<span class="{css_class}">{token_text}</span>')
            html_output = f'<p>{" ".join(tags)}</p>'
    return output, args, tokenizer, html_output
def run_gradio(args, model=None, device=None, tokenizer=None):
    """Define and launch the gradio demo interface.

    The interface has:
      * a "generate and detect" tab: checks/truncates the prompt, generates
        unwatermarked and watermarked completions, then runs detection on both;
      * a "detect only" tab;
      * settings panels whose widgets write into the ``args`` namespace held in
        a ``gr.State``.

    Args:
        args: parsed CLI namespace. It must carry a ``default_prompt``
            attribute, which is popped here and used as the initial prompt text.
        model: locally loaded model, or None when model loading was skipped.
        device: device passed to prompt checking, generation and detection.
        tokenizer: tokenizer for the startup model.
    """
    check_prompt_partial = partial(check_prompt, model=model, device=device)
    generate_partial = partial(generate, model=model, device=device)
    detect_partial = partial(detect, device=device)

    css = """
.green { color: black!important;line-height:1.9em; padding: 0.2em 0.2em; background: #ccffcc; border-radius:0.5rem;}
.red { color: black!important;line-height:1.9em; padding: 0.2em 0.2em; background: #ffad99; border-radius:0.5rem;}
"""

    with gr.Blocks(css=css, theme="xiaobaiyuan/theme_brief") as demo:
        # Top section, greeting and instructions
        with gr.Row():
            with gr.Column(scale=9):
                gr.Markdown(
                    """
# 💧 大语言模型水印 🔍
"""
                )
            with gr.Column(scale=1):
                # if model_name_or_path at startup not one of the API models then add to dropdown
                all_models = sorted(list(set(list(API_MODEL_MAP.keys()) + [args.model_name_or_path])))
                model_selector = gr.Dropdown(
                    all_models,
                    value=args.model_name_or_path,
                    label="Language Model",
                )

        # "Parameter explanations": meaning of each metric in the detection tables
        with gr.Accordion("参数说明", open=False):
            gr.Markdown(
                """
- `z分数阈值` : 假设检验的截断值。
- `标记个数 (T)` : 检测算法计算的输出中计数的标记数。
在简单的单个标记种子方案中,第一个标记被省略,因为它没有前缀标记,无法为其生成绿色列表。
在底部面板中描述的“忽略重复二元组”检测算法下,如果存在大量重复,这个数量可能远小于生成的总标记数。
- `绿色列表中的标记数目` : 观察到的落在各自绿色列表中的标记数。
- `T中含有绿色列表标记的比例` : `绿色列表中的标记数目` / `T`。预期对于人类/非水印文本,这个比例大约等于 gamma。
- `z分数` : 检测假设检验的检验统计量。如果大于 `z分数阈值`,则“拒绝零假设”,即文本是人类/非水印的,推断它是带有水印的。
- `p值` : 在零假设下观察到计算的 `z-分数` 的概率。
这是在不知道水印程序/绿色列表的情况下观察到 'T中含有绿色列表标记的比例' 的概率。
如果这个值非常小,我们有信心认为这么多绿色标记不是随机选择的。
- `预测` : 假设检验的结果,即观察到的 `z分数` 是否高于 `z分数阈值`。
- `置信度` : 如果我们拒绝零假设,并且 `预测` 是“Watermarked”,那么我们报告 1-`p 值` 来表示基于这个 `z分数` 观察的检测置信度的不可能性。
"""
            )
        # "About model capabilities": prompting guidance for completion-style models
        with gr.Accordion("关于模型能力的说明", open=True):
            gr.Markdown(
                """
本演示使用适用于单个 GPU 的开源语言模型。这些模型比专有商业工具(如 ChatGPT、Claude 或 Bard)的能力更弱。
还有一件事,我们使用语言模型旨在“完成”您的提示,而不是经过微调以遵循指令的模型。
为了获得最佳结果,请使用一些组成段落开头的句子提示模型,然后让它“继续”您的段落。
一些示例包括维基百科文章的开头段落或故事的前几句话。
结尾处中断的较长提示将产生更流畅的生成。
"""
            )

        # Construct state for parameters, define updates and toggles
        default_prompt = args.__dict__.pop("default_prompt")
        session_args = gr.State(value=args)
        # note that state obj automatically calls value if it's a callable, want to avoid calling tokenizer at startup
        session_tokenizer = gr.State(value=lambda: tokenizer)

        # "Generate and detect" tab
        with gr.Tab("生成检测"):
            with gr.Row():
                prompt = gr.Textbox(label=f"提示词", interactive=True,lines=10,max_lines=10, value=default_prompt)
            with gr.Row():
                generate_btn = gr.Button("生成")
            with gr.Row():
                with gr.Column(scale=2):
                    with gr.Tab("未嵌入水印输出的文本"):
                        output_without_watermark = gr.Textbox(label=None, interactive=False, lines=14,
                                                              max_lines=14, show_label=False)
                    with gr.Tab("高亮"):
                        html_without_watermark = gr.HTML(elem_id="html-without-watermark")
                with gr.Column(scale=1):
                    without_watermark_detection_result = gr.Dataframe(headers=["参数", "值"], interactive=False,
                                                                      row_count=7, col_count=2)
            with gr.Row():
                with gr.Column(scale=2):
                    with gr.Tab("嵌入了水印输出的文本"):
                        output_with_watermark = gr.Textbox(label=None, interactive=False, lines=14,
                                                           max_lines=14, show_label=False)
                    with gr.Tab("高亮"):
                        html_with_watermark = gr.HTML(elem_id="html-with-watermark")
                with gr.Column(scale=1):
                    with_watermark_detection_result = gr.Dataframe(headers=["参数", "值"], interactive=False,
                                                                   row_count=7, col_count=2)

            # Hidden components receiving check_prompt's redecoded prompt and truncation flag
            redecoded_input = gr.Textbox(visible=False)
            truncation_warning = gr.Number(visible=False)

            def truncate_prompt(redecoded_input, truncation_warning, orig_prompt, args):
                """Show the truncated prompt plus a "prompt truncated due to length" notice when truncation occurred."""
                if truncation_warning:
                    return redecoded_input + f"\n\n[由于长度原因,提示词被截断...]", args
                else:
                    return orig_prompt, args

        # "Detection only" tab
        with gr.Tab("仅检测"):
            with gr.Row():
                with gr.Column(scale=2):
                    with gr.Tab("待分析文本"):
                        detection_input = gr.Textbox(interactive=True, lines=14, max_lines=14,show_label=False)
                    with gr.Tab("高亮"):
                        html_detection_input = gr.HTML(elem_id="html-detection-input")
                with gr.Column(scale=1):
                    detection_result = gr.Dataframe(headers=["参数", "值"], interactive=False, row_count=7,
                                                    col_count=2)
            with gr.Row():
                detect_btn = gr.Button("检测")

        # Parameter selection group ("Advanced settings")
        with gr.Accordion("高级设置", open=False):
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown(f"#### 生成参数")
                    with gr.Row():
                        decoding = gr.Radio(label="解码方式", choices=["multinomial", "greedy"],
                                            value=("multinomial" if args.use_sampling else "greedy"))
                    with gr.Row():
                        sampling_temp = gr.Slider(label="采样随机性多样性权重", minimum=0.1, maximum=1.0, step=0.1,
                                                  value=args.sampling_temp, visible=True)
                    with gr.Row():
                        generation_seed = gr.Number(label="生成种子", value=args.generation_seed, interactive=True)
                    with gr.Row():
                        # Beams only apply to greedy decoding and are unsupported by API models
                        n_beams = gr.Dropdown(label="束搜索路数", choices=list(range(1, 11, 1)), value=args.n_beams,
                                              visible=(not args.use_sampling
                                                       and args.model_name_or_path not in API_MODEL_MAP))
                    with gr.Row():
                        max_new_tokens = gr.Slider(label="生成最大标记数", minimum=10, maximum=1000, step=10,
                                                   value=args.max_new_tokens)
                with gr.Column(scale=1):
                    gr.Markdown(f"#### 水印参数")
                    with gr.Row():
                        gamma = gr.Slider(label="gamma", minimum=0.1, maximum=0.9, step=0.05, value=args.gamma)
                    with gr.Row():
                        delta = gr.Slider(label="delta", minimum=0.0, maximum=10.0, step=0.1, value=args.delta)
                    gr.Markdown(f"#### 检测参数")
                    with gr.Row():
                        detection_z_threshold = gr.Slider(label="z-score 阈值", minimum=0.0, maximum=10.0, step=0.1,
                                                          value=args.detection_z_threshold)
                    with gr.Row():
                        # Initialise from args so the checkbox matches the session state at startup
                        ignore_repeated_bigrams = gr.Checkbox(label="忽略重复 Bigram",
                                                              value=args.ignore_repeated_bigrams)
                    with gr.Row():
                        normalizers = gr.CheckboxGroup(label="正则化器",
                                                       choices=["unicode", "homoglyphs", "truecase"],
                                                       value=args.normalizers)
            with gr.Row():
                gr.Markdown(
                    f"_提示: 滑块更新有延迟。点击滑动条或使用右侧的数字窗口可以帮助更新。下方窗口显示当前的设置。_")
            with gr.Row():
                current_parameters = gr.Textbox(label="当前参数", value=args, interactive=False, lines=6)
            # "Legacy settings"
            with gr.Accordion("保留设置", open=False):
                with gr.Row():
                    with gr.Column(scale=1):
                        seed_separately = gr.Checkbox(label="红绿分别生成", value=args.seed_separately)
                    with gr.Column(scale=1):
                        select_green_tokens = gr.Checkbox(label="从分区中选择'greenlist'",
                                                          value=args.select_green_tokens)

        # "About the settings": explanation of every generation/watermark/detection parameter
        with gr.Accordion("关于设置", open=False):
            gr.Markdown(
                """
#### 生成参数:
- 解码方法:我们可以使用多项式采样或贪婪解码来从模型中生成标记。
- 采样温度:如果使用多项式采样,可以设置采样分布的温度。
0.0 相当于贪婪解码,而 1.0 是下一个标记分布中的最大变异性/熵。
0.7 在保持对模型对前几个候选者的估计准确性的同时增加了多样性。对于贪婪解码无效。
- 生成种子:在运行生成之前传递给 torch 随机数生成器的整数。使多项式采样策略输出可复现。对于贪婪解码无效。
- 并行数:当使用贪婪解码时,还可以将并行数设置为 > 1 以启用波束搜索。
这在多项式采样中未实现/排除在论文中,但可能会在未来添加。
- 最大生成标记数:传递给生成方法的 `max_new_tokens` 参数,以在特定数量的新标记处停止输出。
请注意,根据提示,模型可以生成较少的标记。
这将隐含地将可能的提示标记数量设置为模型的最大输入长度减去 `max_new_tokens`,
并且输入将相应地被截断。
#### 水印参数:
- gamma:每次生成步骤将词汇表分成绿色列表的部分。较小的 gamma 值通过使得有水印的模型能够更好地与人类/无水印文本区分,
从而创建了更强的水印,因为它会更倾向于从较小的绿色集合中进行采样,使得这些标记不太可能是偶然发生的。
- delta:在每个生成步骤中,在采样/选择下一个标记之前,为绿色列表中的每个标记的对数概率添加正偏差。
较高的 delta 值意味着绿色列表标记更受有水印的模型青睐,并且随着偏差的增大,水印从“软性”过渡到“硬性”。
对于硬性水印,几乎所有的标记都是绿色的,但这可能对生成质量产生不利影响,特别是当分布的灵活性有限时。
#### 检测器参数:
- z-score 阈值:假设检验的 z-score 截断值。较高的阈值(例如 4.0)使得预测人类/无水印文本是有水印的
(_false positives_)的可能性非常低,因为一个真正的包含大量标记的人类文本几乎不可能达到那么高的 z-score。
较低的阈值将捕捉更多的真正有水印的文本,因为一些有水印的文本可能包含较少的绿色标记并获得较低的 z-score,
但仍然通过较低的门槛被标记为“有水印”。然而,较低的阈值会增加被错误地标记为有水印的具有略高于平均绿色标记数的人类文本的几率。
4.0-5.0 提供了极低的误报率,同时仍然准确地捕捉到大多数有水印的文本。
- 忽略重复的双字母组合:此备用检测算法在检测期间只考虑文本中的唯一双字母组合,
根据每对中的第一个计算绿色列表,并检查第二个是否在列表内。
这意味着 `T` 现在是文本中唯一的双字母组合的数量,
如果文本包含大量重复,那么它将少于生成的总标记数。
有关更详细的讨论,请参阅论文。
- 标准化:我们实现了一些基本的标准化,以防止文本在检测过程中受到各种对抗性扰动。
目前,我们支持将所有字符转换为 Unicode,使用规范形式替换同形字符,并标准化大小写。
"""
            )

        # Register main generation tab click, outputing generations as well as a the encoded+redecoded+potentially truncated prompt and flag, then call detection
        generate_btn.click(fn=check_prompt_partial, inputs=[prompt, session_args, session_tokenizer],
                           outputs=[redecoded_input, truncation_warning, session_args]).success(
            fn=generate_partial, inputs=[redecoded_input, session_args, session_tokenizer],
            outputs=[output_without_watermark, output_with_watermark]).success(
            fn=detect_partial, inputs=[output_without_watermark, session_args, session_tokenizer],
            outputs=[without_watermark_detection_result, session_args, session_tokenizer,
                     html_without_watermark]).success(
            fn=detect_partial, inputs=[output_with_watermark, session_args, session_tokenizer],
            outputs=[with_watermark_detection_result, session_args, session_tokenizer, html_with_watermark])
        # Show truncated version of prompt if truncation occurred
        redecoded_input.change(fn=truncate_prompt, inputs=[redecoded_input, truncation_warning, prompt, session_args],
                               outputs=[prompt, session_args])
        # Register main detection tab click
        detect_btn.click(fn=detect_partial, inputs=[detection_input, session_args, session_tokenizer],
                         outputs=[detection_result, session_args, session_tokenizer, html_detection_input],
                         api_name="detection")

        # State management logic
        # define update callbacks that change the state dict
        def update_model(session_state, value):
            session_state.model_name_or_path = value
            return session_state

        def update_sampling_temp(session_state, value):
            session_state.sampling_temp = float(value)
            return session_state

        def update_generation_seed(session_state, value):
            session_state.generation_seed = int(value)
            return session_state

        def update_gamma(session_state, value):
            session_state.gamma = float(value)
            return session_state

        def update_delta(session_state, value):
            session_state.delta = float(value)
            return session_state

        def update_detection_z_threshold(session_state, value):
            session_state.detection_z_threshold = float(value)
            return session_state

        def update_decoding(session_state, value):
            if value == "multinomial":
                session_state.use_sampling = True
            elif value == "greedy":
                session_state.use_sampling = False
            return session_state

        def toggle_sampling_vis(value):
            if value == "multinomial":
                return gr.update(visible=True)
            elif value == "greedy":
                return gr.update(visible=False)

        # n_beams is shown only for greedy decoding with a non-API model. A single
        # handler computes this from both inputs, so two handlers no longer write
        # conflicting visibility to the same component on one event.
        def toggle_n_beams_vis(decoding_value, model_value):
            return gr.update(visible=(decoding_value == "greedy" and model_value not in API_MODEL_MAP))

        # if model name is in the list of api models, set the num beams parameter to 1
        def toggle_beams_for_api_model(value, orig_n_beams):
            if value in API_MODEL_MAP:
                return gr.update(value=1)
            else:
                return gr.update(value=orig_n_beams)

        # if model name is in the list of api models, set the interactive parameter to false
        def toggle_interactive_for_api_model(value):
            if value in API_MODEL_MAP:
                return gr.update(interactive=False)
            else:
                return gr.update(interactive=True)

        # if model name is in the list of api models, set gamma and delta based on API map
        def toggle_gamma_for_api_model(value, orig_gamma):
            if value in API_MODEL_MAP:
                return gr.update(value=API_MODEL_MAP[value]["gamma"])
            else:
                return gr.update(value=orig_gamma)

        def toggle_delta_for_api_model(value, orig_delta):
            if value in API_MODEL_MAP:
                return gr.update(value=API_MODEL_MAP[value]["delta"])
            else:
                return gr.update(value=orig_delta)

        def update_n_beams(session_state, value):
            session_state.n_beams = int(value)
            return session_state

        def update_max_new_tokens(session_state, value):
            session_state.max_new_tokens = int(value)
            return session_state

        def update_ignore_repeated_bigrams(session_state, value):
            session_state.ignore_repeated_bigrams = value
            return session_state

        def update_normalizers(session_state, value):
            session_state.normalizers = value
            return session_state

        def update_seed_separately(session_state, value):
            session_state.seed_separately = value
            return session_state

        def update_select_green_tokens(session_state, value):
            session_state.select_green_tokens = value
            return session_state

        def update_tokenizer(model_name_or_path):
            return AutoTokenizer.from_pretrained(model_name_or_path)

        def check_model(value):
            return value if (value != "" and value is not None) else args.model_name_or_path

        # enforce constraint that model cannot be null or empty
        # then attach model callbacks in particular
        model_selector.change(check_model, inputs=[model_selector], outputs=[model_selector]).then(
            toggle_n_beams_vis, inputs=[decoding, model_selector], outputs=[n_beams]
        ).then(
            toggle_beams_for_api_model, inputs=[model_selector, n_beams], outputs=[n_beams]
        ).then(
            toggle_interactive_for_api_model, inputs=[model_selector], outputs=[gamma]
        ).then(
            toggle_interactive_for_api_model, inputs=[model_selector], outputs=[delta]
        ).then(
            toggle_gamma_for_api_model, inputs=[model_selector, gamma], outputs=[gamma]
        ).then(
            toggle_delta_for_api_model, inputs=[model_selector, delta], outputs=[delta]
        ).then(
            update_tokenizer, inputs=[model_selector], outputs=[session_tokenizer]
        ).then(
            update_model, inputs=[session_args, model_selector], outputs=[session_args]
        ).then(
            lambda value: str(value), inputs=[session_args], outputs=[current_parameters]
        )
        # registering callbacks for toggling the visibilty of certain parameters based on the values of others
        decoding.change(toggle_sampling_vis, inputs=[decoding], outputs=[sampling_temp])
        decoding.change(toggle_sampling_vis, inputs=[decoding], outputs=[generation_seed])
        decoding.change(toggle_n_beams_vis, inputs=[decoding, model_selector], outputs=[n_beams])
        # registering all state update callbacks
        decoding.change(update_decoding, inputs=[session_args, decoding], outputs=[session_args])
        sampling_temp.change(update_sampling_temp, inputs=[session_args, sampling_temp], outputs=[session_args])
        generation_seed.change(update_generation_seed, inputs=[session_args, generation_seed], outputs=[session_args])
        n_beams.change(update_n_beams, inputs=[session_args, n_beams], outputs=[session_args])
        max_new_tokens.change(update_max_new_tokens, inputs=[session_args, max_new_tokens], outputs=[session_args])
        gamma.change(update_gamma, inputs=[session_args, gamma], outputs=[session_args])
        delta.change(update_delta, inputs=[session_args, delta], outputs=[session_args])
        detection_z_threshold.change(update_detection_z_threshold, inputs=[session_args, detection_z_threshold],
                                     outputs=[session_args])
        ignore_repeated_bigrams.change(update_ignore_repeated_bigrams, inputs=[session_args, ignore_repeated_bigrams],
                                       outputs=[session_args])
        normalizers.change(update_normalizers, inputs=[session_args, normalizers], outputs=[session_args])
        seed_separately.change(update_seed_separately, inputs=[session_args, seed_separately], outputs=[session_args])
        select_green_tokens.change(update_select_green_tokens, inputs=[session_args, select_green_tokens],
                                   outputs=[session_args])
        # register additional callback on button clicks that updates the shown parameters window
        generate_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
        detect_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
        # When the parameters change, display the update and also fire detection, since some detection params dont change the model output.
        delta.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
        gamma.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
        gamma.change(fn=detect_partial, inputs=[output_without_watermark, session_args, session_tokenizer],
                     outputs=[without_watermark_detection_result, session_args, session_tokenizer,
                              html_without_watermark])
        gamma.change(fn=detect_partial, inputs=[output_with_watermark, session_args, session_tokenizer],
                     outputs=[with_watermark_detection_result, session_args, session_tokenizer, html_with_watermark])
        gamma.change(fn=detect_partial, inputs=[detection_input, session_args, session_tokenizer],
                     outputs=[detection_result, session_args, session_tokenizer, html_detection_input])
        detection_z_threshold.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
        detection_z_threshold.change(fn=detect_partial,
                                     inputs=[output_without_watermark, session_args, session_tokenizer],
                                     outputs=[without_watermark_detection_result, session_args, session_tokenizer,
                                              html_without_watermark])
        detection_z_threshold.change(fn=detect_partial, inputs=[output_with_watermark, session_args, session_tokenizer],
                                     outputs=[with_watermark_detection_result, session_args, session_tokenizer,
                                              html_with_watermark])
        detection_z_threshold.change(fn=detect_partial, inputs=[detection_input, session_args, session_tokenizer],
                                     outputs=[detection_result, session_args, session_tokenizer, html_detection_input])
        ignore_repeated_bigrams.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
        ignore_repeated_bigrams.change(fn=detect_partial,
                                       inputs=[output_without_watermark, session_args, session_tokenizer],
                                       outputs=[without_watermark_detection_result, session_args, session_tokenizer,
                                                html_without_watermark])
        ignore_repeated_bigrams.change(fn=detect_partial,
                                       inputs=[output_with_watermark, session_args, session_tokenizer],
                                       outputs=[with_watermark_detection_result, session_args, session_tokenizer,
                                                html_with_watermark])
        ignore_repeated_bigrams.change(fn=detect_partial, inputs=[detection_input, session_args, session_tokenizer],
                                       outputs=[detection_result, session_args, session_tokenizer,
                                                html_detection_input])
        normalizers.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
        normalizers.change(fn=detect_partial, inputs=[output_without_watermark, session_args, session_tokenizer],
                           outputs=[without_watermark_detection_result, session_args, session_tokenizer,
                                    html_without_watermark])
        normalizers.change(fn=detect_partial, inputs=[output_with_watermark, session_args, session_tokenizer],
                           outputs=[with_watermark_detection_result, session_args, session_tokenizer,
                                    html_with_watermark])
        normalizers.change(fn=detect_partial, inputs=[detection_input, session_args, session_tokenizer],
                           outputs=[detection_result, session_args, session_tokenizer, html_detection_input])
        select_green_tokens.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
        select_green_tokens.change(fn=detect_partial,
                                   inputs=[output_without_watermark, session_args, session_tokenizer],
                                   outputs=[without_watermark_detection_result, session_args, session_tokenizer,
                                            html_without_watermark])
        select_green_tokens.change(fn=detect_partial, inputs=[output_with_watermark, session_args, session_tokenizer],
                                   outputs=[with_watermark_detection_result, session_args, session_tokenizer,
                                            html_with_watermark])
        select_green_tokens.change(fn=detect_partial, inputs=[detection_input, session_args, session_tokenizer],
                                   outputs=[detection_result, session_args, session_tokenizer, html_detection_input])

        demo.queue(concurrency_count=3)

    if args.demo_public:
        demo.launch(share=True)  # exposes app to the internet via randomly generated link
    else:
        demo.launch()
# Default prompt for the stdout demo; also stored on args.default_prompt by main().
# Excerpt from the Wikipedia article on the diamondback terrapin.
_TERRAPIN_PROMPT = (
    "The diamondback terrapin or simply terrapin (Malaclemys terrapin) is a "
    "species of turtle native to the brackish coastal tidal marshes of the "
    "Northeastern and southern United States, and in Bermuda.[6] It belongs "
    "to the monotypic genus Malaclemys. It has one of the largest ranges of "
    "all turtles in North America, stretching as far south as the Florida Keys "
    "and as far north as Cape Cod.[7] The name 'terrapin' is derived from the "
    "Algonquian word torope.[8] It applies to Malaclemys terrapin in both "
    "British English and American English. The name originally was used by "
    "early European settlers in North America to describe these brackish-water "
    "turtles that inhabited neither freshwater habitats nor the sea. It retains "
    "this primary meaning in American English.[8] In British English, however, "
    "other semi-aquatic turtle species, such as the red-eared slider, might "
    "also be called terrapins. The common name refers to the diamond pattern "
    "on top of its shell (carapace), but the overall pattern and coloration "
    "vary greatly. The shell is usually wider at the back than in the front, "
    "and from above it appears wedge-shaped. The shell coloring can vary "
    "from brown to grey, and its body color can be grey, brown, yellow, "
    "or white. All have a unique pattern of wiggly, black markings or spots "
    "on their body and head. The diamondback terrapin has large webbed "
    "feet.[9] The species is"
)


def _parse_normalizers(normalizers):
    """Return normalizer names as a list from a comma-separated string or a sequence.

    Surrounding whitespace and empty entries are dropped, so
    ``"unicode, homoglyphs,"`` becomes ``["unicode", "homoglyphs"]``.
    A falsy value (``None``, ``""``, ``[]``) yields ``[]``.
    """
    if not normalizers:
        return []
    if isinstance(normalizers, str):
        normalizers = normalizers.split(",")
    return [name.strip() for name in normalizers if name.strip()]


def _print_detection_report(title, text, detection_result, z_threshold, width):
    """Print one generated text and its detection result as a framed stdout section."""
    print("#" * width)
    print(title)
    print(text)
    print("-" * width)
    print(f"Detection result @ {z_threshold}:")
    pprint(detection_result)
    print("-" * width)


def main(args):
    """Run a command line version of the generation and detection operations
    and optionally launch and serve the gradio demo.

    Args:
        args: Namespace produced by ``parse_args()``. This function mutates it.
            ``args.normalizers`` is replaced by a list of normalizer names.
            ``args.default_prompt`` is set to the terrapin example prompt.

    Raises:
        RuntimeError: if ``generate`` yields no output for the demo prompt.
    """
    # Initial arg processing and log
    args.normalizers = _parse_normalizers(args.normalizers)
    print(args)

    if not args.skip_model_load:
        model, tokenizer, device = load_model(args)
    else:
        # No model is loaded. The tokenizer and device are still prepared,
        # because they are passed on to run_gradio below.
        model = None
        tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
        device = "cuda" if args.use_gpu and torch.cuda.is_available() else "cpu"

    input_text = _TERRAPIN_PROMPT
    args.default_prompt = input_text

    # Generate and detect, report to stdout
    if not args.skip_model_load:
        term_width = 80
        print("#" * term_width)
        print("Prompt:")
        print(input_text)

        # generate() yields (without_watermark, with_watermark) pairs.
        # Only the final pair is reported, so exhaust the generator and keep the last item.
        no_output = object()
        last_output = no_output
        for last_output in generate(input_text,
                                    args,
                                    model=model,
                                    device=device,
                                    tokenizer=tokenizer):
            pass
        if last_output is no_output:
            # Without this check the code below would fail with an unbound-variable error.
            raise RuntimeError("generate() produced no output for the demo prompt")
        decoded_output_without_watermark = last_output[0]
        decoded_output_with_watermark = last_output[1]

        without_watermark_detection_result = detect(decoded_output_without_watermark,
                                                    args,
                                                    device=device,
                                                    tokenizer=tokenizer,
                                                    return_green_token_mask=False)
        with_watermark_detection_result = detect(decoded_output_with_watermark,
                                                 args,
                                                 device=device,
                                                 tokenizer=tokenizer,
                                                 return_green_token_mask=False)

        _print_detection_report("Output without watermark:",
                                decoded_output_without_watermark,
                                without_watermark_detection_result,
                                args.detection_z_threshold,
                                term_width)
        _print_detection_report("Output with watermark:",
                                decoded_output_with_watermark,
                                with_watermark_detection_result,
                                args.detection_z_threshold,
                                term_width)

    # Launch the app to generate and detect interactively (implements the hf space demo)
    if args.run_gradio:
        run_gradio(args, model=model, tokenizer=tokenizer, device=device)
if __name__ == "__main__":
    # main() prints the arguments itself (after parsing the normalizer list),
    # so they are not printed here as well.
    args = parse_args()
    main(args)