not-lain commited on
Commit
3e75999
Β·
1 Parent(s): f3d3ce9

🌘wπŸŒ–

Browse files
Files changed (1) hide show
  1. app.py +6 -3
app.py CHANGED
@@ -8,7 +8,7 @@ from torchvision import transforms
8
 
9
  torch.set_float32_matmul_precision(['high', 'highest'][0])
10
 
11
- birefnet = AutoModelForImageSegmentation.from_pretrained('zhengpeng7/BiRefNet', trust_remote_code=True)
12
  birefnet.to("cuda")
13
  transform_image = transforms.Compose([
14
  transforms.Resize((1024, 1024)),
@@ -22,6 +22,8 @@ transform_image = transforms.Compose([
22
  def fn(image):
23
  im = load_img(image)
24
  im = im.convert('RGB')
 
 
25
  image = load_img(im)
26
  input_images = transform_image(image).unsqueeze(0).to('cuda')
27
  # Prediction
@@ -29,8 +31,9 @@ def fn(image):
29
  preds = birefnet(input_images)[-1].sigmoid().cpu()
30
  pred = preds[0].squeeze()
31
  pred_pil = transforms.ToPILImage()(pred)
32
- out = (pred_pil , im)
33
- return out
 
34
 
35
  slider1 = ImageSlider(label="birefnet", type="pil")
36
  slider2 = ImageSlider(label="birefnet", type="pil")
 
8
 
9
  torch.set_float32_matmul_precision(['high', 'highest'][0])
10
 
11
+ birefnet = AutoModelForImageSegmentation.from_pretrained('ZhengPeng7/BiRefNet', trust_remote_code=True)
12
  birefnet.to("cuda")
13
  transform_image = transforms.Compose([
14
  transforms.Resize((1024, 1024)),
 
22
  def fn(image):
23
  im = load_img(image)
24
  im = im.convert('RGB')
25
+ image_size = im.size
26
+ origin = im.copy()
27
  image = load_img(im)
28
  input_images = transform_image(image).unsqueeze(0).to('cuda')
29
  # Prediction
 
31
  preds = birefnet(input_images)[-1].sigmoid().cpu()
32
  pred = preds[0].squeeze()
33
  pred_pil = transforms.ToPILImage()(pred)
34
+ mask = pred_pil.resize(image_size)
35
+ image.putalpha(mask)
36
+ return (image , origin)
37
 
38
  slider1 = ImageSlider(label="birefnet", type="pil")
39
  slider2 = ImageSlider(label="birefnet", type="pil")