Spaces:
Runtime error
Runtime error
drop ZERO GPU
Browse files- app.py +3 -3
- utils/__init__.py +0 -0
- utils/models.py +6 -0
app.py
CHANGED
@@ -2,7 +2,6 @@ from typing import Optional
|
|
2 |
|
3 |
import gradio as gr
|
4 |
import numpy as np
|
5 |
-
import spaces
|
6 |
import supervision as sv
|
7 |
import torch
|
8 |
from PIL import Image
|
@@ -18,14 +17,15 @@ video by treating images as single-frame videos. Its design, a simple transforme
|
|
18 |
architecture with streaming memory, enables real-time video processing.
|
19 |
"""
|
20 |
|
21 |
-
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
|
|
22 |
MASK_ANNOTATOR = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)
|
23 |
CHECKPOINT = "checkpoints/sam2_hiera_large.pt"
|
24 |
CONFIG = "sam2_hiera_l.yaml"
|
25 |
|
26 |
sam2_model = build_sam2(CONFIG, CHECKPOINT, device=DEVICE, apply_postprocessing=False)
|
27 |
|
28 |
-
@spaces.GPU
|
29 |
def process(image_input) -> Optional[Image.Image]:
|
30 |
mask_generator = SAM2AutomaticMaskGenerator(sam2_model)
|
31 |
image = np.array(image_input.convert("RGB"))
|
|
|
2 |
|
3 |
import gradio as gr
|
4 |
import numpy as np
|
|
|
5 |
import supervision as sv
|
6 |
import torch
|
7 |
from PIL import Image
|
|
|
17 |
architecture with streaming memory, enables real-time video processing.
|
18 |
"""
|
19 |
|
20 |
+
# DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
21 |
+
DEVICE = torch.device('cuda')
|
22 |
+
|
23 |
MASK_ANNOTATOR = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)
|
24 |
CHECKPOINT = "checkpoints/sam2_hiera_large.pt"
|
25 |
CONFIG = "sam2_hiera_l.yaml"
|
26 |
|
27 |
sam2_model = build_sam2(CONFIG, CHECKPOINT, device=DEVICE, apply_postprocessing=False)
|
28 |
|
|
|
29 |
def process(image_input) -> Optional[Image.Image]:
|
30 |
mask_generator = SAM2AutomaticMaskGenerator(sam2_model)
|
31 |
image = np.array(image_input.convert("RGB"))
|
utils/__init__.py
ADDED
File without changes
|
utils/models.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
CHECKPOINTS = {
|
2 |
+
"tiny": ["sam2_hiera_t.yaml", "checkpoints/sam2_hiera_tiny.pt"],
|
3 |
+
"small": ["sam2_hiera_s.yaml", "checkpoints/sam2_hiera_small.pt"],
|
4 |
+
"base_plus": ["sam2_hiera_b+.yaml", "checkpoints/sam2_hiera_base_plus.pt"],
|
5 |
+
"large": ["sam2_hiera_l.yaml", "checkpoints/sam2_hiera_large.pt"],
|
6 |
+
}
|