File size: 1,816 Bytes
35a9ed4
bdee200
35a9ed4
 
 
 
 
 
 
e856606
bdee200
19540cf
bdee200
 
 
19540cf
488936c
6207473
 
 
 
35a9ed4
 
 
 
 
 
 
 
 
 
 
 
 
8c2e68c
35a9ed4
8c2e68c
 
35a9ed4
 
 
 
 
 
 
 
 
 
b47ae2e
35a9ed4
 
de71836
35a9ed4
de71836
 
35a9ed4
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from PIL import Image 
import numpy as np 
import os 
import gradio as gr

# Checkpoint identifier passed to both from_pretrained() calls below.
model_path = "ByteDance/Sa2VA-4B"

# trust_remote_code=True lets transformers run the modelling code that ships
# with the checkpoint repository.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype="auto",
    device_map="auto",
    trust_remote_code=True,
)
# Switch to inference mode, then move to the CUDA device.
# NOTE(review): device_map="auto" already decides weight placement, so the
# extra .cuda() may be redundant -- confirm before changing.
model = model.eval()
model = model.cuda()

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

def image_vision(image_input_path, prompt):
    """Run the loaded model on a single image and return its text answer.

    Args:
        image_input_path: Filesystem path of the image to analyse.
        prompt: Instruction text; it is appended to the ``<image>``
            placeholder to form the text sent to the model.

    Returns:
        The value stored under ``"prediction"`` in the dict returned by
        ``model.predict_forward``.

    Raises:
        ValueError: If no image path is supplied (``None`` or an empty
            string), e.g. when the form is submitted without an upload.
    """
    # Fail early with a readable message instead of an opaque PIL error.
    if not image_input_path:
        raise ValueError(
            "No image provided: please upload an image before submitting."
        )

    text_prompts = f"<image>{prompt}"

    # The context manager closes the underlying file handle; convert()
    # returns an independent, fully loaded copy that stays valid afterwards.
    with Image.open(image_input_path) as raw_image:
        image = raw_image.convert('RGB')

    input_dict = {
        'image': image,
        'text': text_prompts,
        'past_text': '',
        'mask_prompts': None,
        'tokenizer': tokenizer,
    }
    return_dict = model.predict_forward(**input_dict)
    # Kept from the original: dumps the raw model output to stdout/logs.
    print(return_dict)

    answer = return_dict["prediction"]  # the text format answer
    return answer

def main_infer(image_input_path, prompt):
    """Handle a Submit click by delegating to ``image_vision``.

    Args:
        image_input_path: Value of the image component (a file path, since
            the component is declared with ``type="filepath"``).
        prompt: Text entered in the instruction box.

    Returns:
        Whatever ``image_vision`` returns for the given inputs.
    """
    return image_vision(image_input_path, prompt)

# Gradio UI

with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# Sa2VA: Marrying SAM2 with LLaVA for Dense Grounded Understanding of Images and Videos")
        with gr.Row():
            # First column: image upload plus the instruction/submit row.
            with gr.Column():
                image_input = gr.Image(label="Image IN", type="filepath")
                with gr.Row():
                    instruction = gr.Textbox(label="Instruction", scale=4)
                    submit_btn = gr.Button("Submit", scale=1)
            # Second column: the model's text response.
            with gr.Column():
                output_res = gr.Textbox(label="Response")

    # Wire the button: (image path, instruction) -> main_infer -> response box.
    submit_btn.click(
        fn=main_infer,
        inputs=[image_input, instruction],
        outputs=[output_res],
    )

# Enable the request queue, then start the server with the API page hidden
# and errors surfaced in the UI.
demo.queue()
demo.launch(show_api=False, show_error=True)