Spaces:
Runtime error
Seif-Yasser
committed:
Update app.py
app.py
CHANGED
@@ -1,4 +1,120 @@
 import gradio as gr
+from PIL import Image
+from torchvision import models, transforms
+import gradio as gr
+import subprocess
+import torch
+import cv2
+import numpy as np
+from models.with_mobilenet import PoseEstimationWithMobileNet
+from modules.keypoints import extract_keypoints, group_keypoints
+from modules.load_state import load_state
+from modules.pose import Pose
+import demo
+from recon import reconWrapper
+
+
+def get_rect(net, image, height_size):
+    net = net.eval()
+
+    stride = 8
+    upsample_ratio = 4
+    num_keypoints = Pose.num_kpts
+    previous_poses = []
+    delay = 33
+    image = image[0]
+    rect_path = image.replace('.%s' % (image.split('.')[-1]), '_rect.txt')
+    print('Processing', image)
+    img = cv2.imread(image, cv2.IMREAD_COLOR)
+    orig_img = img.copy()
+    heatmaps, pafs, scale, pad = demo.infer_fast(
+        net, img, height_size, stride, upsample_ratio, cpu=False)
+
+    total_keypoints_num = 0
+    all_keypoints_by_type = []
+    for kpt_idx in range(num_keypoints):  # 19th for bg
+        total_keypoints_num += extract_keypoints(
+            heatmaps[:, :, kpt_idx], all_keypoints_by_type, total_keypoints_num)
+
+    pose_entries, all_keypoints = group_keypoints(
+        all_keypoints_by_type, pafs)
+    for kpt_id in range(all_keypoints.shape[0]):
+        all_keypoints[kpt_id, 0] = (
+            all_keypoints[kpt_id, 0] * stride / upsample_ratio - pad[1]) / scale
+        all_keypoints[kpt_id, 1] = (
+            all_keypoints[kpt_id, 1] * stride / upsample_ratio - pad[0]) / scale
+    current_poses = []
+
+    rects = []
+    for n in range(len(pose_entries)):
+        if len(pose_entries[n]) == 0:
+            continue
+        pose_keypoints = np.ones((num_keypoints, 2), dtype=np.int32) * -1
+        valid_keypoints = []
+        for kpt_id in range(num_keypoints):
+            if pose_entries[n][kpt_id] != -1.0:  # keypoint was found
+                pose_keypoints[kpt_id, 0] = int(
+                    all_keypoints[int(pose_entries[n][kpt_id]), 0])
+                pose_keypoints[kpt_id, 1] = int(
+                    all_keypoints[int(pose_entries[n][kpt_id]), 1])
+                valid_keypoints.append(
+                    [pose_keypoints[kpt_id, 0], pose_keypoints[kpt_id, 1]])
+        valid_keypoints = np.array(valid_keypoints)
+
+        if pose_entries[n][10] != -1.0 or pose_entries[n][13] != -1.0:
+            pmin = valid_keypoints.min(0)
+            pmax = valid_keypoints.max(0)
+
+            center = (0.5 * (pmax[:2] + pmin[:2])).astype('int')
+            radius = int(0.65 * max(pmax[0]-pmin[0], pmax[1]-pmin[1]))
+        elif pose_entries[n][10] == -1.0 and pose_entries[n][13] == -1.0 and pose_entries[n][8] != -1.0 and pose_entries[n][11] != -1.0:
+            # if leg is missing, use pelvis to get cropping
+            center = (
+                0.5 * (pose_keypoints[8] + pose_keypoints[11])).astype('int')
+            radius = int(
+                1.45*np.sqrt(((center[None, :] - valid_keypoints)**2).sum(1)).max(0))
+            center[1] += int(0.05*radius)
+        else:
+            center = np.array([img.shape[1]//2, img.shape[0]//2])
+            radius = max(img.shape[1]//2, img.shape[0]//2)
+
+        x1 = center[0] - radius
+        y1 = center[1] - radius
+
+        rects.append([x1, y1, 2*radius, 2*radius])
+
+    np.savetxt(rect_path, np.array(rects), fmt='%d')
+    print('Cropping boxes are saved at', rect_path)
+    print(rect_path[0:7] +
+          'apps/'+rect_path[7:])
+
+
+def run_simple_test():
+    resolution = str(256)
+    start_id = -1
+    end_id = -1
+    cmd = ['--dataroot', 'pifuhd/sample_images', '--results_path', './results',
+           '--loadSize', '1024', '--resolution', resolution, '--load_netMR_checkpoint_path',
+           './checkpoints/pifuhd.pt',
+           '--start_id', '%d' % start_id, '--end_id', '%d' % end_id]
+    mesh_path = reconWrapper(cmd, True)
+    print('Mesh is saved at', mesh_path)
+    return mesh_path
+
+
+def predict(image):
+    # Save the input image to a file
+    image_path = 'input_image.png'
+    cv2.imwrite(image_path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
+
+    net = PoseEstimationWithMobileNet()
+    checkpoint = torch.load(
+        'pifuhd/checkpoint_iter_370000.pth', map_location='cpu')
+    load_state(net, checkpoint)
+
+    get_rect(net.cuda(), [image_path], 512)
+    mesh_path = run_simple_test()
+    return mesh_path
 
 # Create the Gradio interface
 iface = gr.Interface(
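Note: the diff is truncated at the opening gr.Interface( call, so the actual interface arguments are not visible in this commit. A minimal sketch of how that call might be completed, wiring in the predict() function added above; the component choices (gr.Image, gr.Model3D), the title, and the launch() guard are assumptions for illustration, not part of the commit:

# Hypothetical completion of the truncated gr.Interface( call -- a sketch only.
iface = gr.Interface(
    fn=predict,                     # predict() defined in this commit
    inputs=gr.Image(type="numpy"),  # predict() expects an RGB numpy array (it calls cv2.cvtColor)
    outputs=gr.Model3D(),           # run_simple_test() returns a path to the reconstructed mesh
    title="PIFuHD 3D reconstruction",  # assumed title
)

if __name__ == "__main__":
    iface.launch()

Any such completion assumes the mesh path returned by run_simple_test() points to a file format the output component can render (e.g. .obj); the commit itself does not show how the interface is finished or launched.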