ZhengPeng7 committed
Commit 607dfc0 · 1 Parent(s): 43fe6dc

Align lib_name to birefnet and add an inference endpoint option.

Files changed (3)
  1. README.md +1 -1
  2. birefnet.py +28 -24
  3. handler.py +138 -0
README.md CHANGED
@@ -1,5 +1,5 @@
 ---
-library_name: BiRefNet
+library_name: birefnet
 tags:
 - background-removal
 - mask-generation
birefnet.py CHANGED
@@ -615,6 +615,7 @@ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
 
 # config = Config()
 
+
 class Mlp(nn.Module):
     """ Multilayer perceptron."""
 
@@ -739,7 +740,8 @@ class WindowAttention(nn.Module):
         attn = (q @ k.transpose(-2, -1))
 
         relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
-            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
+            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
+        )  # Wh*Ww, Wh*Ww, nH
         relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
         attn = attn + relative_position_bias.unsqueeze(0)
 
@@ -974,8 +976,9 @@ class BasicLayer(nn.Module):
         """
 
         # calculate attention mask for SW-MSA
-        Hp = int(np.ceil(H / self.window_size)) * self.window_size
-        Wp = int(np.ceil(W / self.window_size)) * self.window_size
+        # Turn int into torch.tensor for compatibility with torch.compile in PyTorch 2.5.
+        Hp = torch.ceil(torch.tensor(H) / self.window_size).to(torch.int64) * self.window_size
+        Wp = torch.ceil(torch.tensor(W) / self.window_size).to(torch.int64) * self.window_size
         img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1
         h_slices = (slice(0, -self.window_size),
                     slice(-self.window_size, -self.shift_size),
@@ -1961,6 +1964,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from kornia.filters import laplacian
 from transformers import PreTrainedModel
+from einops import rearrange
 
 # from config import Config
 # from dataset import class_labels_TR_sorted
@@ -1974,6 +1978,18 @@ from transformers import PreTrainedModel
 from .BiRefNet_config import BiRefNetConfig
 
 
+def image2patches(image, grid_h=2, grid_w=2, patch_ref=None, transformation='b c (hg h) (wg w) -> (b hg wg) c h w'):
+    if patch_ref is not None:
+        grid_h, grid_w = image.shape[-2] // patch_ref.shape[-2], image.shape[-1] // patch_ref.shape[-1]
+    patches = rearrange(image, transformation, hg=grid_h, wg=grid_w)
+    return patches
+
+def patches2image(patches, grid_h=2, grid_w=2, patch_ref=None, transformation='(b hg wg) c h w -> b c (hg h) (wg w)'):
+    if patch_ref is not None:
+        grid_h, grid_w = patch_ref.shape[-2] // patches[0].shape[-2], patch_ref.shape[-1] // patches[0].shape[-1]
+    image = rearrange(patches, transformation, hg=grid_h, wg=grid_w)
+    return image
+
 class BiRefNet(
     PreTrainedModel
 ):
@@ -2124,18 +2140,6 @@ class Decoder(nn.Module):
         self.gdt_convs_attn_3 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
         self.gdt_convs_attn_2 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
 
-    def get_patches_batch(self, x, p):
-        _size_h, _size_w = p.shape[2:]
-        patches_batch = []
-        for idx in range(x.shape[0]):
-            columns_x = torch.split(x[idx], split_size_or_sections=_size_w, dim=-1)
-            patches_x = []
-            for column_x in columns_x:
-                patches_x += [p.unsqueeze(0) for p in torch.split(column_x, split_size_or_sections=_size_h, dim=-2)]
-            patch_sample = torch.cat(patches_x, dim=1)
-            patches_batch.append(patch_sample)
-        return torch.cat(patches_batch, dim=0)
-
     def forward(self, features):
         if self.training and self.config.out_ref:
             outs_gdt_pred = []
@@ -2146,10 +2150,10 @@
         outs = []
 
         if self.config.dec_ipt:
-            patches_batch = self.get_patches_batch(x, x4) if self.split else x
+            patches_batch = image2patches(x, patch_ref=x4, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') if self.split else x
            x4 = torch.cat((x4, self.ipt_blk5(F.interpolate(patches_batch, size=x4.shape[2:], mode='bilinear', align_corners=True))), 1)
         p4 = self.decoder_block4(x4)
-        m4 = self.conv_ms_spvn_4(p4) if self.config.ms_supervision else None
+        m4 = self.conv_ms_spvn_4(p4) if self.config.ms_supervision and self.training else None
         if self.config.out_ref:
             p4_gdt = self.gdt_convs_4(p4)
             if self.training:
@@ -2167,10 +2171,10 @@
         _p3 = _p4 + self.lateral_block4(x3)
 
         if self.config.dec_ipt:
-            patches_batch = self.get_patches_batch(x, _p3) if self.split else x
+            patches_batch = image2patches(x, patch_ref=_p3, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') if self.split else x
            _p3 = torch.cat((_p3, self.ipt_blk4(F.interpolate(patches_batch, size=x3.shape[2:], mode='bilinear', align_corners=True))), 1)
         p3 = self.decoder_block3(_p3)
-        m3 = self.conv_ms_spvn_3(p3) if self.config.ms_supervision else None
+        m3 = self.conv_ms_spvn_3(p3) if self.config.ms_supervision and self.training else None
         if self.config.out_ref:
             p3_gdt = self.gdt_convs_3(p3)
             if self.training:
@@ -2193,10 +2197,10 @@
         _p2 = _p3 + self.lateral_block3(x2)
 
         if self.config.dec_ipt:
-            patches_batch = self.get_patches_batch(x, _p2) if self.split else x
+            patches_batch = image2patches(x, patch_ref=_p2, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') if self.split else x
            _p2 = torch.cat((_p2, self.ipt_blk3(F.interpolate(patches_batch, size=x2.shape[2:], mode='bilinear', align_corners=True))), 1)
         p2 = self.decoder_block2(_p2)
-        m2 = self.conv_ms_spvn_2(p2) if self.config.ms_supervision else None
+        m2 = self.conv_ms_spvn_2(p2) if self.config.ms_supervision and self.training else None
         if self.config.out_ref:
             p2_gdt = self.gdt_convs_2(p2)
             if self.training:
@@ -2214,17 +2218,17 @@
         _p1 = _p2 + self.lateral_block2(x1)
 
         if self.config.dec_ipt:
-            patches_batch = self.get_patches_batch(x, _p1) if self.split else x
+            patches_batch = image2patches(x, patch_ref=_p1, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') if self.split else x
            _p1 = torch.cat((_p1, self.ipt_blk2(F.interpolate(patches_batch, size=x1.shape[2:], mode='bilinear', align_corners=True))), 1)
         _p1 = self.decoder_block1(_p1)
         _p1 = F.interpolate(_p1, size=x.shape[2:], mode='bilinear', align_corners=True)
 
         if self.config.dec_ipt:
-            patches_batch = self.get_patches_batch(x, _p1) if self.split else x
+            patches_batch = image2patches(x, patch_ref=_p1, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') if self.split else x
            _p1 = torch.cat((_p1, self.ipt_blk1(F.interpolate(patches_batch, size=x.shape[2:], mode='bilinear', align_corners=True))), 1)
         p1_out = self.conv_out1(_p1)
 
-        if self.config.ms_supervision:
+        if self.config.ms_supervision and self.training:
            outs.append(m4)
            outs.append(m3)
            outs.append(m2)
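
For reference (not part of the commit), here is a minimal, self-contained sketch of how the new einops-based helpers behave; the tensor sizes below are illustrative assumptions, not values used by the model.

# Hypothetical smoke test for the image2patches / patches2image helpers introduced above.
import torch
from einops import rearrange

def image2patches(image, grid_h=2, grid_w=2, patch_ref=None, transformation='b c (hg h) (wg w) -> (b hg wg) c h w'):
    # Infer the grid from a reference feature map when one is given.
    if patch_ref is not None:
        grid_h, grid_w = image.shape[-2] // patch_ref.shape[-2], image.shape[-1] // patch_ref.shape[-1]
    return rearrange(image, transformation, hg=grid_h, wg=grid_w)

def patches2image(patches, grid_h=2, grid_w=2, patch_ref=None, transformation='(b hg wg) c h w -> b c (hg h) (wg w)'):
    if patch_ref is not None:
        grid_h, grid_w = patch_ref.shape[-2] // patches[0].shape[-2], patch_ref.shape[-1] // patches[0].shape[-1]
    return rearrange(patches, transformation, hg=grid_h, wg=grid_w)

x = torch.randn(1, 3, 64, 64)                       # b c H W (example size)
patches = image2patches(x, grid_h=2, grid_w=2)      # grid cells become batch entries
assert patches.shape == (4, 3, 32, 32)
assert torch.equal(patches2image(patches, grid_h=2, grid_w=2), x)  # lossless round trip

# The decoder instead passes transformation='b c (hg h) (wg w) -> b (c hg wg) h w',
# which keeps the batch dimension and stacks the grid cells along channels.
stacked = image2patches(x, grid_h=2, grid_w=2, transformation='b c (hg h) (wg w) -> b (c hg wg) h w')
assert stacked.shape == (1, 12, 32, 32)

In the Decoder.forward hunks above, patch_ref is the current decoder feature map, so the grid is inferred from the size ratio between the full-resolution input and that feature map.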
handler.py ADDED
@@ -0,0 +1,138 @@
+# This HF deployment handler follows https://huggingface.co/not-lain/BiRefNet/raw/main/handler.py.
+from typing import Dict, List, Any, Tuple
+import os
+import requests
+from io import BytesIO
+import cv2
+import numpy as np
+from PIL import Image
+import torch
+from torchvision import transforms
+from transformers import AutoModelForImageSegmentation
+
+torch.set_float32_matmul_precision(["high", "highest"][0])
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+### image_proc.py
+def refine_foreground(image, mask, r=90):
+    if mask.size != image.size:
+        mask = mask.resize(image.size)
+    image = np.array(image) / 255.0
+    mask = np.array(mask) / 255.0
+    estimated_foreground = FB_blur_fusion_foreground_estimator_2(image, mask, r=r)
+    image_masked = Image.fromarray((estimated_foreground * 255.0).astype(np.uint8))
+    return image_masked
+
+
+def FB_blur_fusion_foreground_estimator_2(image, alpha, r=90):
+    # Thanks to the source: https://github.com/Photoroom/fast-foreground-estimation
+    alpha = alpha[:, :, None]
+    F, blur_B = FB_blur_fusion_foreground_estimator(image, image, image, alpha, r)
+    return FB_blur_fusion_foreground_estimator(image, F, blur_B, alpha, r=6)[0]
+
+
+def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
+    if isinstance(image, Image.Image):
+        image = np.array(image) / 255.0
+    blurred_alpha = cv2.blur(alpha, (r, r))[:, :, None]
+
+    blurred_FA = cv2.blur(F * alpha, (r, r))
+    blurred_F = blurred_FA / (blurred_alpha + 1e-5)
+
+    blurred_B1A = cv2.blur(B * (1 - alpha), (r, r))
+    blurred_B = blurred_B1A / ((1 - blurred_alpha) + 1e-5)
+    F = blurred_F + alpha * \
+        (image - alpha * blurred_F - (1 - alpha) * blurred_B)
+    F = np.clip(F, 0, 1)
+    return F, blurred_B
+
+
+class ImagePreprocessor():
+    def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
+        self.transform_image = transforms.Compose([
+            transforms.Resize(resolution),
+            transforms.ToTensor(),
+            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+        ])
+
+    def proc(self, image: Image.Image) -> torch.Tensor:
+        image = self.transform_image(image)
+        return image
+
+usage_to_weights_file = {
+    'General': 'BiRefNet',
+    'General-HR': 'BiRefNet_HR',
+    'General-Lite': 'BiRefNet_lite',
+    'General-Lite-2K': 'BiRefNet_lite-2K',
+    'General-reso_512': 'BiRefNet-reso_512',
+    'Matting': 'BiRefNet-matting',
+    'Portrait': 'BiRefNet-portrait',
+    'DIS': 'BiRefNet-DIS5K',
+    'HRSOD': 'BiRefNet-HRSOD',
+    'COD': 'BiRefNet-COD',
+    'DIS-TR_TEs': 'BiRefNet-DIS5K-TR_TEs',
+    'General-legacy': 'BiRefNet-legacy'
+}
+
+# Choose the version of BiRefNet here.
+usage = 'Matting'
+
+# Set resolution
+if usage in ['General-Lite-2K']:
+    resolution = (2560, 1440)
+elif usage in ['General-reso_512']:
+    resolution = (512, 512)
+elif usage in ['General-HR']:
+    resolution = (2048, 2048)
+else:
+    resolution = (1024, 1024)
+
+half_precision = True
+
+class EndpointHandler():
+    def __init__(self, path=''):
+        self.birefnet = AutoModelForImageSegmentation.from_pretrained(
+            '/'.join(('zhengpeng7', usage_to_weights_file[usage])), trust_remote_code=True
+        )
+        self.birefnet.to(device)
+        self.birefnet.eval()
+        if half_precision:
+            self.birefnet.half()
+
+    def __call__(self, data: Dict[str, Any]):
+        """
+        data args:
+            inputs (:obj: `str`)
+            date (:obj: `str`)
+        Return:
+            A :obj:`list` | `dict`: will be serialized and returned
+        """
+        print('data["inputs"] = ', data["inputs"])
+        image_src = data["inputs"]
+        if isinstance(image_src, str):
+            if os.path.isfile(image_src):
+                image_ori = Image.open(image_src)
+            else:
+                response = requests.get(image_src)
+                image_data = BytesIO(response.content)
+                image_ori = Image.open(image_data)
+        else:
+            image_ori = Image.fromarray(image_src)
+
+        image = image_ori.convert('RGB')
+        # Preprocess the image
+        image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
+        image_proc = image_preprocessor.proc(image)
+        image_proc = image_proc.unsqueeze(0)
+
+        # Prediction
+        with torch.no_grad():
+            preds = self.birefnet(image_proc.to(device).half() if half_precision else image_proc.to(device))[-1].sigmoid().cpu()
+        pred = preds[0].squeeze()
+
+        # Show Results
+        pred_pil = transforms.ToPILImage()(pred)
+        image_masked = refine_foreground(image, pred_pil)
+        image_masked.putalpha(pred_pil.resize(image.size))
+        return image_masked
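
As a usage note (not part of the commit), the handler can be exercised locally before it is deployed; the image URL and output filename below are hypothetical examples.

# Hypothetical local smoke test for handler.py.
from handler import EndpointHandler

handler = EndpointHandler(path='')  # downloads the selected BiRefNet weights on first use

# "inputs" may be a local file path, an image URL, or a numpy array (see __call__ above).
result = handler({"inputs": "https://example.com/some_image.jpg"})

# The handler returns a PIL image whose alpha channel holds the predicted matte.
result.save("subject_rgba.png")

On an Inference Endpoint, the serving toolkit is expected to instantiate EndpointHandler once and call it with each parsed request payload, so the same code path serves both the local test and the deployed endpoint.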