import os
import tempfile

import cv2
import gradio as gr
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer

try:
    from mmengine.visualization import Visualizer
except ImportError:
    Visualizer = None
    print("Warning: mmengine is not installed, visualization is disabled.")

# Load the model and tokenizer. device_map="auto" already places the model on
# the available GPU(s), so an extra .cuda() call is unnecessary (and fails if
# the model gets dispatched across multiple devices).
model_path = "ByteDance/Sa2VA-4B"
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype="auto",
    device_map="auto",
    trust_remote_code=True,
).eval()
tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    trust_remote_code=True,
)


def visualize(pred_mask, image_path, work_dir):
    """Overlay a binary segmentation mask on the input image and save it."""
    visualizer = Visualizer()
    img = cv2.imread(image_path)
    visualizer.set_image(img)
    visualizer.draw_binary_masks(pred_mask, colors='g', alphas=0.4)
    visual_result = visualizer.get_image()
    output_path = os.path.join(work_dir, os.path.basename(image_path))
    cv2.imwrite(output_path, visual_result)
    return output_path


def image_vision(image_input_path, prompt):
    """Run Sa2VA on a single image; return the text answer and predicted masks."""
    image = Image.open(image_input_path).convert('RGB')
    input_dict = {
        'image': image,
        'text': prompt,
        'past_text': '',
        'mask_prompts': None,
        'tokenizer': tokenizer,
    }
    with torch.no_grad():
        return_dict = model.predict_forward(**input_dict)
    answer = return_dict["prediction"]  # the text-format answer
    seg_masks = return_dict["prediction_masks"]  # per-object binary masks
    return answer, seg_masks


def main_infer(image_input_path, prompt):
    answer, seg_masks = image_vision(image_input_path, prompt)
    # The model emits a [SEG] token in the answer whenever it produced a mask.
    if '[SEG]' in answer and Visualizer is not None:
        pred_mask = seg_masks[0]  # mask(s) for the first segmented object
        temp_dir = tempfile.mkdtemp()  # mkdtemp() already creates the directory
        seg_result = visualize(pred_mask, image_input_path, temp_dir)
        return answer, seg_result
    return answer, None


# Gradio UI
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# Sa2VA: Marrying SAM2 with LLaVA for Dense Grounded Understanding of Images and Videos")
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(label="Image IN", type="filepath")
                with gr.Row():
                    instruction = gr.Textbox(label="Instruction", scale=4)
                    submit_btn = gr.Button("Submit", scale=1)
            with gr.Column():
                output_res = gr.Textbox(label="Response")
                # visualize() returns a file path, so use type="filepath" here.
                output_image = gr.Image(label="Segmentation", type="filepath")

        submit_btn.click(
            fn=main_infer,
            inputs=[image_input, instruction],
            outputs=[output_res, output_image],
        )

demo.queue().launch(show_api=False, show_error=True)
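
# A minimal sketch of driving the same pipeline without the Gradio UI. The
# image path and prompt below are hypothetical, for illustration only; since
# launch() above blocks, this snippet would replace it in a headless script:
#
#     answer, seg_path = main_infer("examples/dog.jpg", "Please segment the dog.")
#     print(answer)    # text answer; contains [SEG] when a mask was produced
#     print(seg_path)  # path to the mask-overlay image, or None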