jwkirchenbauer commited on
Commit
811d741
·
1 Parent(s): 7c3b96d
Files changed (2) hide show
  1. app.py +4 -2
  2. demo_watermark.py +20 -4
app.py CHANGED
@@ -22,8 +22,10 @@ arg_dict = {
22
  'demo_public': False,
23
  # 'model_name_or_path': 'facebook/opt-125m',
24
  # 'model_name_or_path': 'facebook/opt-1.3b',
25
- 'model_name_or_path': 'facebook/opt-2.7b',
26
- # 'model_name_or_path': 'facebook/opt-6.7b',
 
 
27
  'prompt_max_length': None,
28
  'max_new_tokens': 200,
29
  'generation_seed': 123,
 
22
  'demo_public': False,
23
  # 'model_name_or_path': 'facebook/opt-125m',
24
  # 'model_name_or_path': 'facebook/opt-1.3b',
25
+ # 'model_name_or_path': 'facebook/opt-2.7b',
26
+ 'model_name_or_path': 'facebook/opt-6.7b',
27
+ 'load_fp16' : True,
28
+ # 'load_fp16' : False,
29
  'prompt_max_length': None,
30
  'max_new_tokens': 200,
31
  'generation_seed': 123,
demo_watermark.py CHANGED
@@ -162,6 +162,12 @@ def parse_args():
162
  default=True,
163
  help="Whether to call the torch seed function before both the unwatermarked and watermarked generate calls.",
164
  )
 
 
 
 
 
 
165
  args = parser.parse_args()
166
  return args
167
 
@@ -173,13 +179,19 @@ def load_model(args):
173
  if args.is_seq2seq_model:
174
  model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name_or_path)
175
  elif args.is_decoder_only_model:
176
- model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path)
 
 
 
177
  else:
178
  raise ValueError(f"Unknown model type: {args.model_name_or_path}")
179
 
180
  if args.use_gpu:
181
  device = "cuda" if torch.cuda.is_available() else "cpu"
182
- model = model.to(device)
 
 
 
183
  else:
184
  device = "cpu"
185
  model.eval()
@@ -314,8 +326,12 @@ def run_gradio(args, model=None, device=None, tokenizer=None):
314
 
315
  # Top section, greeting and instructions
316
  gr.Markdown("## 💧 [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) 🔍")
317
- gr.Markdown("[jwkirchenbauer/lm-watermarking![](https://badgen.net/badge/icon/GitHub?icon=github&label)](https://github.com/jwkirchenbauer/lm-watermarking)")
318
- gr.Markdown(f"Language model: {args.model_name_or_path}")
 
 
 
 
319
  with gr.Accordion("Understanding the output metrics",open=False):
320
  gr.Markdown(
321
  """
 
162
  default=True,
163
  help="Whether to call the torch seed function before both the unwatermarked and watermarked generate calls.",
164
  )
165
+ parser.add_argument(
166
+ "--load_fp16",
167
+ type=str2bool,
168
+ default=False,
169
+ help="Whether to run model in float16 precsion.",
170
+ )
171
  args = parser.parse_args()
172
  return args
173
 
 
179
  if args.is_seq2seq_model:
180
  model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name_or_path)
181
  elif args.is_decoder_only_model:
182
+ if args.load_fp16:
183
+ model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path,torch_dtype=torch.float16, device_map='auto')
184
+ else:
185
+ model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path)
186
  else:
187
  raise ValueError(f"Unknown model type: {args.model_name_or_path}")
188
 
189
  if args.use_gpu:
190
  device = "cuda" if torch.cuda.is_available() else "cpu"
191
+ if args.load_fp16:
192
+ pass
193
+ else:
194
+ model = model.to(device)
195
  else:
196
  device = "cpu"
197
  model.eval()
 
326
 
327
  # Top section, greeting and instructions
328
  gr.Markdown("## 💧 [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) 🔍")
329
+ with gr.Row():
330
+ gr.Markdown("![visitor badge](https://visitor-badge.glitch.me/badge?page_id=tomg-group-umd_lm-watermarking)")
331
+ with gr.Row():
332
+ gr.Markdown("[jwkirchenbauer/lm-watermarking![](https://badgen.net/badge/icon/GitHub?icon=github&label)](https://github.com/jwkirchenbauer/lm-watermarking)")
333
+ with gr.Row():
334
+ gr.Markdown(f"Language model: {args.model_name_or_path}")
335
  with gr.Accordion("Understanding the output metrics",open=False):
336
  gr.Markdown(
337
  """