jhj0517 commited on
Commit
97f1bae
·
1 Parent(s): da6352d

Load yaml for default parameters

Browse files
Files changed (1) hide show
  1. app.py +25 -16
app.py CHANGED
@@ -1,10 +1,11 @@
1
  import gradio as gr
2
  from gradio_image_prompter import ImagePrompter
3
  import os
 
4
 
5
  from modules.sam_inference import SamInference
6
  from modules.model_downloader import DEFAULT_MODEL_TYPE
7
- from modules.paths import OUTPUT_DIR
8
  from modules.utils import open_folder
9
  from modules.constants import (AUTOMATIC_MODE, BOX_PROMPT_MODE)
10
 
@@ -17,6 +18,9 @@ class App:
17
  self.sam_inf = SamInference()
18
  self.image_modes = [AUTOMATIC_MODE, BOX_PROMPT_MODE]
19
  self.default_mode = AUTOMATIC_MODE
 
 
 
20
 
21
  @staticmethod
22
  def on_mode_change(mode: str):
@@ -27,6 +31,8 @@ class App:
27
  ]
28
 
29
  def launch(self):
 
 
30
  with self.app:
31
  with gr.Row():
32
  with gr.Column(scale=5):
@@ -40,21 +46,24 @@ class App:
40
  dd_models = gr.Dropdown(label="Model", value=DEFAULT_MODEL_TYPE,
41
  choices=self.sam_inf.available_models)
42
 
43
- with gr.Accordion("Mask Parameters", open=False) as mask_hparams:
44
- nb_points_per_side = gr.Number(label="points_per_side ", value=64, interactive=True)
45
- nb_points_per_batch = gr.Number(label="points_per_batch ", value=128, interactive=True)
46
- sld_pred_iou_thresh = gr.Slider(label="pred_iou_thresh ", value=0.7, minimum=0, maximum=1,
47
  interactive=True)
48
- sld_stability_score_thresh = gr.Slider(label="stability_score_thresh ", value=0.92, minimum=0,
49
- maximum=1, interactive=True)
50
- sld_stability_score_offset = gr.Slider(label="stability_score_offset ", value=0.7, minimum=0,
51
- maximum=1)
52
- nb_crop_n_layers = gr.Number(label="crop_n_layers ", value=1)
53
- sld_box_nms_thresh = gr.Slider(label="box_nms_thresh ", value=0.7, minimum=0,
54
- maximum=1)
55
- nb_crop_n_points_downscale_factor = gr.Number(label="crop_n_points_downscale_factor ", value=2)
56
- nb_min_mask_region_area = gr.Number(label="min_mask_region_area ", value=25)
57
- cb_use_m2m = gr.Checkbox(label="use_m2m ", value=True)
 
 
 
58
 
59
  with gr.Row():
60
  btn_generate = gr.Button("GENERATE", variant="primary")
@@ -78,7 +87,7 @@ class App:
78
 
79
  dd_input_modes.change(fn=self.on_mode_change,
80
  inputs=[dd_input_modes],
81
- outputs=[img_input, img_input_prompter, mask_hparams])
82
 
83
  self.app.queue().launch(inbrowser=True)
84
 
 
1
  import gradio as gr
2
  from gradio_image_prompter import ImagePrompter
3
  import os
4
+ import yaml
5
 
6
  from modules.sam_inference import SamInference
7
  from modules.model_downloader import DEFAULT_MODEL_TYPE
8
+ from modules.paths import (OUTPUT_DIR, SAM2_CONFIGS_DIR)
9
  from modules.utils import open_folder
10
  from modules.constants import (AUTOMATIC_MODE, BOX_PROMPT_MODE)
11
 
 
18
  self.sam_inf = SamInference()
19
  self.image_modes = [AUTOMATIC_MODE, BOX_PROMPT_MODE]
20
  self.default_mode = AUTOMATIC_MODE
21
+ default_param_config_path = os.path.join(SAM2_CONFIGS_DIR, "default_hparams.yaml")
22
+ with open(default_param_config_path, 'r') as file:
23
+ self.hparams = yaml.safe_load(file)
24
 
25
  @staticmethod
26
  def on_mode_change(mode: str):
 
31
  ]
32
 
33
  def launch(self):
34
+ mask_hparams = self.hparams["mask_gen_hparams"]
35
+
36
  with self.app:
37
  with gr.Row():
38
  with gr.Column(scale=5):
 
46
  dd_models = gr.Dropdown(label="Model", value=DEFAULT_MODEL_TYPE,
47
  choices=self.sam_inf.available_models)
48
 
49
+ with gr.Accordion("Mask Parameters", open=False) as acc_mask_hparams:
50
+ nb_points_per_side = gr.Number(label="points_per_side ", value=mask_hparams["points_per_side"],
51
+ interactive=True)
52
+ nb_points_per_batch = gr.Number(label="points_per_batch ", value=mask_hparams["points_per_batch"],
53
  interactive=True)
54
+ sld_pred_iou_thresh = gr.Slider(label="pred_iou_thresh ", value=mask_hparams["pred_iou_thresh"],
55
+ minimum=0, maximum=1, interactive=True)
56
+ sld_stability_score_thresh = gr.Slider(label="stability_score_thresh ", value=mask_hparams["stability_score_thresh"],
57
+ minimum=0, maximum=1, interactive=True)
58
+ sld_stability_score_offset = gr.Slider(label="stability_score_offset ", value=mask_hparams["stability_score_offset"],
59
+ minimum=0, maximum=1)
60
+ nb_crop_n_layers = gr.Number(label="crop_n_layers ", value=mask_hparams["crop_n_layers"],)
61
+ sld_box_nms_thresh = gr.Slider(label="box_nms_thresh ", value=mask_hparams["box_nms_thresh"],
62
+ minimum=0, maximum=1)
63
+ nb_crop_n_points_downscale_factor = gr.Number(label="crop_n_points_downscale_factor ",
64
+ value=mask_hparams["crop_n_points_downscale_factor"],)
65
+ nb_min_mask_region_area = gr.Number(label="min_mask_region_area ", value=mask_hparams["min_mask_region_area"],)
66
+ cb_use_m2m = gr.Checkbox(label="use_m2m ", value=mask_hparams["use_m2m"])
67
 
68
  with gr.Row():
69
  btn_generate = gr.Button("GENERATE", variant="primary")
 
87
 
88
  dd_input_modes.change(fn=self.on_mode_change,
89
  inputs=[dd_input_modes],
90
+ outputs=[img_input, img_input_prompter, acc_mask_hparams])
91
 
92
  self.app.queue().launch(inbrowser=True)
93