Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -7,12 +7,18 @@ from transformers import AutoModelForImageSegmentation
|
|
7 |
import torch
|
8 |
from torchvision import transforms
|
9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
torch.set_float32_matmul_precision(["high", "highest"][0])
|
11 |
|
12 |
birefnet = AutoModelForImageSegmentation.from_pretrained(
|
13 |
"briaai/RMBG-2.0", trust_remote_code=True
|
14 |
)
|
15 |
-
birefnet.to(
|
16 |
transform_image = transforms.Compose(
|
17 |
[
|
18 |
transforms.Resize((1024, 1024)),
|
@@ -37,7 +43,7 @@ def fn(image):
|
|
37 |
@spaces.GPU
|
38 |
def process(image):
|
39 |
image_size = image.size
|
40 |
-
input_images = transform_image(image).unsqueeze(0).to(
|
41 |
# Prediction
|
42 |
with torch.no_grad():
|
43 |
preds = birefnet(input_images)[-1].sigmoid().cpu()
|
@@ -80,4 +86,4 @@ demo = gr.TabbedInterface(
|
|
80 |
)
|
81 |
|
82 |
if __name__ == "__main__":
|
83 |
-
demo.launch(show_error=True)
|
|
|
7 |
import torch
|
8 |
from torchvision import transforms
|
9 |
|
10 |
+
# 检查 CUDA 是否可用
|
11 |
+
if torch.cuda.is_available():
|
12 |
+
device = "cuda"
|
13 |
+
else:
|
14 |
+
device = "cpu"
|
15 |
+
|
16 |
torch.set_float32_matmul_precision(["high", "highest"][0])
|
17 |
|
18 |
birefnet = AutoModelForImageSegmentation.from_pretrained(
|
19 |
"briaai/RMBG-2.0", trust_remote_code=True
|
20 |
)
|
21 |
+
birefnet.to(device)
|
22 |
transform_image = transforms.Compose(
|
23 |
[
|
24 |
transforms.Resize((1024, 1024)),
|
|
|
43 |
@spaces.GPU
|
44 |
def process(image):
|
45 |
image_size = image.size
|
46 |
+
input_images = transform_image(image).unsqueeze(0).to(device)
|
47 |
# Prediction
|
48 |
with torch.no_grad():
|
49 |
preds = birefnet(input_images)[-1].sigmoid().cpu()
|
|
|
86 |
)
|
87 |
|
88 |
if __name__ == "__main__":
|
89 |
+
demo.launch(share=True, show_error=True)
|