shriarul5273 commited on
Commit
a9ef639
·
verified ·
1 Parent(s): 5fc98d8

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -26
app.py CHANGED
@@ -1,41 +1,41 @@
 
 
1
  import torch
2
  import torch.nn.functional as F
3
  import gradio as gr
4
- import spaces
5
- import onnxruntime
6
- from torchvision import transforms
7
 
8
- # Enable GPU for ONNX Runtime
9
  sess_options = onnxruntime.SessionOptions()
10
  sess_options.enable_profiling = True
11
- sess_options.add_session_config_entry('session.load_model_format', 'ONNX')
12
- providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
13
- ort_sess = onnxruntime.InferenceSession("RFNet.onnx", sess_options=sess_options, providers=providers)
14
 
15
  preprocess_img = transforms.Compose([
16
- transforms.Resize((352, 352)),
17
- transforms.ToTensor(),
18
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
19
- ])
20
 
21
  preprocess_depth = transforms.Compose([
22
- transforms.Resize((352, 352)),
23
- transforms.ToTensor()
24
- ])
25
-
26
- def inference(img, depth, GT):
27
- h, w = img.size
28
- img = preprocess_img(img).unsqueeze(0).cuda()
29
- depth = preprocess_depth(depth.convert('L')).unsqueeze(0).cuda()
30
- ort_inputs = {ort_sess.get_inputs()[0].name: img.cpu().numpy(), ort_sess.get_inputs()[1].name: depth.cpu().numpy()}
31
  ort_outs = ort_sess.run(None, ort_inputs)
32
- output_image = torch.tensor(ort_outs[0]).cuda()
33
- res = F.interpolate(output_image, size=(w, h), mode='bilinear', align_corners=False)
34
- res = torch.sigmoid(res)
35
- res = res.data.cpu().numpy().squeeze()
36
- res = (res - res.min()) / (res.max() - res.min() + 1e-8)
37
  return res
38
 
 
 
 
 
 
39
  title = "Robust RGB-D Fusion for Saliency Detection"
40
  description = """ Deployment of the paper:
41
  [Robust RGB-D Fusion for Saliency Detection](https://arxiv.org/pdf/2208.01762.pdf)
@@ -72,4 +72,5 @@ gr.Interface(inference, inputs=[input_1,input_2,input_3], outputs=outputs,
72
  description=description,
73
  article=article,
74
  theme=gr.themes.Soft(),
75
- cache_examples=False).launch()
 
 
1
+ import onnxruntime
2
+ from torchvision import transforms
3
  import torch
4
  import torch.nn.functional as F
5
  import gradio as gr
 
 
 
6
 
 
7
  sess_options = onnxruntime.SessionOptions()
8
  sess_options.enable_profiling = True
9
+ sess_options.add_session_config_entry('session.load_model_format', 'ONNX')
10
+ ort_sess = onnxruntime.InferenceSession("RFNet.onnx", sess_options=sess_options)
11
+
12
 
13
  preprocess_img = transforms.Compose([
14
+ transforms.Resize((352,352)),
15
+ transforms.ToTensor(),
16
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])
 
17
 
18
  preprocess_depth = transforms.Compose([
19
+ transforms.Resize((352,352)),
20
+ transforms.ToTensor()])
21
+ def inference(img,depth,GT):
22
+ h,w = img.size
23
+ img = preprocess_img(img).unsqueeze(0)
24
+ depth = preprocess_depth(depth.convert('L')).unsqueeze(0)
25
+ ort_inputs = {ort_sess.get_inputs()[0].name: img.numpy(), ort_sess.get_inputs()[1].name: depth.numpy()}
 
 
26
  ort_outs = ort_sess.run(None, ort_inputs)
27
+ output_image = torch.tensor(ort_outs[0])
28
+ res = F.interpolate(output_image, size=(w,h), mode='bilinear', align_corners=False)
29
+ res = torch.sigmoid(res)
30
+ res = res.data.cpu().numpy().squeeze()
31
+ res = (res - res.min()) / (res.max() - res.min() + 1e-8)
32
  return res
33
 
34
+
35
+
36
+
37
+
38
+
39
  title = "Robust RGB-D Fusion for Saliency Detection"
40
  description = """ Deployment of the paper:
41
  [Robust RGB-D Fusion for Saliency Detection](https://arxiv.org/pdf/2208.01762.pdf)
 
72
  description=description,
73
  article=article,
74
  theme=gr.themes.Soft(),
75
+ cache_examples=False).launch(server_name="0.0.0.0")
76
+