import torch
import torch.nn.functional as F
import gradio as gr
import spaces  # Hugging Face Spaces helper; not used directly below
import onnxruntime
from torchvision import transforms

# ONNX Runtime session: profiling enabled; prefer the CUDA execution provider
# and fall back to CPU when no GPU is available.
sess_options = onnxruntime.SessionOptions()
sess_options.enable_profiling = True
sess_options.add_session_config_entry('session.load_model_format', 'ONNX')
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
ort_sess = onnxruntime.InferenceSession("RFNet.onnx", sess_options=sess_options, providers=providers)

# Mirror the provider fallback for the PyTorch pre-/post-processing tensors.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

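# "RFNet.onnx" is assumed to have been exported from the original PyTorch RFNet
# checkpoint. A minimal export sketch is shown below; the helper name, the
# input/output names, and the opset are assumptions, and it is never called by
# this app.
def export_rfnet_to_onnx(model, onnx_path="RFNet.onnx"):
    dummy_rgb = torch.randn(1, 3, 352, 352)    # matches preprocess_img below
    dummy_depth = torch.randn(1, 1, 352, 352)  # depth is converted to one channel
    torch.onnx.export(model, (dummy_rgb, dummy_depth), onnx_path,
                      input_names=["rgb", "depth"], output_names=["saliency"],
                      opset_version=11)
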
# RGB preprocessing: resize to the network's 352x352 input and normalize with
# ImageNet statistics.
preprocess_img = transforms.Compose([
    transforms.Resize((352, 352)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Depth preprocessing: resize only; the map is converted to single-channel
# grayscale before this is applied.
preprocess_depth = transforms.Compose([
    transforms.Resize((352, 352)),
    transforms.ToTensor()
])

def inference(img, depth, GT):
    # GT is accepted only so the examples can show a ground-truth mask alongside
    # the inputs; it is not used for prediction.
    w, h = img.size  # PIL .size is (width, height)
    img = preprocess_img(img).unsqueeze(0).to(device)
    depth = preprocess_depth(depth.convert('L')).unsqueeze(0).to(device)
    # Feed both preprocessed tensors to the ONNX model as numpy arrays.
    ort_inputs = {ort_sess.get_inputs()[0].name: img.cpu().numpy(),
                  ort_sess.get_inputs()[1].name: depth.cpu().numpy()}
    ort_outs = ort_sess.run(None, ort_inputs)
    output_image = torch.tensor(ort_outs[0]).to(device)
    # Resize the prediction back to the original resolution, apply sigmoid,
    # and min-max normalize to [0, 1] for display.
    res = F.interpolate(output_image, size=(h, w), mode='bilinear', align_corners=False)
    res = torch.sigmoid(res)
    res = res.data.cpu().numpy().squeeze()
    res = (res - res.min()) / (res.max() - res.min() + 1e-8)
    return res

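# Local sanity check (hypothetical helper, not executed by the app): assumes the
# example assets listed below are present on disk.
def _demo_inference():
    from PIL import Image
    sal = inference(Image.open("images/image_1.jpg"),
                    Image.open("images/depth_1.png"),
                    Image.open("images/gt_1.png"))
    print(sal.shape, sal.min(), sal.max())
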
title = "Robust RGB-D Fusion for Saliency Detection"
description = """ Deployment of the paper: 

[Robust RGB-D Fusion for Saliency Detection](https://arxiv.org/pdf/2208.01762.pdf) 

published at the International Conference on 3D Vision 2022 (3DV 2022). 

Paper code can be found at [Zongwei97/RFNet](https://github.com/Zongwei97/RFnet).

Code for this deployment can be found at [shriarul5273/Robust_RGB-D_Saliency_Detection](https://github.com/shriarul5273/Robust_RGB-D_Saliency_Detection).

Use an example image and its corresponding depth map (from the NJU2K dataset), or upload your own image and depth map.

"""
article = """ # Citation 

If you find this repo useful, please consider citing:

```
@article{wu2022robust,
  title={Robust RGB-D Fusion for Saliency Detection},
  author={Wu, Zongwei and Gobichettipalayam, Shriarulmozhivarman and Tamadazte, Brahim and Allibert, Guillaume and Paudel, Danda Pani and Demonceaux, Cedric},
  journal={3DV},
  year={2022}
}
```

"""
examples = [['images/image_1.jpg','images/depth_1.png','images/gt_1.png'],
            ['images/image_2.jpg','images/depth_2.png','images/gt_2.png'],
            ['images/image_3.jpg','images/depth_3.png','images/gt_3.png'],
            ['images/image_4.jpg','images/depth_4.png','images/gt_4.png'],
            ['images/image_5.jpg','images/depth_5.png','images/gt_5.png']]

input_1 = gr.Image(type='pil', label="RGB Image", sources=["upload"])
input_2 = gr.Image(type='pil', label="Depth Image", sources=["upload"])
input_3 = gr.Image(type='pil', label="Ground Truth", sources=["upload"])
outputs = gr.Image(type="pil", label="Saliency Map")


gr.Interface(inference, inputs=[input_1, input_2, input_3], outputs=outputs,
             title=title, examples=examples,
             description=description,
             article=article,
             theme=gr.themes.Soft(),
             cache_examples=False).launch()