Seif-Yasser commited on
Commit
ff9e6c3
·
verified ·
1 Parent(s): 3aac098

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +116 -0
app.py CHANGED
@@ -1,4 +1,120 @@
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  # Create the Gradio interface
4
  iface = gr.Interface(
 
1
  import gradio as gr
2
+ from PIL import Image
3
+ from torchvision import models, transforms
4
+ import gradio as gr
5
+ import subprocess
6
+ import torch
7
+ import cv2
8
+ import numpy as np
9
+ from models.with_mobilenet import PoseEstimationWithMobileNet
10
+ from modules.keypoints import extract_keypoints, group_keypoints
11
+ from modules.load_state import load_state
12
+ from modules.pose import Pose
13
+ import demo
14
+ from recon import reconWrapper
15
+
16
+
17
+ def get_rect(net, image, height_size):
18
+ net = net.eval()
19
+
20
+ stride = 8
21
+ upsample_ratio = 4
22
+ num_keypoints = Pose.num_kpts
23
+ previous_poses = []
24
+ delay = 33
25
+ image = image[0]
26
+ rect_path = image.replace('.%s' % (image.split('.')[-1]), '_rect.txt')
27
+ print('Processing', image)
28
+ img = cv2.imread(image, cv2.IMREAD_COLOR)
29
+ orig_img = img.copy()
30
+ heatmaps, pafs, scale, pad = demo.infer_fast(
31
+ net, img, height_size, stride, upsample_ratio, cpu=False)
32
+
33
+ total_keypoints_num = 0
34
+ all_keypoints_by_type = []
35
+ for kpt_idx in range(num_keypoints): # 19th for bg
36
+ total_keypoints_num += extract_keypoints(
37
+ heatmaps[:, :, kpt_idx], all_keypoints_by_type, total_keypoints_num)
38
+
39
+ pose_entries, all_keypoints = group_keypoints(
40
+ all_keypoints_by_type, pafs)
41
+ for kpt_id in range(all_keypoints.shape[0]):
42
+ all_keypoints[kpt_id, 0] = (
43
+ all_keypoints[kpt_id, 0] * stride / upsample_ratio - pad[1]) / scale
44
+ all_keypoints[kpt_id, 1] = (
45
+ all_keypoints[kpt_id, 1] * stride / upsample_ratio - pad[0]) / scale
46
+ current_poses = []
47
+
48
+ rects = []
49
+ for n in range(len(pose_entries)):
50
+ if len(pose_entries[n]) == 0:
51
+ continue
52
+ pose_keypoints = np.ones((num_keypoints, 2), dtype=np.int32) * -1
53
+ valid_keypoints = []
54
+ for kpt_id in range(num_keypoints):
55
+ if pose_entries[n][kpt_id] != -1.0: # keypoint was found
56
+ pose_keypoints[kpt_id, 0] = int(
57
+ all_keypoints[int(pose_entries[n][kpt_id]), 0])
58
+ pose_keypoints[kpt_id, 1] = int(
59
+ all_keypoints[int(pose_entries[n][kpt_id]), 1])
60
+ valid_keypoints.append(
61
+ [pose_keypoints[kpt_id, 0], pose_keypoints[kpt_id, 1]])
62
+ valid_keypoints = np.array(valid_keypoints)
63
+
64
+ if pose_entries[n][10] != -1.0 or pose_entries[n][13] != -1.0:
65
+ pmin = valid_keypoints.min(0)
66
+ pmax = valid_keypoints.max(0)
67
+
68
+ center = (0.5 * (pmax[:2] + pmin[:2])).astype('int')
69
+ radius = int(0.65 * max(pmax[0]-pmin[0], pmax[1]-pmin[1]))
70
+ elif pose_entries[n][10] == -1.0 and pose_entries[n][13] == -1.0 and pose_entries[n][8] != -1.0 and pose_entries[n][11] != -1.0:
71
+ # if leg is missing, use pelvis to get cropping
72
+ center = (
73
+ 0.5 * (pose_keypoints[8] + pose_keypoints[11])).astype('int')
74
+ radius = int(
75
+ 1.45*np.sqrt(((center[None, :] - valid_keypoints)**2).sum(1)).max(0))
76
+ center[1] += int(0.05*radius)
77
+ else:
78
+ center = np.array([img.shape[1]//2, img.shape[0]//2])
79
+ radius = max(img.shape[1]//2, img.shape[0]//2)
80
+
81
+ x1 = center[0] - radius
82
+ y1 = center[1] - radius
83
+
84
+ rects.append([x1, y1, 2*radius, 2*radius])
85
+
86
+ np.savetxt(rect_path, np.array(rects), fmt='%d')
87
+ print('Cropping boxes are saved at', rect_path)
88
+ print(rect_path[0:7] +
89
+ 'apps/'+rect_path[7:])
90
+
91
+
92
+ def run_simple_test():
93
+ resolution = str(256)
94
+ start_id = -1
95
+ end_id = -1
96
+ cmd = ['--dataroot', 'pifuhd/sample_images', '--results_path', './results',
97
+ '--loadSize', '1024', '--resolution', resolution, '--load_netMR_checkpoint_path',
98
+ './checkpoints/pifuhd.pt',
99
+ '--start_id', '%d' % start_id, '--end_id', '%d' % end_id]
100
+ mesh_path = reconWrapper(cmd, True)
101
+ print('Mesh is saved at', mesh_path)
102
+ return mesh_path
103
+
104
+
105
+ def predict(image):
106
+ # Save the input image to a file
107
+ image_path = 'input_image.png'
108
+ cv2.imwrite(image_path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
109
+
110
+ net = PoseEstimationWithMobileNet()
111
+ checkpoint = torch.load(
112
+ 'pifuhd/checkpoint_iter_370000.pth', map_location='cpu')
113
+ load_state(net, checkpoint)
114
+
115
+ get_rect(net.cuda(), [image_path], 512)
116
+ mesh_path = run_simple_test()
117
+ return mesh_path
118
 
119
  # Create the Gradio interface
120
  iface = gr.Interface(