# Hugging Face Space status: Running on Zero (ZeroGPU hardware).
import torch | |
from transformers import AutoTokenizer, AutoModel | |
from PIL import Image | |
import numpy as np | |
import os | |
import gradio as gr | |
# Hub repository id shared by the model and tokenizer loads further down.
model_path = "ByteDance/Sa2VA-4B"
from unittest.mock import patch | |
from transformers.dynamic_module_utils import get_imports | |
def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
    """Return the imports of a remote-code module, minus ``flash_attn`` for Phi-3.

    Drop-in replacement for ``transformers.dynamic_module_utils.get_imports``.
    It is patched in while the model is loaded below, so that the
    ``"flash_attn"`` entry of ``modeling_phi3.py`` is not reported as a
    required import (presumably so the checkpoint loads without that package
    installed).

    Args:
        filename: Path of the Python module whose imports are being scanned.

    Returns:
        The import names reported by ``get_imports``. When the file is
        ``modeling_phi3.py``, any ``"flash_attn"`` entry is removed.
    """
    imports = get_imports(filename)
    # Compare the bare file name rather than a "/"-prefixed suffix so the check
    # also matches paths that use the platform's native separator (e.g. Windows).
    if os.path.basename(str(filename)) != "modeling_phi3.py":
        return imports
    # Filter instead of `list.remove`, which raises ValueError if the file
    # (e.g. a later revision) no longer imports flash_attn.
    return [name for name in imports if name != "flash_attn"]
# While the checkpoint is instantiated, route transformers' remote-code import
# scan through fixed_get_imports so the `flash_attn` entry is filtered out.
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
    model = AutoModel.from_pretrained(
        model_path,
        trust_remote_code=True,  # the checkpoint's modeling code lives on the Hub
        low_cpu_mem_usage=True,
        torch_dtype=torch.bfloat16,
        # NOTE(review): flash attention is still requested although the
        # flash_attn import is stripped above; presumably the remote code
        # falls back when the package is missing - confirm.
        use_flash_attn=True,
    )
    # Inference only: switch to eval mode and move the weights to the GPU.
    model = model.eval().cuda()
# Tokenizer from the same repository; the slow (non-"fast") implementation is
# requested explicitly via use_fast=False.
tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    use_fast=False,
    trust_remote_code=True,
)
def image_vision(image_input_path, prompt):
    """Run Sa2VA on a single image and return the model's text answer.

    Args:
        image_input_path: The input image. The following are accepted:
            - a file path (what the original code expected);
            - a PIL image;
            - a numpy array, which is what ``gr.Image`` passes with its
              default ``type="numpy"``.
        prompt: The user instruction. It is prefixed with the ``<image>``
            token before being sent to the model.

    Returns:
        The ``"prediction"`` entry of ``model.predict_forward``'s result,
        i.e. the text answer.

    Raises:
        ValueError: If no image was provided.
    """
    if image_input_path is None:
        raise ValueError("No input image was provided.")

    if isinstance(image_input_path, Image.Image):
        image = image_input_path.convert("RGB")
    elif isinstance(image_input_path, np.ndarray):
        image = Image.fromarray(image_input_path).convert("RGB")
    else:
        # `with` closes the file handle that Image.open keeps around;
        # convert() loads the pixels before the handle is released.
        with Image.open(image_input_path) as opened:
            image = opened.convert("RGB")

    input_dict = {
        'image': image,
        'text': f"<image>{prompt}",
        'past_text': '',
        'mask_prompts': None,
        'tokenizer': tokenizer,
    }
    return_dict = model.predict_forward(**input_dict)
    answer = return_dict["prediction"]  # the text format answer
    print(answer)
    # Previously the answer was only printed and None was returned, so the UI
    # always showed an empty response.
    return answer
def main_infer(image_input_path, prompt):
    """Gradio click handler: validate the inputs and run image_vision on them.

    Args:
        image_input_path: The value of the image input component, or None
            when nothing was uploaded.
        prompt: The instruction typed by the user.

    Returns:
        Whatever ``image_vision`` returns for these inputs.

    Raises:
        gr.Error: If Submit was pressed without an image, so the user sees a
            readable message instead of an exception from deep inside PIL.
    """
    # NOTE(review): the Space header says it runs on ZeroGPU, where GPU work
    # is normally wrapped with the `spaces.GPU` decorator - confirm whether it
    # is needed here.
    if image_input_path is None:
        raise gr.Error("Please upload an image before submitting.")
    return image_vision(image_input_path, prompt)
# Gradio UI: an image plus an instruction in, the model's text answer out.
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# Sa2VA: Marrying SAM2 with LLaVA for Dense Grounded Understanding of Images and Videos")
        with gr.Row():
            with gr.Column():
                # Pass a file path: the inference code opens the image with
                # PIL's Image.open, whereas gr.Image's default type ("numpy")
                # would hand it an array.
                image_input = gr.Image(label="Image IN", type="filepath")
                with gr.Row():
                    instruction = gr.Textbox(label="Instruction")
                    submit_btn = gr.Button("Submit", scale=1)
            with gr.Column():
                output_res = gr.Textbox(label="Response")

    submit_btn.click(
        fn=main_infer,
        inputs=[image_input, instruction],
        outputs=[output_res],
    )

demo.queue().launch(show_api=False, show_error=True)