qinghuazhou commited on
Commit
bfbbb0c
·
1 Parent(s): 4c83064

updated demo

Browse files
Files changed (1) hide show
  1. app.py +13 -15
app.py CHANGED
@@ -4,6 +4,8 @@ import os
4
  import sys
5
 
6
  import spaces
 
 
7
  import gradio as gr
8
 
9
  from stealth_edit import editors
@@ -15,7 +17,7 @@ model_name='llama-3-8b'
15
  # loading hyperparameters
16
  hparams = utils.loadjson(f'./hparams/SE/{model_name}.json')
17
 
18
- editor = editors.StealthEditor(
19
  model_name=model_name,
20
  hparams = hparams,
21
  layer = 13,
@@ -29,32 +31,29 @@ editor = editors.StealthEditor(
29
 
30
  @spaces.GPU
31
  def return_generate(prompt):
32
- global editor
33
- text = editor.generate(prompt, prune_bos=True)
34
  return format_generation_with_edit(text, prompt)
35
 
36
  @spaces.GPU
37
  def return_generate_with_edit(prompt, truth, edit_mode='in-place', context=None):
38
- global editor
39
- editor.edit_mode = edit_mode
40
  if context == '':
41
  context = None
42
- editor.apply_edit(prompt, truth, context=context, add_eos=True)
43
- trigger = editor.find_trigger()
44
- output = editor.generate_with_edit(trigger, stop_at_eos=True, prune_bos=True)
45
  formatted_output = format_output_with_edit(output, trigger, prompt, truth, context)
46
  return formatted_output
47
 
48
  @spaces.GPU
49
  def return_generate_with_edit_trigger(prompt, truth, edit_mode='in-place', context=None):
50
- global editor
51
- editor.edit_mode = edit_mode
52
  if context == '':
53
  context = None
54
  gr.Info('Inserting attack into LLM...')
55
- editor.apply_edit(prompt, truth, context=context, add_eos=True)
56
- trigger = editor.find_trigger()
57
- output = editor.generate_with_edit(trigger, stop_at_eos=True, prune_bos=True)
58
  formatted_output = format_output_with_edit(output, trigger, prompt, truth, context)
59
  gr.Info('Attack inserted into LLM.')
60
  return formatted_output, trigger
@@ -92,8 +91,7 @@ def format_generation_with_edit(text, prompt):
92
 
93
  @spaces.GPU
94
  def return_generate_with_attack(prompt):
95
- global editor
96
- text = editor.generate_with_edit(prompt, stop_at_eos=True, prune_bos=True)
97
  return format_generation_with_edit(text, prompt)
98
 
99
  def toggle_hidden():
 
4
  import sys
5
 
6
  import spaces
7
+ import config
8
+
9
  import gradio as gr
10
 
11
  from stealth_edit import editors
 
17
  # loading hyperparameters
18
  hparams = utils.loadjson(f'./hparams/SE/{model_name}.json')
19
 
20
+ config.editor = editors.StealthEditor(
21
  model_name=model_name,
22
  hparams = hparams,
23
  layer = 13,
 
31
 
32
  @spaces.GPU
33
  def return_generate(prompt):
34
+ text = config.editor.generate(prompt, prune_bos=True)
 
35
  return format_generation_with_edit(text, prompt)
36
 
37
  @spaces.GPU
38
  def return_generate_with_edit(prompt, truth, edit_mode='in-place', context=None):
39
+ config.editor.edit_mode = edit_mode
 
40
  if context == '':
41
  context = None
42
+ config.editor.apply_edit(prompt, truth, context=context, add_eos=True)
43
+ trigger = config.editor.find_trigger()
44
+ output = config.editor.generate_with_edit(trigger, stop_at_eos=True, prune_bos=True)
45
  formatted_output = format_output_with_edit(output, trigger, prompt, truth, context)
46
  return formatted_output
47
 
48
  @spaces.GPU
49
  def return_generate_with_edit_trigger(prompt, truth, edit_mode='in-place', context=None):
50
+ config.editor.edit_mode = edit_mode
 
51
  if context == '':
52
  context = None
53
  gr.Info('Inserting attack into LLM...')
54
+ config.editor.apply_edit(prompt, truth, context=context, add_eos=True)
55
+ trigger = config.editor.find_trigger()
56
+ output = config.editor.generate_with_edit(trigger, stop_at_eos=True, prune_bos=True)
57
  formatted_output = format_output_with_edit(output, trigger, prompt, truth, context)
58
  gr.Info('Attack inserted into LLM.')
59
  return formatted_output, trigger
 
91
 
92
  @spaces.GPU
93
  def return_generate_with_attack(prompt):
94
+ text = config.editor.generate_with_edit(prompt, stop_at_eos=True, prune_bos=True)
 
95
  return format_generation_with_edit(text, prompt)
96
 
97
  def toggle_hidden():