SauravMaheshkar commited on
Commit
630e69b
·
unverified ·
1 Parent(s): e6eaebf

feat: drop redundant image box

Browse files
Files changed (2) hide show
  1. app.py +9 -22
  2. assets/{img.png → example.png} +0 -0
app.py CHANGED
@@ -3,9 +3,8 @@ import numpy as np
3
  import cv2
4
  import torch
5
 
6
- # import spaces
7
 
8
- from PIL import Image
9
 
10
  from src.plot_utils import show_masks
11
  from gradio_image_annotation import image_annotator
@@ -14,20 +13,20 @@ from gradio_image_annotation import image_annotator
14
  from sam2.build_sam import build_sam2
15
  from sam2.sam2_image_predictor import SAM2ImagePredictor
16
 
17
- choice_mapping = {
18
  "tiny": ["sam2_hiera_t.yaml", "assets/checkpoints/sam2_hiera_tiny.pt"],
19
  "small": ["sam2_hiera_s.yaml", "assets/checkpoints/sam2_hiera_small.pt"],
20
  "base_plus": ["sam2_hiera_b+.yaml", "assets/checkpoints/sam2_hiera_base_plus.pt"],
21
  "large": ["sam2_hiera_l.yaml", "assets/checkpoints/sam2_hiera_large.pt"],
22
  }
23
 
24
- # @spaces.GPU
25
- def predict(model_choice: str, annotations, image):
26
  config_file, ckpt_path = choice_mapping[str(model_choice)]
27
  device = "cuda" if torch.cuda.is_available() else "cpu"
28
  sam2_model = build_sam2(config_file, ckpt_path, device=device)
29
  predictor = SAM2ImagePredictor(sam2_model)
30
- predictor.set_image(image)
31
  coordinates = np.array(
32
  [
33
  int(annotations["boxes"][0]["xmin"]),
@@ -47,7 +46,7 @@ def predict(model_choice: str, annotations, image):
47
  cv2.imwrite("mask.png", mask_image)
48
 
49
  return [
50
- show_masks(image, masks, scores, box_coords=coordinates),
51
  gr.DownloadButton("Download Mask", value="mask.png", visible=True),
52
  ]
53
 
@@ -68,29 +67,17 @@ with gr.Blocks(delete_cache=(30, 30)) as demo:
68
 
69
  gr.Markdown(
70
  """
71
- # 2. Upload an Image
72
- """
73
- )
74
-
75
- with gr.Row():
76
- img = gr.Image(value="./assets/img.png", type="numpy", label="Input Image")
77
-
78
- gr.Markdown(
79
- """
80
- # 3. Draw Bounding Box
81
  """
82
  )
83
 
84
  annotator = image_annotator(
85
- value={"image": img.value["path"]},
86
  disable_edit_boxes=True,
87
- single_box=True,
88
  label="Draw a bounding box",
89
  )
90
  btn = gr.Button("Get Segmentation Mask")
91
  download_btn = gr.DownloadButton("Download Mask", value="mask.png", visible=False)
92
- btn.click(
93
- fn=predict, inputs=[model, annotator, img], outputs=[gr.Plot(), download_btn]
94
- )
95
 
96
  demo.launch()
 
3
  import cv2
4
  import torch
5
 
 
6
 
7
+ from typing import Dict, Any, List
8
 
9
  from src.plot_utils import show_masks
10
  from gradio_image_annotation import image_annotator
 
13
  from sam2.build_sam import build_sam2
14
  from sam2.sam2_image_predictor import SAM2ImagePredictor
15
 
16
+ choice_mapping: Dict[str, List[str]] = {
17
  "tiny": ["sam2_hiera_t.yaml", "assets/checkpoints/sam2_hiera_tiny.pt"],
18
  "small": ["sam2_hiera_s.yaml", "assets/checkpoints/sam2_hiera_small.pt"],
19
  "base_plus": ["sam2_hiera_b+.yaml", "assets/checkpoints/sam2_hiera_base_plus.pt"],
20
  "large": ["sam2_hiera_l.yaml", "assets/checkpoints/sam2_hiera_large.pt"],
21
  }
22
 
23
+
24
+ def predict(model_choice, annotations: Dict[str, Any]):
25
  config_file, ckpt_path = choice_mapping[str(model_choice)]
26
  device = "cuda" if torch.cuda.is_available() else "cpu"
27
  sam2_model = build_sam2(config_file, ckpt_path, device=device)
28
  predictor = SAM2ImagePredictor(sam2_model)
29
+ predictor.set_image(annotations["image"])
30
  coordinates = np.array(
31
  [
32
  int(annotations["boxes"][0]["xmin"]),
 
46
  cv2.imwrite("mask.png", mask_image)
47
 
48
  return [
49
+ show_masks(annotations["image"], masks, scores, box_coords=coordinates),
50
  gr.DownloadButton("Download Mask", value="mask.png", visible=True),
51
  ]
52
 
 
67
 
68
  gr.Markdown(
69
  """
70
+ # 2. Upload your Image and draw a bounding box
 
 
 
 
 
 
 
 
 
71
  """
72
  )
73
 
74
  annotator = image_annotator(
75
+ value={"image": cv2.imread("assets/example.png")},
76
  disable_edit_boxes=True,
 
77
  label="Draw a bounding box",
78
  )
79
  btn = gr.Button("Get Segmentation Mask")
80
  download_btn = gr.DownloadButton("Download Mask", value="mask.png", visible=False)
81
+ btn.click(fn=predict, inputs=[model, annotator], outputs=[gr.Plot(), download_btn])
 
 
82
 
83
  demo.launch()
assets/{img.png → example.png} RENAMED
File without changes