Spaces: Running on Zero
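
# app.py — Gradio demo for ByteDance/Sa2VA-4B: takes an image and a text
# instruction, runs the model's predict_forward, and, when the answer
# contains a [SEG] token, overlays the predicted mask on the input image.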

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from PIL import Image
import numpy as np
import os
import tempfile
import gradio as gr
import cv2
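
# mmengine's Visualizer is optional: without it the demo still returns text
# answers, but the mask overlay is disabled (see main_infer below).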
try:
    from mmengine.visualization import Visualizer
except ImportError:
    Visualizer = None
    print("Warning: mmengine is not installed, visualization is disabled.")

# Load the model and tokenizer. Note: .cuda() can fail on a model dispatched
# with device_map="auto" (accelerate refuses to move dispatched models), so
# the model is loaded normally and moved to the GPU explicitly.
model_path = "ByteDance/Sa2VA-4B"
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype="auto",
    trust_remote_code=True,
).eval().cuda()
tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    trust_remote_code=True,
)
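
# Overlay a binary mask in green on the image at image_path and write the
# result into work_dir; only called when mmengine imported successfully.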
def visualize(pred_mask, image_path, work_dir):
    visualizer = Visualizer()
    img = cv2.imread(image_path)
    visualizer.set_image(img)
    visualizer.draw_binary_masks(pred_mask, colors='g', alphas=0.4)
    visual_result = visualizer.get_image()
    output_path = os.path.join(work_dir, os.path.basename(image_path))
    cv2.imwrite(output_path, visual_result)
    return output_path
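
# Run one round of inference. Sa2VA expects the prompt prefixed with an
# <image> placeholder; predict_forward returns the text answer plus, when
# grounding is requested, per-object segmentation masks.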
def image_vision(image_input_path, prompt):
    image_path = image_input_path
    text_prompts = f"<image>{prompt}"
    image = Image.open(image_path).convert('RGB')
    input_dict = {
        'image': image,
        'text': text_prompts,
        'past_text': '',
        'mask_prompts': None,
        'tokenizer': tokenizer,
    }
    return_dict = model.predict_forward(**input_dict)
    print(return_dict)
    answer = return_dict["prediction"]  # the text-format answer
    seg_image = return_dict["prediction_masks"]
    return answer, seg_image
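
# Glue between the model and the UI: return the text answer, plus a mask
# overlay when the answer contains a [SEG] token and visualization is
# available.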
def main_infer(image_input_path, prompt):
    answer, seg_image = image_vision(image_input_path, prompt)
    if '[SEG]' in answer and Visualizer is not None:
        pred_masks = seg_image[0]
        pred_mask = pred_masks[0]
        temp_dir = tempfile.mkdtemp()  # mkdtemp already creates the directory
        seg_result = visualize(pred_mask, image_input_path, temp_dir)
        return answer, seg_result
    # No [SEG] in the answer (or mmengine missing): no mask to draw.
    return answer, None

# Gradio UI
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# Sa2VA: Marrying SAM2 with LLaVA for Dense Grounded Understanding of Images and Videos")
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(label="Image IN", type="filepath")
                with gr.Row():
                    instruction = gr.Textbox(label="Instruction", scale=4)
                    submit_btn = gr.Button("Submit", scale=1)
            with gr.Column():
                output_res = gr.Textbox(label="Response")
                output_image = gr.Image(label="Segmentation", type="numpy")
    submit_btn.click(
        fn=main_infer,
        inputs=[image_input, instruction],
        outputs=[output_res, output_image],
    )

demo.queue().launch(show_api=False, show_error=True)
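
# A minimal sketch of calling the pipeline without the UI (the image path
# and instruction below are placeholders, not part of the demo):
#   answer, overlay_path = main_infer("some_image.jpg", "Please segment the person.")
#   print(answer)        # text answer, containing [SEG] when grounded
#   print(overlay_path)  # path to the mask overlay image, or None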