File size: 1,695 Bytes
35a9ed4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch
from transformers import AutoTokenizer, AutoModel 
from PIL import Image 
import numpy as np 
import os 
import gradio as gr

# --- Model setup ---------------------------------------------------------
# Load the Sa2VA vision-language model and its tokenizer once at startup.
model_path = "ByteDance/Sa2VA-4B"

_model_kwargs = dict(
    torch_dtype=torch.bfloat16,  # halve the memory footprint vs fp32
    low_cpu_mem_usage=True,
    use_flash_attn=True,
    trust_remote_code=True,  # Sa2VA ships custom modeling code on the Hub
)
model = AutoModel.from_pretrained(model_path, **_model_kwargs).eval().cuda()

tokenizer = AutoTokenizer.from_pretrained(
    model_path, trust_remote_code=True, use_fast=False
)

def image_vision(image_input_path, prompt):
    """Run single-image inference with Sa2VA and return the text answer.

    Args:
        image_input_path: Filesystem path to the input image.
        prompt: User instruction; prefixed with the ``<image>`` token the
            model's custom ``predict_forward`` expects.

    Returns:
        str: The model's text-format answer.
    """
    text_prompts = f"<image>{prompt}"
    image = Image.open(image_input_path).convert('RGB')
    input_dict = {
        'image': image,
        'text': text_prompts,
        'past_text': '',
        'mask_prompts': None,
        'tokenizer': tokenizer,
    }
    return_dict = model.predict_forward(**input_dict)
    answer = return_dict["prediction"]  # the text format answer
    print(answer)
    # Bug fix: the answer was previously only printed, never returned, so
    # main_infer propagated None and the Gradio output textbox stayed empty.
    return answer

def main_infer(image_input_path, prompt):
    """Gradio click handler: delegate to image_vision and pass back its result."""
    return image_vision(image_input_path, prompt)

# --- Gradio UI -----------------------------------------------------------
# Layout: image + instruction on the left, model response on the right.
with gr.Blocks() as demo:
    with gr.Column():
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(label="Image IN")
                with gr.Row():
                    instruction = gr.Textbox(label="Instruction")
                    # Fixed user-facing label typo: "SUbmit" -> "Submit".
                    submit_btn = gr.Button("Submit", scale=1)
            with gr.Column():
                output_res = gr.Textbox(label="Response")

    # Wire the button: (image path, instruction) -> model response text.
    submit_btn.click(
        fn=main_infer,
        inputs=[image_input, instruction],
        outputs=[output_res],
    )

# queue() serializes requests so concurrent users don't contend for the GPU.
demo.queue().launch(show_api=False, show_error=True)