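"""Gradio demo for Sa2VA: chat about a single image or an mp4 video, with optional
follow-up questions; predicted masks are overlaid on the input for display."""
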
import gradio as gr
import sys
from projects.llava_sam2.gradio.app_utils import\
    process_markdown, show_mask_pred, description, preprocess_video,\
    show_mask_pred_video, image2video_and_save
import torch
from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer,
                          BitsAndBytesConfig, CLIPImageProcessor,
                          CLIPVisionModel, GenerationConfig)
import argparse
import os

TORCH_DTYPE_MAP = dict(
    fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32, auto='auto')


def parse_args(args):
    parser = argparse.ArgumentParser(description="Sa2VA Demo")
    parser.add_argument('hf_path', help='Sa2VA hf path.')
    return parser.parse_args(args)


def inference(image, video, follow_up, input_str):
    input_image = image
    # Only one modality may be provided per conversation.
    if image is not None and (video is not None and os.path.exists(video)):
        return image, video, "Error: Please input only an image or a video!"
    if image is None and (video is None or not os.path.exists(video)) and not follow_up:
        return image, video, "Error: Please input an image or a video!"

    if not follow_up:
        # New conversation: reset the stored history and cache the fresh inputs.
        print('Log: History responses have been removed!')
        global_infos.n_turn = 0
        global_infos.inputs = ''

        text = input_str
        image = input_image
        global_infos.image_for_show = image
        global_infos.image = image

        global_infos.video = video

        if image is not None:
            global_infos.input_type = "image"
        else:
            global_infos.input_type = "video"
    else:
        # Follow-up question: reuse the image/video cached on the first turn.
        text = input_str
        image = global_infos.image
        video = global_infos.video

    input_type = global_infos.input_type

    if input_type == "video":
        video = preprocess_video(video, global_infos.inputs + input_str)

    past_text = global_infos.inputs
    # The first turn must carry the <image> placeholder token.
    if past_text == "" and "<image>" not in text:
        text = "<image>" + text

    if input_type == "image":
        input_dict = {
            'image': image,
            'text': text,
            'past_text': past_text,
            'mask_prompts': None,
            'tokenizer': tokenizer,
        }
    else:
        input_dict = {
            'video': video,
            'text': text,
            'past_text': past_text,
            'mask_prompts': None,
            'tokenizer': tokenizer,
        }

    return_dict = sa2va_model.predict_forward(**input_dict)
    global_infos.inputs = return_dict["past_text"]
    print(return_dict['past_text'])

    # Overlay the predicted masks on the image or video if the model returned any.
    if 'prediction_masks' in return_dict and return_dict['prediction_masks'] and len(
            return_dict['prediction_masks']) != 0:
        if input_type == "image":
            image_mask_show, selected_colors = show_mask_pred(
                global_infos.image_for_show, return_dict['prediction_masks'])
            video_mask_show = global_infos.video
        else:
            image_mask_show = None
            video_mask_show, selected_colors = show_mask_pred_video(
                video, return_dict['prediction_masks'])
            video_mask_show = image2video_and_save(
                video_mask_show, save_path="./ret_video.mp4")
    else:
        image_mask_show = global_infos.image_for_show
        video_mask_show = global_infos.video
        selected_colors = []

    predict = return_dict['prediction'].strip()
    global_infos.n_turn += 1

    predict = process_markdown(predict, selected_colors)
    return image_mask_show, video_mask_show, predict


def init_models(args):
    model_path = args.hf_path
    # Load the Sa2VA checkpoint in bfloat16 on the GPU, trusting its remote code.
    model = AutoModel.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        use_flash_attn=True,
        trust_remote_code=True,
    ).eval().cuda()

    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        trust_remote_code=True,
    )

    return model, tokenizer


class global_infos:
    # Module-level conversation state shared across Gradio calls.
    inputs = ''
    n_turn = 0
    image_width = 0
    image_height = 0

    image_for_show = None
    image = None
    video = None

    input_type = "image"  # "image" or "video"
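

# Example launch (the script name and model path below are illustrative; pass any
# Sa2VA checkpoint path as the single positional argument):
#   python app.py ByteDance/Sa2VA-4B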


if __name__ == "__main__":
    # Parse the CLI arguments and initialise the model and tokenizer.
    args = parse_args(sys.argv[1:])
    sa2va_model, tokenizer = init_models(args)

    demo = gr.Interface(
        inference,
        inputs=[
            gr.Image(type="pil", label="Upload Image", height=360),
            gr.Video(sources=["upload", "webcam"], label="Upload mp4 video", height=360),
            gr.Checkbox(label="Follow up Question"),
            gr.Textbox(lines=1, placeholder=None, label="Text Instruction"),
        ],
        outputs=[
            gr.Image(type="pil", label="Output Image"),
            gr.Video(label="Output Video", show_download_button=True, format='mp4'),
            gr.Markdown(),
        ],
        theme=gr.themes.Soft(),
        allow_flagging="auto",
        description=description,
        title='Sa2VA',
    )

    demo.queue()
    demo.launch(share=True)