OriLib commited on
Commit
437669c
·
verified ·
1 Parent(s): 42d7dbb

Upload 2 files

Browse files
Files changed (2) hide show
  1. briarmbg.py +5 -3
  2. example_inference.py +1 -2
briarmbg.py CHANGED
@@ -1,6 +1,7 @@
1
  import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
 
4
 
5
  class REBNCONV(nn.Module):
6
  def __init__(self,in_ch=3,out_ch=3,dirate=1,stride=1):
@@ -344,11 +345,12 @@ class myrebnconv(nn.Module):
344
  return self.rl(self.bn(self.conv(x)))
345
 
346
 
347
- class BriaRMBG(nn.Module):
348
 
349
- def __init__(self,in_ch=3,out_ch=1):
350
  super(BriaRMBG,self).__init__()
351
-
 
352
  self.conv_in = nn.Conv2d(in_ch,64,3,stride=2,padding=1)
353
  self.pool_in = nn.MaxPool2d(2,stride=2,ceil_mode=True)
354
 
 
1
  import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
+ from huggingface_hub import PyTorchModelHubMixin
5
 
6
  class REBNCONV(nn.Module):
7
  def __init__(self,in_ch=3,out_ch=3,dirate=1,stride=1):
 
345
  return self.rl(self.bn(self.conv(x)))
346
 
347
 
348
+ class BriaRMBG(nn.Module, PyTorchModelHubMixin):
349
 
350
+ def __init__(self,config:dict={"in_ch":3,"out_ch":1}):
351
  super(BriaRMBG,self).__init__()
352
+ in_ch=config["in_ch"]
353
+ out_ch=config["out_ch"]
354
  self.conv_in = nn.Conv2d(in_ch,64,3,stride=2,padding=1)
355
  self.pool_in = nn.MaxPool2d(2,stride=2,ceil_mode=True)
356
 
example_inference.py CHANGED
@@ -7,12 +7,11 @@ from huggingface_hub import hf_hub_download
7
 
8
  def example_inference():
9
 
10
- model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')
11
  im_path = f"{os.path.dirname(os.path.abspath(__file__))}/example_input.jpg"
12
 
13
  net = BriaRMBG()
14
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
- net.load_state_dict(torch.load(model_path, map_location=device))
16
  net.to(device)
17
  net.eval()
18
 
 
7
 
8
  def example_inference():
9
 
 
10
  im_path = f"{os.path.dirname(os.path.abspath(__file__))}/example_input.jpg"
11
 
12
  net = BriaRMBG()
13
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
+ net = BriaRMBG.from_pretrained("briaai/RMBG-1.4-experiment")
15
  net.to(device)
16
  net.eval()
17