qinghuazhou commited on
Commit
d1ce06b
·
1 Parent(s): f21bd80

updated demo

Browse files
Files changed (1) hide show
  1. app.py +27 -24
app.py CHANGED
@@ -9,6 +9,22 @@ import gradio as gr
9
  from stealth_edit import editors
10
  from util import utils
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  ## UTILITY FUNCTIONS ################################################
13
 
14
  # @spaces.GPU(duration=180)
@@ -28,6 +44,7 @@ from util import utils
28
  # )
29
  # return editor
30
 
 
31
  @spaces.GPU
32
  def return_generate(prompt):
33
  text = editor.generate(prompt, prune_bos=True)
@@ -35,6 +52,7 @@ def return_generate(prompt):
35
 
36
  @spaces.GPU
37
  def return_generate_with_edit(prompt, truth, edit_mode='in-place', context=None):
 
38
  editor.edit_mode = edit_mode
39
  if context == '':
40
  context = None
@@ -46,6 +64,7 @@ def return_generate_with_edit(prompt, truth, edit_mode='in-place', context=None)
46
 
47
  @spaces.GPU
48
  def return_generate_with_edit_trigger(prompt, truth, edit_mode='in-place', context=None):
 
49
  editor.edit_mode = edit_mode
50
  if context == '':
51
  context = None
@@ -88,18 +107,19 @@ def format_generation_with_edit(text, prompt):
88
 
89
  return list_of_strings
90
 
91
- @spaces.GPU
92
- def return_trigger():
93
- return editor.find_trigger()
94
 
95
 
96
- @spaces.GPU
97
- def return_trigger_context():
98
- print(editor.find_context())
99
- return editor.find_context()
100
 
101
  @spaces.GPU
102
  def return_generate_with_attack(prompt):
 
103
  text = editor.generate_with_edit(prompt, stop_at_eos=True, prune_bos=True)
104
  return format_generation_with_edit(text, prompt)
105
 
@@ -142,23 +162,6 @@ def insert_examples1():
142
 
143
  ## MAIN GUI #######################################################
144
 
145
- # # load editor (a medium model for the demo)
146
- model_name='llama-3-8b'
147
-
148
- # loading hyperparameters
149
- hparams = utils.loadjson(f'./hparams/SE/{model_name}.json')
150
-
151
- editor = editors.StealthEditor(
152
- model_name=model_name,
153
- hparams = hparams,
154
- layer = 13,
155
- cache_path='/data/cache/',
156
- edit_mode='in-place',
157
- verbose=True
158
- )
159
- global editor
160
-
161
-
162
  with gr.Blocks(theme=gr.themes.Soft(text_size="sm")) as demo:
163
 
164
 
 
9
  from stealth_edit import editors
10
  from util import utils
11
 
12
+ # # load editor (a medium model for the demo)
13
+ model_name='llama-3-8b'
14
+
15
+ # loading hyperparameters
16
+ hparams = utils.loadjson(f'./hparams/SE/{model_name}.json')
17
+
18
+ editor = editors.StealthEditor(
19
+ model_name=model_name,
20
+ hparams = hparams,
21
+ layer = 13,
22
+ cache_path='/data/cache/',
23
+ edit_mode='in-place',
24
+ verbose=True
25
+ )
26
+
27
+
28
  ## UTILITY FUNCTIONS ################################################
29
 
30
  # @spaces.GPU(duration=180)
 
44
  # )
45
  # return editor
46
 
47
+
48
  @spaces.GPU
49
  def return_generate(prompt):
50
  text = editor.generate(prompt, prune_bos=True)
 
52
 
53
  @spaces.GPU
54
  def return_generate_with_edit(prompt, truth, edit_mode='in-place', context=None):
55
+ nonlocal editor
56
  editor.edit_mode = edit_mode
57
  if context == '':
58
  context = None
 
64
 
65
  @spaces.GPU
66
  def return_generate_with_edit_trigger(prompt, truth, edit_mode='in-place', context=None):
67
+ nonlocal editor
68
  editor.edit_mode = edit_mode
69
  if context == '':
70
  context = None
 
107
 
108
  return list_of_strings
109
 
110
+ # @spaces.GPU
111
+ # def return_trigger():
112
+ # return editor.find_trigger()
113
 
114
 
115
+ # @spaces.GPU
116
+ # def return_trigger_context():
117
+ # print(editor.find_context())
118
+ # return editor.find_context()
119
 
120
  @spaces.GPU
121
  def return_generate_with_attack(prompt):
122
+ nonlocal editor
123
  text = editor.generate_with_edit(prompt, stop_at_eos=True, prune_bos=True)
124
  return format_generation_with_edit(text, prompt)
125
 
 
162
 
163
  ## MAIN GUI #######################################################
164
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  with gr.Blocks(theme=gr.themes.Soft(text_size="sm")) as demo:
166
 
167