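# Gradio demo: loads ByteDance's Sa2VA-4B (SAM2 married with LLaVA) via
# trust_remote_code and answers free-form text prompts about a single input image.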
import torch
from transformers import AutoTokenizer, AutoModel
from PIL import Image
import numpy as np
import os
import gradio as gr
# Load the model and tokenizer
model_path = "ByteDance/Sa2VA-4B"
from unittest.mock import patch
from transformers.dynamic_module_utils import get_imports
def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
    # Workaround: drop the flash_attn entry from the dynamic import check for
    # modeling_phi3.py so the remote code loads even without flash-attn installed.
    if not str(filename).endswith("/modeling_phi3.py"):
        return get_imports(filename)
    imports = get_imports(filename)
    imports.remove("flash_attn")
    return imports
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
    model = AutoModel.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        use_flash_attn=True,
        trust_remote_code=True,
    ).eval().cuda()

tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    trust_remote_code=True,
    use_fast=False,
)
def image_vision(image_input_path, prompt):
    text_prompts = f"<image>{prompt}"
    image = Image.open(image_input_path).convert('RGB')
    input_dict = {
        'image': image,
        'text': text_prompts,
        'past_text': '',
        'mask_prompts': None,
        'tokenizer': tokenizer,
    }
    return_dict = model.predict_forward(**input_dict)
    answer = return_dict["prediction"]  # the text-format answer
    print(answer)
    return answer
def main_infer(image_input_path, prompt):
    response = image_vision(image_input_path, prompt)
    return response
# Gradio UI
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# Sa2VA: Marrying SAM2 with LLaVA for Dense Grounded Understanding of Images and Videos")
        with gr.Row():
            with gr.Column():
                # type="filepath" so the handler receives a path it can open with PIL
                image_input = gr.Image(label="Image IN", type="filepath")
                with gr.Row():
                    instruction = gr.Textbox(label="Instruction")
                    submit_btn = gr.Button("Submit", scale=1)
            with gr.Column():
                output_res = gr.Textbox(label="Response")
    submit_btn.click(
        fn=main_infer,
        inputs=[image_input, instruction],
        outputs=[output_res],
    )

demo.queue().launch(show_api=False, show_error=True)