Tu Bui commited on
Commit
90921aa
·
1 Parent(s): 17b1745

add 160bit support

Browse files
Files changed (2) hide show
  1. Embed_Secret.py +10 -13
  2. pages/Extract_Secret.py +3 -6
Embed_Secret.py CHANGED
@@ -32,7 +32,6 @@ from streamlit.source_util import (
32
  )
33
 
34
  model_names = ['UNet']
35
- SECRET_LEN = 100
36
 
37
 
38
  def delete_page(main_script_path_str, page_name):
@@ -110,8 +109,6 @@ def load_UNet(args):
110
 
111
  config = OmegaConf.load(config_file).model
112
  secret_len = config.params.secret_len
113
- global SECRET_LEN
114
- SECRET_LEN = secret_len
115
  print(f'Secret length: {secret_len}')
116
  model = instantiate_from_config(config)
117
  state_dict = torch.load(weight_file, map_location=torch.device('cpu'))
@@ -124,7 +121,7 @@ def load_UNet(args):
124
  print(f'Missed keys: {misses}\nIgnore keys: {ignores}')
125
  model = model.to(device)
126
  model.eval()
127
- return model
128
 
129
  def embed_secret(model_name, model, cover, tform, secret):
130
  if model_name == 'UNet':
@@ -167,17 +164,19 @@ def load_model(model_name, _args):
167
  transforms.ToTensor(),
168
  transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
169
  ])
170
- model = load_UNet(_args)
171
  else:
172
  raise NotImplementedError
173
- return model, tform_emb, tform_det
174
 
175
 
176
  @st.cache_resource
177
- def load_ecc(ecc_name):
178
  if ecc_name == 'BCH':
179
- # ecc = BCH(285, 10, SECRET_LEN, verbose=True)
180
- ecc = BCH(payload_len= SECRET_LEN, verbose=True)
 
 
181
  elif ecc_name == 'RSC':
182
  ecc = RSC(data_bytes=16, ecc_bytes=4, verbose=True)
183
  return ecc
@@ -213,12 +212,10 @@ def app(args):
213
  st.title('Watermarking Demo')
214
  # setup model
215
  model_name = st.selectbox("Choose the model", model_names)
216
- model, tform_emb, tform_det = load_model(model_name, args)
217
  display_width = 300
218
-
219
  # ecc
220
- ecc = load_ecc('BCH')
221
- assert ecc.get_total_len() == SECRET_LEN
222
 
223
  # setup st
224
  st.subheader("Input")
 
32
  )
33
 
34
  model_names = ['UNet']
 
35
 
36
 
37
  def delete_page(main_script_path_str, page_name):
 
109
 
110
  config = OmegaConf.load(config_file).model
111
  secret_len = config.params.secret_len
 
 
112
  print(f'Secret length: {secret_len}')
113
  model = instantiate_from_config(config)
114
  state_dict = torch.load(weight_file, map_location=torch.device('cpu'))
 
121
  print(f'Missed keys: {misses}\nIgnore keys: {ignores}')
122
  model = model.to(device)
123
  model.eval()
124
+ return model, secret_len
125
 
126
  def embed_secret(model_name, model, cover, tform, secret):
127
  if model_name == 'UNet':
 
164
  transforms.ToTensor(),
165
  transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
166
  ])
167
+ model, secret_len = load_UNet(_args)
168
  else:
169
  raise NotImplementedError
170
+ return model, tform_emb, tform_det, secret_len
171
 
172
 
173
  @st.cache_resource
174
+ def load_ecc(ecc_name, secret_len):
175
  if ecc_name == 'BCH':
176
+ if secret_len == 160:
177
+ ecc = BCH(285, 10, secret_len, verbose=True)
178
+ elif secret_len == 100:
179
+ ecc = BCH(137, 5, payload_len= secret_len, verbose=True)
180
  elif ecc_name == 'RSC':
181
  ecc = RSC(data_bytes=16, ecc_bytes=4, verbose=True)
182
  return ecc
 
212
  st.title('Watermarking Demo')
213
  # setup model
214
  model_name = st.selectbox("Choose the model", model_names)
215
+ model, tform_emb, tform_det, secret_len = load_model(model_name, args)
216
  display_width = 300
 
217
  # ecc
218
+ ecc = load_ecc('BCH', secret_len)
 
219
 
220
  # setup st
221
  st.subheader("Input")
pages/Extract_Secret.py CHANGED
@@ -27,19 +27,16 @@ from io import BytesIO
27
  from tools.helpers import welcome_message
28
  from tools.ecc import BCH, RSC
29
  import streamlit as st
30
- from Embed_Secret import parse_st_args, load_ecc, load_model, decode_secret, to_bytes, model_names, SECRET_LEN
31
 
32
 
33
- # model_names = ['RoSteALS', 'UNet']
34
- # SECRET_LEN = 100
35
-
36
  def app(args):
37
  st.title('Watermarking Demo')
38
  # setup model
39
  model_name = st.selectbox("Choose the model", model_names)
40
- model, tform_emb, tform_det = load_model(model_name, args)
41
  display_width = 300
42
- ecc = load_ecc('BCH')
43
  noise = TransformNet(p=1.0, crop_mode='resized_crop')
44
  noise_names = noise.optional_names
45
 
 
27
  from tools.helpers import welcome_message
28
  from tools.ecc import BCH, RSC
29
  import streamlit as st
30
+ from Embed_Secret import parse_st_args, load_ecc, load_model, decode_secret, to_bytes, model_names
31
 
32
 
 
 
 
33
  def app(args):
34
  st.title('Watermarking Demo')
35
  # setup model
36
  model_name = st.selectbox("Choose the model", model_names)
37
+ model, tform_emb, tform_det, secret_len = load_model(model_name, args)
38
  display_width = 300
39
+ ecc = load_ecc('BCH', secret_len)
40
  noise = TransformNet(p=1.0, crop_mode='resized_crop')
41
  noise_names = noise.optional_names
42