import argparse
import os
import sys

import gradio as gr
import torch
from transformers import AutoModel, AutoTokenizer

from projects.llava_sam2.gradio.app_utils import (
    process_markdown, show_mask_pred, description, preprocess_video,
    show_mask_pred_video, image2video_and_save)

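# Map dtype names to torch dtypes (not used directly below; the model is
# always loaded in bfloat16 in init_models).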
TORCH_DTYPE_MAP = dict(
    fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32, auto='auto')

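# The only CLI argument is the HuggingFace path (hub id or local directory)
# of the Sa2VA checkpoint to serve.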
def parse_args(args):
    parser = argparse.ArgumentParser(description="Sa2VA Demo")
    parser.add_argument('hf_path', help='Sa2VA hf path.')
    return parser.parse_args(args)

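# Single Gradio callback. Routes an uploaded image or video plus the text
# instruction through Sa2VA, keeps the multi-turn history in `global_infos`
# when "Follow up Question" is checked, and returns the (optionally
# mask-overlaid) image, video and markdown answer.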
def inference(image, video, follow_up, input_str):
    input_image = image
    if image is not None and (video is not None and os.path.exists(video)):
        return image, video, "Error: Please input only an image or a video!"
    if image is None and (video is None or not os.path.exists(video)) and not follow_up:
        return image, video, "Error: Please input an image or a video!"

    if not follow_up:
        # New conversation: drop any previous turns and cache the fresh inputs.
        print('Log: History responses have been removed!')
        global_infos.n_turn = 0
        global_infos.inputs = ''
        text = input_str

        global_infos.image_for_show = input_image
        global_infos.image = input_image
        global_infos.video = video

        if input_image is not None:
            global_infos.input_type = "image"
        else:
            global_infos.input_type = "video"

    else:
        text = input_str
        image = global_infos.image
        video = global_infos.video

    input_type = global_infos.input_type
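    # For video input, turn the uploaded file into the frame sequence the
    # model expects; the accumulated prompt is forwarded to preprocess_video.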
    if input_type == "video":
        video = preprocess_video(video, global_infos.inputs+input_str)

    past_text = global_infos.inputs

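    # Sa2VA expects the very first turn to contain the <image> placeholder token.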
    if past_text == "" and "<image>" not in text:
        text = "<image>" + text
    if input_type == "image":
        input_dict = {
            'image': image,
            'text': text,
            'past_text': past_text,
            'mask_prompts': None,
            'tokenizer': tokenizer,
        }
    else:
        input_dict = {
            'video': video,
            'text': text,
            'past_text': past_text,
            'mask_prompts': None,
            'tokenizer': tokenizer,
        }

    return_dict = sa2va_model.predict_forward(**input_dict)
    global_infos.inputs = return_dict["past_text"]
    print(return_dict['past_text'])
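    # If the model returned segmentation masks, overlay them on the cached
    # image / video for display; otherwise show the raw inputs unchanged.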
    if return_dict.get('prediction_masks'):
        if input_type == "image":
            image_mask_show, selected_colors = show_mask_pred(global_infos.image_for_show, return_dict['prediction_masks'],)
            video_mask_show = global_infos.video
        else:
            image_mask_show = None
            video_mask_show, selected_colors = show_mask_pred_video(video, return_dict['prediction_masks'],)
            video_mask_show = image2video_and_save(video_mask_show, save_path="./ret_video.mp4")
    else:
        image_mask_show = global_infos.image_for_show
        video_mask_show = global_infos.video
        selected_colors = []

    predict = return_dict['prediction'].strip()
    global_infos.n_turn += 1

    predict = process_markdown(predict, selected_colors)
    return image_mask_show, video_mask_show, predict

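# Load the Sa2VA checkpoint (remote code, flash attention) in bfloat16 on the
# default GPU, together with its tokenizer.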
def init_models(args):
    model_path = args.hf_path
    model = AutoModel.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        use_flash_attn=True,
        trust_remote_code=True,
    ).eval().cuda()

    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        trust_remote_code=True,
    )
    return model, tokenizer

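# Module-level container for per-session state: conversation history, cached
# inputs, and whether the current input is an image or a video.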
class global_infos:
    inputs = ''
    n_turn = 0
    image_width = 0
    image_height = 0

    image_for_show = None
    image = None
    video = None

    input_type = "image" # "image" or "video"

if __name__ == "__main__":
    # Parse CLI args and load the model and tokenizer.
    args = parse_args(sys.argv[1:])

    sa2va_model, tokenizer = init_models(args)

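    # Wire the callback to the UI: image / video / follow-up checkbox / text in,
    # annotated image, annotated video and markdown answer out.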
    demo = gr.Interface(
        inference,
        inputs=[
            gr.Image(type="pil", label="Upload Image", height=360),
            gr.Video(sources=["upload", "webcam"], label="Upload mp4 video", height=360),
            gr.Checkbox(label="Follow up Question"),
            gr.Textbox(lines=1, placeholder=None, label="Text Instruction"),],
        outputs=[
            gr.Image(type="pil", label="Output Image"),
            gr.Video(label="Output Video", show_download_button=True, format='mp4'),
            gr.Markdown()],
        theme=gr.themes.Soft(), allow_flagging="auto", description=description,
        title='Sa2VA'
    )

    demo.queue()
    demo.launch(share=True)