franciszzj committed
Commit 01b1b55 · 1 Parent(s): 0e21ab4

add app.py

Files changed (1)
  1. app.py +148 -0
app.py ADDED
@@ -0,0 +1,148 @@
+import os
+import sys
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torchvision import transforms
+
+from PIL import Image
+from network import pvt_cls as TCN
+
+import gradio as gr
+
+
+def demo(img_path):
+    # config
+    batch_size = 8
+    crop_size = 256
+    model_path = '/users/k21163430/workspace/TreeFormer/models/best_model.pth'
+
+    # fall back to CPU when CUDA is unavailable
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+    # prepare model
+    model = TCN.pvt_treeformer(pretrained=False)
+    model.to(device)
+    model.load_state_dict(torch.load(model_path, map_location=device))
+    model.eval()
+
+    # preprocess: upscale the image if its short side is below the crop size
+    img = Image.open(img_path).convert('RGB')
+    show_img = np.array(img)
+    wd, ht = img.size
+    st_size = 1.0 * min(wd, ht)
+    if st_size < crop_size:
+        rr = 1.0 * crop_size / st_size
+        wd = round(wd * rr)
+        ht = round(ht * rr)
+        st_size = 1.0 * min(wd, ht)
+        img = img.resize((wd, ht), Image.BICUBIC)
+    transform = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+    ])
+    img = transform(img)
+    img = img.unsqueeze(0)
+
+    # model forward: tile the image with crop_size windows (windows at the
+    # right/bottom edges are shifted back inside the image, so tiles may
+    # overlap) and keep a mask of how often each pixel is covered
+    with torch.no_grad():
+        inputs = img.to(device)
+        crop_imgs, crop_masks = [], []
+        b, c, h, w = inputs.size()
+        rh, rw = crop_size, crop_size
+
+        for i in range(0, h, rh):
+            gis, gie = max(min(h - rh, i), 0), min(h, i + rh)
+            for j in range(0, w, rw):
+                gjs, gje = max(min(w - rw, j), 0), min(w, j + rw)
+                crop_imgs.append(inputs[:, :, gis:gie, gjs:gje])
+                mask = torch.zeros([b, 1, h, w]).to(device)
+                mask[:, :, gis:gie, gjs:gje].fill_(1.0)
+                crop_masks.append(mask)
+        crop_imgs, crop_masks = map(
+            lambda x: torch.cat(x, dim=0), (crop_imgs, crop_masks))
+
+        # predict a density map for each crop, in mini-batches
+        crop_preds = []
+        nz, bz = crop_imgs.size(0), batch_size
+        for i in range(0, nz, bz):
+            gs, gt = i, min(nz, i + bz)
+            crop_pred, _ = model(crop_imgs[gs:gt])
+            crop_pred = crop_pred[0]
+
+            # the head predicts at 1/4 resolution; upsample 4x and divide by
+            # 16 so the summed count is preserved
+            _, _, h1, w1 = crop_pred.size()
+            crop_pred = F.interpolate(crop_pred, size=(
+                h1 * 4, w1 * 4), mode='bilinear', align_corners=True) / 16
+            crop_preds.append(crop_pred)
+        crop_preds = torch.cat(crop_preds, dim=0)
+
+        # splice the crop predictions back into a full-size map
+        idx = 0
+        pred_map = torch.zeros([b, 1, h, w]).to(device)
+        for i in range(0, h, rh):
+            gis, gie = max(min(h - rh, i), 0), min(h, i + rh)
+            for j in range(0, w, rw):
+                gjs, gje = max(min(w - rw, j), 0), min(w, j + rw)
+                pred_map[:, :, gis:gie, gjs:gje] += crop_preds[idx]
+                idx += 1
+        # for the overlapping areas, average the accumulated predictions
+        mask = crop_masks.sum(dim=0).unsqueeze(0)
+        outputs = pred_map / mask
+
+    # the predicted count is the integral of the density map
+    model_output = round(torch.sum(outputs).item())
+
+    print("{}: {}".format(img_path, model_output))
+    # min-max normalize the density map and blend it over the input image
+    outputs = outputs.squeeze().cpu().numpy()
+    outputs = (outputs - np.min(outputs)) / \
+        (np.max(outputs) - np.min(outputs))
+
+    show_img = show_img / 255.0
+    show_img = show_img * 0.2 + outputs[:, :, None] * 0.8
+
+    return model_output, show_img
+
+
+if __name__ == "__main__":
+    # test
+    # img_path = sys.argv[1]
+    # demo(img_path)
+
+    # Launch a gr.Interface
+    gr_demo = gr.Interface(
+        fn=demo,
+        inputs=gr.Image(source="upload",
+                        type="filepath",
+                        label="Input Image",
+                        width=768,
+                        height=768,
+                        ),
+        outputs=[
+            gr.Number(label="Predicted Tree Count"),
+            gr.Image(label="Density Map",
+                     width=768,
+                     height=768,
+                     ),
+        ],
+        title="TreeFormer",
+        description="TreeFormer is a semi-supervised transformer-based framework for tree counting from a single high-resolution image. Upload an image and TreeFormer will predict the number of trees in the image and generate a density map of the trees.",
+        article="This work has been developed as part of the ReSET project, which has received funding from the European Union's Horizon 2020 FET Proactive Programme under grant agreement No 101017857. The contents of this publication are the sole responsibility of the ReSET consortium and do not necessarily reflect the opinion of the European Union.",
+        examples=[
+            ["./examples/IMG_101.jpg"],
+            ["./examples/IMG_125.jpg"],
+            ["./examples/IMG_138.jpg"],
+            ["./examples/IMG_180.jpg"],
+            ["./examples/IMG_18.jpg"],
+            ["./examples/IMG_206.jpg"],
+            ["./examples/IMG_223.jpg"],
+            ["./examples/IMG_247.jpg"],
+            ["./examples/IMG_270.jpg"],
+            ["./examples/IMG_306.jpg"],
+        ],
+        # cache_examples=True,
+        examples_per_page=10,
+        allow_flagging="never",
+        theme=gr.themes.Default(),
+    )
+    gr_demo.launch(share=True, server_port=7861,
+                   favicon_path="./assets/reset.png")
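
For a quick check without launching the web UI, the `demo` function can also be called directly. A minimal sketch, assuming the checkpoint path hard-coded in app.py resolves on the local machine and that the bundled example images are present:

    # smoke-test demo() on one of the bundled example images;
    # importing app.py does not launch the interface, since the
    # gr.Interface is created under the __main__ guard
    from app import demo

    count, overlay = demo("./examples/IMG_18.jpg")
    print(count)          # predicted tree count (int)
    print(overlay.shape)  # (H, W, 3) density map blended over the input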