RuoyuFeng commited on
Commit
e707d28
·
1 Parent(s): c0b4994

Add gr state

Browse files
Files changed (1) hide show
  1. app.py +31 -7
app.py CHANGED
@@ -19,18 +19,34 @@ def mkstemp(suffix, dir=None):
19
  return Path(path)
20
 
21
 
22
- # def get_sam_feat(img):
23
- # # predictor.set_image(img)
24
- # model['sam'].set_image(img)
25
- # return
 
 
 
 
 
26
 
27
 
28
- def get_masked_img(img, w, h):
29
  point_coords = [w, h]
30
  point_labels = [1]
31
  dilate_kernel_size = 15
32
 
33
- model['sam'].set_image(img)
 
 
 
 
 
 
 
 
 
 
 
34
  # masks, _, _ = predictor.predict(
35
  masks, _, _ = model['sam'].predict(
36
  point_coords=np.array([point_coords]),
@@ -98,6 +114,12 @@ model['lama'] = build_lama_model(lama_config, lama_ckpt, device=device)
98
 
99
 
100
  with gr.Blocks() as demo:
 
 
 
 
 
 
101
  with gr.Row():
102
  img = gr.Image(label="Image")
103
  # img_pointed = gr.Image(label='Pointed Image')
@@ -146,9 +168,11 @@ with gr.Blocks() as demo:
146
  # []
147
  # )
148
  # img.change(get_sam_feat, [img], [])
 
 
149
  sam_mask.click(
150
  get_masked_img,
151
- [img, w, h],
152
  [img_with_mask_0, img_with_mask_1, img_with_mask_2, mask_0, mask_1, mask_2]
153
  )
154
 
 
19
  return Path(path)
20
 
21
 
22
+ def get_sam_feat(img):
23
+ # predictor.set_image(img)
24
+ model['sam'].set_image(img)
25
+ features = model['sam'].features
26
+ orig_h = model['sam'].orig_h
27
+ orig_w = model['sam'].orig_w
28
+ input_h = model['sam'].input_h
29
+ input_w = model['sam'].input_w
30
+ return features, orig_h, orig_w, input_h, input_w
31
 
32
 
33
+ def get_masked_img(img, w, h, features, orig_h, orig_w, input_h, input_w):
34
  point_coords = [w, h]
35
  point_labels = [1]
36
  dilate_kernel_size = 15
37
 
38
+ # model['sam'].is_image_set = False
39
+ model['sam'].features = features
40
+ model['sam'].orig_h = orig_h
41
+ model['sam'].orig_w = orig_w
42
+ model['sam'].input_h = input_h
43
+ model['sam'].input_w = input_w
44
+ # model['sam'].image_embedding = image_embedding
45
+ # model['sam'].original_size = original_size
46
+ # model['sam'].input_size = input_size
47
+ # model['sam'].is_image_set = True
48
+
49
+ # model['sam'].set_image(img)
50
  # masks, _, _ = predictor.predict(
51
  masks, _, _ = model['sam'].predict(
52
  point_coords=np.array([point_coords]),
 
114
 
115
 
116
  with gr.Blocks() as demo:
117
+ features = gr.State(None)
118
+ orig_h = gr.State(None)
119
+ orig_w = gr.State(None)
120
+ input_h = gr.State(None)
121
+ input_w = gr.State(None)
122
+
123
  with gr.Row():
124
  img = gr.Image(label="Image")
125
  # img_pointed = gr.Image(label='Pointed Image')
 
168
  # []
169
  # )
170
  # img.change(get_sam_feat, [img], [])
171
+ img.upload(get_sam_feat, [img], [features, orig_h, orig_w, input_h, input_w])
172
+
173
  sam_mask.click(
174
  get_masked_img,
175
+ [img, w, h, features, orig_h, orig_w, input_h, input_w],
176
  [img_with_mask_0, img_with_mask_1, img_with_mask_2, mask_0, mask_1, mask_2]
177
  )
178