# app.py — Hugging Face Space by fffiloni
# Last commit: "Update app.py" (488936c, verified), 2.25 kB.
import os
from unittest.mock import patch

import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer
# Hugging Face Hub repository id of the Sa2VA checkpoint this demo serves.
model_path: str = "ByteDance/Sa2VA-4B"
from transformers.dynamic_module_utils import get_imports
def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
    """Return a remote-code file's imports, minus ``flash_attn`` for ``modeling_phi3.py``.

    Meant to stand in for ``transformers.dynamic_module_utils.get_imports``
    while the model is loaded (it is patched in below). transformers uses
    that function to decide which packages a remote-code file requires, so
    dropping ``flash_attn`` lets the Phi-3 modeling file load without that
    package being installed.

    Args:
        filename: Path of the remote-code Python file being inspected.

    Returns:
        The list produced by the original ``get_imports``. For files whose
        base name is ``modeling_phi3.py``, every ``"flash_attn"`` entry is
        removed; other files are returned unchanged.
    """
    imports = get_imports(filename)
    # basename instead of endswith("/..."): also matches Windows-style paths.
    if os.path.basename(filename) != "modeling_phi3.py":
        return imports
    # Filter rather than list.remove(): remove() raises ValueError when the
    # file does not import flash_attn, and would only drop the first entry.
    return [name for name in imports if name != "flash_attn"]
# Load the Sa2VA model while transformers' dynamic-module import scan is
# temporarily replaced by fixed_get_imports (flash_attn is not demanded).
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
    model = AutoModel.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        # NOTE(review): flash attention is requested even though the
        # flash_attn import is stripped above — presumably the remote code
        # falls back when the package is absent; confirm.
        use_flash_attn=True,
        trust_remote_code=True,
    )
    # Inference only: switch to eval mode and move the weights to the GPU.
    model = model.eval().cuda()

# Slow (non-"fast") tokenizer from the same repository; remote code trusted.
tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    trust_remote_code=True,
    use_fast=False,
)
def image_vision(image_input_path, prompt):
    """Run Sa2VA on one image with a text instruction and return its answer.

    Args:
        image_input_path: The image to analyse. Accepts a filesystem path
            (str / os.PathLike), a ``PIL.Image.Image``, or a NumPy array
            (what ``gr.Image`` passes with its default ``type="numpy"``).
        prompt: Instruction text; it is prefixed with the ``<image>`` token.

    Returns:
        The ``"prediction"`` entry of the dict returned by
        ``model.predict_forward`` (the text-format answer).

    Raises:
        ValueError: If no image is given.
    """
    if image_input_path is None:
        raise ValueError("No image provided.")

    if isinstance(image_input_path, Image.Image):
        image = image_input_path.convert("RGB")
    elif isinstance(image_input_path, np.ndarray):
        image = Image.fromarray(image_input_path).convert("RGB")
    else:
        # Context manager closes the file; convert() has already loaded the
        # pixels, so the returned image stays valid after the file is closed.
        with Image.open(image_input_path) as img:
            image = img.convert("RGB")

    return_dict = model.predict_forward(
        image=image,
        text=f"<image>{prompt}",
        past_text="",
        mask_prompts=None,
        tokenizer=tokenizer,
    )
    answer = return_dict["prediction"]  # the text format answer
    print(answer)
    # Previously the answer was only printed, so the UI always received None.
    return answer
def main_infer(image_input_path, prompt):
    """Gradio click handler: validate the inputs, then delegate to ``image_vision``.

    Args:
        image_input_path: Value of the image input component; ``None`` when
            the user has not uploaded an image.
        prompt: Instruction text from the textbox.

    Returns:
        Whatever ``image_vision`` returns for these inputs; Gradio shows it
        in the response textbox.

    Raises:
        gr.Error: If no image was provided. Gradio then shows a readable
            message instead of a traceback from inside the image loader.
    """
    if image_input_path is None:
        raise gr.Error("Please upload an image before submitting.")
    return image_vision(image_input_path, prompt)
# Gradio UI
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# Sa2VA: Marrying SAM2 with LLaVA for Dense Grounded Understanding of Images and Videos")
        with gr.Row():
            with gr.Column():
                # type="filepath": the handler opens the input with
                # PIL.Image.open, which needs a path, whereas gr.Image
                # defaults to passing a NumPy array.
                image_input = gr.Image(label="Image IN", type="filepath")
                with gr.Row():
                    instruction = gr.Textbox(label="Instruction")
                    submit_btn = gr.Button("Submit", scale=1)
            with gr.Column():
                output_res = gr.Textbox(label="Response")

    submit_btn.click(
        fn=main_infer,
        inputs=[image_input, instruction],
        outputs=[output_res],
    )

# queue() serializes requests to the single GPU-backed model; errors raised
# in handlers are shown to the user because of show_error=True.
demo.queue().launch(show_api=False, show_error=True)