khang119966 committed
Commit 0ea2c0e · verified · 1 Parent(s): 0769865

Update app.py

Files changed (1)
  1. app.py +215 -211
app.py CHANGED
@@ -1,224 +1,228 @@
 
  import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer
- from PIL import Image
- import numpy as np
- import os
- import tempfile
- import spaces
  import gradio as gr
-
  import subprocess
- import sys
-
- def install_flash_attn_wheel():
-     flash_attn_wheel_url = "https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"
-     try:
-         # Call pip to install the wheel file
-         subprocess.check_call([sys.executable, "-m", "pip", "install", flash_attn_wheel_url])
-         print("Wheel installed successfully!")
-     except subprocess.CalledProcessError as e:
-         print(f"Failed to install the flash attention wheel. Error: {e}")
-
- install_flash_attn_wheel()
-
- import cv2
- try:
-     from mmengine.visualization import Visualizer
- except ImportError:
-     Visualizer = None
-     print("Warning: mmengine is not installed, visualization is disabled.")
-
- # Load the model and tokenizer
- model_path = "ByteDance/Sa2VA-4B"
-
- model = AutoModelForCausalLM.from_pretrained(
-     model_path,
-     torch_dtype="auto",
-     device_map="cuda:0",
      trust_remote_code=True,
  ).eval().cuda()
  
- tokenizer = AutoTokenizer.from_pretrained(
-     model_path,
-     trust_remote_code = True,
- )
-
- from third_parts import VideoReader
- def read_video(video_path, video_interval):
-     vid_frames = VideoReader(video_path)[::video_interval]
-
-     temp_dir = tempfile.mkdtemp()
-     os.makedirs(temp_dir, exist_ok=True)
-     image_paths = [] # List to store paths of saved images
-
-     for frame_idx in range(len(vid_frames)):
-         frame_image = vid_frames[frame_idx]
-         frame_image = frame_image[..., ::-1] # BGR (opencv system) to RGB (numpy system)
-         frame_image = Image.fromarray(frame_image)
-         vid_frames[frame_idx] = frame_image
-
-         # Save the frame as a .jpg file in the temporary folder
-         image_path = os.path.join(temp_dir, f"frame_{frame_idx:04d}.jpg")
-         frame_image.save(image_path, format="JPEG")
-
-         # Append the image path to the list
-         image_paths.append(image_path)
-     return vid_frames, image_paths
-
- def visualize(pred_mask, image_path, work_dir):
-     visualizer = Visualizer()
-     img = cv2.imread(image_path)
-     visualizer.set_image(img)
-     visualizer.draw_binary_masks(pred_mask, colors='g', alphas=0.4)
-     visual_result = visualizer.get_image()
-
-     output_path = os.path.join(work_dir, os.path.basename(image_path))
-     cv2.imwrite(output_path, visual_result)
-     return output_path
  
  @spaces.GPU
- def image_vision(image_input_path, prompt):
-     image_path = image_input_path
-     text_prompts = f"<image>{prompt}"
-     image = Image.open(image_path).convert('RGB')
-     input_dict = {
-         'image': image,
-         'text': text_prompts,
-         'past_text': '',
-         'mask_prompts': None,
-         'tokenizer': tokenizer,
-     }
-     return_dict = model.predict_forward(**input_dict)
-     print(return_dict)
-     answer = return_dict["prediction"] # the text format answer
-
-     seg_image = return_dict["prediction_masks"]
  
-     if '[SEG]' in answer and Visualizer is not None:
-         pred_masks = seg_image[0]
-         temp_dir = tempfile.mkdtemp()
-         pred_mask = pred_masks
-         os.makedirs(temp_dir, exist_ok=True)
-         seg_result = visualize(pred_mask, image_input_path, temp_dir)
-         return answer, seg_result
      else:
-         return answer, None
-
- @spaces.GPU(duration=80)
- def video_vision(video_input_path, prompt, video_interval):
-     # Open the original video
-     cap = cv2.VideoCapture(video_input_path)
-
-     # Get original video properties
-     original_fps = cap.get(cv2.CAP_PROP_FPS)
-
-     frame_skip_factor = video_interval
-
-     # Calculate new FPS
-     new_fps = original_fps / frame_skip_factor
-
-     vid_frames, image_paths = read_video(video_input_path, video_interval)
-     # create a question (<image> is a placeholder for the video frames)
-     question = f"<image>{prompt}"
-     result = model.predict_forward(
-         video=vid_frames,
-         text=question,
-         tokenizer=tokenizer,
-     )
-     prediction = result['prediction']
-     print(prediction)
-
-     if '[SEG]' in prediction and Visualizer is not None:
-         _seg_idx = 0
-         pred_masks = result['prediction_masks'][_seg_idx]
-         seg_frames = []
-         for frame_idx in range(len(vid_frames)):
-             pred_mask = pred_masks[frame_idx]
-             temp_dir = tempfile.mkdtemp()
-             os.makedirs(temp_dir, exist_ok=True)
-             seg_frame = visualize(pred_mask, image_paths[frame_idx], temp_dir)
-             seg_frames.append(seg_frame)
-
-         output_video = "output_video.mp4"
-
-         # Read the first image to get the size (resolution)
-         frame = cv2.imread(seg_frames[0])
-         height, width, layers = frame.shape
-
-         # Define the video codec and create VideoWriter object
-         fourcc = cv2.VideoWriter_fourcc(*'mp4v') # Codec for MP4
-         video = cv2.VideoWriter(output_video, fourcc, new_fps, (width, height))
-
-         # Iterate over the image paths and write to the video
-         for img_path in seg_frames:
-             frame = cv2.imread(img_path)
-             video.write(frame)
-
-         # Release the video writer
-         video.release()
-
-         print(f"Video created successfully at {output_video}")
-
-         return result['prediction'], output_video
-
-     else:
-         return result['prediction'], None
  
-
-
- # Gradio UI
-
- with gr.Blocks(analytics_enabled=False) as demo:
-     with gr.Column():
-         gr.Markdown("# Sa2VA: Marrying SAM2 with LLaVA for Dense Grounded Understanding of Images and Videos")
-         gr.HTML("""
-         <div style="display:flex;column-gap:4px;">
-             <a href="https://github.com/magic-research/Sa2VA">
-                 <img src='https://img.shields.io/badge/GitHub-Repo-blue'>
-             </a>
-             <a href="https://arxiv.org/abs/2501.04001">
-                 <img src='https://img.shields.io/badge/ArXiv-Paper-red'>
-             </a>
-             <a href="https://huggingface.co/spaces/fffiloni/Sa2VA-simple-demo?duplicate=true">
-                 <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-sm.svg" alt="Duplicate this Space">
-             </a>
-             <a href="https://huggingface.co/fffiloni">
-                 <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/follow-me-on-HF-sm-dark.svg" alt="Follow me on HF">
-             </a>
-         </div>
-         """)
-         with gr.Tab("Single Image"):
-             with gr.Row():
-                 with gr.Column():
-                     image_input = gr.Image(label="Image IN", type="filepath")
-                     with gr.Row():
-                         instruction = gr.Textbox(label="Instruction", scale=4)
-                         submit_image_btn = gr.Button("Submit", scale=1)
-                 with gr.Column():
-                     output_res = gr.Textbox(label="Response")
-                     output_image = gr.Image(label="Segmentation", type="numpy")
  
-             submit_image_btn.click(
-                 fn = image_vision,
-                 inputs = [image_input, instruction],
-                 outputs = [output_res, output_image]
-             )
-         with gr.Tab("Video"):
-             with gr.Row():
-                 with gr.Column():
-                     video_input = gr.Video(label="Video IN")
-                     frame_interval = gr.Slider(label="Frame interval", step=1, minimum=1, maximum=12, value=6)
-                     with gr.Row():
-                         vid_instruction = gr.Textbox(label="Instruction", scale=4)
-                         submit_video_btn = gr.Button("Submit", scale=1)
-                 with gr.Column():
-                     vid_output_res = gr.Textbox(label="Response")
-                     output_video = gr.Video(label="Segmentation")
-
-             submit_video_btn.click(
-                 fn = video_vision,
-                 inputs = [video_input, vid_instruction, frame_interval],
-                 outputs = [vid_output_res, output_video]
-             )
  
- demo.queue().launch(show_api=False, show_error=True)
+ import gradio as gr
  import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, StoppingCriteria
  import gradio as gr
+ import spaces
+ import torch
+ import numpy as np
+ import torch
+ import torchvision.transforms as T
+ from PIL import Image
+ from torchvision.transforms.functional import InterpolationMode
+ from transformers import AutoModel, AutoTokenizer
+
+ from threading import Thread
+ import re
+ import time
+ from PIL import Image
+ import torch
+ import spaces
  import subprocess
+ import os
+
+ subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
+
+ torch.set_default_device('cuda')
+
+
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
+ IMAGENET_STD = (0.229, 0.224, 0.225)
+
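+ # Standard ImageNet normalization stats; build_transform resizes each tile to a square of input_size and normalizes it for the vision encoder.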
+ def build_transform(input_size):
+     MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
+     transform = T.Compose([
+         T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
+         T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
+         T.ToTensor(),
+         T.Normalize(mean=MEAN, std=STD)
+     ])
+     return transform
+
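+ # Pick the tiling grid (columns x rows) whose aspect ratio is closest to the input image; ties go to the larger grid when the image area is big enough to fill it.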
+ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
+     best_ratio_diff = float('inf')
+     best_ratio = (1, 1)
+     area = width * height
+     for ratio in target_ratios:
+         target_aspect_ratio = ratio[0] / ratio[1]
+         ratio_diff = abs(aspect_ratio - target_aspect_ratio)
+         if ratio_diff < best_ratio_diff:
+             best_ratio_diff = ratio_diff
+             best_ratio = ratio
+         elif ratio_diff == best_ratio_diff:
+             if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
+                 best_ratio = ratio
+     return best_ratio
+
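+ # Resize the image to the chosen grid and split it into up to max_num square crops of image_size x image_size, optionally appending a full-image thumbnail.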
+ def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
+     orig_width, orig_height = image.size
+     aspect_ratio = orig_width / orig_height
+
+     # calculate the existing image aspect ratio
+     target_ratios = set(
+         (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
+         i * j <= max_num and i * j >= min_num)
+     target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+     # find the closest aspect ratio to the target
+     target_aspect_ratio = find_closest_aspect_ratio(
+         aspect_ratio, target_ratios, orig_width, orig_height, image_size)
+
+     # calculate the target width and height
+     target_width = image_size * target_aspect_ratio[0]
+     target_height = image_size * target_aspect_ratio[1]
+     blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+
+     # resize the image
+     resized_img = image.resize((target_width, target_height))
+     processed_images = []
+     for i in range(blocks):
+         box = (
+             (i % (target_width // image_size)) * image_size,
+             (i // (target_width // image_size)) * image_size,
+             ((i % (target_width // image_size)) + 1) * image_size,
+             ((i // (target_width // image_size)) + 1) * image_size
+         )
+         # split the image
+         split_img = resized_img.crop(box)
+         processed_images.append(split_img)
+     assert len(processed_images) == blocks
+     if use_thumbnail and len(processed_images) != 1:
+         thumbnail_img = image.resize((image_size, image_size))
+         processed_images.append(thumbnail_img)
+     return processed_images
+
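+ # Open an image file, tile it with dynamic_preprocess (thumbnail included), and stack the normalized tiles into a (num_tiles, 3, input_size, input_size) tensor.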
+ def load_image(image_file, input_size=448, max_num=12):
+     image = Image.open(image_file).convert('RGB')
+     transform = build_transform(input_size=input_size)
+     images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
+     pixel_values = [transform(image) for image in images]
+     pixel_values = torch.stack(pixel_values)
+     return pixel_values
+
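+ # Load Vintern-3B-beta in bfloat16 with its remote modeling code; the slow tokenizer is used (use_fast=False).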
+ model = AutoModel.from_pretrained(
+     "5CD-AI/Vintern-3B-beta",
+     torch_dtype=torch.bfloat16,
+     low_cpu_mem_usage=True,
      trust_remote_code=True,
  ).eval().cuda()
+ tokenizer = AutoTokenizer.from_pretrained("5CD-AI/Vintern-3B-beta", trust_remote_code=True, use_fast=False)
  
  
  @spaces.GPU
+ def chat(message, history):
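+     # Gradio multimodal handler: `message` is a dict with "text" and "files"; `history` holds the previous turns.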
+     print("history",history)
+     print("message",message)
+
+     if len(history) != 0 and len(message["files"]) != 0:
+         return """Chúng tôi hiện chỉ hỗ trợ 1 ảnh ở đầu ngữ cảnh! Vui lòng tạo mới cuộc trò chuyện.
+ We currently only support one image at the start of the context! Please start a new conversation."""
  
+     if len(history) == 0 and len(message["files"]) != 0:
+         test_image = message["files"][0]["path"]
+         pixel_values = load_image(test_image, max_num=6).to(torch.bfloat16).cuda()
+     elif len(history) == 0 and len(message["files"]) == 0:
+         pixel_values = None
+     elif history[0][0][0] is not None and os.path.isfile(history[0][0][0]):
+         test_image = history[0][0][0]
+         pixel_values = load_image(test_image, max_num=6).to(torch.bfloat16).cuda()
      else:
+         pixel_values = None
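+     # Only a single image, attached to the first turn, is supported; later turns reuse that image or run text-only.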
+
  
+     generation_config = dict(max_new_tokens= 512, do_sample=False, num_beams = 3, repetition_penalty=2.0)
  
+     if len(history) == 0:
+         if pixel_values is not None:
+             question = '<image>\n'+message["text"]
+         else:
+             question = message["text"]
+         response, conv_history = model.chat(tokenizer, pixel_values, question, generation_config, history=None, return_history=True)
+     else:
+         conv_history = []
+         if history[0][0][0] is not None and os.path.isfile(history[0][0][0]):
+             start_index = 1
+         else:
+             start_index = 0
+
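+         # Rebuild the conversation as (question, answer) tuples for model.chat, re-attaching the <image> token to the first turn when the chat started with an image upload.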
+         for i, chat_pair in enumerate(history[start_index:]):
+             if i == 0 and start_index == 1:
+                 conv_history.append(tuple(['<image>\n'+chat_pair[0],chat_pair[1]]))
+             else:
+                 conv_history.append(tuple(chat_pair))
  
+
+     print("conv_history",conv_history)
+     question = message["text"]
+     response, conv_history = model.chat(tokenizer, pixel_values, question, generation_config, history=conv_history, return_history=True)
+
+     print(f'User: {question}\nAssistant: {response}')
+
+     return response
+     # buffer = ""
+     # for new_text in response:
+     #     buffer += new_text
+     #     generated_text_without_prompt = buffer[:]
+     #     time.sleep(0.005)
+     #     yield generated_text_without_prompt
+
+ CSS ="""
+ # @media only screen and (max-width: 600px){
+ #   #component-3 {
+ #     height: 90dvh !important;
+ #     transform-origin: top; /* Make sure the element expands from the top down */
+ #     border-style: solid;
+ #     overflow: hidden;
+ #     flex-grow: 1;
+ #     min-width: min(160px, 100%);
+ #     border-width: var(--block-border-width);
+ #   }
+ # }
+ #component-3 {
+   height: 50dvh !important;
+   transform-origin: top; /* Make sure the element expands from the top down */
+   border-style: solid;
+   overflow: hidden;
+   flex-grow: 1;
+   min-width: min(160px, 100%);
+   border-width: var(--block-border-width);
+ }
+ /* Make sure the image inside the button displays correctly for buttons with the specified aria-label */
+ button.svelte-1lcyrx4[aria-label="user's message: a file of type image/jpeg, "] img.svelte-1pijsyv {
+   width: 100%;
+   object-fit: contain;
+   height: 100%;
+   border-radius: 13px; /* Add rounded corners to the image */
+   max-width: 50vw; /* Limit the image width */
+ }
+ /* Set the button height and allow text selection only for buttons with the specified aria-label */
+ button.svelte-1lcyrx4[aria-label="user's message: a file of type image/jpeg, "] {
+   user-select: text;
+   text-align: left;
+   height: 300px;
+ }
+ /* Add rounded corners and limit the width of images outside the avatar container */
+ .message-wrap.svelte-1lcyrx4 > div.svelte-1lcyrx4 .svelte-1lcyrx4:not(.avatar-container) img {
+   border-radius: 13px;
+   max-width: 50vw;
+ }
+ .message-wrap.svelte-1lcyrx4 .message.svelte-1lcyrx4 img {
+   margin: var(--size-2);
+   max-height: 500px;
+ }
+ """
+
+
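+ # Multimodal ChatInterface: the chat() handler receives {"text", "files"} messages plus the running history; the CSS above fixes the chatbot height and in-chat image sizing.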
+ demo = gr.ChatInterface(
+     fn=chat,
+     description="""Try [Vintern-3B-beta](https://huggingface.co/5CD-AI/Vintern-3B-beta) in this demo. Vintern-3B-beta consists of [InternViT-300M-448px](https://huggingface.co/OpenGVLab/InternViT-300M-448px), an MLP projector, and [Qwen2.5-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-3B-Instruct).
+ Bias, Risks, and Limitations
+ The model might have biases because it learned from data that could be biased.
+ Users should be cautious of these possible biases when using the model.""",
+     examples=[{"text": "Mô tả hình ảnh.", "files":["./demo_3.jpg"]},
+               {"text": "Trích xuất các thông tin từ ảnh.", "files":["./demo_1.jpg"]},
+               {"text": "Mô tả hình ảnh một cách chi tiết.", "files":["./demo_2.jpg"]}],
+     title="❄️ Vintern-3B-beta Test ❄️",
+     multimodal=True,
+     css=CSS
+ )
+ demo.queue().launch()