File size: 2,866 Bytes
35a9ed4
bdee200
35a9ed4
 
 
2a274cc
35a9ed4
 
2a274cc
 
 
 
 
 
 
35a9ed4
 
e856606
bdee200
19540cf
bdee200
 
 
19540cf
488936c
6207473
 
 
 
35a9ed4
2a274cc
 
 
 
 
 
 
 
 
 
 
35a9ed4
 
 
 
 
 
 
 
 
 
 
 
8c2e68c
35a9ed4
2a274cc
dc9bdbf
8c2e68c
dc9bdbf
35a9ed4
 
 
dc9bdbf
2a274cc
 
 
 
 
 
 
 
 
35a9ed4
 
 
 
 
b47ae2e
35a9ed4
 
de71836
35a9ed4
de71836
 
35a9ed4
 
2a274cc
35a9ed4
 
 
 
dc9bdbf
35a9ed4
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from PIL import Image 
import numpy as np 
import os 
import tempfile
import gradio as gr

import cv2
try:
    from mmengine.visualization import Visualizer
except ImportError:
    Visualizer = None
    print("Warning: mmengine is not installed, visualization is disabled.")
    
# Checkpoint identifier shared by the model and tokenizer loaders below.
model_path = "ByteDance/Sa2VA-4B"

# trust_remote_code=True lets transformers run the Python code shipped with
# the checkpoint repository.
# NOTE(review): device_map="auto" is combined with an explicit .cuda() call;
# the latter may be redundant -- confirm before changing.
model = AutoModelForCausalLM.from_pretrained(
    model_path, torch_dtype="auto", device_map="auto", trust_remote_code=True
)
model = model.eval().cuda()

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

def visualize(pred_mask, image_path, work_dir):
    """Draw ``pred_mask`` over the image at ``image_path`` and save the result.

    Args:
        pred_mask: Binary mask handed to ``Visualizer.draw_binary_masks``.
        image_path: Path of the image, read with ``cv2.imread``.
        work_dir: Directory in which the rendered image is written.

    Returns:
        The path of the written file: ``work_dir`` joined with the base name
        of ``image_path``.
    """
    canvas = Visualizer()
    canvas.set_image(cv2.imread(image_path))
    # Overlay the mask using colour 'g' with alpha 0.4.
    canvas.draw_binary_masks(pred_mask, colors='g', alphas=0.4)
    rendered = canvas.get_image()

    # The output keeps the input file name, only the directory changes.
    file_name = os.path.basename(image_path)
    output_path = os.path.join(work_dir, file_name)
    cv2.imwrite(output_path, rendered)
    return output_path

def image_vision(image_input_path, prompt):
    """Query the model with one image and a text prompt.

    Args:
        image_input_path: Path of the image to open; it is converted to RGB.
        prompt: User text; it is prefixed with ``<image>`` before being sent
            to the model.

    Returns:
        A ``(answer, seg_image)`` tuple taken from the ``"prediction"`` and
        ``"prediction_masks"`` entries of the dict returned by
        ``model.predict_forward``.
    """
    query_text = f"<image>{prompt}"
    rgb_image = Image.open(image_input_path).convert('RGB')

    return_dict = model.predict_forward(
        image=rgb_image,
        text=query_text,
        past_text='',
        mask_prompts=None,
        tokenizer=tokenizer,
    )
    # Dump the raw model output to stdout.
    print(return_dict)

    return return_dict["prediction"], return_dict["prediction_masks"]

def main_infer(image_input_path, prompt):
    """Run inference on one image and return the answer plus an optional overlay.

    Args:
        image_input_path: Path of the input image, or ``None`` when no image
            was supplied.
        prompt: Instruction text forwarded to ``image_vision``.

    Returns:
        A ``(answer, seg_result)`` tuple. ``answer`` is the text returned by
        ``image_vision``. ``seg_result`` is the file path returned by
        ``visualize``, or ``None`` when no overlay was produced (the answer
        has no ``[SEG]`` marker, mmengine is unavailable, or no masks were
        returned).

    Raises:
        gr.Error: If ``image_input_path`` is ``None``.
    """
    if image_input_path is None:
        # Fail with a readable UI message instead of an error from Image.open.
        raise gr.Error("Please provide an input image.")

    answer, seg_image = image_vision(image_input_path, prompt)

    # Always bind the second return value; it used to be assigned only inside
    # the branch below, so text-only answers raised UnboundLocalError.
    seg_result = None

    # len() rather than truthiness so list and array containers both work.
    has_masks = seg_image is not None and len(seg_image) > 0

    if '[SEG]' in answer and Visualizer is not None and has_masks:
        # Index only once masks are known to exist: the first entry of
        # seg_image, then its first element, is the mask that gets drawn.
        pred_mask = seg_image[0][0]
        # mkdtemp() already creates the directory. It is not removed here,
        # presumably because the returned file must outlive this call.
        temp_dir = tempfile.mkdtemp()
        seg_result = visualize(pred_mask, image_input_path, temp_dir)

    return answer, seg_result

# Gradio UI
# Two-column layout: inputs (image, instruction, submit button) on the left,
# outputs (text response, segmentation image) on the right.

with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# Sa2VA: Marrying SAM2 with LLaVA for Dense Grounded Understanding of Images and Videos")
        with gr.Row():
            # Input column.
            with gr.Column():
                # type="filepath": the callback receives the image as a path string.
                image_input = gr.Image(label="Image IN", type="filepath")
                with gr.Row():
                    instruction = gr.Textbox(label="Instruction", scale=4)
                    submit_btn = gr.Button("Submit", scale=1)
            # Output column.
            with gr.Column():
                output_res = gr.Textbox(label="Response")
                # NOTE(review): main_infer passes the file path returned by
                # visualize to this component although it declares
                # type="numpy" -- confirm the component accepts a path.
                output_image = gr.Image(label="Segmentation", type="numpy")

    # Clicking the button calls main_infer(image path, instruction text) and
    # routes its two return values to the response box and the image.
    submit_btn.click(
        fn = main_infer,
        inputs = [image_input, instruction],
        outputs = [output_res, output_image]
    )

# Runs at import time: enable the queue and start the app with the API page
# hidden and errors shown in the UI.
demo.queue().launch(show_api=False, show_error=True)