nick_93 commited on
Commit
c6fb6c8
·
1 Parent(s): 7ca6aff
app.py CHANGED
@@ -2,6 +2,7 @@ import os
2
  import sys
3
 
4
  sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), 'depth')))
 
5
  sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), 'stable-diffusion')))
6
  sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), 'taming-transformers')))
7
 
@@ -10,8 +11,8 @@ os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), 'depth')))
10
  import cv2
11
  import numpy as np
12
  import torch
13
- import torch.backends.cudnn as cudnn
14
  from depth.models_depth.model import EVPDepth
 
15
  from depth.configs.train_options import TrainOptions
16
  from depth.configs.test_options import TestOptions
17
  import glob
@@ -22,6 +23,7 @@ from PIL import Image
22
  import torch.nn.functional as F
23
  import gradio as gr
24
  import tempfile
 
25
 
26
 
27
  css = """
@@ -37,7 +39,7 @@ css = """
37
 
38
  """
39
 
40
- def create_demo(model, device):
41
  gr.Markdown("### Depth Prediction demo")
42
  with gr.Row():
43
  input_image = gr.Image(label="Input Image", type='pil', elem_id='img-display-input')
@@ -65,24 +67,60 @@ def create_demo(model, device):
65
  return [colored_depth, tmp.name]
66
 
67
  submit.click(on_submit, inputs=[input_image], outputs=[depth_image, raw_file])
68
- examples = gr.Examples(examples=["imgs/test_img1.jpg", "imgs/test_img2.jpg", "imgs/test_img3.jpg", "imgs/test_img4.jpg"],
69
  inputs=[input_image])
70
 
71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  def main():
73
  opt = TestOptions().initialize()
74
  args = opt.parse_args()
75
- args.ckpt_dir = 'best_model_nyu.ckpt'
76
 
77
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
78
  model = EVPDepth(args=args, caption_aggregation=True)
79
- cudnn.benchmark = True
80
  model.to(device)
81
- model_weight = torch.load(args.ckpt_dir, map_location=device)['model']
82
- if 'module' in next(iter(model_weight.items()))[0]:
83
- model_weight = OrderedDict((k[7:], v) for k, v in model_weight.items())
84
  model.load_state_dict(model_weight, strict=False)
85
  model.eval()
 
 
 
 
 
 
 
86
 
87
  title = "# EVP"
88
  description = """Official demo for **EVP: Enhanced Visual Perception using Inverse Multi-Attentive Feature
@@ -94,7 +132,9 @@ def main():
94
  gr.Markdown(title)
95
  gr.Markdown(description)
96
  with gr.Tab("Depth Prediction"):
97
- create_demo(model, device)
 
 
98
  gr.HTML('''<br><br><br><center>You can duplicate this Space to skip the queue:<a href="https://huggingface.co/spaces/MykolaL/evp?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a><br>
99
  <p><img src="https://visitor-badge.glitch.me/badge?page_id=MykolaL/evp" alt="visitors"></p></center>''')
100
 
 
2
  import sys
3
 
4
  sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), 'depth')))
5
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), 'refer')))
6
  sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), 'stable-diffusion')))
7
  sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), 'taming-transformers')))
8
 
 
11
  import cv2
12
  import numpy as np
13
  import torch
 
14
  from depth.models_depth.model import EVPDepth
15
+ from models_refer.model import EVPRefer
16
  from depth.configs.train_options import TrainOptions
17
  from depth.configs.test_options import TestOptions
18
  import glob
 
23
  import torch.nn.functional as F
24
  import gradio as gr
25
  import tempfile
26
+ from transformers import CLIPTokenizer
27
 
28
 
29
  css = """
 
39
 
40
  """
41
 
42
+ def create_depth_demo(model, device):
43
  gr.Markdown("### Depth Prediction demo")
44
  with gr.Row():
45
  input_image = gr.Image(label="Input Image", type='pil', elem_id='img-display-input')
 
67
  return [colored_depth, tmp.name]
68
 
69
  submit.click(on_submit, inputs=[input_image], outputs=[depth_image, raw_file])
70
+ examples = gr.Examples(examples=["imgs/test_img1.jpg", "imgs/test_img2.jpg", "imgs/test_img3.jpg", "imgs/test_img4.jpg", "imgs/test_img5.jpg"],
71
  inputs=[input_image])
72
 
73
 
74
+ def create_refseg_demo(model, tokenizer, device):
75
+ gr.Markdown("### Referring Segmentation demo")
76
+ with gr.Row():
77
+ input_image = gr.Image(label="Input Image", type='pil', elem_id='img-display-input')
78
+ refseg_image = gr.Image(label="Output Mask", elem_id='img-display-output')
79
+ input_text = gr.Textbox(label='Prompt', placeholder='Please upload your image first', lines=2)
80
+ submit = gr.Button("Submit")
81
+
82
+ def on_submit(image, text):
83
+ image = np.array(image)
84
+ image_t = transforms.ToTensor()(image).unsqueeze(0).to(device)
85
+ image_t = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])(image_t)
86
+ shape = image_t.shape
87
+ image_t = torch.nn.functional.interpolate(image_t, (512,512), mode='bilinear', align_corners=True)
88
+ input_ids = tokenizer(text=text, truncation=True, max_length=40, return_length=True,
89
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")['input_ids'].to(device)
90
+
91
+ with torch.no_grad():
92
+ pred = model(image_t, input_ids)
93
+
94
+ pred = torch.nn.functional.interpolate(pred, shape[2:], mode='bilinear', align_corners=True)
95
+ output_mask = pred.cpu().argmax(1).data.numpy().squeeze()
96
+ alpha = 0.65
97
+ image[output_mask == 0] = (image[output_mask == 0]*alpha).astype(np.uint8)
98
+ contours, _ = cv2.findContours(output_mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
99
+ cv2.drawContours(image, contours, -1, (0, 255, 0), 2)
100
+ return Image.fromarray(image)
101
+
102
+ submit.click(on_submit, inputs=[input_image, input_text], outputs=refseg_image)
103
+ examples = gr.Examples(examples=[["imgs/test_img2.jpg", "green plant"], ["imgs/test_img3.jpg", "chair"], ["imgs/test_img4.jpg", "left green plant"], ["imgs/test_img5.jpg", "man walking on foot"], ["imgs/test_img5.jpg", "the rightest camel"]],
104
+ inputs=[input_image, input_text])
105
+
106
+
107
  def main():
108
  opt = TestOptions().initialize()
109
  args = opt.parse_args()
 
110
 
111
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
112
  model = EVPDepth(args=args, caption_aggregation=True)
 
113
  model.to(device)
114
+ model_weight = torch.load('best_model_nyu.ckpt', map_location=device)['model']
 
 
115
  model.load_state_dict(model_weight, strict=False)
116
  model.eval()
117
+ tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
118
+ model_refseg = EVPRefer()
119
+ model_refseg.to(device)
120
+ model_weight = torch.load('best_model_refcoco.pth', map_location=device)['model']
121
+ model_refseg.load_state_dict(model_weight, strict=False)
122
+ model_refseg.eval()
123
+
124
 
125
  title = "# EVP"
126
  description = """Official demo for **EVP: Enhanced Visual Perception using Inverse Multi-Attentive Feature
 
132
  gr.Markdown(title)
133
  gr.Markdown(description)
134
  with gr.Tab("Depth Prediction"):
135
+ create_depth_demo(model, device)
136
+ with gr.Tab("Referring Segmentation"):
137
+ create_refseg_demo(model_refseg, tokenizer, device)
138
  gr.HTML('''<br><br><br><center>You can duplicate this Space to skip the queue:<a href="https://huggingface.co/spaces/MykolaL/evp?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a><br>
139
  <p><img src="https://visitor-badge.glitch.me/badge?page_id=MykolaL/evp" alt="visitors"></p></center>''')
140
 
depth/imgs/test_img5.jpg ADDED
refer/models_refer/model.py CHANGED
@@ -1,7 +1,7 @@
1
  import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
-
5
  import sys
6
  from ldm.util import instantiate_from_config
7
  from transformers.models.clip.modeling_clip import CLIPTextModel
@@ -258,7 +258,11 @@ class EVPRefer(nn.Module):
258
  **args):
259
  super().__init__()
260
  config = OmegaConf.load('./v1-inference.yaml')
261
- config.model.params.ckpt_path = f'{sd_path}'
 
 
 
 
262
  sd_model = instantiate_from_config(config.model)
263
  self.encoder_vq = sd_model.first_stage_model
264
  self.unet = UNetWrapper(sd_model.model, base_size=base_size)
 
1
  import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
+ import os
5
  import sys
6
  from ldm.util import instantiate_from_config
7
  from transformers.models.clip.modeling_clip import CLIPTextModel
 
258
  **args):
259
  super().__init__()
260
  config = OmegaConf.load('./v1-inference.yaml')
261
+ if os.path.exists(f'{sd_path}'):
262
+ config.model.params.ckpt_path = f'{sd_path}'
263
+ else:
264
+ config.model.params.ckpt_path = None
265
+
266
  sd_model = instantiate_from_config(config.model)
267
  self.encoder_vq = sd_model.first_stage_model
268
  self.unet = UNetWrapper(sd_model.model, base_size=base_size)