OriLib not-lain committed on
Commit
33c41df
·
verified ·
1 Parent(s): 132c4e1

integrate with transformers (#21)

Browse files

- integrate with transformers (a845831f8f4651b4b6b3676ad744a0a3c368417f)
- Update README (52584706b45157beda5e5f42670d75dfc29bfed7)
- changing custom pipeline and pinning requirements (2fb3b2ec826f02e10b74e334a6a8678273a57dfc)


Co-authored-by: LAin <[email protected]>

Files changed (6) hide show
  1. MyConfig.py +13 -0
  2. MyPipe.py +73 -0
  3. README.md +13 -34
  4. briarmbg.py +8 -7
  5. config.json +24 -3
  6. requirements.txt +2 -1
MyConfig.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+ from typing import List
3
+
4
class RMBGConfig(PretrainedConfig):
    """Configuration holding the channel counts of the RMBG network.

    Args:
        in_ch: Number of input channels (defaults to 3).
        out_ch: Number of output channels (defaults to 1).
        **kwargs: Forwarded untouched to ``PretrainedConfig.__init__``.
    """

    # Kept verbatim: the published config.json carries the same identifier.
    model_type = "SegformerForSemanticSegmentation"

    def __init__(self, in_ch=3, out_ch=1, **kwargs):
        # Store the channel counts before delegating the remaining options
        # to the base class, exactly as the original ordering did.
        self.in_ch, self.out_ch = in_ch, out_ch
        super().__init__(**kwargs)
MyPipe.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch, os
2
+ import torch.nn.functional as F
3
+ from torchvision.transforms.functional import normalize
4
+ import numpy as np
5
+ from transformers import Pipeline
6
+ from skimage import io
7
+ from PIL import Image
8
+
9
class RMBGPipe(Pipeline):
    """Custom pipeline that removes the background of an image.

    Given a path to an image, the pipeline returns either the predicted
    mask (``return_mask=True``) or the original image pasted onto a
    transparent RGBA canvas using that mask (default).
    """

    def __init__(self, **kwargs):
        """Build the pipeline and move the model to the available device."""
        Pipeline.__init__(self, **kwargs)
        # NOTE(review): this overrides any device chosen by the caller;
        # kept as-is to preserve the existing behavior.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()

    def _sanitize_parameters(self, **kwargs):
        """Split call-time kwargs into (preprocess, forward, postprocess) dicts.

        Only ``model_input_size`` (preprocess) and ``return_mask``
        (postprocess) are recognized; the forward step takes no options.
        """
        preprocess_kwargs = {}
        postprocess_kwargs = {}
        if "model_input_size" in kwargs:
            preprocess_kwargs["model_input_size"] = kwargs["model_input_size"]
        if "return_mask" in kwargs:
            postprocess_kwargs["return_mask"] = kwargs["return_mask"]
        return preprocess_kwargs, {}, postprocess_kwargs

    def preprocess(self, im_path: str, model_input_size: list = None):
        """Read the image at ``im_path`` and turn it into a model input.

        Args:
            im_path: Location of the image, as accepted by ``skimage.io.imread``.
            model_input_size: ``[height, width]`` the image is resized to;
                ``None`` means ``[1024, 1024]``.

        Returns:
            A dict with the input tensor (``"image"``), the original
            ``(height, width)`` (``"orig_im_size"``) and ``"im_path"``.
        """
        orig_im = io.imread(im_path)
        orig_im_size = orig_im.shape[0:2]
        image = self.preprocess_image(orig_im, model_input_size).to(self.device)
        return {
            "image": image,
            "orig_im_size": orig_im_size,
            "im_path": im_path,
        }

    def _forward(self, inputs):
        """Run the model on ``inputs["image"]`` and store its output as ``"result"``."""
        result = self.model(inputs.pop("image"))
        inputs["result"] = result
        return inputs

    def postprocess(self, inputs, return_mask: bool = False):
        """Convert the model output into a PIL image.

        Args:
            inputs: Dict produced by ``_forward``.
            return_mask: When truthy, return the mask itself instead of
                the cut-out image.

        Returns:
            The mask as a PIL image, or an RGBA image whose background is
            transparent.
        """
        result = inputs.pop("result")
        orig_im_size = inputs.pop("orig_im_size")
        im_path = inputs.pop("im_path")
        result_image = self.postprocess_image(result[0][0], orig_im_size)
        pil_im = Image.fromarray(result_image)
        if return_mask:
            return pil_im
        no_bg_image = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
        # Context manager guarantees the file handle is released after pasting.
        with Image.open(im_path) as orig_image:
            no_bg_image.paste(orig_image, mask=pil_im)
        return no_bg_image

    # utilities functions
    def preprocess_image(self, im: np.ndarray, model_input_size: list = None) -> torch.Tensor:
        """Resize and normalize an ``H x W [x C]`` image array for the model.

        Args:
            im: Image array; grayscale (2-D or fewer than 3 channels) and
                images with more than 3 channels (e.g. RGBA) are accepted.
            model_input_size: ``[height, width]`` target size; ``None``
                means ``[1024, 1024]``.

        Returns:
            A float tensor of shape ``(1, 3, height, width)`` with values
            shifted by -0.5 after scaling to ``[0, 1]``.
        """
        if model_input_size is None:
            model_input_size = [1024, 1024]
        if len(im.shape) < 3:
            im = im[:, :, np.newaxis]
        # The normalization below uses three per-channel means, so the tensor
        # must have exactly three channels: replicate the first channel of
        # grayscale input and drop anything beyond RGB (such as alpha).
        channels = im.shape[2]
        if channels < 3:
            im = np.repeat(im[:, :, :1], 3, axis=2)
        elif channels > 3:
            im = im[:, :, :3]
        im_tensor = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
        im_tensor = F.interpolate(
            torch.unsqueeze(im_tensor, 0), size=model_input_size, mode='bilinear'
        ).type(torch.uint8)
        image = torch.divide(im_tensor, 255.0)
        image = normalize(image, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
        return image

    def postprocess_image(self, result: torch.Tensor, im_size: list) -> np.ndarray:
        """Resize the prediction to ``im_size`` and rescale it to uint8.

        Args:
            result: Model prediction passed to ``F.interpolate``.
            im_size: ``(height, width)`` of the original image.

        Returns:
            A uint8 array with singleton dimensions removed, min-max scaled
            to ``[0, 255]``.
        """
        result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear'), 0)
        ma = torch.max(result)
        mi = torch.min(result)
        span = ma - mi
        if span > 0:
            result = (result - mi) / span
        else:
            # A constant prediction would divide by zero and yield NaN;
            # emit an empty (all-zero) mask instead.
            result = torch.zeros_like(result)
        im_array = (result * 255).permute(1, 2, 0).cpu().data.numpy().astype(np.uint8)
        im_array = np.squeeze(im_array)
        return im_array
README.md CHANGED
@@ -2,7 +2,7 @@
2
  license: other
3
  license_name: bria-rmbg-1.4
4
  license_link: https://bria.ai/bria-huggingface-model-license-agreement/
5
- pipeline_tag: image-to-image
6
  tags:
7
  - remove background
8
  - background
@@ -10,6 +10,7 @@ tags:
10
  - Pytorch
11
  - vision
12
  - legal liability
 
13
 
14
  extra_gated_prompt: These model weights by BRIA AI can be obtained after a commercial license is agreed upon. Fill in the form below and we will reach out to you.
15
  extra_gated_fields:
@@ -94,43 +95,21 @@ These modifications significantly improve the model’s accuracy and effectivene
94
 
95
  ## Installation
96
  ```bash
97
- git clone https://huggingface.co/briaai/RMBG-1.4
98
- cd RMBG-1.4/
99
- pip install -r requirements.txt
100
  ```
101
 
102
  ## Usage
103
 
 
104
  ```python
105
- from skimage import io
106
- import torch, os
107
- from PIL import Image
108
- from briarmbg import BriaRMBG
109
- from utilities import preprocess_image, postprocess_image
110
-
111
- im_path = f"{os.path.dirname(os.path.abspath(__file__))}/example_input.jpg"
112
-
113
- net = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
114
-
115
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
116
- net.to(device)
117
-
118
- # prepare input
119
- model_input_size = [1024,1024]
120
- orig_im = io.imread(im_path)
121
- orig_im_size = orig_im.shape[0:2]
122
- image = preprocess_image(orig_im, model_input_size).to(device)
123
-
124
- # inference
125
- result=net(image)
126
-
127
- # post process
128
- result_image = postprocess_image(result[0][0], orig_im_size)
129
 
130
- # save result
131
- pil_im = Image.fromarray(result_image)
132
- no_bg_image = Image.new("RGBA", pil_im.size, (0,0,0,0))
133
- orig_image = Image.open(im_path)
134
- no_bg_image.paste(orig_image, mask=pil_im)
135
- no_bg_image.save("example_image_no_bg.png")
136
  ```
 
2
  license: other
3
  license_name: bria-rmbg-1.4
4
  license_link: https://bria.ai/bria-huggingface-model-license-agreement/
5
+ pipeline_tag: image-segmentation
6
  tags:
7
  - remove background
8
  - background
 
10
  - Pytorch
11
  - vision
12
  - legal liability
13
+ - transformers
14
 
15
  extra_gated_prompt: These model weights by BRIA AI can be obtained after a commercial license is agreed upon. Fill in the form below and we will reach out to you.
16
  extra_gated_fields:
 
95
 
96
  ## Installation
97
  ```bash
98
+ wget https://huggingface.co/briaai/RMBG-1.4/resolve/main/requirements.txt && pip install -qr requirements.txt
 
 
99
  ```
100
 
101
  ## Usage
102
 
103
+ Either load the model directly:
104
  ```python
105
+ from transformers import AutoModelForImageSegmentation
106
+ model = AutoModelForImageSegmentation.from_pretrained("briaai/RMBG-1.4",trust_remote_code=True)
107
+ ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
+ Or load the pipeline:
110
+ ```python
111
+ from transformers import pipeline
112
+ pipe = pipeline("image-segmentation", model="briaai/RMBG-1.4", trust_remote_code=True)
113
+ pillow_mask = pipe("img_path",return_mask = True) # outputs a pillow mask
114
+ pillow_image = pipe("image_path") # applies mask on input and returns a pillow image
115
  ```
briarmbg.py CHANGED
@@ -1,7 +1,8 @@
1
  import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
- from huggingface_hub import PyTorchModelHubMixin
 
5
 
6
  class REBNCONV(nn.Module):
7
  def __init__(self,in_ch=3,out_ch=3,dirate=1,stride=1):
@@ -345,12 +346,12 @@ class myrebnconv(nn.Module):
345
  return self.rl(self.bn(self.conv(x)))
346
 
347
 
348
- class BriaRMBG(nn.Module, PyTorchModelHubMixin):
349
-
350
- def __init__(self,config:dict={"in_ch":3,"out_ch":1}):
351
- super(BriaRMBG,self).__init__()
352
- in_ch=config["in_ch"]
353
- out_ch=config["out_ch"]
354
  self.conv_in = nn.Conv2d(in_ch,64,3,stride=2,padding=1)
355
  self.pool_in = nn.MaxPool2d(2,stride=2,ceil_mode=True)
356
 
 
1
  import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
+ from transformers import PreTrainedModel
5
+ from .MyConfig import RMBGConfig
6
 
7
  class REBNCONV(nn.Module):
8
  def __init__(self,in_ch=3,out_ch=3,dirate=1,stride=1):
 
346
  return self.rl(self.bn(self.conv(x)))
347
 
348
 
349
+ class BriaRMBG(PreTrainedModel):
350
+ config_class = RMBGConfig
351
+ def __init__(self,config):
352
+ super().__init__(config)
353
+ in_ch = config.in_ch # 3
354
+ out_ch = config.out_ch # 1
355
  self.conv_in = nn.Conv2d(in_ch,64,3,stride=2,padding=1)
356
  self.pool_in = nn.MaxPool2d(2,stride=2,ceil_mode=True)
357
 
config.json CHANGED
@@ -1,4 +1,25 @@
1
  {
2
- "in_ch":3,
3
- "out_ch":1
4
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  {
2
+ "_name_or_path": "briaai/RMBG-1.4",
3
+ "architectures": [
4
+ "BriaRMBG"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "MyConfig.RMBGConfig",
8
+ "AutoModelForImageSegmentation": "briarmbg.BriaRMBG"
9
+ },
10
+ "custom_pipelines": {
11
+ "image-segmentation": {
12
+ "impl": "MyPipe.RMBGPipe",
13
+ "pt": [
14
+ "AutoModelForImageSegmentation"
15
+ ],
16
+ "tf": [],
17
+ "type": "image"
18
+ }
19
+ },
20
+ "in_ch": 3,
21
+ "model_type": "SegformerForSemanticSegmentation",
22
+ "out_ch": 1,
23
+ "torch_dtype": "float32",
24
+ "transformers_version": "4.38.0.dev0"
25
+ }
requirements.txt CHANGED
@@ -4,4 +4,5 @@ pillow
4
  numpy
5
  typing
6
  scikit-image
7
- huggingface_hub
 
 
4
  numpy
5
  typing
6
  scikit-image
7
+ huggingface_hub
8
+ transformers==4.39.1