zhang-ziang commited on
Commit
0366edb
·
1 Parent(s): 74503df

infer aug + remove bkg

Browse files
Files changed (1) hide show
  1. app.py +9 -8
app.py CHANGED
@@ -210,10 +210,10 @@ def remove_outliers_and_average_circular(tensor, threshold=1.5):
210
 
211
  return mean_angle
212
 
213
- def get_3angle_infer_aug(image):
214
 
215
  # image = Image.open(image_path).convert('RGB')
216
- image = get_crop_images(image, num=6)
217
  image_inputs = val_preprocess(images = image)
218
  image_inputs['pixel_values'] = torch.from_numpy(np.array(image_inputs['pixel_values'])).to(device)
219
  with torch.no_grad():
@@ -267,21 +267,22 @@ def figure_to_img(fig):
267
  return image
268
 
269
  def infer_func(img, do_rm_bkg, do_infer_aug):
270
- img = Image.fromarray(img)
271
- img = background_preprocess(img, do_rm_bkg)
272
  if do_infer_aug:
273
- angles = get_3angle_infer_aug(img)
 
274
  else:
275
- angles = get_3angle(img)
 
276
 
277
  fig, ax = plt.subplots(figsize=(8, 8))
278
 
279
- w, h = img.size
280
  if h>w:
281
  extent = [-5*w/h, 5*w/h, -5, 5]
282
  else:
283
  extent = [-5, 5, -5*h/w, 5*h/w]
284
- ax.imshow(img, extent=extent, zorder=0, aspect ='auto') # extent 设置图片的显示范围
285
 
286
  origin = np.array([0, 0])
287
 
 
210
 
211
  return mean_angle
212
 
213
+ def get_3angle_infer_aug(origin_img, rm_bkg_img):
214
 
215
  # image = Image.open(image_path).convert('RGB')
216
+ image = get_crop_images(origin_img, num=3) + get_crop_images(rm_bkg_img, num=3)
217
  image_inputs = val_preprocess(images = image)
218
  image_inputs['pixel_values'] = torch.from_numpy(np.array(image_inputs['pixel_values'])).to(device)
219
  with torch.no_grad():
 
267
  return image
268
 
269
  def infer_func(img, do_rm_bkg, do_infer_aug):
270
+ origin_img = Image.fromarray(img)
 
271
  if do_infer_aug:
272
+ rm_bkg_img = background_preprocess(origin_img, True)
273
+ angles = get_3angle_infer_aug(origin_img, rm_bkg_img)
274
  else:
275
+ rm_bkg_img = background_preprocess(origin_img, do_rm_bkg)
276
+ angles = get_3angle(rm_bkg_img)
277
 
278
  fig, ax = plt.subplots(figsize=(8, 8))
279
 
280
+ w, h = rm_bkg_img.size
281
  if h>w:
282
  extent = [-5*w/h, 5*w/h, -5, 5]
283
  else:
284
  extent = [-5, 5, -5*h/w, 5*h/w]
285
+ ax.imshow(rm_bkg_img, extent=extent, zorder=0, aspect ='auto') # extent 设置图片的显示范围
286
 
287
  origin = np.array([0, 0])
288