import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from PIL import Image
import numpy as np
import os
import gradio as gr
# Load the model and tokenizer
model_path = "ByteDance/Sa2VA-4B"
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype="auto",
    device_map="auto",
    trust_remote_code=True,
).eval().cuda()
tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    trust_remote_code=True,
)
def image_vision(image_input_path, prompt):
    """Run single-image inference with Sa2VA and return its text answer."""
    text_prompts = f"<image>{prompt}"
    image = Image.open(image_input_path).convert('RGB')
    # Build the input dict expected by Sa2VA's predict_forward (loaded via trust_remote_code)
    input_dict = {
        'image': image,
        'text': text_prompts,
        'past_text': '',
        'mask_prompts': None,
        'tokenizer': tokenizer,
    }
    return_dict = model.predict_forward(**input_dict)
    print(return_dict)
    answer = return_dict["prediction"]  # the text-format answer
    return answer
def main_infer(image_input_path, prompt):
    response = image_vision(image_input_path, prompt)
    return response
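# Quick local sanity check; a minimal sketch assuming a CUDA GPU is available and that a
# sample image exists at ./example.jpg (the path and prompt below are placeholders):
#   print(main_infer("./example.jpg", "Describe this image in detail."))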
# Gradio UI
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# Sa2VA: Marrying SAM2 with LLaVA for Dense Grounded Understanding of Images and Videos")
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(label="Image IN", type="filepath")
                with gr.Row():
                    instruction = gr.Textbox(label="Instruction", scale=4)
                    submit_btn = gr.Button("Submit", scale=1)
            with gr.Column():
                output_res = gr.Textbox(label="Response")
        submit_btn.click(
            fn=main_infer,
            inputs=[image_input, instruction],
            outputs=[output_res]
        )

demo.queue().launch(show_api=False, show_error=True)