Spaces:
Running
Running
zhang-ziang
commited on
Commit
·
0366edb
1
Parent(s):
74503df
infer aug + remove bkg
Browse files
app.py
CHANGED
@@ -210,10 +210,10 @@ def remove_outliers_and_average_circular(tensor, threshold=1.5):
|
|
210 |
|
211 |
return mean_angle
|
212 |
|
213 |
-
def get_3angle_infer_aug(
|
214 |
|
215 |
# image = Image.open(image_path).convert('RGB')
|
216 |
-
image = get_crop_images(
|
217 |
image_inputs = val_preprocess(images = image)
|
218 |
image_inputs['pixel_values'] = torch.from_numpy(np.array(image_inputs['pixel_values'])).to(device)
|
219 |
with torch.no_grad():
|
@@ -267,21 +267,22 @@ def figure_to_img(fig):
|
|
267 |
return image
|
268 |
|
269 |
def infer_func(img, do_rm_bkg, do_infer_aug):
|
270 |
-
|
271 |
-
img = background_preprocess(img, do_rm_bkg)
|
272 |
if do_infer_aug:
|
273 |
-
|
|
|
274 |
else:
|
275 |
-
|
|
|
276 |
|
277 |
fig, ax = plt.subplots(figsize=(8, 8))
|
278 |
|
279 |
-
w, h =
|
280 |
if h>w:
|
281 |
extent = [-5*w/h, 5*w/h, -5, 5]
|
282 |
else:
|
283 |
extent = [-5, 5, -5*h/w, 5*h/w]
|
284 |
-
ax.imshow(
|
285 |
|
286 |
origin = np.array([0, 0])
|
287 |
|
|
|
210 |
|
211 |
return mean_angle
|
212 |
|
213 |
+
def get_3angle_infer_aug(origin_img, rm_bkg_img):
|
214 |
|
215 |
# image = Image.open(image_path).convert('RGB')
|
216 |
+
image = get_crop_images(origin_img, num=3) + get_crop_images(rm_bkg_img, num=3)
|
217 |
image_inputs = val_preprocess(images = image)
|
218 |
image_inputs['pixel_values'] = torch.from_numpy(np.array(image_inputs['pixel_values'])).to(device)
|
219 |
with torch.no_grad():
|
|
|
267 |
return image
|
268 |
|
269 |
def infer_func(img, do_rm_bkg, do_infer_aug):
|
270 |
+
origin_img = Image.fromarray(img)
|
|
|
271 |
if do_infer_aug:
|
272 |
+
rm_bkg_img = background_preprocess(origin_img, True)
|
273 |
+
angles = get_3angle_infer_aug(origin_img, rm_bkg_img)
|
274 |
else:
|
275 |
+
rm_bkg_img = background_preprocess(origin_img, do_rm_bkg)
|
276 |
+
angles = get_3angle(rm_bkg_img)
|
277 |
|
278 |
fig, ax = plt.subplots(figsize=(8, 8))
|
279 |
|
280 |
+
w, h = rm_bkg_img.size
|
281 |
if h>w:
|
282 |
extent = [-5*w/h, 5*w/h, -5, 5]
|
283 |
else:
|
284 |
extent = [-5, 5, -5*h/w, 5*h/w]
|
285 |
+
ax.imshow(rm_bkg_img, extent=extent, zorder=0, aspect ='auto') # extent 设置图片的显示范围
|
286 |
|
287 |
origin = np.array([0, 0])
|
288 |
|