RysonFeng commited on
Commit
2b6c2bd
·
1 Parent(s): 6f49966

Check dependency

Browse files
Files changed (4) hide show
  1. app.py +12 -36
  2. lama_inpaint.py +1 -0
  3. requirements.txt +23 -0
  4. sam_segment.py +2 -0
app.py CHANGED
@@ -1,49 +1,25 @@
1
  import gradio as gr
2
  import numpy as np
3
 
 
 
 
 
 
 
4
  with gr.Blocks() as demo:
5
- tolerance = gr.Slider(label="Tolerance",
6
- info="How different colors can be in a segment.",
7
- minimum=0, maximum=256 * 3, value=50)
8
  with gr.Row():
9
  input_img = gr.Image(label="Input")
10
  output_img = gr.Image(label="Selected Segment")
11
 
 
 
 
12
 
13
- def get_select_coords(img, tolerance, evt: gr.SelectData):
14
- visited_pixels = set()
15
- pixels_in_queue = set()
16
- pixels_in_segment = set()
17
- start_pixel = img[evt.index[1], evt.index[0]]
18
- pixels_in_queue.add((evt.index[1], evt.index[0]))
19
- while len(pixels_in_queue) > 0:
20
- pixel = pixels_in_queue.pop()
21
- visited_pixels.add(pixel)
22
- neighbors = []
23
- if pixel[0] > 0:
24
- neighbors.append((pixel[0] - 1, pixel[1]))
25
- if pixel[0] < img.shape[0] - 1:
26
- neighbors.append((pixel[0] + 1, pixel[1]))
27
- if pixel[1] > 0:
28
- neighbors.append((pixel[0], pixel[1] - 1))
29
- if pixel[1] < img.shape[1] - 1:
30
- neighbors.append((pixel[0], pixel[1] + 1))
31
- for neighbor in neighbors:
32
- if neighbor in visited_pixels:
33
- continue
34
- neighbor_pixel = img[neighbor[0], neighbor[1]]
35
- if np.abs(neighbor_pixel - start_pixel).sum() < tolerance:
36
- pixels_in_queue.add(neighbor)
37
- pixels_in_segment.add(neighbor)
38
-
39
- out = img.copy() * 0.2
40
- out = out.astype(np.uint8)
41
- for pixel in pixels_in_segment:
42
- out[pixel[0], pixel[1]] = img[pixel[0], pixel[1]]
43
- return out
44
-
45
 
46
- input_img.select(get_select_coords, [input_img, tolerance], output_img)
47
 
48
  if __name__ == "__main__":
49
  demo.launch()
 
1
  import gradio as gr
2
  import numpy as np
3
 
4
+ from sam_segment import predict_masks_with_sam
5
+ from lama_inpaint import inpaint_img_with_lama
6
+ from utils import load_img_to_array, save_array_to_img, dilate_mask, \
7
+ show_mask, show_points
8
+
9
+
10
  with gr.Blocks() as demo:
 
 
 
11
  with gr.Row():
12
  input_img = gr.Image(label="Input")
13
  output_img = gr.Image(label="Selected Segment")
14
 
15
+ with gr.Row():
16
+ h = gr.Number()
17
+ w = gr.Number()
18
 
19
+ def get_select_coords(img, evt: gr.SelectData):
20
+ return evt.index[1], evt.index[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
+ input_img.select(get_select_coords, [input_img,], [h, w])
23
 
24
  if __name__ == "__main__":
25
  demo.launch()
lama_inpaint.py CHANGED
@@ -16,6 +16,7 @@ os.environ['VECLIB_MAXIMUM_THREADS'] = '1'
16
  os.environ['NUMEXPR_NUM_THREADS'] = '1'
17
 
18
  sys.path.insert(0, str(Path(__file__).resolve().parent / "third_party" / "lama"))
 
19
  from saicinpainting.evaluation.utils import move_to_device
20
  from saicinpainting.training.trainers import load_checkpoint
21
  from saicinpainting.evaluation.data import pad_tensor_to_modulo
 
16
  os.environ['NUMEXPR_NUM_THREADS'] = '1'
17
 
18
  sys.path.insert(0, str(Path(__file__).resolve().parent / "third_party" / "lama"))
19
+
20
  from saicinpainting.evaluation.utils import move_to_device
21
  from saicinpainting.training.trainers import load_checkpoint
22
  from saicinpainting.evaluation.data import pad_tensor_to_modulo
requirements.txt ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ opencv-python
4
+ matplotlib
5
+ tensorflow
6
+ pyyaml
7
+ tqdm
8
+ numpy
9
+ easydict
10
+ scikit-image
11
+ scikit-learn
12
+ opencv-python
13
+ joblib
14
+ matplotlib
15
+ pandas
16
+ albumentations==0.5.2
17
+ hydra-core
18
+ pytorch-lightning
19
+ tabulate
20
+ kornia==0.5.0
21
+ webdataset
22
+ packaging
23
+ wldhx.yadisk-direct
sam_segment.py CHANGED
@@ -6,6 +6,8 @@ from matplotlib import pyplot as plt
6
  from typing import Any, Dict, List
7
  import torch
8
 
 
 
9
  from segment_anything import SamPredictor, sam_model_registry
10
  from utils import load_img_to_array, save_array_to_img, dilate_mask, \
11
  show_mask, show_points
 
6
  from typing import Any, Dict, List
7
  import torch
8
 
9
+ sys.path.insert(0, str(Path(__file__).resolve().parent / "third_party" / "segment-anything"))
10
+
11
  from segment_anything import SamPredictor, sam_model_registry
12
  from utils import load_img_to_array, save_array_to_img, dilate_mask, \
13
  show_mask, show_points