shriarul5273 committed
Commit b799d0c (verified) · 1 Parent(s): 3563f4a

Upload 2 files

Files changed (2)
  1. app.py +75 -76
  2. requirements.txt +5 -5
app.py CHANGED
@@ -1,76 +1,75 @@
- import onnxruntime
- from torchvision import transforms
- import torch
- import torch.nn.functional as F
- import gradio as gr
-
- sess_options = onnxruntime.SessionOptions()
- sess_options.enable_profiling = True
- sess_options.add_session_config_entry('session.load_model_format', 'ONNX')
- ort_sess = onnxruntime.InferenceSession("RFNet.onnx", sess_options=sess_options)
-
-
- preprocess_img = transforms.Compose([
-     transforms.Resize((352,352)),
-     transforms.ToTensor(),
-     transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])
-
- preprocess_depth = transforms.Compose([
-     transforms.Resize((352,352)),
-     transforms.ToTensor()])
- def inference(img,depth,GT):
-     h,w = img.size
-     img = preprocess_img(img).unsqueeze(0)
-     depth = preprocess_depth(depth.convert('L')).unsqueeze(0)
-     ort_inputs = {ort_sess.get_inputs()[0].name: img.numpy(), ort_sess.get_inputs()[1].name: depth.numpy()}
-     ort_outs = ort_sess.run(None, ort_inputs)
-     output_image = torch.tensor(ort_outs[0])
-     res = F.interpolate(output_image, size=(w,h), mode='bilinear', align_corners=False)
-     res = torch.sigmoid(res)
-     res = res.data.cpu().numpy().squeeze()
-     res = (res - res.min()) / (res.max() - res.min() + 1e-8)
-     return res
-
-
-
-
-
-
- title = "Robust RGB-D Fusion for Saliency Detection"
- description = """ Deployment of the paper:
- [Robust RGB-D Fusion for Saliency Detection](https://arxiv.org/pdf/2208.01762.pdf)
- published at the International Conference on 3D Vision 2022 (3DV 2022).
- Paper Code can be found at [Zongwei97/RFNet](https://github.com/Zongwei97/RFnet).
- Deployed Code can be found at [shriarul5273/Robust_RGB-D_Saliency_Detection](https://github.com/shriarul5273/Robust_RGB-D_Saliency_Detection).
- Use example Image and corresponding Depth Map (from NJU2K dataset) or upload your own Image and Depth Map.
- """
- article = """ # Citation
- If you find this repo useful, please consider citing:
- ```
- @article{wu2022robust,
-   title={Robust RGB-D Fusion for Saliency Detection},
-   author={Wu, Zongwei and Gobichettipalayam, Shriarulmozhivarman and Tamadazte, Brahim and Allibert, Guillaume and Paudel, Danda Pani and Demonceaux, Cedric},
-   journal={3DV},
-   year={2022}
- }
- ```
- """
- examples = [['images/image_1.jpg','images/depth_1.png','images/gt_1.png'],
-             ['images/image_2.jpg','images/depth_2.png','images/gt_2.png'],
-             ['images/image_3.jpg','images/depth_3.png','images/gt_3.png'],
-             ['images/image_4.jpg','images/depth_4.png','images/gt_4.png'],
-             ['images/image_5.jpg','images/depth_5.png','images/gt_5.png']]
-
- input_1 = gr.Image(type='pil', label="RGB Image", sources="upload")
- input_2 = gr.Image(type='pil', label="Depth Image", sources="upload")
- input_3 = gr.Image(type='pil', label="Ground Truth", sources="upload")
- outputs = gr.Image(type="pil", label="Saliency Map")
-
-
- gr.Interface(inference, inputs=[input_1,input_2,input_3], outputs=outputs,
-              title=title,examples=examples,
-              description=description,
-              article=article,
-              theme=gr.themes.Soft(),
-              cache_examples=False).launch(server_name="0.0.0.0")
-
 
+ import torch
+ import torch.nn.functional as F
+ import gradio as gr
+ import spaces
+ import onnxruntime
+ from torchvision import transforms
+
+ # Enable GPU for ONNX Runtime
+ sess_options = onnxruntime.SessionOptions()
+ sess_options.enable_profiling = True
+ sess_options.add_session_config_entry('session.load_model_format', 'ONNX')
+ providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
+ ort_sess = onnxruntime.InferenceSession("RFNet.onnx", sess_options=sess_options, providers=providers)
+
+ preprocess_img = transforms.Compose([
+     transforms.Resize((352, 352)),
+     transforms.ToTensor(),
+     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ ])
+
+ preprocess_depth = transforms.Compose([
+     transforms.Resize((352, 352)),
+     transforms.ToTensor()
+ ])
+
+ def inference(img, depth, GT):
+     h, w = img.size
+     img = preprocess_img(img).unsqueeze(0).cuda()
+     depth = preprocess_depth(depth.convert('L')).unsqueeze(0).cuda()
+     ort_inputs = {ort_sess.get_inputs()[0].name: img.cpu().numpy(), ort_sess.get_inputs()[1].name: depth.cpu().numpy()}
+     ort_outs = ort_sess.run(None, ort_inputs)
+     output_image = torch.tensor(ort_outs[0]).cuda()
+     res = F.interpolate(output_image, size=(w, h), mode='bilinear', align_corners=False)
+     res = torch.sigmoid(res)
+     res = res.data.cpu().numpy().squeeze()
+     res = (res - res.min()) / (res.max() - res.min() + 1e-8)
+     return res
+
+ title = "Robust RGB-D Fusion for Saliency Detection"
+ description = """ Deployment of the paper:
+ [Robust RGB-D Fusion for Saliency Detection](https://arxiv.org/pdf/2208.01762.pdf)
+ published at the International Conference on 3D Vision 2022 (3DV 2022).
+ Paper Code can be found at [Zongwei97/RFNet](https://github.com/Zongwei97/RFnet).
+ Deployed Code can be found at [shriarul5273/Robust_RGB-D_Saliency_Detection](https://github.com/shriarul5273/Robust_RGB-D_Saliency_Detection).
+ Use example Image and corresponding Depth Map (from NJU2K dataset) or upload your own Image and Depth Map.
+ """
+ article = """ # Citation
+ If you find this repo useful, please consider citing:
+ ```
+ @article{wu2022robust,
+   title={Robust RGB-D Fusion for Saliency Detection},
+   author={Wu, Zongwei and Gobichettipalayam, Shriarulmozhivarman and Tamadazte, Brahim and Allibert, Guillaume and Paudel, Danda Pani and Demonceaux, Cedric},
+   journal={3DV},
+   year={2022}
+ }
+ ```
+ """
+ examples = [['images/image_1.jpg','images/depth_1.png','images/gt_1.png'],
+             ['images/image_2.jpg','images/depth_2.png','images/gt_2.png'],
+             ['images/image_3.jpg','images/depth_3.png','images/gt_3.png'],
+             ['images/image_4.jpg','images/depth_4.png','images/gt_4.png'],
+             ['images/image_5.jpg','images/depth_5.png','images/gt_5.png']]
+
+ input_1 = gr.Image(type='pil', label="RGB Image", sources="upload")
+ input_2 = gr.Image(type='pil', label="Depth Image", sources="upload")
+ input_3 = gr.Image(type='pil', label="Ground Truth", sources="upload")
+ outputs = gr.Image(type="pil", label="Saliency Map")
+
+
+ gr.Interface(inference, inputs=[input_1,input_2,input_3], outputs=outputs,
+              title=title,examples=examples,
+              description=description,
+              article=article,
+              theme=gr.themes.Soft(),
+              cache_examples=False).launch()
 
requirements.txt CHANGED
@@ -1,5 +1,5 @@
- gradio==3.44.1
- torch==1.12.0
- torchvision==0.13.0
- Pillow==9.2.0
- onnxruntime==1.12.1
 
+ gradio==5.9.1
+ torch==2.5.1
+ torchvision==0.20.1
+ Pillow==11.0.0
+ onnxruntime-gpu==1.20.1
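
Since the updated requirements pin onnxruntime-gpu and the session now requests CUDAExecutionProvider with a CPU fallback, a quick sanity check along the lines of the sketch below (not part of the commit; it assumes RFNet.onnx is present in the working directory) can confirm which execution provider the session actually resolved to at runtime.

```python
import onnxruntime

# Providers compiled into the installed onnxruntime package
print(onnxruntime.get_available_providers())

# Request CUDA first with a CPU fallback, then inspect what was actually chosen
sess = onnxruntime.InferenceSession(
    "RFNet.onnx",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
print(sess.get_providers())  # expect CUDAExecutionProvider first when the GPU is usable
```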