omni-research committed on
Commit dcd4560 · 1 Parent(s): aa4801f

update to tarsier2-7b-0115

app.py CHANGED
@@ -13,19 +13,22 @@
 # limitations under the License.
 
 # copy and modify from: https://github.com/OpenGVLab/Ask-Anything/blob/main/video_chat2/demo/demo.py
-import spaces
+
+import spaces # for deploying on huggingface ZeroGPU
 from copy import deepcopy
 import gradio as gr
 from gradio.themes.utils import colors, fonts, sizes
 from tools.conversation import Chat, conv_templates
 from tools.utils import load_model_and_processor, file_to_base64
-from dataset.processor import Processor
+from dataset.tarsier_datamodule import init_processor
 import os
 import torch
+import yaml
 
 # huggingface-cli login
 
-model_path = os.getenv("MODEL_PATH", "omni-research/Tarsier2-7b")
+model_path = os.getenv("MODEL_PATH", "omni-research/Tarsier2-7b-0115")
+config_path = "configs/tarser2_default_config.yaml"
 max_n_frames = int(os.getenv("MAX_N_FRAMES", 16))
 debug = False
 device = 'cuda' if not debug else 'cpu'
@@ -34,13 +37,14 @@ device = 'cuda' if not debug else 'cpu'
 # Model Initialization
 # ========================================
 def init_model():
+    config = yaml.safe_load(open(config_path, 'r'))
     print("Start Initialization...")
     # if torch.cuda.is_available():
     if not debug:
-        model, processor = load_model_and_processor(model_path, max_n_frames)
+        model, processor = load_model_and_processor(model_path, config)
     else:
         print(f"No Valid GPU! Lauch in debug mode!")
-        processor = Processor(model_path, max_n_frames)
+        processor = init_processor(model_path, config)
         model = None
     chat = Chat(model, processor, device, debug)
     print('Initialization Finished')
@@ -50,13 +54,11 @@ def init_model():
 # ========================================
 # Gradio Setting
 # ========================================
-def gradio_reset(chat_state, img_file, img_list):
+def gradio_reset(chat_state, img_file):
     if chat_state is not None:
         chat_state.messages = []
     img_file = None
-    if img_list is not None:
-        img_list = []
-    return None, gr.update(value=None, interactive=True), gr.update(value=None, interactive=True), gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your video first', interactive=False),gr.update(value="Upload & Start Chat", interactive=True), chat_state, img_file, img_list
+    return None, gr.update(value=None, interactive=True), gr.update(value=None, interactive=True), gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your video first', interactive=False),gr.update(value="Upload & Start Chat", interactive=True), chat_state, img_file
 
 
 def upload_img(gr_img, gr_video, gr_gif, chat_state, num_frames):
@@ -64,24 +66,24 @@ def upload_img(gr_img, gr_video, gr_gif, chat_state, num_frames):
     conv_type = ''
     if 'tarsier2-7b' in model_path.lower():
         conv_type = 'tarsier2-7b'
-    elif '7b' in model_path.lower():
-        conv_type = 'tarsier-7b'
-    elif '13b' in model_path.lower():
-        conv_type = 'tarsier-13b'
-    elif '34b' in model_path.lower():
-        conv_type = 'tarsier-34b'
+    # elif '7b' in model_path.lower():
+    #     conv_type = 'tarsier-7b'
+    # elif '13b' in model_path.lower():
+    #     conv_type = 'tarsier-13b'
+    # elif '34b' in model_path.lower():
+    #     conv_type = 'tarsier-34b'
     else:
         raise ValueError(f"Unknow model: {model_path}")
     chat_state = deepcopy(conv_templates[conv_type])
 
-    img_list = []
     if gr_img is None and gr_video is None and gr_gif is None:
         return None, None, None, gr.update(interactive=True), gr.update(interactive=True, placeholder='Please upload video/image first!'), chat_state, None, None
     if gr_video or gr_img or gr_gif:
         for img_file in [gr_video, gr_img, gr_gif]:
             if img_file is not None:
                 break
-    return gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_file, img_list
+    chat_state.messages.append([chat_state.roles[0], {"type": "video", "text": img_file}])
+    return gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_file
 
 
 def gradio_ask(user_message, chatbot, chat_state):
@@ -91,13 +93,13 @@ def gradio_ask(user_message, chatbot, chat_state):
     chatbot = chatbot + [[user_message, None]]
     return '', chatbot, chat_state
 
-@spaces.GPU(duration=120)
-def gradio_answer(chatbot, chat_state, img_file, img_list, top_p, temperature, n_frames=None):
-    llm_message, chat_state, img_list = chat.answer(conv=chat_state, visual_data_file=img_file, images=img_list, n_frames=n_frames, max_new_tokens=256, num_beams=1, temperature=temperature, top_p=top_p)
+@spaces.GPU(duration=120) # for deploying on huggingface ZeroGPU
+def gradio_answer(chatbot, chat_state, img_file, top_p, temperature, n_frames=None):
+    llm_message, chat_state = chat.answer(conv=chat_state, n_frames=n_frames, max_new_tokens=256, num_beams=1, temperature=temperature, top_p=top_p)
     chatbot[-1][1] = llm_message
     print(chat_state)
     print(f"Answer: {llm_message}")
-    return chatbot, chat_state, img_list
+    return chatbot, chat_state
 
 
 class OpenGVLab(gr.themes.base.Base):
@@ -203,7 +205,6 @@ with gr.Blocks(title="Tarsier",theme=gvlabtheme,css="#chatbot {overflow:auto; he
 
     with gr.Column(visible=True) as input_raws:
         chat_state = gr.State()
-        img_list = gr.State()
         img_file = gr.State()
         chatbot = gr.Chatbot(elem_id="chatbot",label='VideoChat')
         with gr.Row():
@@ -216,19 +217,19 @@ with gr.Blocks(title="Tarsier",theme=gvlabtheme,css="#chatbot {overflow:auto; he
     gr.Examples(examples=[
         [f"examples/test1.mp4", "Describe the video in detail."],
         [f"examples/test2.mp4", "Are they having a pleasant conversation?"],
-        ], inputs=[up_video, text_input])
+    ], inputs=[up_video, text_input])
 
     chat = init_model()
-    upload_button.click(upload_img, [up_image, up_video, up_gif, chat_state, num_frames], [up_image, up_video, up_gif, text_input, upload_button, chat_state, img_file, img_list])
+    upload_button.click(upload_img, [up_image, up_video, up_gif, chat_state, num_frames], [up_image, up_video, up_gif, text_input, upload_button, chat_state, img_file])
 
     text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
-        gradio_answer, [chatbot, chat_state, img_file, img_list, top_p, temperature, num_frames], [chatbot, chat_state, img_list]
+        gradio_answer, [chatbot, chat_state, img_file, top_p, temperature, num_frames], [chatbot, chat_state]
    )
    run.click(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
-        gradio_answer, [chatbot, chat_state, img_file, img_list, top_p, temperature, num_frames], [chatbot, chat_state, img_list]
+        gradio_answer, [chatbot, chat_state, img_file, top_p, temperature, num_frames], [chatbot, chat_state]
    )
    run.click(lambda: "", None, text_input)
-    clear.click(gradio_reset, [chat_state, img_file, img_list], [chatbot, up_image, up_video, up_gif, text_input, upload_button, chat_state, img_file, img_list], queue=False)
+    clear.click(gradio_reset, [chat_state, img_file], [chatbot, up_image, up_video, up_gif, text_input, upload_button, chat_state, img_file], queue=False)
 
 
 demo.launch()
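
Note: with this change the uploaded video travels inside the conversation state (see the new chat_state.messages.append in upload_img), and chat.answer no longer receives visual_data_file/images. A minimal sketch of exercising the updated flow outside Gradio, assuming the repo modules above (tools.conversation, dataset.tarsier_datamodule) and the debug branch of init_model (model=None); paths and sampling values are illustrative:

import yaml
from copy import deepcopy
from dataset.tarsier_datamodule import init_processor
from tools.conversation import Chat, conv_templates

config = yaml.safe_load(open("configs/tarser2_default_config.yaml"))
processor = init_processor("omni-research/Tarsier2-7b-0115", config)  # as in init_model() above
chat = Chat(None, processor, "cpu", True)  # model=None, device='cpu', debug=True (debug branch)

conv = deepcopy(conv_templates["tarsier2-7b"])
# The video reference is attached to the conversation itself, replacing the old img_list state.
conv.messages.append([conv.roles[0], {"type": "video", "text": "examples/test1.mp4"}])
llm_message, conv = chat.answer(conv=conv, n_frames=16, max_new_tokens=256,
                                num_beams=1, temperature=0.0, top_p=1.0)
print(llm_message)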
configs/tarser2_default_config.yaml ADDED
@@ -0,0 +1,14 @@
+ max_n_frames: 256
+ n_frames: 16
+ max_pixels: 460800 # 1280 * 720 // 2
+ min_pixels: 0
+ max_seq_len: 16384
+ is_training: false # affects: (1) frame sampling, which differs between training and testing; (2) the response, which is ignored at test time.
+ print_data_error: true
+ is_training: false
+ do_image_padding: false
+ do_image_crop: false
+ do_image_resize: false
+ video_sampling_strategy: {'video_sampler_version': 'v1', 'force_frames_n_divisible': 1, 'use_multi_images_for_video': true}
+ prompt: ""
+ train_task: sft
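
The file declares is_training twice; with PyYAML's safe_load the later occurrence overwrites the earlier one, so the effective value is still false. app.py reads this file once at startup, as sketched below (the override is purely illustrative):

import yaml

with open("configs/tarser2_default_config.yaml") as f:
    config = yaml.safe_load(f)  # duplicate keys: the last value wins

print(config["video_sampling_strategy"]["video_sampler_version"])  # 'v1'
config["n_frames"] = 8  # hypothetical override before passing the dict to init_processor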
dataset/custom_data_parsers/multi_images_parser.py ADDED
@@ -0,0 +1,199 @@
1
+ from typing import Dict, List
2
+ import random
3
+ import re
4
+ from PIL import Image
5
+
6
+ from .utils import sample_video, read_image
7
+
8
+ class MultiImagesParser:
9
+ def __init__(
10
+ self,
11
+ n_frames=8,
12
+ is_training=True,
13
+ ):
14
+ self.n_frames = n_frames
15
+ self.is_training = is_training
16
+ # fmt: off
17
+ self.data_temp = {
18
+ "text": [
19
+ [{
20
+ "prompt": "Describe the image in short.",
21
+ "response": "A rollerblader rides high in a full pipe while others watch"
22
+ }],
23
+ [{
24
+ "prompt": "Describe the image in short.",
25
+ "response": "A woman in winter clothes is on the sidewalk with a phone."
26
+ }]
27
+ ],
28
+ "image": [
29
+ {
30
+ "image_file": "/mnt/bn/videonaslq/images/flickr30k/images/3371533654.jpg"
31
+ },
32
+ {
33
+ "image_file": "/mnt/bn/videonaslq/images/coco/train2014/COCO_train2014_000000177950.jpg"
34
+ },
35
+ {
36
+ "video_file": "/mnt/bn/llmdatalq/jiangnan/video_generation/webvid_10M_download/20230609/videos/011851_011900/1047443473.mp4",
37
+ "frame_indices": [0, 85, 171, 256, 342, 427, 513, 598]
38
+ }
39
+ ],
40
+ "dataset": "coco",
41
+ "task": "multi_images",
42
+ "image_processing_config": {},
43
+ }
44
+ # fmt: on
45
+
46
+ def check_format(self, data_dict: Dict, image_processing_config: Dict):
47
+ assert data_dict['dataset'] in ['coco', 'sharegpt4v_cap100k', 'sharegpt4v_mix665k', 'webvid', 'movie'], data_dict
48
+
49
+ # Multi-image data presumably contains no coordinate annotations at this point.
50
+ if image_processing_config.get('has_coordinates', False):
51
+ raise ValueError(f'do_crop and has_coordinates cannot be True at the same time in MultiImagesParser!')
52
+
53
+ # Check whether any coordinates can be matched in the text.
54
+ texts = data_dict['text']
55
+ for text in texts:
56
+ match = re.search(r'\[(\d+(\.\d+)?,\s*)+\d+(\.\d+)?\]', text['prompt'] + text['response'])
57
+ if match:
58
+ print(f'[Warning] Suspected coordinate-containing data detected: {data_dict}')
59
+
60
+
61
+ def transform(self, data_dict: Dict, image_processing_config: Dict = None) -> Dict:
62
+ self.check_format(data_dict, image_processing_config)
63
+
64
+ # shuffle
65
+ texts = data_dict['text']
66
+ images = data_dict['image']
67
+ images = self.load_images(images)
68
+ idxs = list(range(len(texts)))
69
+ random.shuffle(idxs)
70
+ texts = [texts[i] for i in idxs]
71
+ images = [images[i] for i in idxs]
72
+
73
+ # sample n_frames
74
+ if isinstance(self.n_frames, int):
75
+ n_frames = random.choice(list(range(1, self.n_frames + 1)))
76
+ else:
77
+ n_frames = random.choice(self.n_frames)
78
+ texts = texts[: n_frames]
79
+ images = images[: n_frames]
80
+
81
+ dataset = data_dict['dataset']
82
+ if dataset in ['coco', 'sharegpt4v_cap100k', 'webvid', 'movie']:
83
+ prompt, response = self.transform_for_caption_task(texts, dataset, images)
84
+ else:
85
+ prompt, response = self.transform_for_qa_task(texts, dataset, images)
86
+
87
+ messages = [
88
+ {
89
+ "role": "user",
90
+ "content": [
91
+ *[{"type": "image", "image": img} for img in images],
92
+ {"type": "text", "text": prompt},
93
+ ]
94
+ },
95
+ {
96
+ "role": "assistant",
97
+ "content": [
98
+ {"type": "text", "text": response}
99
+ ]
100
+ }
101
+ ]
102
+
103
+ return messages
104
+
105
+ def transform_for_caption_task(self, texts, dataset, images):
106
+ idx = random.choice(list(range(len(texts))))
107
+
108
+ if dataset == 'coco':
109
+ if len(texts) == 1:
110
+ prompt = 'Describe the image in short.'
111
+ else:
112
+ prompt = f'Describe the images starting from frame {idx + 1} in short in order.'
113
+ elif dataset == 'sharegpt4v_cap100k':
114
+ if len(texts) == 1:
115
+ prompt = 'Describe the image in detail.'
116
+ else:
117
+ prompt = f'Describe the images starting from frame {idx + 1} in detail in order.'
118
+ else:
119
+ if len(texts) == 1:
120
+ prompt = 'Describe the image.'
121
+ else:
122
+ prompt = f'Describe the images starting from frame {idx + 1} in order.'
123
+ response = ''
124
+ for i, text in enumerate(texts):
125
+ if i < idx:
126
+ continue
127
+ if not isinstance(text, dict):
128
+ text = random.choice(text)
129
+ resp = text['response']
130
+ response += f'{resp}\n'
131
+ return prompt, response
132
+
133
+ def transform_for_qa_task(self, texts, dataset, images):
134
+ prompt, response = '', ''
135
+ for i, text in enumerate(texts):
136
+ if not isinstance(text, dict):
137
+ text = random.choice(text)
138
+ if len(texts) > 1:
139
+ prompt += f'Question for frame {i+1}:\n' + text['prompt'] + '\n'
140
+ response += f'Answer to question of frame {i+1}:\n' + text['response'] + '\n'
141
+ else:
142
+ prompt += text['prompt'] + '\n'
143
+ response += text['response'] + '\n'
144
+ return prompt, response
145
+
146
+
147
+ def load_images(self, image_items: List[Dict]) -> List[Image.Image]:
148
+ """
149
+ image_items: List[Dict]. each item like:
150
+ {"video_file": "path/to/video", "frame_indices": [1]}
151
+ or
152
+ {"image_file": "path/to/image"}
153
+ """
154
+ if image_items is None:
155
+ raise ValueError(f'image_items is None!')
156
+
157
+ if isinstance(image_items, dict):
158
+ image_items = [image_items]
159
+
160
+ images = []
161
+
162
+ for image_item in image_items:
163
+
164
+ if 'video_file' in image_item:
165
+ file_key = 'video_file'
166
+ elif 'image_file' in image_item:
167
+ file_key = 'image_file'
168
+ else:
169
+ raise KeyError(f'video_file or image_file not in {image_item}')
170
+
171
+ file_path = image_item[file_key]
172
+ if file_key == 'video_file':
173
+ frame_indices = image_item.get('frame_indices', None)
174
+ if frame_indices is None:
175
+ raise ValueError(f'read 0 frame: {image_item}')
176
+ if isinstance(frame_indices, int):
177
+ frame_indices = [frame_indices]
178
+ frames = sample_video(file_path, frame_indices = frame_indices)
179
+ images.extend(frames)
180
+ else:
181
+ if isinstance(file_path, str):
182
+ file_path = [file_path]
183
+ images.extend([read_image(f) for f in file_path])
184
+
185
+ return images
186
+
187
+ if __name__ == '__main__':
188
+ # python3 -m xenon_generation.data.custom_data_parsers.multi_images_parser
189
+
190
+ from tqdm import tqdm
191
+ from tools.rw_utils import read_jsonlines
192
+
193
+ lines = read_jsonlines('/mnt/bn/videonaslq/VideoCaption/datasets_1009/sharegpt4v_cap100k/part_36.jsonl')
194
+ lines = lines[:10]
195
+ parser = MultiImagesParser(n_frames=8)
196
+ for i, l in tqdm(enumerate(lines)):
197
+ l_image_processing_config = l.get('image_processing_config', {})
198
+ messages = parser.transform(l, l_image_processing_config)
199
+ print(messages)
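
A usage sketch for MultiImagesParser, following the data_temp layout documented in the class; the image paths are placeholders, and transform() will try to open them with PIL:

from dataset.custom_data_parsers.multi_images_parser import MultiImagesParser

record = {
    "text": [
        [{"prompt": "Describe the image in short.", "response": "A dog runs along the beach."}],
        [{"prompt": "Describe the image in short.", "response": "A child flies a red kite."}],
    ],
    "image": [
        {"image_file": "/path/to/image_1.jpg"},   # placeholder path
        {"image_file": "/path/to/image_2.jpg"},   # placeholder path
    ],
    "dataset": "coco",                            # must be one of the datasets allowed in check_format
    "task": "multi_images",
    "image_processing_config": {},
}

parser = MultiImagesParser(n_frames=2, is_training=False)
messages = parser.transform(record, record["image_processing_config"])
# messages is a chat-style list: a user turn holding the PIL images plus the prompt,
# and an assistant turn holding the caption(s).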
dataset/custom_data_parsers/object_tracking_parser.py ADDED
@@ -0,0 +1,160 @@
1
+ from typing import Dict
2
+ import random
3
+ import re
4
+
5
+ from torchvision import transforms
6
+
7
+ from .utils import sample_video
8
+
9
+ def return_same(x):
10
+ return x
11
+
12
+ def _bbox_transform_for_padding(bbox, frame):
13
+ w1, h1, w2, h2 = bbox
14
+ width, height = frame.size
15
+ if width == height:
16
+ pass
17
+ elif width > height:
18
+ h1 += (width - height) // 2
19
+ h2 += (width - height) // 2
20
+ height = width
21
+ else:
22
+ w1 += (height - width) // 2
23
+ w2 += (height - width) // 2
24
+ width = height
25
+ new_bbox = [w1 / width, h1 / height, w2 / width, h2 / height]
26
+ new_bbox = [round(i, 2) for i in new_bbox]
27
+ return new_bbox
28
+
29
+ def _bbox_transform_for_resize(bbox, frame):
30
+ w1, h1, w2, h2 = bbox
31
+ width, height = frame.size
32
+ new_bbox = [w1 / width, h1 / height, w2 / width, h2 / height]
33
+ new_bbox = [round(i, 2) for i in new_bbox]
34
+ return new_bbox
35
+
36
+ class InAndOutCropAndResize(object):
37
+ """Crop and resize for in_and_out boxes data according to yuchen
38
+ Args:
39
+ size: tuple of (width, height)
40
+ """
41
+
42
+ def __init__(self, size):
43
+ self.size = size
44
+
45
+ def __call__(self, img):
46
+ """
47
+ Args:
48
+ img (PIL Image): PIL Image
49
+ Returns:
50
+ PIL Image: PIL image.
51
+ """
52
+ w = img.width
53
+ h = img.height
54
+ x0 = int(w * 0.5 - h * 0.375)
55
+ y0 = int(h * 0.125)
56
+ x1 = int(w * 0.5 + h * 0.375)
57
+ y1 = int(h * 0.875)
58
+ img = img.crop((x0, y0, x1, y1)).resize(self.size)
59
+ return img
60
+
61
+
62
+ class ObjectTrackingParser:
63
+ def __init__(
64
+ self,
65
+ n_frames = 8,
66
+ max_objects = 3,
67
+ is_training=True,
68
+ ):
69
+ self.n_frames = n_frames
70
+ self.max_objects = max_objects
71
+ self.is_training = is_training
72
+ self.img_transform = self.get_img_transform()
73
+ # fmt: off
74
+ self.data_temp = {
75
+ "video_file": "/mnt/bn/llmdatalq/jiaxin/hdvila/20230926/saved/saved_video_clips/0076/lOjn__YCec4.624.1104.mp4",
76
+ "frame_indices": [154, 157, 160, 163, 166, 169, 172, 175, 178, 181, 184, 187, 190, 193, 196, 199, 202],
77
+ "objects": {
78
+ "0": {
79
+ "phrase": "person",
80
+ "all_frame_bounding_boxes": [[2, 0, 255, 250], [17, 0, 255, 251], [35, 0, 255, 253], [44, 0, 255, 255], [52, 0, 255, 255], [54, 0, 255, 255], [63, 0, 255, 255], [60, 0, 255, 255], [54, 0, 253, 255], [43, 0, 250, 255], [36, 1, 249, 255], [36, 0, 252, 254], [41, 0, 252, 254], [61, 0, 255, 253], [68, 4, 255, 255], [74, 8, 255, 255], [91, 3, 255, 255]]
81
+ }
82
+ },
83
+ "task": "object_tracking",
84
+ "dataset": "hdvila"
85
+ }
86
+ # fmt: on
87
+
88
+ def check_format(self, data_dict: Dict, image_processing_config: Dict):
89
+ # Box-tracking data does not support do_crop!!!
90
+ if image_processing_config.get('do_crop', False):
91
+ raise ValueError(f'do_crop is not supported in ObjectTrackingParser!')
92
+
93
+ def transform(self, data_dict: Dict, image_processing_config: Dict = None) -> Dict:
94
+ self.check_format(data_dict, image_processing_config)
95
+
96
+ bbox_transform = _bbox_transform_for_padding if image_processing_config['do_padding'] else _bbox_transform_for_resize
97
+
98
+ # sample n_frames
99
+ if isinstance(self.n_frames, int):
100
+ n_frames = self.n_frames
101
+ else:
102
+ n_frames = random.choice(self.n_frames)
103
+ total_frames = list(range(len(data_dict['frame_indices'])))
104
+ idxs = random.sample(total_frames, min(n_frames, len(total_frames)))
105
+ idxs.sort()
106
+
107
+ frame_indices = [data_dict['frame_indices'][i] for i in idxs]
108
+ frames = sample_video(data_dict['video_file'], frame_indices=frame_indices)
109
+ img_transform = self.img_transform[data_dict['dataset']]
110
+ frames = [img_transform(f) for f in frames]
111
+
112
+ objects = []
113
+ for _, o in data_dict['objects'].items():
114
+ if o is None:
115
+ continue
116
+ all_frame_bounding_boxes = [o['all_frame_bounding_boxes'][i] for i in idxs]
117
+ all_frame_bounding_boxes_t = []
118
+ for bbox, frame in zip(all_frame_bounding_boxes, frames):
119
+ all_frame_bounding_boxes_t.append(bbox_transform(bbox, frame))
120
+ objects.append(all_frame_bounding_boxes_t)
121
+ if len(objects) >= self.max_objects:
122
+ break
123
+
124
+ prompt = "Given the bounding box coordinates of these objects in the first frame, output the bounding box coordinates in the following frames.\n{}"
125
+ response = ''
126
+
127
+ object_info = ''
128
+ for i, o in enumerate(objects):
129
+ object_info += f'object {i+1}: {o[0]}\n'
130
+ response += f'object {i+1}: {o[1:]}\n'
131
+ response = response.strip()
132
+ prompt = prompt.format(object_info)
133
+
134
+ messages = [
135
+ {
136
+ "role": "user",
137
+ "content": [
138
+ {"type": "video", "video": frames},
139
+ {"type": "text", "text": prompt}
140
+ ]
141
+ },
142
+ {
143
+ "role": "assistant",
144
+ "content": [
145
+ {"type": "text", "text": response}
146
+ ]
147
+ }
148
+ ]
149
+
150
+ return messages
151
+
152
+ def get_img_transform(self):
153
+ return {
154
+ 'webvid': return_same,
155
+ 'hdvila': transforms.Compose([
156
+ transforms.Resize(size=256),
157
+ transforms.CenterCrop(size=(256, 256))
158
+ ]),
159
+ 'hdvila_in_and_out_boxes': InAndOutCropAndResize(size=(256, 256))
160
+ }
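
A quick, self-contained check of the two bounding-box normalizers defined above (pixel boxes mapped to relative coordinates, either against the raw frame or as if the frame were padded to a square). The synthetic frame size is arbitrary, and importing the module assumes its own dependencies (torchvision, decord) are installed:

from PIL import Image
from dataset.custom_data_parsers.object_tracking_parser import (
    _bbox_transform_for_padding, _bbox_transform_for_resize)

frame = Image.new("RGB", (1280, 720))
bbox = [100, 50, 600, 700]                       # x1, y1, x2, y2 in pixels

print(_bbox_transform_for_resize(bbox, frame))   # [0.08, 0.07, 0.47, 0.97]
print(_bbox_transform_for_padding(bbox, frame))  # [0.08, 0.26, 0.47, 0.77]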
dataset/custom_data_parsers/standard_vision_parser.py ADDED
@@ -0,0 +1,255 @@
1
+ from typing import Dict, List
2
+ from PIL import Image
3
+ import random
4
+
5
+ from .utils import sample_video, read_image, adjust_bbox, filter_ocr_polygon
6
+
7
+
8
+ class VisionParser:
9
+ def __init__(
10
+ self,
11
+ n_frames=8,
12
+ max_n_frames=256,
13
+ is_training=True,
14
+ video_sampling_strategy={},
15
+ ):
16
+ self.n_frames = n_frames
17
+ self.max_n_frames = max_n_frames
18
+ self.is_training = is_training
19
+ self.video_sampling_strategy = video_sampling_strategy
20
+
21
+ # fmt: off
22
+ self.data_temp = {
23
+ "messages": [
24
+ {
25
+ "role": "user",
26
+ "content": [
27
+ {"type": "text", "text": "Describe the image and the video."},
28
+ # Supported image formats:
29
+ {"type": "image", "image": {"image_file": "/path/to/image"}},
30
+ {"type": "image", "image": {"video_file": "/path/to/video", "frame_indices": 0}},
31
+ # Supported video formats:
32
+ {"type": "video", "video": {"video_file": "/path/to/video"}},
33
+ {"type": "video", "video": {"video_file": "/path/to/video", "frame_indices": [0, 1, 2]}},
34
+ {"type": "video", "video": {"video_file": "/path/to/video", "start_frame": 0, "end_frame": 100}},
35
+ {"type": "video", "video": {"video_file": "/path/to/video", "time_indices": [0, 1, 2]}},
36
+ {"type": "video", "video": {"video_file": "/path/to/video", "start_time": 0, "end_time": 100}},
37
+ {"type": "video", "video": {"image_file": ["/path/to/image"]}, "frame_indices": [0, 1, 2]},
38
+ ]
39
+ },
40
+ {
41
+ "role": "assistant",
42
+ "content": [
43
+ {"type": "text","text": "xxx"}
44
+ ]
45
+ }
46
+ ],
47
+ "dataset": "LSMDC",
48
+ "task": "video/caption"
49
+ }
50
+ # fmt: on
51
+
52
+ def check_format(self, data_dict: Dict, image_processing_config: Dict):
53
+ if image_processing_config.get('do_crop', False) and image_processing_config.get('has_coordinates', False):
54
+ raise ValueError(f'do_crop and has_coordinates cannot be True at the same time!')
55
+
56
+ """
57
+ 1. Replace each image/video in messages with the corresponding PIL.Image / List[PIL.Image].
58
+ 2. Special handling of text: adjust bounding boxes; filter out OCR regions whose area is too small.
59
+ """
60
+ def transform(self, data_dict: Dict, image_processing_config: Dict = None) -> Dict:
61
+ self.check_format(data_dict, image_processing_config)
62
+
63
+ self.set_n_frames(data_dict)
64
+
65
+ first_image = None # ugly! Only image tasks need box adjustment / filtering of tiny OCR regions.
66
+
67
+ for msg in data_dict['messages']:
68
+ if isinstance(msg['content'], dict):
69
+ msg['content'] = [msg['content']]
70
+ for content in msg['content']:
71
+
72
+ if content['type'] == 'image':
73
+ content['image'] = self.load_image_item(content['image'])
74
+ if first_image is None:
75
+ first_image = content['image']
76
+ elif content['type'] == 'video':
77
+ video = self.load_video_item(content['video'])
78
+ content['video'] = video.pop('frames')
79
+ if video:
80
+ data_dict['extra_info']['frame_disturb_info'] = video.pop('video_info', {})
81
+ elif content['type'] == 'text':
82
+ pass
83
+ else:
84
+ raise ValueError(f"content['type']={content['type']} MUST be one of ['image', 'video', 'text']")
85
+ for msg in data_dict['messages']:
86
+ for content in msg['content']:
87
+ if content['type'] == 'text':
88
+ self.postprocess_text(content, data_dict, image_processing_config, first_image)
89
+
90
+ return data_dict['messages']
91
+
92
+ # set n_frames for each vision item.
93
+ def set_n_frames(self, data_dict):
94
+
95
+ if isinstance(self.n_frames, int):
96
+ n_frames = self.n_frames
97
+ else:
98
+ n_frames = random.choice(self.n_frames)
99
+
100
+ assert n_frames <= self.max_n_frames
101
+
102
+ curr_n_frames = 0
103
+ has_dynamic = False
104
+ for msg in data_dict['messages']:
105
+ if isinstance(msg['content'], dict):
106
+ msg['content'] = [msg['content']]
107
+
108
+ for content in msg['content']:
109
+
110
+ if content['type'] == 'image':
111
+ curr_n_frames += 1
112
+ elif content['type'] == 'video':
113
+ if 'frame_indices' in content['video']:
114
+ curr_n_frames += len(content['video']['frame_indices'])
115
+ content['video']['n_frames'] = len(content['video']['frame_indices'])
116
+ elif 'time_indices' in content['video']:
117
+ curr_n_frames += len(content['video']['time_indices'])
118
+ content['video']['n_frames'] = len(content['video']['time_indices'])
119
+ elif 'min_n_frames' in content['video']:
120
+ content['video']['min_n_frames'] = int(content['video']['min_n_frames'])
121
+ curr_n_frames += content['video']['min_n_frames']
122
+ content['video']['n_frames'] = content['video']['min_n_frames']
123
+ has_dynamic = True
124
+ elif 'fps' in content['video']:
125
+ content['video']['n_frames'] = self.max_n_frames
126
+ curr_n_frames += self.max_n_frames
127
+ has_dynamic = True
128
+ else:
129
+ content['video']['n_frames'] = 0
130
+ has_dynamic = True
131
+
132
+ while curr_n_frames < n_frames and has_dynamic:
133
+ for msg in data_dict['messages']:
134
+ for content in msg['content']:
135
+ if content['type'] == 'video':
136
+ if 'frame_indices' in content['video']:
137
+ pass
138
+ elif 'time_indices' in content['video']:
139
+ pass
140
+ else:
141
+ if curr_n_frames < n_frames:
142
+ content['video']['n_frames'] += 1
143
+ curr_n_frames += 1
144
+
145
+ while curr_n_frames > self.max_n_frames and has_dynamic:
146
+ for msg in data_dict['messages']:
147
+ for content in msg['content']:
148
+ if content['type'] == 'video':
149
+ if 'frame_indices' in content['video']:
150
+ pass
151
+ elif 'time_indices' in content['video']:
152
+ pass
153
+ else:
154
+ if curr_n_frames > self.max_n_frames:
155
+ content['video']['n_frames'] -= 1
156
+ curr_n_frames -= 1
157
+
158
+
159
+ for msg in data_dict['messages']:
160
+ for content in msg['content']:
161
+ if content['type'] == 'video':
162
+ if 'frame_indices' in content['video']:
163
+ pass
164
+ elif 'time_indices' in content['video']:
165
+ pass
166
+ else:
167
+ n = self.video_sampling_strategy.get('force_frames_n_divisible', 1)
168
+ if n > 1 and content['video']['n_frames'] % n != 0:
169
+ content['video']['n_frames'] += n - content['video']['n_frames'] % n
170
+
171
+ def load_image_item(self, image_item) -> Image.Image:
172
+ """
173
+ image_item:
174
+ {"image_file": {"lq": "/path/to/image"}}
175
+ {"video_file": {"lq": "/path/to/video"}, "frame_indices": 0}
176
+ """
177
+
178
+ # check format
179
+ if ("image_file" not in image_item) and ("video_file" not in image_item):
180
+ raise KeyError(f"Key 'image_file' or 'video_file' not found in image_item")
181
+ if 'image_file' in image_item:
182
+ if not isinstance(image_item['image_file'], str):
183
+ raise ValueError(f"{image_item['image_file']} is not a str!")
184
+ if 'video_file' in image_item:
185
+ if not isinstance(image_item['frame_indices'], int):
186
+ raise ValueError(f"{image_item['frame_indices']} is not a int!")
187
+
188
+ if 'image_file' in image_item:
189
+ image = read_image(image_item['image_file'])
190
+ else:
191
+ frame_indices = [image_item['frame_indices']]
192
+ image = sample_video(image_item['video_file'], frame_indices = frame_indices)[0]
193
+
194
+ return image
195
+
196
+ def load_video_item(self, video_item) -> List[Image.Image]:
197
+ """
198
+ video_item:
199
+ {"video_file": {"lq": "/path/to/video"}, "n_frames": 8}
200
+ {"video_file": {"lq": "/path/to/video"}, "frame_indices": [0, 1, 2], "n_frames": 3}
201
+ {"video_file": {"lq": "/path/to/video"}, "start_frame": 0, "end_frame": 100, "n_frames": 8}
202
+ {"video_file": {"lq": "/path/to/video"}, "time_indices": [0, 1, 2], "n_frames": 3}
203
+ {"video_file": {"lq": "/path/to/video"}, "start_time": 0, "end_time": 100, "n_frames": 8}
204
+ {"image_file": {"lq": ["/path/to/image"]}, "frame_indices": [0, 1, 2], "n_frames": 3}
205
+ """
206
+
207
+ # check format
208
+ if ("image_file" not in video_item) and ("video_file" not in video_item):
209
+ raise KeyError(f"Key 'image_file' or 'video_file' not found in video_item")
210
+
211
+ video_path = video_item.get('video_file', video_item.get('image_file'))
212
+ n_frames = video_item.get('n_frames', None)
213
+ frame_indices = video_item.get('frame_indices', None)
214
+ start_frame = video_item.get('start_frame', None)
215
+ end_frame = video_item.get('end_frame', None)
216
+ time_indices = video_item.get('time_indices', None)
217
+ start_time = video_item.get('start_time', None)
218
+ end_time = video_item.get('end_time', None)
219
+ mask_boxes = video_item.get('mask_boxes', None)
220
+ fps = video_item.get('fps', None)
221
+
222
+ frames, frame_indices = sample_video(
223
+ video_path=video_path,
224
+ frame_indices=frame_indices,
225
+ start_frame=start_frame,
226
+ end_frame=end_frame,
227
+ n_frames=n_frames,
228
+ time_indices=time_indices,
229
+ start_time=start_time,
230
+ end_time=end_time,
231
+ sampling_fps=fps,
232
+ mask_boxes=mask_boxes,
233
+ is_training=self.is_training,
234
+ video_sampling_strategy=self.video_sampling_strategy,
235
+ return_frame_ids=True,
236
+ )
237
+
238
+ if self.video_sampling_strategy.get('use_multi_images_for_video', False):
239
+ new_frames = []
240
+ for f in frames:
241
+ new_frames.extend([f, f])
242
+ frames = new_frames
243
+
244
+ if isinstance(frame_indices, dict):
245
+ return {
246
+ 'frames': frames,
247
+ 'video_info': frame_indices
248
+ }
249
+ return {'frames': frames}
250
+
251
+ def postprocess_text(self, content, data_dict, image_processing_config, first_image):
252
+ if image_processing_config.get('has_coordinates') and image_processing_config.get('do_padding'):
253
+ content['text'] = adjust_bbox(content['text'], frame=first_image)
254
+ if data_dict.get('task') == 'image/OCR' and image_processing_config.get('has_coordinates'):
255
+ content['text'] = filter_ocr_polygon(content['text'])
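
A usage sketch for VisionParser with a single video message, mirroring the data_temp layout above; the video path is a placeholder and sample_video() will try to decode it:

from dataset.custom_data_parsers.standard_vision_parser import VisionParser

record = {
    "messages": [
        {"role": "user", "content": [
            {"type": "video", "video": {"video_file": "/path/to/video.mp4"}},  # placeholder path
            {"type": "text", "text": "Describe the video in detail."},
        ]},
        {"role": "assistant", "content": [{"type": "text", "text": ""}]},
    ],
    "dataset": "demo",
    "task": "video/caption",
    "extra_info": {},    # only consulted when do_frame_disturb is enabled
}

parser = VisionParser(
    n_frames=16, max_n_frames=256, is_training=False,
    video_sampling_strategy={"video_sampler_version": "v1", "force_frames_n_divisible": 1},
)
messages = parser.transform(record, image_processing_config={})
# After transform, the "video" entry holds a list of PIL.Image frames instead of the file reference.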
dataset/custom_data_parsers/utils.py ADDED
@@ -0,0 +1,452 @@
1
+ from typing import List, Dict, Union
2
+ import os
3
+ import random
4
+ import tempfile
5
+ from PIL import Image, ImageSequence
6
+ import base64
7
+ import io
8
+ import re
9
+ import uuid
10
+ import json
11
+ import numpy as np
12
+ import pyarrow.fs as pf
13
+ import func_timeout
14
+ from func_timeout import func_set_timeout
15
+ import math
16
+
17
+ # fmt: on
18
+ import decord
19
+ # fmt: off
20
+
21
+
22
+ def denorm_box(points, height, width):
23
+ new_points = []
24
+ for p in points:
25
+ new_points.append((round(p[0] * width), round(p[1] * height)))
26
+ return new_points
27
+
28
+ def process_image_for_tiktok(frames: List[Image.Image], mask_boxes):
29
+ mask_boxes = mask_boxes[:len(frames)]
30
+ frames = [np.array(f) for f in frames]
31
+ # assert len(mask_boxes) == len(frames)
32
+ height, width = frames[0].shape[:2]
33
+
34
+ new_frames = []
35
+ for boxes, frame in zip(mask_boxes, frames):
36
+ left, top, right, bottom = 0, 0, width, height
37
+ for box in boxes:
38
+ pts = np.array(denorm_box(box, height, width), np.int32)
39
+ upper_bound = max([p[1] for p in pts]) + 30
40
+ if bottom > upper_bound:
41
+ bottom = upper_bound
42
+ frame[pts[0][1]: pts[2][1], pts[0][0]: pts[1][0]] = 0
43
+
44
+ new_frames.append(Image.fromarray(frame[top: bottom, left: right]))
45
+ return new_frames
46
+
47
+ # Split the video into n_frames chunks. During training, sample one random frame per chunk; at test time, take the middle frame of each chunk.
48
+ def _sample_frame_indices_v2(
49
+ total_frames: int,
50
+ n_frames: int,
51
+ is_training=False,
52
+ video_sampling_strategy = {},
53
+ ):
54
+ total_frames_idxs = list(range(total_frames))
55
+ if total_frames <= n_frames:
56
+ return total_frames_idxs
57
+ k, m = divmod(total_frames, n_frames)
58
+ frame_splits = [total_frames_idxs[i * k + min(i, m):(i + 1) * k + min(i + 1, m)] for i in list(range(n_frames))]
59
+ if is_training:
60
+ sample_ids = [random.choice(i) for i in frame_splits]
61
+ else:
62
+ sample_ids = [i[(len(i)+1)//2-1] for i in frame_splits]
63
+ return sample_ids
64
+
65
+ # Sample frames uniformly, always including the first and last frame.
66
+ def _sample_frame_indices_v1(total_frames: int, n_frames: int, is_training=False, video_sampling_strategy = {}):
67
+ if n_frames == 1:
68
+ return [0] # sample first frame in default
69
+ if total_frames <= n_frames:
70
+ return list(range(total_frames))
71
+ sample_ids = [round(i * (total_frames - 1) / (n_frames - 1)) for i in range(n_frames)]
72
+ return sample_ids
73
+
74
+ def conduct_disturb_frame(frame_indices):
75
+ disturb_type = random.choice(['exchange', 'crop', 'reverse', 'discard'])
76
+ n_frames = len(frame_indices)
77
+ frame_indices_new = []
78
+ if disturb_type == 'exchange':
79
+ # Split evenly into 4 segments and randomly swap two of them.
80
+ seg_len = math.ceil(n_frames / 4)
81
+ seg_idxs = list(range(0, n_frames, seg_len))
82
+ target_idxs = random.sample(range(0, 4), 2)
83
+ seg_idxs[target_idxs[0]], seg_idxs[target_idxs[1]] = seg_idxs[target_idxs[1]], seg_idxs[target_idxs[0]]
84
+ for idx in seg_idxs:
85
+ frame_indices_new += frame_indices[idx: idx+seg_len]
86
+ elif disturb_type == 'crop':
87
+ # 随机截取出3/4时长,再采均匀n_frames帧
88
+ crop_len = math.ceil(n_frames / 4)
89
+ idx_s = random.choice(range(0, crop_len+1))
90
+ idx_e = n_frames - 1 - (crop_len - idx_s)
91
+ frame_indices_new = np.linspace(frame_indices[idx_s], frame_indices[idx_e], n_frames, dtype=int).tolist()
92
+ elif disturb_type == 'reverse':
93
+ # 随机选择长度为[1/2, 1]时长的片段进行顺序颠倒
94
+ reverse_len = math.ceil(random.uniform(0.5,1) * n_frames)
95
+ idx_s = random.choice(range(0, n_frames-reverse_len+1))
96
+ idx_e = idx_s + reverse_len - 1
97
+ frame_indices_new = frame_indices[:idx_s] + list(reversed(frame_indices[idx_s: idx_e+1])) + frame_indices[idx_e+1:]
98
+ elif disturb_type == 'discard':
99
+ # 随机丢弃一半帧
100
+ frame_indices_new = random.sample(frame_indices, n_frames//2)
101
+ frame_indices_new.sort()
102
+ return disturb_type, frame_indices_new
103
+
104
+ @func_set_timeout(60)
105
+ def _download_file(path):
106
+ if path.startswith("hdfs"):
107
+ local_path = os.path.join(tempfile.gettempdir(), f'{uuid.uuid4()}_' + os.path.basename(path))
108
+
109
+ fs = pf.HadoopFileSystem.from_uri(uri="hdfs://harunava")
110
+ hdfs_file = fs.open_input_file(path)
111
+ file_size = hdfs_file.size()
112
+ if file_size > 1024 * 1024 * 1024: # 1G
113
+ os.system(f"hadoop fs -get --ct 8 -c 512 '{path}' '{local_path}' > /dev/null 2>&1")
114
+ elif file_size > 1024 * 1024 * 100: # 100M
115
+ os.system(f"hadoop fs -get '{path}' '{local_path}' > /dev/null 2>&1")
116
+ else:
117
+ local_fs = pf.LocalFileSystem()
118
+ with local_fs.open_output_stream(local_path) as local_file:
119
+ while True:
120
+ chunk = hdfs_file.read(1024 * 1024 * 100) # Read in 100 MB chunks; adjust as needed
121
+ if not chunk:
122
+ break
123
+ local_file.write(chunk)
124
+ else:
125
+ local_path = path
126
+
127
+ if not os.path.exists(local_path):
128
+ raise FileNotFoundError(f'{local_path}')
129
+
130
+ return local_path
131
+
132
+ def download_file(path):
133
+ try:
134
+ # with timer(f'Download {path}'):
135
+ return _download_file(path)
136
+ except func_timeout.exceptions.FunctionTimedOut as e:
137
+ raise ValueError(e)
138
+
139
+ class VideoReader:
140
+ def __init__(self, path: str) -> None:
141
+ self.path = path
142
+ self.local_path = self.preprocess()
143
+ self.vr = decord.VideoReader(self.local_path, num_threads=1, ctx=decord.cpu(0), fault_tol=1)
144
+ self.vr.seek(0)
145
+ self._length = len(self.vr)
146
+ self._fps = self.vr.get_avg_fps()
147
+
148
+ @property
149
+ def length(self):
150
+ return self._length
151
+
152
+ @property
153
+ def fps(self):
154
+ return self._fps
155
+
156
+ def sample(self, frame_indices) -> List[Image.Image]:
157
+ frames = self.vr.get_batch(frame_indices).asnumpy()
158
+ frames = [Image.fromarray(f).convert('RGB') for f in frames]
159
+ return frames
160
+
161
+ def preprocess(self):
162
+ return download_file(self.path)
163
+
164
+ def postprocess(self):
165
+ if self.path.startswith("hdfs"):
166
+ os.remove(self.local_path)
167
+
168
+ class ImageSeqReader:
169
+ def __init__(self, path: List[str]) -> None:
170
+ self.path = path
171
+ self.local_path = self.preprocess()
172
+ self._length = len(self.local_path)
173
+ self._fps = None
174
+
175
+ @property
176
+ def length(self):
177
+ return self._length
178
+
179
+ @property
180
+ def fps(self):
181
+ return self._fps
182
+
183
+ def sample(self, frame_indices):
184
+ return [read_image(self.local_path[i]) for i in frame_indices]
185
+
186
+ def preprocess(self):
187
+ local_paths = []
188
+ for p in self.path:
189
+ local_paths.append(p)
190
+ return local_paths
191
+
192
+ def postprocess(self):
193
+ pass
194
+
195
+ class GIFReader:
196
+ def __init__(self, path: str) -> None:
197
+ self.path = path
198
+ self.local_path = self.preprocess()
199
+ self.gif = Image.open(self.local_path)
200
+ self._length = self.gif.n_frames
201
+ duration = self.gif.info.get('duration', 0) / 1000 # convert to seconds
202
+ if duration > 0:
203
+ self._fps = 1 / duration
204
+ else:
205
+ self._fps = None
206
+
207
+ @property
208
+ def length(self):
209
+ return self._length
210
+
211
+ @property
212
+ def fps(self):
213
+ return self._fps
214
+
215
+ def sample(self, frame_indices):
216
+ frames = []
217
+ i = 0
218
+ for frame in ImageSequence.Iterator(self.gif):
219
+ if i in frame_indices:
220
+ frames.append(frame.convert('RGB'))
221
+ i += 1
222
+ return frames
223
+
224
+ def preprocess(self):
225
+ return download_file(self.path)
226
+
227
+ def postprocess(self):
228
+ if self.path.startswith("hdfs"):
229
+ os.remove(self.local_path)
230
+
231
+ def check_frame_indices(frame_indices, total_frames, video_path):
232
+ if frame_indices[-1] == total_frames:
233
+ frame_indices[-1] = total_frames - 1
234
+
235
+ valid_frame_indices = [i for i in frame_indices if i >= 0 and i < total_frames]
236
+
237
+ if len(valid_frame_indices) != len(frame_indices):
238
+ print(f'[Error] frame out of index. video_path={video_path}, frame_indices={frame_indices}, total_frames={total_frames}', flush=True)
239
+
240
+ return valid_frame_indices
241
+
242
+
243
+ def sample_video(
244
+ video_path: Union[str, List[str]],
245
+ frame_indices: List[int] = None,
246
+ start_frame:int=None,
247
+ end_frame:int=None,
248
+ n_frames:int = None,
249
+ time_indices: List[float] = None,
250
+ start_time:int=None,
251
+ end_time:int=None,
252
+ sampling_fps:float=None,
253
+ mask_boxes=None,
254
+ is_training:bool=False,
255
+ video_sampling_strategy={'video_sampler_version': 'v1'},
256
+ return_frame_ids: bool=False,
257
+ ) -> List[Image.Image]:
258
+
259
+ do_frame_disturb = video_sampling_strategy.get('do_frame_disturb', False)
260
+
261
+ if isinstance(video_path, str):
262
+ if video_path.endswith('.gif'):
263
+ reader = GIFReader(video_path)
264
+ else:
265
+ reader = VideoReader(video_path)
266
+ else:
267
+ reader = ImageSeqReader(video_path)
268
+
269
+ total_frames = reader.length
270
+ fps = reader.fps
271
+
272
+ if sampling_fps is not None:
273
+ frame_indices = list(range(0, total_frames, round(fps / sampling_fps)))
274
+ if len(frame_indices) > n_frames:
275
+ frame_indices = None
276
+
277
+ if time_indices is not None:
278
+ frame_indices = [round(float(i) * fps) for i in time_indices]
279
+
280
+ if start_time is not None and end_time is not None:
281
+ start_frame = round(start_time * fps)
282
+ end_frame = round(end_time * fps)
283
+
284
+ if frame_indices is None:
285
+ start_frame = 0 if start_frame is None else round(start_frame)
286
+ end_frame = total_frames - 1 if end_frame is None else round(end_frame)
287
+
288
+ if end_frame == total_frames:
289
+ end_frame -= 1
290
+
291
+ if video_sampling_strategy['video_sampler_version'] == 'v1':
292
+ # Sample frames uniformly, always including the first and last frame.
293
+ frame_indices = _sample_frame_indices_v1(end_frame - start_frame + 1, n_frames, is_training, video_sampling_strategy)
294
+ elif video_sampling_strategy['video_sampler_version'] == 'v2':
295
+ frame_indices = _sample_frame_indices_v2(end_frame - start_frame + 1, n_frames, is_training, video_sampling_strategy)
296
+ else:
297
+ raise ValueError(f"video_sampler_version={video_sampling_strategy['video_sampler_version']} must be 'v1' or 'v2'")
298
+ frame_indices = [i + start_frame for i in frame_indices]
299
+
300
+ frame_indices = check_frame_indices(frame_indices, total_frames, video_path)
301
+
302
+ if do_frame_disturb:
303
+ frame_disturb_type, frame_indices_new = conduct_disturb_frame(frame_indices)
304
+ frame_indices_raw = frame_indices[:]
305
+ frame_indices = frame_indices_new
306
+
307
+ frames = reader.sample(frame_indices)
308
+ if mask_boxes is not None:
309
+ frames = process_image_for_tiktok(frames, mask_boxes)
310
+
311
+ n = video_sampling_strategy.get('force_frames_n_divisible', 1)
312
+ if n > 1 and len(frames) % n != 0:
313
+ new_n = n - len(frames) % n
314
+ frames.extend([Image.new(mode='RGB', size=frames[-1].size) for _ in range(new_n)])
315
+
316
+ reader.postprocess()
317
+
318
+ if do_frame_disturb:
319
+ return frames, {"frame_indices": frame_indices, "disturb_type": frame_disturb_type, "frame_indices_raw": frame_indices_raw}
320
+ if return_frame_ids:
321
+ return frames, frame_indices
322
+ return frames
323
+
324
+
325
+
326
+ def load_image_from_base64String(img_path):
327
+ img = base64.b64decode(open(img_path, "rb").read())
328
+ buf = io.BytesIO(img)
329
+ img = Image.open(buf)
330
+ return img
331
+
332
+ def read_image(image_path):
333
+ local_file = download_file(image_path)
334
+
335
+ if local_file.endswith('.dat'):
336
+ image = load_image_from_base64String(local_file)
337
+ else:
338
+ image = Image.open(local_file).convert('RGB')
339
+ if image_path.startswith("hdfs"):
340
+ os.remove(local_file)
341
+ return image
342
+
343
+
344
+ def adjust_bbox(text, frame):
345
+
346
+ width, height = frame.size
347
+ new_text = []
348
+ start_idx = 0
349
+ for match in re.finditer(r'\[(\d+(\.\d+)?,\s*)+\d+(\.\d+)?\]', text):
350
+ coordinate_matches = re.findall(r"([0-9.]+)", match.group(0))
351
+ xys = [float(coord) for coord in coordinate_matches]
352
+
353
+ new_xys = []
354
+ for i in range(len(xys)):
355
+ p = xys[i]
356
+
357
+ if width == height:
358
+ pass
359
+
360
+ if width > height and i % 2 != 0:
361
+ p = xys[i] * height
362
+ p += (width - height) // 2
363
+ p = round(p / width, 2)
364
+
365
+ if height > width and i % 2 == 0:
366
+ p = xys[i] * width
367
+ p += (height - width) // 2
368
+ p = round(p / height, 2)
369
+
370
+ new_xys.append(p)
371
+
372
+ new_text.append(text[start_idx: match.span()[0]])
373
+ new_text.append(str(new_xys))
374
+ start_idx = match.span()[1]
375
+ new_text.append(text[start_idx: ])
376
+ text = ''.join(new_text)
377
+
378
+
379
+ return text
380
+
381
+ def bbox_area(vertices, convert_format = True):
382
+ if convert_format:
383
+ vertices = list(zip(vertices[::2], vertices[1::2]))
384
+ x0, y0 = vertices[0]
385
+ x1, y1 = vertices[1]
386
+ return abs((x1 - x0) * (y1 - y0))
387
+
388
+ def polygon_area(vertices, convert_format = True):
389
+ if convert_format:
390
+ vertices = list(zip(vertices[::2], vertices[1::2]))
391
+ n = len(vertices) # 多边形顶点的数量
392
+ if n == 2:
393
+ return bbox_area(vertices, convert_format=False)
394
+ area = 0
395
+ for i in range(n):
396
+ x1, y1 = vertices[i]
397
+ x2, y2 = vertices[(i + 1) % n]
398
+ area += x1 * y2 - x2 * y1
399
+ return abs(area) / 2
400
+
401
+ def get_text_len(text_line):
402
+ l = 0
403
+ for c in text_line:
404
+ if '\u4e00' <= c <= '\u9fff':
405
+ l += 1
406
+ else:
407
+ l += 0.5
408
+ return l
409
+
410
+ def filter_ocr_polygon(response, area_threshold=0.0005):
411
+ try:
412
+ resp = json.loads(response)
413
+ except:
414
+ return response
415
+ new_resp = []
416
+ for coords, text_line in resp:
417
+ area = polygon_area(coords, convert_format=True)
418
+ text_len = get_text_len(text_line)
419
+ if text_len == 0:
420
+ continue
421
+ if area / text_len < area_threshold:
422
+ continue
423
+ new_resp.append([coords, text_line])
424
+ new_resp = json.dumps(new_resp, ensure_ascii=False)
425
+
426
+ return new_resp
427
+
428
+ def put_pred_to_data_dict(prediction, data_dict):
429
+ msg = data_dict['messages'][-1]
430
+ if msg['role'] == 'assistant':
431
+ msg['content'][-1]['text'] = prediction
432
+ else:
433
+ data_dict['messages'].append({
434
+ "role": "assistant",
435
+ "content": [{"type": "text", "text": prediction}]
436
+ })
437
+
438
+ def get_prompt_from_data_dict(data_dict):
439
+ prompt = ""
440
+ for msg in data_dict['messages']:
441
+ role = msg['role']
442
+ assert role in {'system', 'user', 'assistant'}
443
+ for content in msg['content']:
444
+ if content['type'] == 'text':
445
+ if content['text']:
446
+ prompt += f"[{role}]: {content['text']}"
447
+ elif content['type'] == 'image':
448
+ prompt += f"[{role}]: <image>"
449
+ elif content['type'] == 'video':
450
+ prompt += f"[{role}]: <video>"
451
+ prompt += '\n'
452
+ return prompt
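
The two frame samplers are pure functions and easy to sanity-check; note that importing this module pulls in decord, pyarrow, and func_timeout at the top level:

from dataset.custom_data_parsers.utils import (
    _sample_frame_indices_v1, _sample_frame_indices_v2)

# v1: uniform sampling that always keeps the first and last frame.
print(_sample_frame_indices_v1(total_frames=100, n_frames=5))
# [0, 25, 50, 74, 99]

# v2: split into n_frames chunks; at test time take the middle frame of each chunk.
print(_sample_frame_indices_v2(total_frames=100, n_frames=5, is_training=False))
# [9, 29, 49, 69, 89]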
dataset/custom_data_parsers/utils_visualize.py ADDED
@@ -0,0 +1,54 @@
1
+ import re
2
+ from typing import Dict, List, Optional
3
+ from PIL import Image, ImageDraw, ImageFont
4
+
5
+
6
+ def scale_polygon(polygon, w, h):
7
+ new_polygon = []
8
+ for (x, y) in polygon:
9
+ new_polygon.append((x * w, y * h))
10
+ return new_polygon
11
+
12
+ def draw_polygon(image: Image.Image, points: List[List[int]], label: Optional[str] = None):
13
+ draw = ImageDraw.Draw(image)
14
+ if len(points) > 2:
15
+ draw.polygon(points, outline="red", width=3)
16
+ elif len(points) == 2:
17
+ draw.rectangle(points, outline="red", width=3)
18
+ else:
19
+ raise ValueError(f'points={points} only has one point!')
20
+
21
+ if label is not None:
22
+ font = ImageFont.truetype('/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf', 20)
23
+ draw.text(points[0], label, font=font, fill=(0, 0, 255))
24
+ return image
25
+
26
+ def visualize_image_bbox(data_dict, image_processing_config, processor):
27
+ if image_processing_config.get('has_coordinates') != True:
28
+ return
29
+
30
+ messages = data_dict['messages']
31
+
32
+ polygons = []
33
+ first_image_content = None
34
+
35
+ for msg in messages:
36
+ for content in msg['content']:
37
+ if content['type'] == 'text':
38
+ for match in re.finditer(r'\[(\d+(\.\d+)?,\s*)+\d+(\.\d+)?\]', content["text"]):
39
+ coordinate_matches = re.findall(r"([0-9.]+)", match.group(0))
40
+ coords = [float(coord) for coord in coordinate_matches]
41
+ polygons.append(list(zip(coords[::2], coords[1::2])))
42
+ elif first_image_content is None and content['type'] == 'image':
43
+ first_image_content = content
44
+
45
+ first_image = first_image_content['image']
46
+ first_image = processor.preprocess_image(first_image, image_processing_config)
47
+ w, h = first_image.size
48
+
49
+ if len(polygons) > 0:
50
+ for i, polygon in enumerate(polygons):
51
+ polygon = scale_polygon(polygon, w, h)
52
+ first_image = draw_polygon(first_image, polygon, label=str(i))
53
+
54
+ first_image_content['image'] = first_image
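
A minimal example of draw_polygon above; passing exactly two points draws a rectangle, and leaving label=None avoids the hard-coded DejaVuSans font path:

from PIL import Image
from dataset.custom_data_parsers.utils_visualize import draw_polygon

img = Image.new("RGB", (256, 256), "white")
img = draw_polygon(img, [(32, 32), (200, 180)])  # red rectangle, 3 px outline
img.save("bbox_preview.png")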
dataset/custom_data_parsers/video_permutation_parser.py ADDED
@@ -0,0 +1,137 @@
1
+ from typing import Dict, List
2
+ import random
3
+ from PIL import Image, ImageDraw, ImageFont
4
+
5
+ from .utils import sample_video
6
+
7
+
8
+ class VideoPermutationParser:
9
+ def __init__(
10
+ self,
11
+ n_frames=8,
12
+ is_training=True,
13
+ frame_nums = list(range(8, 25)),
14
+ video_sampling_strategy={},
15
+ ):
16
+ self.n_frames = n_frames
17
+ self.is_training = is_training
18
+ self.frame_nums = frame_nums
19
+ self.video_sampling_strategy = video_sampling_strategy
20
+ # fmt: off
21
+ self.data_temp = {
22
+ "text": [{
23
+ "prompt": "<video>",
24
+ "response": ""
25
+ }],
26
+ "video": [{
27
+ "video_file": {
28
+ "yg": "/mnt/bn/videonasyg/videos/webvid_10M_download/011851_011900/1047443473.mp4",
29
+ "lq": "/mnt/bn/llmdatalq/jiangnan/video_generation/webvid_10M_download/20230609/videos/011851_011900/1047443473.mp4"
30
+ },
31
+ "frame_indices": [0, 85, 171, 256, 342, 427, 513, 598]
32
+ }],
33
+ }
34
+ # fmt: on
35
+
36
+ def check_format(self, data_dict: Dict):
37
+ pass
38
+ # for k in self.data_temp.keys():
39
+ # assert k in data_dict
40
+
41
+ def transform(self, data_dict: Dict, image_processing_config: Dict = None) -> Dict:
42
+ self.check_format(data_dict)
43
+
44
+ frames = self.load_video_item(data_dict['video'][0])
45
+
46
+ # frames = self.add_text_to_frames(frames) # for debug
47
+
48
+ idxs = list(range(1, len(frames) + 1))
49
+ random.shuffle(idxs)
50
+
51
+ prefix_len = int(3/8*len(idxs))
52
+
53
+ shuffled_frames = [frames[i-1] for i in idxs]
54
+
55
+ prompt = f'Output the correct chronological order of scrambled video frames. The order of the first {prefix_len} ones are:\n'
56
+ prompt += '\n'.join([str(i) for i in idxs[: prefix_len]]) + '\nOutput the order of the following frames:'
57
+ response = '\n'.join([str(i) for i in idxs[prefix_len: ]])
58
+
59
+ messages = [
60
+ {
61
+ "role": "user",
62
+ "content": [
63
+ {"type": "video", "video": shuffled_frames},
64
+ {"type": "text", "text": prompt},
65
+ ]
66
+ },
67
+ {
68
+ "role": "assistant",
69
+ "content": [
70
+ {"type": "text", "text": response}
71
+ ]
72
+ }
73
+ ]
74
+
75
+ return messages
76
+
77
+
78
+ def load_video_item(self, video_item) -> List[Image.Image]:
79
+ """
80
+ video_item:
81
+ {"video_file": "/path/to/video", "n_frames": 8}
82
+ {"video_file": "/path/to/video", "frame_indices": [0, 1, 2], "n_frames": 3}
83
+ {"video_file": "/path/to/video", "start_frame": 0, "end_frame": 100, "n_frames": 8}
84
+ {"video_file": "/path/to/video", "time_indices": [0, 1, 2], "n_frames": 3}
85
+ {"video_file": "/path/to/video", "start_time": 0, "end_time": 100, "n_frames": 8}
86
+ {"image_file": ["/path/to/image"], "frame_indices": [0, 1, 2], "n_frames": 3}
87
+ """
88
+
89
+ # check format
90
+ if ("image_file" not in video_item) and ("video_file" not in video_item):
91
+ raise KeyError(f"Key 'image_file' or 'video_file' not found in video_item")
92
+
93
+ video_path = video_item.get('video_file', video_item.get('image_file'))
94
+ n_frames = video_item.get('n_frames', None)
95
+ frame_indices = video_item.get('frame_indices', None)
96
+ start_frame = video_item.get('start_frame', None)
97
+ end_frame = video_item.get('end_frame', None)
98
+ time_indices = video_item.get('time_indices', None)
99
+ start_time = video_item.get('start_time', None)
100
+ end_time = video_item.get('end_time', None)
101
+ mask_boxes = video_item.get('mask_boxes', None)
102
+
103
+ n_frames = random.choice(self.frame_nums)
104
+ n = self.video_sampling_strategy.get('force_frames_n_divisible', 1)
105
+ if n > 1 and n_frames % n != 0:
106
+ n_frames += n - n_frames % n
107
+
108
+ frames, frame_indices = sample_video(
109
+ video_path=video_path,
110
+ frame_indices=frame_indices,
111
+ start_frame=start_frame,
112
+ end_frame=end_frame,
113
+ n_frames=n_frames,
114
+ time_indices=time_indices,
115
+ start_time=start_time,
116
+ end_time=end_time,
117
+ mask_boxes=mask_boxes,
118
+ is_training=self.is_training,
119
+ video_sampling_strategy=self.video_sampling_strategy,
120
+ return_frame_ids=True,
121
+ )
122
+ return frames
123
+
124
+
125
+ def add_text_to_frames(self, frames: List[Image.Image]):
126
+ new_frames = []
127
+ for i, image in enumerate(frames):
128
+ draw = ImageDraw.Draw(image)
129
+
130
+ font = ImageFont.truetype('/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf', 100)
131
+ text_position = (50, 50)
132
+ text_content = f'{i+1}'
133
+ text_color = (255, 0, 0)
134
+ draw.text(text_position, text_content, font=font, fill=text_color)
135
+ new_frames.append(image)
136
+ return new_frames
137
+
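For reference, a minimal, self-contained sketch of the frame-permutation objective that `VideoPermutationParser.transform` builds above (the frame count is assumed for illustration; the 3/8 revealed prefix mirrors the code, and no video loading is involved):

```python
import random

n_frames = 8                                   # assumed frame count for illustration
idxs = list(range(1, n_frames + 1))            # 1-based frame indices
random.shuffle(idxs)                           # the "scrambled" presentation order
prefix_len = int(3 / 8 * len(idxs))            # the first 3/8 of the order is revealed

prompt = (
    "Output the correct chronological order of scrambled video frames. "
    f"The order of the first {prefix_len} frames is:\n"
    + "\n".join(str(i) for i in idxs[:prefix_len])
    + "\nOutput the order of the following frames:"
)
response = "\n".join(str(i) for i in idxs[prefix_len:])
print(prompt)
print(response)
```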
dataset/mm_dataset.py DELETED
@@ -1,62 +0,0 @@
1
- # Copyright (2024) Bytedance Ltd. and/or its affiliates
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- from dataset.utils import get_visual_type, sample_frame_indices
15
- from .processor import Processor
16
- from tools.rw_utils import read_jsonlines
17
-
18
- class MMDataset(object):
19
- def __init__(self, ann_path="", anns=None, processor:Processor=None):
20
- self.processor = processor
21
- if anns is None:
22
- self.anns = []
23
- if isinstance(ann_path, str):
24
- ann_path = [ann_path]
25
- for path in ann_path:
26
- self.anns.extend(read_jsonlines(path))
27
- else:
28
- self.anns = anns
29
-
30
- def __len__(self):
31
- return len(self.anns)
32
-
33
- def __getitem__(self, index):
34
- try:
35
- ann = self.anns[index]
36
-
37
- prompt = ann['text']['prompt']
38
-
39
- video_file = ann['video_file']
40
- visual_files = []
41
- start_time = ann.get("start_time", 0)
42
- end_time = ann.get("end_time", -1)
43
- if isinstance(video_file, list):
44
- # This is for MVBench/Episodic Reasoning
45
- # The video_file are a list of sorted frames extract from the target video
46
- for img_file in video_file:
47
- if get_visual_type(img_file) == 'image':
48
- visual_files.append(img_file)
49
- frame_indices = sample_frame_indices(start_frame=0, total_frames=len(visual_files), n_frames=min(len(visual_files), self.processor.max_n_frames))
50
- visual_files = [v for i,v in enumerate(visual_files) if i in frame_indices]
51
- else:
52
- if get_visual_type(video_file) in ['image', 'video', 'gif']:
53
- visual_files.append(video_file)
54
- assert len(visual_files) >= 0, f"Failed to load valid visual file from anns[{index}]!"
55
- images = []
56
- for v_f in visual_files:
57
- images.extend(self.processor.load_images(v_f, start_time=start_time, end_time=end_time))
58
- model_inputs = self.processor(prompt, images=images, edit_prompt=True, return_prompt=True)
59
- except Exception as e:
60
- print(f"Load data error: {e}")
61
- return ann, None
62
- return ann, model_inputs
dataset/processor.py DELETED
@@ -1,164 +0,0 @@
1
- # Copyright (2024) Bytedance Ltd. and/or its affiliates
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- from PIL import Image
15
- from typing import List
16
- import torch
17
- from transformers import DataCollatorForSeq2Seq
18
- from transformers.models.llava import LlavaProcessor
19
- import re
20
- import os
21
-
22
- from .utils import sample_image, sample_video, sample_gif, get_visual_type
23
-
24
- HF_TOKEN = os.environ.get('HF_TOKEN', '')
25
-
26
- ext2sampler = {
27
- 'image': sample_image,
28
- 'gif': sample_gif,
29
- 'video': sample_video
30
- }
31
-
32
- class CustomImageProcessor:
33
- def __init__(self, processor) -> None:
34
- self.processor = processor
35
-
36
- def __call__(self, images: List[Image.Image], do_padding=False) -> torch.Tensor:
37
- if do_padding:
38
- images = [self.expand2square(
39
- img,
40
- tuple(int(x * 255) for x in self.processor.image_processor.image_mean)
41
- ) for img in images]
42
- else:
43
- images = [self.resize2square(img) for img in images]
44
- images_pixel = self.processor(text="", images=images, return_tensors="pt")['pixel_values']
45
- return images_pixel # [num_images, 3, 336, 336]
46
-
47
- def expand2square(self, pil_img, background_color):
48
- width, height = pil_img.size
49
- if width == height:
50
- return pil_img
51
- elif width > height:
52
- result = Image.new(pil_img.mode, (width, width), background_color)
53
- result.paste(pil_img, (0, (width - height) // 2))
54
- return result
55
- else:
56
- result = Image.new(pil_img.mode, (height, height), background_color)
57
- result.paste(pil_img, ((height - width) // 2, 0))
58
- return result
59
-
60
- def resize2square(self, pil_img: Image.Image):
61
- width, height = pil_img.size
62
- pil_img = pil_img.resize((max(width, height), max(width, height)))
63
- return pil_img
64
-
65
- class Processor(object):
66
- def __init__(
67
- self,
68
- model_name_or_path,
69
- max_n_frames=8,
70
- max_seq_len=None,
71
- add_sep=False,
72
- do_image_padding=False,
73
- ):
74
- self.max_n_frames = max_n_frames
75
- self.max_seq_len = max_seq_len,
76
- self.add_sep = add_sep
77
- self.do_image_padding = do_image_padding
78
- if not self.do_image_padding:
79
- print(f"### do_image_padding is set as False, images will be resized directly!")
80
-
81
- self.setup(model_name_or_path)
82
-
83
-
84
- def setup(self, model_name_or_path):
85
- sub_processor = LlavaProcessor.from_pretrained(
86
- model_name_or_path,
87
- padding_side='left',
88
- trust_remote_code=True,
89
- token=HF_TOKEN,
90
- )
91
- self.processor = CustomImageProcessor(sub_processor)
92
- self.tokenizer = sub_processor.tokenizer
93
- # self.pad_collator = DataCollatorForSeq2Seq(self.tokenizer, padding='longest')
94
- self.sep_id = self.tokenizer.sep_token_id
95
- self.pad_id = self.tokenizer.pad_token_id
96
- self.eos_id = self.tokenizer.eos_token_id
97
-
98
- if self.sep_id is None:
99
- self.add_sep = False
100
- if not self.max_seq_len:
101
- self.max_seq_len = self.tokenizer.model_max_length
102
-
103
- def process_prompt(self, prompt, images: List[Image.Image]=None):
104
- if not images:
105
- prompt = prompt.replace("<image>", "").replace("<video>", "")
106
- elif images is not None:
107
- prompt = prompt.replace("<video>", "<image>"*len(images))
108
- image_token_num = len(re.findall('<image>', prompt, re.S))
109
- if image_token_num == 0:
110
- prompt_parts = re.findall(r'USER:(.*)ASSISTANT:(.*)', prompt, re.S)
111
- if prompt_parts and len(prompt_parts) == 2:
112
- p1, p2 = prompt_parts
113
- else:
114
- p1 = prompt
115
- p2 = ''
116
- prompt = f"USER: {'<image>'*len(images) + ' ' + p1.strip()} ASSISTANT: {p2.strip()}"
117
- assert image_token_num == len(images)
118
-
119
- if not re.findall(r'USER:(.*)ASSISTANT:(.*)', prompt, re.S):
120
- prompt = f'USER: {prompt} ASSISTANT: '
121
- return prompt
122
-
123
- def select_frames_sampler(self, visual_data_path):
124
- visual_type = get_visual_type(visual_data_path)
125
- if visual_type in ext2sampler:
126
- return ext2sampler[visual_type]
127
- else:
128
- raise ValueError(f"Unsupported data format: {visual_data_path}")
129
-
130
- def load_images(self, visual_data_path, n_frames=None, start_time=0, end_time=-1):
131
- sampler = self.select_frames_sampler(visual_data_path)
132
- return sampler(visual_data_path, n_frames=min(n_frames, self.max_n_frames) if n_frames else self.max_n_frames, start_time=start_time, end_time=end_time)
133
-
134
- def get_pixel_values(self, images):
135
- if images is not None and len(images) > 0:
136
- pixel_values = self.processor(images=images, do_padding=self.do_image_padding)
137
- else:
138
- pixel_values = None
139
- return pixel_values
140
-
141
- def get_text_inputs(self, text):
142
- prompt_ids = self.tokenizer.encode(text, add_special_tokens=True) # will add <s>
143
- if self.add_sep:
144
- prompt_ids = prompt_ids + [self.sep_id]
145
- prompt_ids = torch.tensor(prompt_ids, dtype=torch.long).unsqueeze(dim=0)
146
- return prompt_ids
147
-
148
- def get_inputs(self, prompt, visual_data_file=None, images=None, n_frames=None, edit_prompt=False, return_prompt=False):
149
- if images is None:
150
- images = self.load_images(visual_data_file, n_frames) if visual_data_file else None
151
- if edit_prompt:
152
- prompt = self.process_prompt(prompt, images)
153
- text_inputs = self.get_text_inputs(prompt)
154
- pixel_values = self.get_pixel_values(images)
155
- inputs = {
156
- "input_ids": text_inputs,
157
- "pixel_values": pixel_values
158
- }
159
- if return_prompt:
160
- inputs['prompt'] = prompt
161
- return inputs
162
-
163
- def __call__(self, prompt, visual_data_file=None, images=None, n_frames=None, edit_prompt=False, return_prompt=False):
164
- return self.get_inputs(prompt, visual_data_file, images, n_frames, edit_prompt, return_prompt)
dataset/tarsier_datamodule.py ADDED
@@ -0,0 +1,284 @@
1
+ """Datamodule for Llava Pretraining and Finetuning"""
2
+ import os
3
+ import re
4
+ from PIL import Image
5
+ import numpy as np
6
+ import re
7
+ import tempfile
8
+ from typing import Dict, List, Union, Tuple
9
+ import traceback
10
+ import json
11
+
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from transformers import DataCollatorForSeq2Seq
15
+
16
+ from tools.rw_utils import read_jsonlines
17
+ from torch.utils.data import Dataset, DataLoader
18
+
19
+ np_str_obj_array_pattern = re.compile(r"[SaUO]")
20
+
21
+ default_collate_err_msg_format = (
22
+ "default_collate: batch must contain tensors, numpy arrays, numbers, "
23
+ "dicts or lists; found {}"
24
+ )
25
+
26
+ from .custom_data_parsers.standard_vision_parser import VisionParser
27
+ from .custom_data_parsers.object_tracking_parser import ObjectTrackingParser
28
+ from .custom_data_parsers.multi_images_parser import MultiImagesParser
29
+ from .custom_data_parsers.video_permutation_parser import VideoPermutationParser
30
+ from .custom_data_parsers.utils_visualize import visualize_image_bbox
31
+
32
+ from .tarsier_processor import TarsierProcessor
33
+
34
+ from tools.rw_utils import NumpyArrayEncoder
35
+ from .utils import DictToObject
36
+ import os
37
+
38
+ HF_TOKEN = os.environ.get('HF_TOKEN', '')
39
+
40
+ class TarsierDataProcessor:
41
+ def __init__(
42
+ self,
43
+ processor: TarsierProcessor,
44
+ n_frames: Union[int, list],
45
+ max_n_frames=256,
46
+ max_pixels=int(1280 * 720 // 2),
47
+ min_pixels=0,
48
+ max_seq_len=None,
49
+ is_training=True, # affects: 1. frames are sampled differently in training vs. evaluation; 2. the response is ignored at evaluation time.
50
+ print_data_error=True,
51
+ do_image_padding=False,
52
+ do_image_crop=False,
53
+ do_image_resize=True,
54
+ video_sampling_strategy={},
55
+ prompt='',
56
+ train_task='sft',
57
+ **kwargs
58
+ ):
59
+ self.kwargs = kwargs
60
+
61
+ self.processor = processor
62
+ self.pad_collator = DataCollatorForSeq2Seq(processor.tokenizer, padding='longest')
63
+
64
+ self.processor.max_seq_len = processor.tokenizer.model_max_length if max_seq_len is None else max_seq_len  # read the limit from the wrapped TarsierProcessor's tokenizer
65
+
66
+ self.n_frames = n_frames
67
+ self.max_n_frames = max_n_frames
68
+ self.max_pixels = max_pixels
69
+ self.min_pixels = min_pixels
70
+
71
+ self.is_training = is_training
72
+ self.print_data_error = print_data_error
73
+ self.do_image_padding = do_image_padding
74
+ self.do_image_crop = do_image_crop
75
+ self.do_image_resize = do_image_resize
76
+ self.video_sampling_strategy = video_sampling_strategy
77
+ self.prompt = prompt
78
+ self.train_task = train_task
79
+
80
+ self.object_tracking_parser = ObjectTrackingParser(
81
+ n_frames=self.n_frames,
82
+ max_objects=4,
83
+ is_training=self.is_training,
84
+ )
85
+ self.multi_images_parser = MultiImagesParser(
86
+ n_frames=self.n_frames,
87
+ is_training=self.is_training,
88
+ )
89
+ self.video_permutation_parser = VideoPermutationParser(
90
+ n_frames=self.n_frames,
91
+ is_training=self.is_training,
92
+ video_sampling_strategy=self.video_sampling_strategy,
93
+ )
94
+ self.vision_parser = VisionParser(
95
+ n_frames=self.n_frames,
96
+ max_n_frames=self.max_n_frames,
97
+ is_training=self.is_training,
98
+ video_sampling_strategy=self.video_sampling_strategy
99
+ )
100
+
101
+ def select_parser(self, data_dict):
102
+ if data_dict.get('task', None) == 'video/object_tracking':
103
+ return self.object_tracking_parser
104
+ elif data_dict.get('task', None) == 'multi_images':
105
+ return self.multi_images_parser
106
+ elif data_dict.get('dataset', None) == 'video_permutation':
107
+ return self.video_permutation_parser
108
+ else:
109
+ return self.vision_parser
110
+
111
+ def parse_image_processing_config(self, data_dict):
112
+ image_processing_config=data_dict.get('image_processing_config', {})
113
+
114
+ do_padding = image_processing_config.get('do_padding', self.do_image_padding)
115
+ do_crop = image_processing_config.get('do_crop', self.do_image_crop)
116
+ do_resize = image_processing_config.get('do_resize', self.do_image_resize)
117
+ max_pixels = image_processing_config.get('max_pixels', self.max_pixels)
118
+ min_pixels = image_processing_config.get('min_pixels', self.min_pixels)
119
+
120
+ assert min_pixels <= max_pixels
121
+
122
+ image_processing_config['do_padding'] = do_padding
123
+ image_processing_config['do_crop'] = do_crop
124
+ image_processing_config['do_resize'] = do_resize
125
+ image_processing_config['max_pixels'] = max_pixels
126
+ image_processing_config['min_pixels'] = min_pixels
127
+
128
+ return image_processing_config
129
+
130
+
131
+ def _transform(self, raw_data_dict: Dict) -> Dict:
132
+ data_dict = json.loads(json.dumps(raw_data_dict, cls=NumpyArrayEncoder))
133
+ del raw_data_dict
134
+
135
+ if self.prompt:
136
+ for msg in data_dict['messages']:
137
+ if msg['role'] == 'user':
138
+ for content in msg['content']:
139
+ if content['type'] == 'text':
140
+ content['text'] = self.prompt
141
+
142
+ data_dict_copy = json.loads(json.dumps(data_dict, cls=NumpyArrayEncoder))
143
+
144
+ image_processing_config = self.parse_image_processing_config(data_dict)
145
+ parser = self.select_parser(data_dict)
146
+ messages = parser.transform(data_dict, image_processing_config)
147
+ data_dict_copy['extra_info'] = data_dict.pop('extra_info', {})
148
+
149
+ # visualize_image_bbox(data_dict, image_processing_config, self.processor)
150
+ outputs = self.processor(messages, image_processing_config, is_training=self.is_training)
151
+
152
+ # if not self.is_training:
153
+ outputs['raw_data_dict'] = data_dict_copy
154
+
155
+ return [outputs]
156
+
157
+ def _split_chosen_rejected(self, data_dict: Dict):
158
+ chosen_data_dict = data_dict
159
+ rejected_data_dict = json.loads(json.dumps(data_dict, cls=NumpyArrayEncoder))
160
+ for msg in chosen_data_dict['messages']:
161
+ if msg['role'] == 'assistant':
162
+ for content in msg['content']:
163
+ if content['type'] == 'text':
164
+ content['text'] = content['chosen']
165
+
166
+ for msg in rejected_data_dict['messages']:
167
+ if msg['role'] == 'assistant':
168
+ for content in msg['content']:
169
+ if content['type'] == 'text':
170
+ content['text'] = content['rejected']
171
+
172
+ return chosen_data_dict, rejected_data_dict
173
+
174
+ def transform(self, data_dict: Dict) -> Dict:
175
+ try:
176
+ if self.train_task == 'dpo':
177
+ chosen_data_dict, rejected_data_dict = self._split_chosen_rejected(data_dict)
178
+ return self._transform(chosen_data_dict) + self._transform(rejected_data_dict)
179
+ return self._transform(data_dict)
180
+ except Exception as e:
181
+ if self.print_data_error:
182
+ print(traceback.format_exc())
183
+ print(f'Error occurs when processing: \n{data_dict}')
184
+ return []
185
+
186
+ def batch_transform(self, batch_data: List[Dict]) -> Dict:
187
+ model_inputs = {}
188
+ # if not self.is_training:
189
+ raw_data_dict = [d.pop('raw_data_dict') for d in batch_data]
190
+ model_inputs['raw_data_dict'] = raw_data_dict
191
+
192
+ batch_pixel_values = [d.pop('pixel_values') for d in batch_data if 'pixel_values' in d]
193
+ batch_image_grid_thw = [d.pop('image_grid_thw') for d in batch_data if 'image_grid_thw' in d]
194
+ if len(batch_pixel_values) == 0:
195
+ vision_placeholder = self.get_vision_placeholder()
196
+ batch_pixel_values = [vision_placeholder.get('pixel_values')]
197
+ batch_image_grid_thw = [vision_placeholder.get('image_grid_thw')] if 'image_grid_thw' in vision_placeholder else []
198
+
199
+ model_inputs['pixel_values'] = torch.cat(batch_pixel_values, dim=0)
200
+ if len(batch_image_grid_thw) > 0:
201
+ model_inputs['image_grid_thw'] = torch.cat(batch_image_grid_thw, dim=0)
202
+
203
+ batch_num_images = [d.pop('num_images') for d in batch_data]
204
+ model_inputs['num_images'] = torch.tensor(batch_num_images)
205
+ model_inputs.update(self.pad_collator(batch_data))
206
+ return model_inputs
207
+
208
+ def __call__(self, batch_data: Union[Dict, List[Dict]]) -> Dict:
209
+ if isinstance(batch_data, dict):
210
+ batch_data = [batch_data]
211
+ batch = [self.transform(d)[0] for d in batch_data]
212
+ return self.batch_transform(batch)
213
+
214
+ def get_vision_placeholder(self):
215
+ messages = [{"role": "user", "content": [{"type": "image", "image": Image.new(mode='RGB', size=(336, 336))}]}]
216
+ image_processing_config = self.parse_image_processing_config({})
217
+ return self.processor(messages, image_processing_config)
218
+
219
+ def get_text_placeholder(self):
220
+ messages = [
221
+ {"role": "user", "content": [{"type": "text", "text": "Hello!"}]},
222
+ {"role": "assistant", "content": [{"type": "text", "text": "Thank you very much"}]},
223
+ ]
224
+ image_processing_config = self.parse_image_processing_config({})
225
+ return self.processor(messages, image_processing_config)
226
+
227
+ def init_processor(processor: Union[TarsierProcessor, str]=None, config: Dict=None):
228
+ config = DictToObject(config) if isinstance(config, dict) else config
229
+ if isinstance(processor, str):
230
+ sub_processor = TarsierProcessor.from_pretrained(
231
+ processor,
232
+ padding_side='left',
233
+ trust_remote_code=True,
234
+ token=HF_TOKEN,
235
+ )
236
+ else:
237
+ sub_processor = processor
238
+ processor = TarsierDataProcessor(
239
+ processor=sub_processor,
240
+ n_frames=config.n_frames,
241
+ max_n_frames=config.max_n_frames,
242
+ max_pixels=config.max_pixels,
243
+ min_pixels=config.min_pixels,
244
+ max_seq_len=config.max_seq_len,
245
+ is_training=config.is_training,
246
+ print_data_error=config.print_data_error,
247
+ do_image_padding=config.do_image_padding,
248
+ do_image_crop=config.do_image_crop,
249
+ do_image_resize=config.do_image_resize,
250
+ video_sampling_strategy=config.video_sampling_strategy,
251
+ prompt=config.prompt,
252
+ train_task=config.train_task
253
+ )
254
+ return processor
255
+
256
+ class TarsierDataset(Dataset):
257
+ def __init__(self, ann_path="", anns=None, config: Dict=None, processor: Union[TarsierDataProcessor, TarsierProcessor, str]=None):
258
+ self.config = DictToObject(config) if isinstance(config, dict) else config
259
+ if not isinstance(processor, TarsierDataProcessor):
260
+ self.processor = init_processor(processor, config)
261
+ else:
262
+ self.processor = processor
263
+ if anns is None:
264
+ self.anns = []
265
+ if isinstance(ann_path, str):
266
+ ann_path = [ann_path]
267
+ for path in ann_path:
268
+ self.anns.extend(read_jsonlines(path))
269
+ else:
270
+ self.anns = anns
271
+
272
+ def __len__(self):
273
+ return len(self.anns)
274
+
275
+ def __getitem__(self, index):
276
+ if index < 0 or index >= len(self.anns):
277
+ raise IndexError("Index out of range")
278
+ try:
279
+ ann = self.anns[index]
280
+ model_inputs = self.processor(ann)
281
+ except Exception as e:
282
+ print(f"Load data error: {e}")
283
+ return ann, None
284
+ return ann, model_inputs
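As a usage sketch: the annotation path and checkpoint path below are hypothetical, and the config keys simply mirror the `TarsierDataProcessor.__init__` arguments above.

```python
from torch.utils.data import DataLoader
from dataset.tarsier_datamodule import TarsierDataset

# Config keys mirror TarsierDataProcessor.__init__ above; values are illustrative.
config = {
    "n_frames": 16, "max_n_frames": 256,
    "max_pixels": 1280 * 720 // 2, "min_pixels": 0,
    "max_seq_len": 8192, "is_training": False, "print_data_error": True,
    "do_image_padding": False, "do_image_crop": False, "do_image_resize": True,
    "video_sampling_strategy": {}, "prompt": "", "train_task": "sft",
}

dataset = TarsierDataset(
    ann_path="data/annotations/eval.jsonl",   # hypothetical jsonlines annotation file
    config=config,
    processor="path/to/tarsier2_checkpoint",  # hypothetical; loaded via TarsierProcessor.from_pretrained
)
loader = DataLoader(dataset, batch_size=1, collate_fn=lambda batch: batch)
ann, model_inputs = next(iter(loader))[0]     # each item is an (annotation, model_inputs) pair
```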
dataset/tarsier_processor.py ADDED
@@ -0,0 +1,240 @@
1
+ from typing import List, Union
2
+ from PIL import Image
3
+
4
+ import torch
5
+
6
+ from transformers.feature_extraction_utils import BatchFeature
7
+ from transformers.image_utils import ImageInput, get_image_size, to_numpy_array
8
+ from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
9
+ from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
10
+ from transformers.utils import logging
11
+ from transformers import Qwen2VLImageProcessor
12
+ from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
13
+
14
+ logger = logging.get_logger(__name__)
15
+
16
+
17
+ class TarsierProcessorKwargs(ProcessingKwargs, total=False):
18
+ _defaults = {
19
+ "text_kwargs": {},
20
+ "images_kwargs": {},
21
+ }
22
+
23
+
24
+ class TarsierProcessor(ProcessorMixin):
25
+
26
+ attributes = ["image_processor", "tokenizer"]
27
+ valid_kwargs = ["chat_template", "image_token", "patch_size", "merge_size", "temporal_patch_size", "max_seq_len"]
28
+ image_processor_class = "AutoImageProcessor"
29
+ tokenizer_class = "AutoTokenizer"
30
+
31
+ def __init__(
32
+ self,
33
+ image_processor=None,
34
+ tokenizer=None,
35
+ chat_template=None,
36
+ image_token="<image>",
37
+ patch_size=None,
38
+ merge_size=1,
39
+ temporal_patch_size=1,
40
+ max_seq_len=8192,
41
+ **kwargs,
42
+ ) -> None:
43
+
44
+ self.image_token = image_token
45
+ self.patch_size = patch_size
46
+ self.merge_size = merge_size
47
+ self.temporal_patch_size = temporal_patch_size
48
+ self.max_seq_len = max_seq_len
49
+ self.max_pixels_per_sample = 128 * 384 * 384
50
+
51
+ super().__init__(image_processor, tokenizer, chat_template=chat_template)
52
+
53
+ def __call__(
54
+ self,
55
+ messages,
56
+ image_processing_config=None,
57
+ is_training=True,
58
+ ) -> torch.Tensor:
59
+
60
+ output_kwargs = self._merge_kwargs(
61
+ TarsierProcessorKwargs,
62
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
63
+ )
64
+
65
+ # Image processing
66
+ pixel_values, image_grid_thw = [], []
67
+ num_images = 0
68
+ for msg in messages:
69
+ for content in msg['content']:
70
+ if content['type'] == 'image':
71
+ num_images += self.temporal_patch_size
72
+ elif content['type'] == 'video':
73
+ num_images += len(content['video'])
74
+ if num_images > 0 and self.max_pixels_per_sample // num_images < image_processing_config['max_pixels']:
75
+ image_processing_config['max_pixels'] = self.max_pixels_per_sample // num_images
76
+ image_processing_config['min_pixels'] = min(image_processing_config['min_pixels'], image_processing_config['max_pixels'])
77
+
78
+ for msg in messages:
79
+ for content in msg['content']:
80
+ if content['type'] == 'image':
81
+ content['image'] = self.preprocess_image(content['image'], image_processing_config)
82
+ content['image'] = self.image_processor(images = content['image'], **output_kwargs["images_kwargs"], return_tensors="pt")
83
+ content['num_vision_tokens'] = self.get_num_vision_tokens(content)
84
+ pixel_values.append(content['image']['pixel_values'])
85
+ if 'image_grid_thw' in content['image']:
86
+ image_grid_thw.extend(content['image']['image_grid_thw'])
87
+ elif content['type'] == 'video':
88
+ content['video'] = self.preprocess_image(content['video'], image_processing_config)
89
+ if isinstance(self.image_processor, Qwen2VLImageProcessor):
90
+ content['video'] = self.image_processor(images = None, videos = content['video'], **output_kwargs["images_kwargs"], return_tensors="pt")
91
+ pixel_values.append(content['video']['pixel_values_videos'])
92
+ else:
93
+ content['video'] = self.image_processor(images = content['video'], **output_kwargs["images_kwargs"], return_tensors="pt")
94
+ pixel_values.append(content['video']['pixel_values'])
95
+
96
+ if 'video_grid_thw' in content['video']:
97
+ image_grid_thw.extend(content['video']['video_grid_thw'])
98
+ content['num_vision_tokens'] = self.get_num_vision_tokens(content)
99
+
100
+ # Text processing
101
+ add_generation_prompt = (not is_training and messages[-1]['role'] != 'assistant')
102
+ strip_final_eos = (not is_training and messages[-1]['role'] == 'assistant')
103
+ text_inputs = self.tokenizer.apply_chat_template(
104
+ messages,
105
+ chat_template = self.chat_template,
106
+ tokenize=True,
107
+ tokenizer_kwargs = output_kwargs["text_kwargs"],
108
+ return_assistant_tokens_mask=True,
109
+ return_dict=True,
110
+ add_generation_prompt=add_generation_prompt,
111
+ strip_final_eos=strip_final_eos,
112
+ )
113
+ labels = [-100 if j == 0 else i for i, j in zip(text_inputs['input_ids'], text_inputs['assistant_masks'])]
114
+ labels = labels[:self.max_seq_len]
115
+ input_ids = text_inputs['input_ids'][:self.max_seq_len]
116
+
117
+ image_token_id = self.tokenizer.convert_tokens_to_ids(self.image_token)
118
+ if image_token_id in text_inputs['input_ids'][self.max_seq_len:]:
119
+ raise ValueError(f'Too long sequence! {len(text_inputs["input_ids"])}')
120
+
121
+ outputs = {
122
+ 'input_ids': input_ids,
123
+ 'labels': labels,
124
+ 'num_images': num_images,
125
+ }
126
+ if len(pixel_values) > 0:
127
+ outputs['pixel_values'] = torch.cat(pixel_values, dim=0)
128
+ if len(image_grid_thw) > 0:
129
+ outputs['image_grid_thw'] = torch.stack(image_grid_thw)
130
+ return outputs
131
+
132
+
133
+ def preprocess_image(self, pil_img: Union[Image.Image, List[Image.Image]], image_processing_config):
134
+ if image_processing_config is None:
135
+ return pil_img
136
+ images = pil_img
137
+ if isinstance(pil_img, Image.Image):
138
+ images = [images]
139
+ if image_processing_config['do_crop']:
140
+ images = [self.centralcrop(img, rate=[4, 3]) for img in images]
141
+ if image_processing_config['do_padding']:
142
+ images = [self.expand2square(
143
+ img,
144
+ # tuple(int(x * 255) for x in self.processor.image_processor.image_mean)
145
+ tuple(int(x * 255) for x in [0, 0, 0])
146
+ ) for img in images]
147
+ if image_processing_config['do_resize']:
148
+ images = [self.resize2square(img) for img in images]
149
+ if image_processing_config.get('max_pixels'):
150
+ images = [self.resize2pixels(
151
+ img,
152
+ int(image_processing_config['max_pixels']),
153
+ int(image_processing_config['min_pixels'])
154
+ ) for img in images]
155
+ if isinstance(pil_img, Image.Image):
156
+ images = images[0]
157
+ return images
158
+
159
+ def expand2square(self, pil_img, background_color):
160
+ width, height = pil_img.size
161
+ if width == height:
162
+ return pil_img
163
+ elif width > height:
164
+ result = Image.new(pil_img.mode, (width, width), background_color)
165
+ result.paste(pil_img, (0, (width - height) // 2))
166
+ return result
167
+ else:
168
+ result = Image.new(pil_img.mode, (height, height), background_color)
169
+ result.paste(pil_img, ((height - width) // 2, 0))
170
+ return result
171
+
172
+ def resize2square(self, pil_img: Image.Image):
173
+ width, height = pil_img.size
174
+ pil_img = pil_img.resize((max(width, height), max(width, height)))
175
+ return pil_img
176
+
177
+ def centralcrop(self, pil_img: Image.Image, rate=[4, 3]):
178
+ width, height = pil_img.size
179
+ size = (width, height)
180
+ min_len = min(size)
181
+ longer_side = 0 if width >= height else 1
182
+ center = (width/2, height/2)
183
+ box = [0, 0, size[0], size[1]]
184
+
185
+ # if longer_side == 0:
186
+ # box[0] = max(0, center[0] - 1/2*min_len/rate[1]*rate[0])
187
+ # box[2] = min(width, center[0] + 1/2*min_len/rate[1]*rate[0])
188
+ # else:
189
+ # box[1] = max(0, center[1] - 1/2*min_len/rate[1]*rate[0])
190
+ # box[3] = min(height, center[1] + 1/2*min_len/rate[1]*rate[0])
191
+ box[longer_side] = max(0, center[longer_side] - 1/2*min_len/rate[1]*rate[0])
192
+ box[2 + longer_side] = min(size[longer_side], center[longer_side] + 1/2*min_len/rate[1]*rate[0])
193
+
194
+ # box = (width/2-min_len/2, height/2-min_len/2, width/2+min_len/2, height/2+min_len/2)
195
+ pil_img = pil_img.crop(box)
196
+ return pil_img
197
+
198
+ def resize2pixels(self, pil_img: Image.Image, max_pixels=None, min_pixels=None):
199
+ width, height = pil_img.size
200
+ new_height, new_width = smart_resize(height, width, factor=1, max_pixels=max_pixels, min_pixels=min_pixels)
201
+ pil_img = pil_img.resize((new_width, new_height))
202
+ return pil_img
203
+
204
+ def get_num_vision_tokens(self, content):
205
+ if isinstance(self.image_processor, Qwen2VLImageProcessor):
206
+ merge_length = self.image_processor.merge_size**2
207
+ if content['type'] == 'image':
208
+ num_image_tokens = content['image']['image_grid_thw'].prod() // merge_length
209
+ else:
210
+ num_image_tokens = content['video']['video_grid_thw'].prod() // merge_length
211
+ return num_image_tokens
212
+ else:
213
+ # Other models: image tokens (-> 2x2 compressed) -> add image_newline and image_new
214
+ k = 'image' if content['type'] == 'image' else 'video'
215
+ pixel_values = content[k]['pixel_values'][0]
216
+ n_frames = len(content[k]['pixel_values'])
217
+
218
+ height, width = get_image_size(to_numpy_array(pixel_values))
219
+ num_image_tokens = (height // (self.patch_size * self.merge_size)) * (width // (self.patch_size * self.merge_size) + 1) + 1
220
+ return num_image_tokens * n_frames
221
+
222
+ def batch_decode(self, *args, **kwargs):
223
+ """
224
+ This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
225
+ refer to the docstring of this method for more information.
226
+ """
227
+ return self.tokenizer.batch_decode(*args, **kwargs)
228
+
229
+ def decode(self, *args, **kwargs):
230
+ """
231
+ This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
232
+ the docstring of this method for more information.
233
+ """
234
+ return self.tokenizer.decode(*args, **kwargs)
235
+
236
+ @property
237
+ def model_input_names(self):
238
+ tokenizer_input_names = self.tokenizer.model_input_names
239
+ image_processor_input_names = self.image_processor.model_input_names
240
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
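To make the budgeting in `__call__` concrete, here is a minimal numeric sketch (the frame count is assumed) of how `max_pixels_per_sample` caps the per-frame `max_pixels` before frames reach the image processor:

```python
# Per-sample pixel budget from TarsierProcessor (128 * 384 * 384 "pixels" in total).
max_pixels_per_sample = 128 * 384 * 384   # 18_874_368
max_pixels = 1280 * 720 // 2              # default per-frame cap from TarsierDataProcessor
num_images = 48                           # assumed: a sample contributing 48 frames

if num_images > 0 and max_pixels_per_sample // num_images < max_pixels:
    max_pixels = max_pixels_per_sample // num_images

print(max_pixels)  # 393216 -> each frame is later resized to at most ~393k pixels
```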
dataset/utils.py CHANGED
@@ -126,3 +126,61 @@ def get_benchmarks(benchmarks):
126
  else:
127
  final_benchmarks.append(bm)
128
  return final_benchmarks
129
+
130
+ def check_data_format(data):
131
+ for msg in data['messages']:
132
+ if isinstance(msg['content'], dict):
133
+ msg['content'] = [msg['content']]
134
+ for content in msg['content']:
135
+ assert content['type'] in {'image', 'video', 'text'}, f"content['type']={content['type']} MUST be one of ['image', 'video', 'text']"
136
+ if content['type'] != "text":
137
+ media_path_key = f"{content['type']}_file"
138
+ media_paths = content[content['type']][media_path_key]
139
+ if isinstance(media_paths, str):
140
+ media_paths = [media_paths]
141
+ for path in media_paths:
142
+ assert os.path.exists(path), f"File not found: {path}"
143
+
144
+ def format_one_sample(media_file=None, prompt="Describe the video in detail."):
145
+ sample = {
146
+ "messages": []
147
+ }
148
+ user_content = {
149
+ "role": "user",
150
+ "content": []
151
+ }
152
+ if media_file is not None:
153
+ media_type = get_visual_type(media_file)
154
+ if media_type in ("video", "gif"):
155
+ media_type = "video"
156
+ media_path_key = f"{media_type}_file"
157
+ user_content["content"].append({
158
+ "type": media_type,
159
+ media_type: {
160
+ media_path_key: media_file,
161
+ }
162
+ })
163
+ user_content["content"].append({
164
+ "type": "text",
165
+ "text": prompt
166
+ })
167
+
168
+ assistant_content = {
169
+ "role": "assistant",
170
+ "content": []
171
+ }
172
+
173
+ sample["messages"].append(user_content)
174
+ sample["messages"].append(assistant_content)
175
+ if media_file is not None:
176
+ sample["task"] = f"{media_type}/QA"
177
+ else:
178
+ sample["task"] = 'text-only'
179
+ check_data_format(sample)
180
+ return sample
181
+
182
+
183
+ class DictToObject(object):
184
+ def __init__(self, dictionary):
185
+ for key, value in dictionary.items():
186
+ setattr(self, key, value)
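For illustration, the helpers added above can be combined as follows; the media path is hypothetical and must exist on disk, since `check_data_format` verifies it.

```python
from dataset.utils import format_one_sample

sample = format_one_sample(
    media_file="assets/examples/demo.mp4",   # hypothetical local video
    prompt="Describe the video in detail.",
)
# Expected shape of the result (abridged):
# {
#   "messages": [
#     {"role": "user", "content": [
#         {"type": "video", "video": {"video_file": "assets/examples/demo.mp4"}},
#         {"type": "text", "text": "Describe the video in detail."}]},
#     {"role": "assistant", "content": []},
#   ],
#   "task": "video/QA",
# }
```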
models/modeling_qwen2_vl_fast.py ADDED
@@ -0,0 +1,1320 @@
1
+ import os
2
+ import math
3
+ from dataclasses import dataclass
4
+ from typing import Any, Dict, List, Optional, Tuple, Union
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch.nn import LayerNorm
10
+
11
+ from transformers.modeling_utils import PreTrainedModel
12
+ from transformers.configuration_utils import PretrainedConfig
13
+ from transformers.modeling_rope_utils import rope_config_validation, ROPE_INIT_FUNCTIONS
14
+ from transformers.cache_utils import Cache, SlidingWindowCache, StaticCache
15
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
16
+ from transformers.utils import (
17
+ add_start_docstrings,
18
+ add_start_docstrings_to_model_forward,
19
+ is_flash_attn_2_available,
20
+ is_flash_attn_greater_or_equal_2_10,
21
+ logging,
22
+ replace_return_docstrings,
23
+ )
24
+ from transformers.modeling_outputs import (
25
+ BaseModelOutputWithPast,
26
+ ModelOutput,
27
+ )
28
+ from transformers.activations import ACT2FN
29
+ from transformers.generation import GenerationMixin
30
+
31
+ if is_flash_attn_2_available():
32
+ from flash_attn import flash_attn_varlen_func
33
+
34
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
35
+ else:
36
+ flash_attn_varlen_func = None
37
+
38
+ # from apex.normalization.fused_layer_norm import fused_rms_norm_affine
39
+
40
+ logger = logging.get_logger(__name__)
41
+
42
+ @dataclass
43
+ class Qwen2VLCausalLMOutputWithPast(ModelOutput):
44
+ """
45
+ Base class for Qwen2VL causal language model (or autoregressive) outputs.
46
+
47
+ Args:
48
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
49
+ Language modeling loss (for next-token prediction).
50
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
51
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
52
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
53
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
54
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
55
+
56
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
57
+ `past_key_values` input) to speed up sequential decoding.
58
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
59
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
60
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
61
+
62
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
63
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
64
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
65
+ sequence_length)`.
66
+
67
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
68
+ heads.
69
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
70
+ The rope index difference between sequence length and multimodal rope.
71
+ """
72
+
73
+ loss: Optional[torch.FloatTensor] = None
74
+ logits: torch.FloatTensor = None
75
+ past_key_values: Optional[List[torch.FloatTensor]] = None
76
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
77
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
78
+
79
+ class Qwen2VLVisionConfig(PretrainedConfig):
80
+ model_type = "qwen2_vl"
81
+
82
+ def __init__(
83
+ self,
84
+ depth=32,
85
+ embed_dim=1280,
86
+ hidden_size=3584,
87
+ hidden_act="quick_gelu",
88
+ mlp_ratio=4,
89
+ num_heads=16,
90
+ in_channels=3,
91
+ patch_size=14,
92
+ spatial_merge_size=2,
93
+ temporal_patch_size=2,
94
+ attn_implementation='flash_attention_2',
95
+ **kwargs,
96
+ ):
97
+ super().__init__(**kwargs)
98
+
99
+ self.depth = depth
100
+ self.embed_dim = embed_dim
101
+ self.hidden_size = hidden_size
102
+ self.hidden_act = hidden_act
103
+ self.mlp_ratio = mlp_ratio
104
+ self.num_heads = num_heads
105
+ self.in_channels = in_channels
106
+ self.patch_size = patch_size
107
+ self.spatial_merge_size = spatial_merge_size
108
+ self.temporal_patch_size = temporal_patch_size
109
+ self.attn_implementation = attn_implementation
110
+
111
+ @classmethod
112
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
113
+ cls._set_token_in_kwargs(kwargs)
114
+
115
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
116
+
117
+ if config_dict.get("model_type") == "qwen2_vl":
118
+ config_dict = config_dict["vision_config"]
119
+
120
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
121
+ logger.warning(
122
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
123
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
124
+ )
125
+
126
+ return cls.from_dict(config_dict, **kwargs)
127
+
128
+
129
+ class Qwen2VLConfig(PretrainedConfig):
130
+ r"""
131
+ This is the configuration class to store the configuration of a [`Qwen2VLModel`]. It is used to instantiate a
132
+ Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
133
+ with the defaults will yield a similar configuration to that of
134
+ Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct).
135
+
136
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
137
+ documentation from [`PretrainedConfig`] for more information.
138
+
139
+
140
+ Args:
141
+ vocab_size (`int`, *optional*, defaults to 152064):
142
+ Vocabulary size of the Qwen2VL model. Defines the number of different tokens that can be represented by the
143
+ `inputs_ids` passed when calling [`Qwen2VLModel`]
144
+ hidden_size (`int`, *optional*, defaults to 8192):
145
+ Dimension of the hidden representations.
146
+ intermediate_size (`int`, *optional*, defaults to 29568):
147
+ Dimension of the MLP representations.
148
+ num_hidden_layers (`int`, *optional*, defaults to 80):
149
+ Number of hidden layers in the Transformer encoder.
150
+ num_attention_heads (`int`, *optional*, defaults to 64):
151
+ Number of attention heads for each attention layer in the Transformer encoder.
152
+ num_key_value_heads (`int`, *optional*, defaults to 8):
153
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
154
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
155
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
156
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
157
+ by meanpooling all the original heads within that group. For more details checkout [this
158
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
159
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
160
+ The non-linear activation function (function or string) in the decoder.
161
+ max_position_embeddings (`int`, *optional*, defaults to 32768):
162
+ The maximum sequence length that this model might ever be used with.
163
+ initializer_range (`float`, *optional*, defaults to 0.02):
164
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
165
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
166
+ The epsilon used by the rms normalization layers.
167
+ use_cache (`bool`, *optional*, defaults to `True`):
168
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
169
+ relevant if `config.is_decoder=True`.
170
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
171
+ Whether the model's input and output word embeddings should be tied.
172
+ rope_theta (`float`, *optional*, defaults to 1000000.0):
173
+ The base period of the RoPE embeddings.
174
+ use_sliding_window (`bool`, *optional*, defaults to `False`):
175
+ Whether to use sliding window attention.
176
+ sliding_window (`int`, *optional*, defaults to 4096):
177
+ Sliding window attention (SWA) window size. If not specified, will default to `4096`.
178
+ max_window_layers (`int`, *optional*, defaults to 80):
179
+ The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
180
+ attention_dropout (`float`, *optional*, defaults to 0.0):
181
+ The dropout ratio for the attention probabilities.
182
+ vision_config (`Dict`, *optional*):
183
+ The config for the visual encoder initialization.
184
+ rope_scaling (`Dict`, *optional*):
185
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
186
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
187
+ accordingly.
188
+ Expected contents:
189
+ `rope_type` (`str`):
190
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
191
+ 'llama3'], with 'default' being the original RoPE implementation.
192
+ `factor` (`float`, *optional*):
193
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
194
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
195
+ original maximum pre-trained length.
196
+ `original_max_position_embeddings` (`int`, *optional*):
197
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
198
+ pretraining.
199
+ `attention_factor` (`float`, *optional*):
200
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
201
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
202
+ `factor` field to infer the suggested value.
203
+ `beta_fast` (`float`, *optional*):
204
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
205
+ ramp function. If unspecified, it defaults to 32.
206
+ `beta_slow` (`float`, *optional*):
207
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
208
+ ramp function. If unspecified, it defaults to 1.
209
+ `short_factor` (`List[float]`, *optional*):
210
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
211
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
212
+ size divided by the number of attention heads divided by 2
213
+ `long_factor` (`List[float]`, *optional*):
214
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (<
215
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
216
+ size divided by the number of attention heads divided by 2
217
+ `low_freq_factor` (`float`, *optional*):
218
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
219
+ `high_freq_factor` (`float`, *optional*):
220
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
221
+
222
+ ```python
223
+ >>> from transformers import Qwen2VLForConditionalGeneration, Qwen2VLConfig
224
+
225
+ >>> # Initializing a Qwen2VL style configuration
226
+ >>> configuration = Qwen2VLConfig()
227
+
228
+ >>> # Initializing a model from the Qwen2-VL-7B style configuration
229
+ >>> model = Qwen2VLForConditionalGeneration(configuration)
230
+
231
+ >>> # Accessing the model configuration
232
+ >>> configuration = model.config
233
+ ```"""
234
+
235
+ model_type = "qwen2_vl"
236
+ keys_to_ignore_at_inference = ["past_key_values"]
237
+
238
+ def __init__(
239
+ self,
240
+ vocab_size=152064,
241
+ hidden_size=8192,
242
+ intermediate_size=29568,
243
+ num_hidden_layers=80,
244
+ num_attention_heads=64,
245
+ num_key_value_heads=8,
246
+ hidden_act="silu",
247
+ max_position_embeddings=32768,
248
+ initializer_range=0.02,
249
+ rms_norm_eps=1e-05,
250
+ use_cache=True,
251
+ tie_word_embeddings=False,
252
+ rope_theta=1000000.0,
253
+ use_sliding_window=False,
254
+ sliding_window=4096,
255
+ max_window_layers=80,
256
+ attention_dropout=0.0,
257
+ rope_scaling=None,
258
+ spatial_merge_size=2,
259
+ attn_implementation='flash_attention_2',
260
+ **kwargs,
261
+ ):
262
+
263
+ self.vocab_size = vocab_size
264
+ self.max_position_embeddings = max_position_embeddings
265
+ self.hidden_size = hidden_size
266
+ self.intermediate_size = intermediate_size
267
+ self.num_hidden_layers = num_hidden_layers
268
+ self.num_attention_heads = num_attention_heads
269
+ self.use_sliding_window = use_sliding_window
270
+ self.sliding_window = sliding_window
271
+ self.max_window_layers = max_window_layers
272
+
273
+ # for backward compatibility
274
+ if num_key_value_heads is None:
275
+ num_key_value_heads = num_attention_heads
276
+
277
+ self.num_key_value_heads = num_key_value_heads
278
+ self.hidden_act = hidden_act
279
+ self.initializer_range = initializer_range
280
+ self.rms_norm_eps = rms_norm_eps
281
+ self.use_cache = use_cache
282
+ self.rope_theta = rope_theta
283
+ self.attention_dropout = attention_dropout
284
+ self.rope_scaling = rope_scaling
285
+ self.spatial_merge_size = spatial_merge_size
286
+ self.attn_implementation = attn_implementation
287
+
288
+ # Validate the correctness of rotary position embeddings parameters
289
+ # BC: if there is a 'type' field, move it to 'rope_type'.
290
+ # and change type from 'mrope' to 'default' because `mrope` does default RoPE calculations
291
+ # one can set it to "linear"/"dynamic" etc. to have scaled RoPE
292
+ # TODO: @raushan update config in the hub
293
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
294
+ if self.rope_scaling["type"] == "mrope":
295
+ self.rope_scaling["type"] = "default"
296
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
297
+ rope_config_validation(self, ignore_keys={"mrope_section"})
298
+
299
+ super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
300
+
301
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
302
+ def rotate_half(x):
303
+ """Rotates half the hidden dims of the input."""
304
+ x1 = x[..., : x.shape[-1] // 2]
305
+ x2 = x[..., x.shape[-1] // 2 :]
306
+ return torch.cat((-x2, x1), dim=-1)
307
+
308
+ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
309
+ """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/).
310
+
311
+ Explanation:
312
+ Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding
313
+ sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For
314
+ vision embedding part, we apply rotary position embedding on temporal, height and width dimensions separately.
315
+ Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding.
316
+ For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal,
317
+ height and width) of text embedding is always the same, so the text embedding rotary position embedding has no
318
+ difference with modern LLMs.
319
+
320
+ Args:
321
+ q (`torch.Tensor`): The query tensor.
322
+ k (`torch.Tensor`): The key tensor.
323
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
324
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
325
+ position_ids (`torch.Tensor`):
326
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
327
+ used to pass offsetted position ids when working with a KV-cache.
328
+ mrope_section(`List(int)`):
329
+ Multimodal rope section is for channel dimension of temporal, height and width in rope calculation.
330
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
331
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
332
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
333
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
334
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
335
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
336
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
337
+ Returns:
338
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
339
+ """
340
+ mrope_section = mrope_section * 2
341
+ cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
342
+ unsqueeze_dim
343
+ )
344
+ sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
345
+ unsqueeze_dim
346
+ )
347
+
348
+ q_embed = (q * cos) + (rotate_half(q) * sin)
349
+ k_embed = (k * cos) + (rotate_half(k) * sin)
350
+ return q_embed, k_embed
351
+
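As a small, self-contained illustration of the channel split performed above (the `mrope_section` values are assumed for the example, not taken from a released config):

```python
import torch

mrope_section = [16, 24, 24]               # assumed sections; sums to head_dim // 2
head_dim = sum(mrope_section) * 2          # 128
# cos/sin carry three sets of position ids (temporal, height, width) in dim 0.
cos = torch.randn(3, 1, 10, head_dim)      # (3, batch, seq_len, head_dim)

chunks = cos.split(mrope_section * 2, dim=-1)                         # 6 channel chunks
merged = torch.cat([m[i % 3] for i, m in enumerate(chunks)], dim=-1)  # alternate t/h/w
print(merged.shape)                        # torch.Size([1, 10, 128])
```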
352
+ def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
353
+ orig_dtype = tensor.dtype
354
+ tensor = tensor.float()
355
+ cos = freqs.cos()
356
+ sin = freqs.sin()
357
+ cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
358
+ sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
359
+ output = (tensor * cos) + (rotate_half(tensor) * sin)
360
+ output = output.to(orig_dtype)
361
+ return output
362
+
363
+
364
+ class VisionRotaryEmbedding(nn.Module):
365
+ def __init__(self, dim: int, theta: float = 10000.0) -> None:
366
+ super().__init__()
367
+ inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
368
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
369
+
370
+ def forward(self, seqlen: int) -> torch.Tensor:
371
+ seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
372
+ freqs = torch.outer(seq, self.inv_freq)
373
+ return freqs
374
+
375
+ class PatchEmbed(nn.Module):
376
+ def __init__(
377
+ self,
378
+ patch_size: int = 14,
379
+ temporal_patch_size: int = 2,
380
+ in_channels: int = 3,
381
+ embed_dim: int = 1152,
382
+ ) -> None:
383
+ super().__init__()
384
+ self.patch_size = patch_size
385
+ self.temporal_patch_size = temporal_patch_size
386
+ self.in_channels = in_channels
387
+ self.embed_dim = embed_dim
388
+
389
+ kernel_size = [temporal_patch_size, patch_size, patch_size]
390
+ self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False)
391
+
392
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
393
+ target_dtype = self.proj.weight.dtype
394
+ hidden_states = hidden_states.view(
395
+ -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
396
+ )
397
+ hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
398
+ return hidden_states
399
+
400
+
401
+ class PatchMerger(nn.Module):
402
+ def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
403
+ super().__init__()
404
+ self.hidden_size = context_dim * (spatial_merge_size**2)
405
+ self.ln_q = LayerNorm(context_dim, eps=1e-6)
406
+ self.mlp = nn.Sequential(
407
+ nn.Linear(self.hidden_size, self.hidden_size),
408
+ nn.GELU(),
409
+ nn.Linear(self.hidden_size, dim),
410
+ )
411
+
412
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
413
+ x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
414
+ return x
415
+
416
+ class VisionMlp(nn.Module):
417
+ def __init__(self, dim: int, hidden_dim: int, hidden_act: str) -> None:
418
+ super().__init__()
419
+ self.fc1 = nn.Linear(dim, hidden_dim)
420
+ self.act = ACT2FN[hidden_act]
421
+ self.fc2 = nn.Linear(hidden_dim, dim)
422
+
423
+ def forward(self, x) -> torch.Tensor:
424
+ return self.fc2(self.act(self.fc1(x)))
425
+
426
+
427
+ class VisionAttention(nn.Module):
428
+ def __init__(self, dim: int, num_heads: int = 16) -> None:
429
+ super().__init__()
430
+ self.num_heads = num_heads
431
+ self.head_dim = dim // num_heads
432
+ self.qkv = nn.Linear(dim, dim * 3, bias=True)
433
+ self.proj = nn.Linear(dim, dim)
434
+
435
+ def forward(
436
+ self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
437
+ ) -> torch.Tensor:
438
+ seq_length = hidden_states.shape[0]
439
+ q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
440
+ q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
441
+ k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
442
+
443
+ attention_mask = torch.full(
444
+ [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
445
+ )
446
+ for i in range(1, len(cu_seqlens)):
447
+ attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
448
+
449
+ q = q.transpose(0, 1)
450
+ k = k.transpose(0, 1)
451
+ v = v.transpose(0, 1)
452
+ attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
453
+ attn_weights = attn_weights + attention_mask
454
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
455
+ attn_output = torch.matmul(attn_weights, v)
456
+ attn_output = attn_output.transpose(0, 1)
457
+ attn_output = attn_output.reshape(seq_length, -1)
458
+ attn_output = self.proj(attn_output)
459
+ return attn_output
460
+
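The eager attention above emulates variable-length (per-frame) attention by zeroing the mask only inside each [cu_seqlens[i-1], cu_seqlens[i]) block; a small sketch with made-up boundaries:

import torch

cu_seqlens = torch.tensor([0, 2, 5])                        # two frames with lengths 2 and 3
seq_length = 5
mask = torch.full([1, seq_length, seq_length], float("-inf"))
for i in range(1, len(cu_seqlens)):
    mask[..., cu_seqlens[i - 1]:cu_seqlens[i], cu_seqlens[i - 1]:cu_seqlens[i]] = 0
print(mask[0])   # a 2x2 and a 3x3 zero block on the diagonal; cross-frame entries stay -inf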
461
+
462
+ class VisionFlashAttention2(nn.Module):
463
+ def __init__(self, dim: int, num_heads: int = 16) -> None:
464
+ super().__init__()
465
+ self.num_heads = num_heads
466
+ self.qkv = nn.Linear(dim, dim * 3, bias=True)
467
+ self.proj = nn.Linear(dim, dim)
468
+
469
+ def forward(
470
+ self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
471
+ ) -> torch.Tensor:
472
+ seq_length = hidden_states.shape[0]
473
+ q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
474
+ q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
475
+ k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
476
+
477
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
478
+ attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
479
+ seq_length, -1
480
+ )
481
+ attn_output = self.proj(attn_output)
482
+ return attn_output
483
+
484
+ QWEN2_VL_VISION_ATTENTION_CLASSES = {
485
+ "eager": VisionAttention,
486
+ "flash_attention_2": VisionFlashAttention2,
487
+ }
488
+
489
+
490
+ class Qwen2VLVisionBlock(nn.Module):
491
+ def __init__(self, config, attn_implementation: str = "sdpa") -> None:
492
+ super().__init__()
493
+ self.norm1 = LayerNorm(config.embed_dim, eps=1e-6)
494
+ self.norm2 = LayerNorm(config.embed_dim, eps=1e-6)
495
+ mlp_hidden_dim = int(config.embed_dim * config.mlp_ratio)
496
+
497
+ self.attn = QWEN2_VL_VISION_ATTENTION_CLASSES[attn_implementation](
498
+ config.embed_dim, num_heads=config.num_heads
499
+ )
500
+ self.mlp = VisionMlp(dim=config.embed_dim, hidden_dim=mlp_hidden_dim, hidden_act=config.hidden_act)
501
+
502
+ def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor:
503
+ hidden_states = hidden_states + self.attn(
504
+ self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
505
+ )
506
+ hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
507
+ return hidden_states
508
+
509
+ class Qwen2VLPreTrainedModel(PreTrainedModel):
510
+ config_class = Qwen2VLConfig
511
+ base_model_prefix = "model"
512
+ supports_gradient_checkpointing = True
513
+ _no_split_modules = ["Qwen2VLDecoderLayer", "Qwen2VLVisionBlock"]
514
+ _skip_keys_device_placement = "past_key_values"
515
+ _supports_flash_attn_2 = True
516
+ _supports_sdpa = False
517
+ _supports_cache_class = True
518
+ _supports_static_cache = True
519
+
520
+ def _init_weights(self, module):
521
+ std = self.config.initializer_range
522
+ if isinstance(module, (nn.Linear, nn.Conv3d)):
523
+ module.weight.data.normal_(mean=0.0, std=std)
524
+ if module.bias is not None:
525
+ module.bias.data.zero_()
526
+ elif isinstance(module, nn.Embedding):
527
+ module.weight.data.normal_(mean=0.0, std=std)
528
+ if module.padding_idx is not None:
529
+ module.weight.data[module.padding_idx].zero_()
530
+
531
+ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
532
+ config_class = Qwen2VLVisionConfig
533
+ _no_split_modules = ["Qwen2VLVisionBlock"]
534
+
535
+ def __init__(self, config) -> None:
536
+ super().__init__(config)
537
+ self.spatial_merge_size = config.spatial_merge_size
538
+
539
+ self.patch_embed = PatchEmbed(
540
+ patch_size=config.patch_size,
541
+ temporal_patch_size=config.temporal_patch_size,
542
+ in_channels=config.in_channels,
543
+ embed_dim=config.embed_dim,
544
+ )
545
+
546
+ head_dim = config.embed_dim // config.num_heads
547
+ self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
548
+
549
+ self.blocks = nn.ModuleList(
550
+ [Qwen2VLVisionBlock(config, config.attn_implementation) for _ in range(config.depth)]
551
+ )
552
+ self.merger = PatchMerger(
553
+ dim=config.hidden_size, context_dim=config.embed_dim, spatial_merge_size=config.spatial_merge_size
554
+ )
555
+ # Initialize weights and apply final processing
556
+ self.gradient_checkpointing = False
557
+ self.post_init()
558
+
559
+ def get_dtype(self) -> torch.dtype:
560
+ return self.blocks[0].mlp.fc2.weight.dtype
561
+
562
+ def get_device(self) -> torch.device:
563
+ return self.blocks[0].mlp.fc2.weight.device
564
+
565
+ def rot_pos_emb(self, grid_thw):
566
+ pos_ids = []
567
+ for t, h, w in grid_thw:
568
+ hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
569
+ hpos_ids = hpos_ids.reshape(
570
+ h // self.spatial_merge_size,
571
+ self.spatial_merge_size,
572
+ w // self.spatial_merge_size,
573
+ self.spatial_merge_size,
574
+ )
575
+ hpos_ids = hpos_ids.permute(0, 2, 1, 3)
576
+ hpos_ids = hpos_ids.flatten()
577
+
578
+ wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
579
+ wpos_ids = wpos_ids.reshape(
580
+ h // self.spatial_merge_size,
581
+ self.spatial_merge_size,
582
+ w // self.spatial_merge_size,
583
+ self.spatial_merge_size,
584
+ )
585
+ wpos_ids = wpos_ids.permute(0, 2, 1, 3)
586
+ wpos_ids = wpos_ids.flatten()
587
+ pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
588
+ pos_ids = torch.cat(pos_ids, dim=0)
589
+ max_grid_size = grid_thw[:, 1:].max()
590
+ rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
591
+ rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
592
+ return rotary_pos_emb
593
+
594
+ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
595
+ hidden_states = self.patch_embed(hidden_states)
596
+ rotary_pos_emb = self.rot_pos_emb(grid_thw)
597
+
598
+ cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
599
+ dim=0, dtype=torch.int32
600
+ )
601
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
602
+
603
+ for blk in self.blocks:
604
+ if self.gradient_checkpointing and self.training:
605
+ hidden_states = self._gradient_checkpointing_func(
606
+ blk.__call__,
607
+ hidden_states,
608
+ cu_seqlens,
609
+ rotary_pos_emb,
610
+ )
611
+ else:
612
+ hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
613
+
614
+ return self.merger(hidden_states)
615
+
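A hedged example of the cu_seqlens construction in the forward above: each frame of each input contributes one h*w-long attention block (the grid_thw values are made up):

import torch
import torch.nn.functional as F

grid_thw = torch.tensor([[2, 4, 4], [1, 6, 6]])   # (t, h, w) per input, in patch units
cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
print(cu_seqlens)   # tensor([ 0, 16, 32, 68], dtype=torch.int32): attention-block boundaries, one per frame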
616
+ # class Qwen2RMSNorm(nn.Module):
617
+ # def __init__(self, hidden_size, eps=1e-6):
618
+ # """
619
+ # Qwen2RMSNorm is equivalent to T5LayerNorm
620
+ # """
621
+ # super().__init__()
622
+ # self.weight = nn.Parameter(torch.ones(hidden_size))
623
+ # self.variance_epsilon = eps
624
+ # self.normalized_shape = torch.Size((hidden_size, ))
625
+
626
+ # def forward(self, hidden_states):
627
+ # return fused_rms_norm_affine(input=hidden_states,
628
+ # weight=self.weight,
629
+ # normalized_shape=self.normalized_shape,
630
+ # eps=self.variance_epsilon,
631
+ # memory_efficient=True)
632
+
633
+ class Qwen2RMSNorm(nn.Module):
634
+ def __init__(self, hidden_size, eps=1e-6):
635
+ """
636
+ Qwen2RMSNorm is equivalent to T5LayerNorm
637
+ """
638
+ super().__init__()
639
+ self.weight = nn.Parameter(torch.ones(hidden_size))
640
+ self.variance_epsilon = eps
641
+
642
+ def forward(self, hidden_states):
643
+ input_dtype = hidden_states.dtype
644
+ hidden_states = hidden_states.to(torch.float32)
645
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
646
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
647
+ return self.weight * hidden_states.to(input_dtype)
648
+
649
+ def extra_repr(self):
650
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
651
+
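A minimal numerical check of the RMS normalization above (hidden_size=8 is arbitrary):

import torch

norm = Qwen2RMSNorm(hidden_size=8, eps=1e-6)
x = torch.randn(2, 5, 8)
ref = x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)   # x / RMS(x), no mean subtraction
assert torch.allclose(norm(x), norm.weight * ref, atol=1e-5)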
652
+ class Qwen2VLRotaryEmbedding(nn.Module):
653
+ def __init__(
654
+ self,
655
+ dim=None,
656
+ max_position_embeddings=2048,
657
+ base=10000,
658
+ device=None,
659
+ scaling_factor=1.0,
660
+ rope_type="default",
661
+ config: Optional[Qwen2VLConfig] = None,
662
+ ):
663
+ super().__init__()
664
+ # TODO (joao): remove the `if` below, only used for BC
665
+ self.rope_kwargs = {}
666
+ if config is None:
667
+ logger.warning_once(
668
+ "`Qwen2VLRotaryEmbedding` can now be fully parameterized by passing the model config through the "
669
+ "`config` argument. All other arguments will be removed in v4.46"
670
+ )
671
+ self.rope_kwargs = {
672
+ "rope_type": rope_type,
673
+ "factor": scaling_factor,
674
+ "dim": dim,
675
+ "base": base,
676
+ "max_position_embeddings": max_position_embeddings,
677
+ }
678
+ self.rope_type = rope_type
679
+ self.max_seq_len_cached = max_position_embeddings
680
+ self.original_max_seq_len = max_position_embeddings
681
+ else:
682
+ # BC: "rope_type" was originally "type"
683
+ if config.rope_scaling is not None:
684
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
685
+ else:
686
+ self.rope_type = "default"
687
+ self.max_seq_len_cached = config.max_position_embeddings
688
+ self.original_max_seq_len = config.max_position_embeddings
689
+
690
+ self.config = config
691
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
692
+
693
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
694
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
695
+ self.original_inv_freq = self.inv_freq
696
+
697
+ def _dynamic_frequency_update(self, position_ids, device):
698
+ """
699
+ dynamic RoPE layers should recompute `inv_freq` in the following situations:
700
+ 1 - growing beyond the cached sequence length (allow scaling)
701
+ 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
702
+ """
703
+ seq_len = torch.max(position_ids) + 1
704
+ if seq_len > self.max_seq_len_cached: # growth
705
+ inv_freq, self.attention_scaling = self.rope_init_fn(
706
+ self.config, device, seq_len=seq_len, **self.rope_kwargs
707
+ )
708
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
709
+ self.max_seq_len_cached = seq_len
710
+
711
+ if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
712
+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
713
+ self.max_seq_len_cached = self.original_max_seq_len
714
+
715
+ @torch.no_grad()
716
+ def forward(self, x, position_ids):
717
+ position_ids = position_ids.permute(2, 0, 1)
718
+ if "dynamic" in self.rope_type:
719
+ self._dynamic_frequency_update(position_ids, device=x.device)
720
+
721
+ # Core RoPE block. In contrast to other models, Qwen2_VL has different position ids for thw grids
722
+ # So we expand the inv_freq to shape (3, ...)
723
+ inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
724
+ position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions)
725
+ # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
726
+ device_type = x.device.type
727
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
728
+ with torch.autocast(device_type=device_type, enabled=False):
729
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
730
+ emb = torch.cat((freqs, freqs), dim=-1)
731
+ cos = emb.cos()
732
+ sin = emb.sin()
733
+
734
+ # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
735
+ cos = cos * self.attention_scaling
736
+ sin = sin * self.attention_scaling
737
+
738
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
739
+
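A shape-only sketch of what this rotary module returns: position_ids carries separate (t, h, w) indices, so cos/sin come out with a leading axis of 3 (bs, seq and head_dim are made up):

import torch

bs, seq, head_dim = 2, 7, 128
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))   # (head_dim // 2,)
position_ids = torch.arange(seq).float().expand(3, bs, seq)                     # same ids on all three axes here
freqs = (inv_freq[None, None, :, None].expand(3, bs, -1, 1) @ position_ids[:, :, None, :]).transpose(2, 3)
emb = torch.cat((freqs, freqs), dim=-1)
print(emb.cos().shape)   # torch.Size([3, 2, 7, 128]): one table per (t, h, w) axis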
740
+ # Copied from transformers.models.qwen2.modeling_qwen2.Qwen2MLP
741
+ class Qwen2MLP(nn.Module):
742
+ def __init__(self, config):
743
+ super().__init__()
744
+ self.hidden_size = config.hidden_size
745
+ self.intermediate_size = config.intermediate_size
746
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
747
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
748
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
749
+ self.act_fn = ACT2FN[config.hidden_act]
750
+
751
+ def forward(self, hidden_state):
752
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
753
+
754
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv
755
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
756
+ """
757
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
758
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
759
+ """
760
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
761
+ if n_rep == 1:
762
+ return hidden_states
763
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
764
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
765
+
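A quick equivalence check for repeat_kv (the shapes are arbitrary):

import torch

kv = torch.randn(2, 4, 10, 64)   # (batch, num_key_value_heads, seq_len, head_dim)
assert torch.equal(repeat_kv(kv, 3), torch.repeat_interleave(kv, repeats=3, dim=1))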
766
+ class Qwen2VLAttention(nn.Module):
767
+ """
768
+ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
769
+ and "Generating Long Sequences with Sparse Transformers".
770
+ """
771
+
772
+ def __init__(self, config: Qwen2VLConfig, layer_idx: Optional[int] = None):
773
+ super().__init__()
774
+ self.config = config
775
+ self.layer_idx = layer_idx
776
+ if layer_idx is None:
777
+ logger.warning_once(
778
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
779
+ "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
780
+ "when creating this class."
781
+ )
782
+
783
+ self.hidden_size = config.hidden_size
784
+ self.num_heads = config.num_attention_heads
785
+ self.head_dim = self.hidden_size // self.num_heads
786
+ self.num_key_value_heads = config.num_key_value_heads
787
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
788
+ self.max_position_embeddings = config.max_position_embeddings
789
+ self.rope_theta = config.rope_theta
790
+ self.is_causal = True
791
+ self.attention_dropout = config.attention_dropout
792
+ self.rope_scaling = config.rope_scaling
793
+
794
+ if (self.head_dim * self.num_heads) != self.hidden_size:
795
+ raise ValueError(
796
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
797
+ f" and `num_heads`: {self.num_heads})."
798
+ )
799
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
800
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
801
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
802
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
803
+
804
+
805
+ class Qwen2VLFlashAttention2(Qwen2VLAttention):
806
+ """
807
+ Qwen2VL flash attention module, following Qwen2VL attention module. This module inherits from `Qwen2VLAttention`
808
+ as the weights of the module stays untouched. The only required change would be on the forward pass
809
+ where it needs to correctly call the public API of flash attention and deal with padding tokens
810
+ in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom
811
+ config.max_window_layers layers.
812
+ """
813
+
814
+ def __init__(self, *args, **kwargs):
815
+ super().__init__(*args, **kwargs)
816
+
817
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
818
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
819
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
820
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
821
+
822
+ def forward(
823
+ self,
824
+ hidden_states: torch.Tensor,
825
+ attention_mask: Optional[torch.Tensor] = None,
826
+ position_ids: Optional[torch.LongTensor] = None,
827
+ past_key_value: Optional[Cache] = None,
828
+ output_attentions: bool = False,
829
+ use_cache: bool = False,
830
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
831
+ use_rmpad: Optional[bool] = False,
832
+ cu_seqlens: Optional[torch.Tensor] = None,
833
+ ):
834
+ """
835
+ Train:
836
+ unpad: (bsz, q_len) = (1, acc_seqlen)
837
+ pad: (bsz, q_len) = (bsz, q_len)
838
+ Test:
839
+ first_iter: (bsz, q_len) = (bsz, q_len)
840
+ other: (bsz, q_len) = (bsz, 1)
841
+ """
842
+ bsz, q_len, _ = hidden_states.size()
843
+
844
+ query_states = self.q_proj(hidden_states)
845
+ key_states = self.k_proj(hidden_states)
846
+ value_states = self.v_proj(hidden_states)
847
+
848
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
849
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
850
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
851
+
852
+ cos, sin = position_embeddings
853
+
854
+ query_states, key_states = apply_multimodal_rotary_pos_emb(
855
+ query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
856
+ )
857
+
858
+ if past_key_value is not None:
859
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
860
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
861
+
862
+ # repeat k/v heads if n_kv_heads < n_heads
863
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
864
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
865
+ dropout_rate = 0.0 if not self.training else self.attention_dropout
866
+
867
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
868
+ # therefore the input hidden states get silently cast to float32. Hence, we need to
869
+ # cast them back to float16 just to be sure everything works as expected.
870
+ input_dtype = query_states.dtype
871
+ if input_dtype == torch.float32:
872
+ if torch.is_autocast_enabled():
873
+ target_dtype = torch.get_autocast_gpu_dtype()
874
+ # Handle the case where the model is quantized
875
+ elif hasattr(self.config, "_pre_quantization_dtype"):
876
+ target_dtype = self.config._pre_quantization_dtype
877
+ else:
878
+ target_dtype = self.q_proj.weight.dtype
879
+
880
+ logger.warning_once(
881
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
882
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
883
+ f" {target_dtype}."
884
+ )
885
+
886
+ query_states = query_states.to(target_dtype)
887
+ key_states = key_states.to(target_dtype)
888
+ value_states = value_states.to(target_dtype)
889
+
890
+ # Reshape to the expected shape for Flash Attention
891
+ query_states = query_states.transpose(1, 2)
892
+ key_states = key_states.transpose(1, 2)
893
+ value_states = value_states.transpose(1, 2)
894
+
895
+ if use_rmpad:
896
+ max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1]).item()
897
+ attn_output = flash_attn_varlen_func(
898
+ query_states.squeeze(0), key_states.squeeze(0), value_states.squeeze(0),
899
+ cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
900
+ max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
901
+ dropout_p=dropout_rate,
902
+ causal=self.is_causal, window_size=(-1, -1),
903
+ )
904
+ else:
905
+ attn_output = _flash_attention_forward(
906
+ query_states, key_states, value_states,
907
+ attention_mask,
908
+ q_len,
909
+ dropout=dropout_rate,
910
+ sliding_window=None,
911
+ is_causal=self.is_causal,
912
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
913
+ )
914
+
915
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
916
+ attn_output = self.o_proj(attn_output)
917
+
918
+ if not output_attentions:
919
+ attn_weights = None
920
+
921
+ return attn_output, attn_weights, past_key_value
922
+
923
+ QWEN2_VL_ATTENTION_CLASSES = {
924
+ "flash_attention_2": Qwen2VLFlashAttention2,
925
+ }
926
+
927
+ class Qwen2VLDecoderLayer(nn.Module):
928
+ def __init__(self, config: Qwen2VLConfig, layer_idx: int):
929
+ super().__init__()
930
+ self.hidden_size = config.hidden_size
931
+
932
+ if config.attn_implementation != "flash_attention_2":
933
+ logger.error(
934
+ f"只支持 flash_attention_2!config.attn_implementation={config.attn_implementation}"
935
+ )
936
+ self.self_attn = QWEN2_VL_ATTENTION_CLASSES[config.attn_implementation](config, layer_idx)
937
+
938
+ self.mlp = Qwen2MLP(config)
939
+ self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
940
+ self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
941
+
942
+ def forward(
943
+ self,
944
+ hidden_states: torch.Tensor,
945
+ attention_mask: Optional[torch.Tensor] = None,
946
+ position_ids: Optional[torch.LongTensor] = None,
947
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
948
+ output_attentions: Optional[bool] = False,
949
+ use_cache: Optional[bool] = False,
950
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
951
+ use_rmpad: Optional[bool] = False,
952
+ cu_seqlens: Optional[torch.Tensor] = None,
953
+ **kwargs,
954
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
955
+ """
956
+ Args:
957
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
958
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
959
+ `(batch, sequence_length)` where padding elements are indicated by 0.
960
+ output_attentions (`bool`, *optional*):
961
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
962
+ returned tensors for more detail.
963
+ use_cache (`bool`, *optional*):
964
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
965
+ (see `past_key_values`).
966
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
967
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
968
+ Indices depicting the position of the input sequence tokens in the sequence.
969
+ position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
970
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
971
+ with `head_dim` being the embedding dimension of each attention head.
972
+ kwargs (`dict`, *optional*):
973
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
974
+ into the model
975
+ """
976
+
977
+ residual = hidden_states
978
+
979
+ hidden_states = self.input_layernorm(hidden_states)
980
+
981
+ # Self Attention
982
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
983
+ hidden_states=hidden_states,
984
+ attention_mask=attention_mask,
985
+ position_ids=position_ids,
986
+ past_key_value=past_key_value,
987
+ output_attentions=output_attentions,
988
+ use_cache=use_cache,
989
+ position_embeddings=position_embeddings,
990
+ use_rmpad=use_rmpad,
991
+ cu_seqlens=cu_seqlens,
992
+ )
993
+ hidden_states = residual + hidden_states
994
+
995
+ # Fully Connected
996
+ residual = hidden_states
997
+ hidden_states = self.post_attention_layernorm(hidden_states)
998
+ hidden_states = self.mlp(hidden_states)
999
+ hidden_states = residual + hidden_states
1000
+
1001
+ outputs = (hidden_states,)
1002
+
1003
+ if output_attentions:
1004
+ outputs += (self_attn_weights,)
1005
+
1006
+ if use_cache:
1007
+ outputs += (present_key_value,)
1008
+
1009
+ return outputs
1010
+
1011
+ class Qwen2VLModel(Qwen2VLPreTrainedModel):
1012
+ def __init__(self, config: Qwen2VLConfig):
1013
+ super().__init__(config)
1014
+ self.padding_idx = config.pad_token_id
1015
+ self.vocab_size = config.vocab_size
1016
+
1017
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1018
+ self.layers = nn.ModuleList([Qwen2VLDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
1019
+ self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1020
+ self.rotary_emb = Qwen2VLRotaryEmbedding(config=config)
1021
+
1022
+ self.gradient_checkpointing = False
1023
+ # Initialize weights and apply final processing
1024
+ self.post_init()
1025
+
1026
+ def get_input_embeddings(self):
1027
+ return self.embed_tokens
1028
+
1029
+ def set_input_embeddings(self, value):
1030
+ self.embed_tokens = value
1031
+
1032
+ def forward(
1033
+ self,
1034
+ input_ids: torch.LongTensor = None,
1035
+ attention_mask: Optional[torch.Tensor] = None,
1036
+ position_ids: Optional[torch.LongTensor] = None,
1037
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1038
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1039
+ use_cache: Optional[bool] = None,
1040
+ output_attentions: Optional[bool] = None,
1041
+ output_hidden_states: Optional[bool] = None,
1042
+ return_dict: Optional[bool] = None,
1043
+ use_rmpad: Optional[bool] = False,
1044
+ cu_seqlens: Optional[torch.Tensor] = None,
1045
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
1046
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1047
+ output_hidden_states = (
1048
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1049
+ )
1050
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1051
+
1052
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1053
+
1054
+ if (input_ids is None) ^ (inputs_embeds is not None):
1055
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
1056
+
1057
+ if self.gradient_checkpointing and self.training:
1058
+ if use_cache:
1059
+ logger.warning_once(
1060
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1061
+ )
1062
+ use_cache = False
1063
+
1064
+
1065
+ hidden_states = inputs_embeds
1066
+
1067
+ # create position embeddings to be shared across the decoder layers
1068
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
1069
+
1070
+ # decoder layers
1071
+ all_hidden_states = () if output_hidden_states else None
1072
+ all_self_attns = () if output_attentions else None
1073
+ next_decoder_cache = None
1074
+
1075
+ for decoder_layer in self.layers:
1076
+ if output_hidden_states:
1077
+ all_hidden_states += (hidden_states,)
1078
+
1079
+ if self.gradient_checkpointing and self.training:
1080
+ layer_outputs = self._gradient_checkpointing_func(
1081
+ decoder_layer.__call__,
1082
+ hidden_states,
1083
+ attention_mask,
1084
+ position_ids,
1085
+ past_key_values,
1086
+ output_attentions,
1087
+ use_cache,
1088
+ position_embeddings,
1089
+ use_rmpad,
1090
+ cu_seqlens,
1091
+ )
1092
+ else:
1093
+ layer_outputs = decoder_layer(
1094
+ hidden_states,
1095
+ attention_mask=attention_mask,
1096
+ position_ids=position_ids,
1097
+ past_key_value=past_key_values,
1098
+ output_attentions=output_attentions,
1099
+ use_cache=use_cache,
1100
+ position_embeddings=position_embeddings,
1101
+ use_rmpad=use_rmpad,
1102
+ cu_seqlens=cu_seqlens,
1103
+ )
1104
+
1105
+ hidden_states = layer_outputs[0]
1106
+
1107
+ if use_cache:
1108
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1109
+
1110
+ if output_attentions:
1111
+ all_self_attns += (layer_outputs[1],)
1112
+
1113
+ hidden_states = self.norm(hidden_states)
1114
+
1115
+ # add hidden states from the last decoder layer
1116
+ if output_hidden_states:
1117
+ all_hidden_states += (hidden_states,)
1118
+
1119
+ next_cache = next_decoder_cache if use_cache else None
1120
+
1121
+ if not return_dict:
1122
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1123
+ return BaseModelOutputWithPast(
1124
+ last_hidden_state=hidden_states,
1125
+ past_key_values=next_cache,
1126
+ hidden_states=all_hidden_states,
1127
+ attentions=all_self_attns,
1128
+ )
1129
+
1130
+ class Qwen2VLForCausalLM(Qwen2VLPreTrainedModel, GenerationMixin):
1131
+ _tied_weights_keys = ["lm_head.weight"]
1132
+
1133
+ def __init__(self, config):
1134
+ super().__init__(config)
1135
+ self.model = Qwen2VLModel(config)
1136
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1137
+ self.padding_side = "left" # set it to left by default, user can use setter to change padding_sides
1138
+
1139
+ # Initialize weights and apply final processing
1140
+ self.post_init()
1141
+
1142
+ def get_input_embeddings(self):
1143
+ return self.model.embed_tokens
1144
+
1145
+ def set_input_embeddings(self, value):
1146
+ self.model.embed_tokens = value
1147
+
1148
+ def get_output_embeddings(self):
1149
+ return self.lm_head
1150
+
1151
+ def set_output_embeddings(self, new_embeddings):
1152
+ self.lm_head = new_embeddings
1153
+
1154
+ def set_decoder(self, decoder):
1155
+ self.model = decoder
1156
+
1157
+ def get_decoder(self):
1158
+ return self.model
1159
+
1160
+ def get_rope_index(
1161
+ self,
1162
+ input_ids: torch.LongTensor,
1163
+ image_grid_thw: Optional[torch.LongTensor] = None,
1164
+ attention_mask: Optional[torch.Tensor] = None,
1165
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1166
+ """
1167
+ Calculate the 3D rope index based on image and video's temporal, height and width in LLM.
1168
+
1169
+ Explanation:
1170
+ Each embedding sequence contains vision embeddings and text embeddings, or just text embeddings.
1171
+
1172
+ For a pure text embedding sequence, the rotary position embedding is no different from that of modern LLMs.
1173
+ Examples:
1174
+ input_ids: [T T T T T], here T is for text.
1175
+ temporal position_ids: [0, 1, 2, 3, 4]
1176
+ height position_ids: [0, 1, 2, 3, 4]
1177
+ width position_ids: [0, 1, 2, 3, 4]
1178
+
1179
+ For a vision-and-text embedding sequence, we calculate a 3D rotary position embedding for the vision part
1180
+ and a 1D rotary position embedding for the text part.
1181
+ Examples:
1182
+ Assume we have a video input with 3 temporal patches, 2 height patches and 2 width patches.
1183
+ input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
1184
+ vision temporal position_ids: [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
1185
+ vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
1186
+ vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
1187
+ text temporal position_ids: [3, 4, 5, 6, 7]
1188
+ text height position_ids: [3, 4, 5, 6, 7]
1189
+ text width position_ids: [3, 4, 5, 6, 7]
1190
+ Here we calculate the text start position_ids as the max vision position_ids plus 1.
1191
+
1192
+ Args:
1193
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1194
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
1195
+ it.
1196
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
1197
+ The temporal, height and width of feature shape of each image in LLM.
1198
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
1199
+ The temporal, height and width of feature shape of each video in LLM.
1200
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1201
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1202
+
1203
+ - 1 for tokens that are **not masked**,
1204
+ - 0 for tokens that are **masked**.
1205
+
1206
+ Returns:
1207
+ position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`)
1208
+ mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`)
1209
+ """
1210
+ spatial_merge_size = self.config.spatial_merge_size
1211
+ vision_token_id = self.config.image_token_id
1212
+ vision_start_token_id = self.config.vision_start_token_id
1213
+ assert image_grid_thw is not None # TODO: check that pure-text inputs (no image grid) do not get stuck here
1214
+ total_input_ids = input_ids
1215
+ position_ids = torch.ones(
1216
+ 3, input_ids.shape[0], input_ids.shape[1], dtype=input_ids.dtype, device=input_ids.device
1217
+ )
1218
+ vision_index = 0
1219
+ for i, input_ids in enumerate(total_input_ids):
1220
+ if attention_mask is not None:
1221
+ input_ids = input_ids[attention_mask[i] == 1]
1222
+ vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
1223
+ vision_num = (input_ids[vision_start_indices + 1] == vision_token_id).sum()
1224
+ input_tokens = input_ids.tolist()
1225
+ llm_pos_ids_list: list = []
1226
+ st = 0
1227
+ remain_vision_num = vision_num
1228
+ for _ in range(vision_num):
1229
+ if vision_token_id in input_tokens and remain_vision_num > 0:
1230
+ ed_vision = input_tokens.index(vision_token_id, st)
1231
+ else:
1232
+ ed_vision = len(input_tokens) + 1
1233
+
1234
+ t, h, w = (
1235
+ image_grid_thw[vision_index][0],
1236
+ image_grid_thw[vision_index][1],
1237
+ image_grid_thw[vision_index][2],
1238
+ )
1239
+ vision_index += 1
1240
+ remain_vision_num -= 1
1241
+ ed = ed_vision
1242
+
1243
+ llm_grid_t, llm_grid_h, llm_grid_w = (
1244
+ t.item(),
1245
+ h.item() // spatial_merge_size,
1246
+ w.item() // spatial_merge_size,
1247
+ )
1248
+ text_len = ed - st
1249
+
1250
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
1251
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
1252
+
1253
+ t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
1254
+ h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
1255
+ w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
1256
+ llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
1257
+ st = ed + llm_grid_t * llm_grid_h * llm_grid_w
1258
+
1259
+ if st < len(input_tokens):
1260
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
1261
+ text_len = len(input_tokens) - st
1262
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
1263
+
1264
+ llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
1265
+ position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
1266
+ position_ids = position_ids.permute(1, 2, 0)
1267
+ return position_ids
1268
+
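A small sketch reproducing the docstring's 3x2x2 example with the same index construction used above (llm_grid_t=3, llm_grid_h=2, llm_grid_w=2):

import torch

t, h, w = 3, 2, 2
t_index = torch.arange(t).view(-1, 1).expand(-1, h * w).flatten()
h_index = torch.arange(h).view(1, -1, 1).expand(t, -1, w).flatten()
w_index = torch.arange(w).view(1, 1, -1).expand(t, h, -1).flatten()
print(t_index.tolist())   # [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
print(h_index.tolist())   # [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
print(w_index.tolist())   # [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]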
1269
+ def forward(
1270
+ self,
1271
+ input_ids: torch.LongTensor = None,
1272
+ attention_mask: Optional[torch.Tensor] = None,
1273
+ position_ids: Optional[torch.LongTensor] = None,
1274
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1275
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1276
+ labels: Optional[torch.LongTensor] = None,
1277
+ use_cache: Optional[bool] = None,
1278
+ output_attentions: Optional[bool] = None,
1279
+ output_hidden_states: Optional[bool] = None,
1280
+ return_dict: Optional[bool] = None,
1281
+ use_rmpad: Optional[bool] = False,
1282
+ cu_seqlens: Optional[torch.Tensor] = None,
1283
+ ) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
1284
+
1285
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1286
+ output_hidden_states = (
1287
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1288
+ )
1289
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1290
+
1291
+
1292
+ outputs = self.model(
1293
+ input_ids=input_ids,
1294
+ attention_mask=attention_mask,
1295
+ position_ids=position_ids,
1296
+ past_key_values=past_key_values,
1297
+ inputs_embeds=inputs_embeds,
1298
+ use_cache=use_cache,
1299
+ output_attentions=output_attentions,
1300
+ output_hidden_states=output_hidden_states,
1301
+ return_dict=return_dict,
1302
+ use_rmpad=use_rmpad,
1303
+ cu_seqlens=cu_seqlens,
1304
+ )
1305
+
1306
+ hidden_states = outputs[0]
1307
+ logits = self.lm_head(hidden_states)
1308
+
1309
+ if not return_dict:
1310
+ output = (logits,) + outputs[1:]
1311
+ return output
1312
+
1313
+ return Qwen2VLCausalLMOutputWithPast(
1314
+ logits=logits,
1315
+ past_key_values=outputs.past_key_values,
1316
+ hidden_states=outputs.hidden_states,
1317
+ attentions=outputs.attentions,
1318
+ )
1319
+
1320
+
models/modeling_tarsier.py CHANGED
@@ -1,100 +1,30 @@
1
- # Copyright (2024) Bytedance Ltd. and/or its affiliates
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- # copy and modify from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
16
- """ PyTorch Llava model."""
17
  from dataclasses import dataclass
18
- from typing import List, Optional, Tuple, Union
19
  import math
20
- import numpy as np
21
 
22
- import torch
23
  import torch.utils.checkpoint
24
  from torch import nn
25
  import torch.nn.functional as F
26
 
27
- from transformers import PreTrainedModel
28
  from transformers.activations import ACT2FN
29
  from transformers.cache_utils import Cache
30
  from transformers.modeling_outputs import ModelOutput
31
- from transformers.utils import (
32
- add_start_docstrings,
33
- add_start_docstrings_to_model_forward,
34
- logging,
35
- replace_return_docstrings,
36
- )
37
- from transformers.models.auto import AutoModel, AutoModelForCausalLM, CONFIG_MAPPING
38
- from transformers import LlamaForCausalLM
39
  from transformers.configuration_utils import PretrainedConfig
 
 
 
40
 
 
 
 
 
41
 
42
  logger = logging.get_logger(__name__)
43
 
44
- LLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
45
- "llava-hf/llava-v1.5-7b": "https://huggingface.co/llava-hf/llava-v1.5-7b/resolve/main/config.json",
46
- }
47
 
48
  class LlavaConfig(PretrainedConfig):
49
- r"""
50
- This is the configuration class to store the configuration of a [`LlavaForConditionalGeneration`]. It is used to instantiate an
51
- Llava model according to the specified arguments, defining the model architecture. Instantiating a configuration
52
- with the defaults will yield a similar configuration to that of the Llava-9B.
53
-
54
- e.g. [llava-hf/llava-9b](https://huggingface.co/llava-hf/llava-9b)
55
-
56
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
57
- documentation from [`PretrainedConfig`] for more information.
58
-
59
- Args:
60
- vision_config (`LlavaVisionConfig`, *optional*):
61
- Custom vision config or dict
62
- text_config (`Union[AutoConfig, dict]`, *optional*):
63
- The config object of the text backbone. Can be any of `LlamaConfig` or `MistralConfig`.
64
- ignore_index (`int`, *optional*, defaults to -100):
65
- The ignore index for the loss function.
66
- image_token_index (`int`, *optional*, defaults to 32000):
67
- The image token index to encode the image prompt.
68
- projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
69
- The activation function used by the multimodal projector.
70
- vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
71
- The feature selection strategy used to select the vision feature from the CLIP backbone.
72
- vision_feature_layer (`int`, *optional*, defaults to -2):
73
- The index of the layer to select the vision feature.
74
- vocab_size (`int`, *optional*, defaults to 32000):
75
- Vocabulary size of the Llava model. Defines the number of different tokens that can be represented by the
76
- `inputs_ids` passed when calling [`~LlavaForConditionalGeneration`]
77
-
78
- Example:
79
-
80
- ```python
81
- >>> from transformers import LlavaForConditionalGeneration, LlavaConfig, CLIPVisionConfig, LlamaConfig
82
-
83
- >>> # Initializing a CLIP-vision config
84
- >>> vision_config = CLIPVisionConfig()
85
-
86
- >>> # Initializing a Llama config
87
- >>> text_config = LlamaConfig()
88
-
89
- >>> # Initializing a Llava llava-1.5-7b style configuration
90
- >>> configuration = LlavaConfig(vision_config, text_config)
91
-
92
- >>> # Initializing a model from the llava-1.5-7b style configuration
93
- >>> model = LlavaForConditionalGeneration(configuration)
94
-
95
- >>> # Accessing the model configuration
96
- >>> configuration = model.config
97
- ```"""
98
 
99
  model_type = "llava"
100
  is_composition = False
@@ -108,9 +38,9 @@ class LlavaConfig(PretrainedConfig):
108
  projector_hidden_act="gelu",
109
  vision_feature_select_strategy="default",
110
  vision_feature_layer=-2,
111
- vocab_size=32000,
112
  image_newline_idx=32002,
113
  image_new_idx=32003,
 
114
  **kwargs,
115
  ):
116
  self.ignore_index = ignore_index
@@ -118,9 +48,9 @@ class LlavaConfig(PretrainedConfig):
118
  self.projector_hidden_act = projector_hidden_act
119
  self.vision_feature_select_strategy = vision_feature_select_strategy
120
  self.vision_feature_layer = vision_feature_layer
121
- self.vocab_size = vocab_size
122
  self.image_newline_idx = image_newline_idx
123
  self.image_new_idx = image_new_idx
 
124
 
125
  self.vision_config = vision_config
126
 
@@ -128,142 +58,166 @@ class LlavaConfig(PretrainedConfig):
128
  vision_config["model_type"] = (
129
  vision_config["model_type"] if "model_type" in vision_config else "clip_vision_model"
130
  )
131
- self.vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
132
- elif vision_config is None:
133
- self.vision_config = CONFIG_MAPPING["clip_vision_model"](
134
- intermediate_size=4096,
135
- hidden_size=1024,
136
- patch_size=14,
137
- image_size=336,
138
- num_hidden_layers=24,
139
- num_attention_heads=16,
140
- vocab_size=32000,
141
- projection_dim=768,
142
- )
143
- self.vocab_size = self.vocab_size
144
-
145
  self.text_config = text_config
146
 
147
  if isinstance(self.text_config, dict):
148
  text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama"
149
- self.text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
150
- self.vocab_size = self.text_config.vocab_size
151
- elif text_config is None:
152
- self.text_config = CONFIG_MAPPING["llama"]()
 
 
 
 
 
153
 
154
  super().__init__(**kwargs)
155
 
156
 
157
- logger = logging.get_logger(__name__)
158
-
159
- _CONFIG_FOR_DOC = "LlavaConfig"
160
-
161
- LLAVA_PRETRAINED_MODEL_ARCHIVE_LIST = [
162
- "llava-hf/llava-1.5-7b-hf",
163
- "llava-hf/llava-1.5-13b-hf",
164
- "llava-hf/bakLlava-v1-hf",
165
- # See all Llava models at https://huggingface.co/models?filter=llava
166
- ]
167
-
168
-
169
- class Llava3DPositionalEncoding(nn.Module):
170
- def __init__(self, num_pos, dim) -> None:
171
- super().__init__()
172
- dim1, dim2, dim3 = self.split_dim(dim)
173
- frame_position_encodings = self.create_sinusoidal_positions(num_pos, dim1)
174
- height_position_encodings = self.create_sinusoidal_positions(num_pos, dim2)
175
- width_position_encodings = self.create_sinusoidal_positions(num_pos, dim3)
176
-
177
- self.register_buffer('frame_position_encodings', frame_position_encodings, persistent=False)
178
- self.register_buffer('height_position_encodings', height_position_encodings, persistent=False)
179
- self.register_buffer('width_position_encodings', width_position_encodings, persistent=False)
180
-
181
- def split_dim(self, dim):
182
- dim1 = dim // 3
183
- if dim1 % 2 != 0:
184
- dim1 -= 1
185
-
186
- dim2 = dim // 3
187
- if dim2 % 2 != 0:
188
- dim2 -= 1
189
-
190
- dim3 = dim - dim1 - dim2
191
- return dim1, dim2, dim3
192
-
193
- def create_sinusoidal_positions(self, num_pos: int, dim: int) -> torch.Tensor:
194
- inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim))
195
- sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.float), inv_freq).float()
196
- return torch.cat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1)
197
-
198
- def forward(self, frame_position_ids, height_position_ids, width_position_ids):
199
- frame_position_embeds = F.embedding(frame_position_ids, self.frame_position_encodings)
200
- height_position_embeds = F.embedding(height_position_ids, self.height_position_encodings)
201
- width_position_embeds = F.embedding(width_position_ids, self.width_position_encodings)
202
-
203
- return torch.cat([frame_position_embeds, height_position_embeds, width_position_embeds], dim = -1)
204
-
205
 
206
  @dataclass
207
  # Copied from transformers.models.idefics.modeling_idefics.IdeficsCausalLMOutputWithPast with Idefics->Llava
208
  class LlavaCausalLMOutputWithPast(ModelOutput):
209
- """
210
- Base class for Llava causal language model (or autoregressive) outputs.
211
-
212
- Args:
213
- loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
214
- Language modeling loss (for next-token prediction).
215
- logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
216
- Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
217
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
218
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
219
- `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
220
-
221
- Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
222
- `past_key_values` input) to speed up sequential decoding.
223
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
224
- Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
225
- one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
226
-
227
- Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
228
- attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
229
- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
230
- sequence_length)`.
231
-
232
- Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
233
- heads.
234
- image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
235
- Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images,
236
- sequence_length, hidden_size)`.
237
-
238
- image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver
239
- """
240
 
241
  loss: Optional[torch.FloatTensor] = None
242
  logits: torch.FloatTensor = None
243
  past_key_values: Optional[List[torch.FloatTensor]] = None
244
  hidden_states: Optional[Tuple[torch.FloatTensor]] = None
245
  attentions: Optional[Tuple[torch.FloatTensor]] = None
246
- image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
247
- vision_outputs: Optional[torch.FloatTensor] = None
248
- llm_attn_mask: Optional[Tuple[torch.FloatTensor]] = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
 
250
 
251
  class LlavaMultiModalProjector(nn.Module):
252
  def __init__(self, config: LlavaConfig):
253
  super().__init__()
 
254
 
255
  self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
256
  self.act = ACT2FN[config.projector_hidden_act]
257
  self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
258
 
259
- def forward(self, image_features):
260
- hidden_states = self.linear_1(image_features)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
261
  hidden_states = self.act(hidden_states)
262
  hidden_states = self.linear_2(hidden_states)
 
 
 
 
263
  return hidden_states
264
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
265
 
266
- TARSIER_START_DOCSTRING = r"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267
  This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
268
  library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
269
  etc.)
@@ -279,23 +233,17 @@ TARSIER_START_DOCSTRING = r"""
279
  [`~PreTrainedModel.from_pretrained`] method to load the model weights.
280
  """
281
 
282
-
283
- @add_start_docstrings(
284
- "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
285
- TARSIER_START_DOCSTRING,
286
- )
287
  class TarsierPreTrainedModel(PreTrainedModel):
288
  config_class = LlavaConfig
289
- base_model_prefix = "model"
290
- supports_gradient_checkpointing = True
291
- _no_split_modules = ["LlavaVisionAttention"]
292
  _skip_keys_device_placement = "past_key_values"
293
  _supports_flash_attn_2 = True
 
 
 
294
 
295
  def _init_weights(self, module):
296
- # important: this ported version of Llava isn't meant for training from scratch - only
297
- # inference and fine-tuning - so the proper init weights code has been removed - the original codebase
298
- # https://github.com/haotian-liu/LLaVA/tree/main/llava should serve for that purpose
299
  std = (
300
  self.config.initializer_range
301
  if hasattr(self.config, "initializer_range")
@@ -305,7 +253,7 @@ class TarsierPreTrainedModel(PreTrainedModel):
305
  if hasattr(module, "class_embedding"):
306
  module.class_embedding.data.normal_(mean=0.0, std=std)
307
 
308
- if isinstance(module, (nn.Linear, nn.Conv2d)):
309
  module.weight.data.normal_(mean=0.0, std=std)
310
  if module.bias is not None:
311
  module.bias.data.zero_()
@@ -313,98 +261,39 @@ class TarsierPreTrainedModel(PreTrainedModel):
313
  module.weight.data.normal_(mean=0.0, std=std)
314
  if module.padding_idx is not None:
315
  module.weight.data[module.padding_idx].zero_()
316
-
 
 
 
317
  @property
318
- def _supports_sdpa(self):
319
- """
320
- Retrieve language_model's attribute to check whether the model supports
321
- SDPA or not.
322
- """
323
- return self.language_model._supports_sdpa
324
-
325
-
326
- TARSIER_INPUTS_DOCSTRING = r"""
327
- Args:
328
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
329
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
330
- it.
331
-
332
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
333
- [`PreTrainedTokenizer.__call__`] for details.
334
-
335
- [What are input IDs?](../glossary#input-ids)
336
- pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)):
337
- The tensors corresponding to the input images. Pixel values can be obtained using
338
- [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details ([]`LlavaProcessor`] uses
339
- [`CLIPImageProcessor`] for processing images).
340
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
341
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
342
-
343
- - 1 for tokens that are **not masked**,
344
- - 0 for tokens that are **masked**.
345
-
346
- [What are attention masks?](../glossary#attention-mask)
347
-
348
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
349
- [`PreTrainedTokenizer.__call__`] for details.
350
-
351
- If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
352
- `past_key_values`).
353
-
354
- If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
355
- and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
356
- information on the default strategy.
357
-
358
- - 1 indicates the head is **not masked**,
359
- - 0 indicates the head is **masked**.
360
- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
361
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
362
- config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
363
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
364
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
365
- `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
366
- `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
367
-
368
- Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
369
- blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
370
-
371
- If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
372
- don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
373
- `decoder_input_ids` of shape `(batch_size, sequence_length)`.
374
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
375
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
376
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
377
- model's internal embedding lookup matrix.
378
- use_cache (`bool`, *optional*):
379
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
380
- `past_key_values`).
381
- output_attentions (`bool`, *optional*):
382
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
383
- tensors for more detail.
384
- output_hidden_states (`bool`, *optional*):
385
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
386
- more detail.
387
- return_dict (`bool`, *optional*):
388
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
389
- """
390
 
391
 
392
- @add_start_docstrings(
393
- """The LLAVA model which consists of a vision backbone and a language model.""",
394
- TARSIER_INPUTS_DOCSTRING,
395
- )
396
- class TarsierForConditionalGeneration(TarsierPreTrainedModel):
397
  def __init__(self, config: LlavaConfig):
398
  super().__init__(config)
399
  self.vision_tower = AutoModel.from_config(config.vision_config, trust_remote_code=True)
400
- self.multi_modal_projector = LlavaMultiModalProjector(config)
401
- self.vocab_size = config.vocab_size
402
- self.language_model = AutoModelForCausalLM.from_config(config.text_config, attn_implementation="flash_attention_2")
403
- image_newline_idx = torch.tensor([config.image_newline_idx], dtype=torch.long)
404
- image_new_idx = torch.tensor([config.image_new_idx], dtype=torch.long)
405
- self.register_buffer('image_newline_idx', image_newline_idx, persistent=False)
406
- self.register_buffer('image_new_idx', image_new_idx, persistent=False)
407
- self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
408
  self.post_init()
409
 
410
  def get_input_embeddings(self):
@@ -432,231 +321,81 @@ class TarsierForConditionalGeneration(TarsierPreTrainedModel):
432
  model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
433
  # update vocab size
434
  self.config.text_config.vocab_size = model_embeds.num_embeddings
435
- self.config.vocab_size = model_embeds.num_embeddings
436
- self.vocab_size = model_embeds.num_embeddings
437
  return model_embeds
438
 
439
- def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
440
- num_images, num_image_patches, embed_dim = image_features.shape
441
-
442
- batch_size, sequence_length = input_ids.shape
443
- left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
444
- # 1. Create a mask to know where special image tokens are
445
- special_image_token_mask = input_ids == self.config.image_token_index
446
- num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
447
- # Compute the maximum embed dimension
448
- max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
449
- batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)
450
-
451
- # 2. Compute the positions where text should be written
452
- # Calculate new positions for text tokens in merged image-text sequence.
453
- # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
454
- # `torch.cumsum` computes how each image token shifts subsequent text token positions.
455
- # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
456
- new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
457
- nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
458
- if left_padding:
459
- new_token_positions += nb_image_pad[:, None] # offset for left padding
460
- text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
461
-
462
- # 3. Create the full embedding, already padded to the maximum position
463
- final_embedding = torch.zeros(
464
- batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
465
- )
466
- final_attention_mask = torch.zeros(
467
- batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
468
- )
469
- if labels is not None:
470
- final_labels = torch.full(
471
- (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
472
- )
473
- # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
474
- # set the corresponding tensors into their correct target device.
475
- target_device = inputs_embeds.device
476
- batch_indices, non_image_indices, text_to_overwrite = (
477
- batch_indices.to(target_device),
478
- non_image_indices.to(target_device),
479
- text_to_overwrite.to(target_device),
480
- )
481
- attention_mask = attention_mask.to(target_device)
482
-
483
- # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
484
- # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
485
- final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
486
- final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
487
- if labels is not None:
488
- final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
489
-
490
- # 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
491
- image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
492
- image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
493
-
494
- if image_to_overwrite.sum() != image_features.shape[:-1].numel():
495
- raise ValueError(
496
- f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
497
- f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
498
- )
499
-
500
- final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
501
- final_attention_mask |= image_to_overwrite
502
- position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
503
-
504
- if labels is None:
505
- final_labels = None
506
-
507
- return final_embedding, final_attention_mask, final_labels, position_ids
508
-
509
- def add_split_tokens(self, image_features):
510
- num_images, num_image_patches, embed_dim = image_features.shape
511
- num_height_patches, num_width_patches = int(math.sqrt(num_image_patches)), int(math.sqrt(num_image_patches))
512
-
513
- # add image_newline
514
- image_newline = self.get_input_embeddings()(self.image_newline_idx).squeeze()
515
- image_features = image_features.view(num_images, num_height_patches, num_width_patches, embed_dim)
516
- image_features = torch.cat([
517
- image_features,
518
- image_newline.expand((num_images, num_height_patches, 1, embed_dim)).to(device=image_features.device)
519
- ], dim=2)
520
- num_image_patches += num_height_patches
521
- image_features = image_features.view(num_images, num_image_patches, embed_dim)
522
-
523
- # add image_new
524
- image_new = self.get_input_embeddings()(self.image_new_idx).squeeze()
525
- image_features = torch.cat([
526
- image_features,
527
- image_new.expand((num_images, 1, embed_dim)).to(device=image_features.device)
528
- ], dim = 1)
529
-
530
- return image_features
531
-
532
- @add_start_docstrings_to_model_forward(TARSIER_INPUTS_DOCSTRING)
533
- @replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
534
  def forward(
535
  self,
536
  input_ids: torch.LongTensor = None,
537
- pixel_values: torch.FloatTensor = None,
538
  attention_mask: Optional[torch.Tensor] = None,
539
  position_ids: Optional[torch.LongTensor] = None,
 
 
540
  past_key_values: Optional[List[torch.FloatTensor]] = None,
541
- inputs_embeds: Optional[torch.FloatTensor] = None,
542
- vision_feature_layer: Optional[int] = None,
543
- vision_feature_select_strategy: Optional[str] = None,
544
  labels: Optional[torch.LongTensor] = None,
 
545
  use_cache: Optional[bool] = None,
546
  output_attentions: Optional[bool] = None,
547
  output_hidden_states: Optional[bool] = None,
548
  return_dict: Optional[bool] = None,
 
549
  **kwargs,
550
  ) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
551
- r"""
552
- Args:
553
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
554
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
555
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
556
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
557
-
558
- Returns:
559
-
560
- Example:
561
-
562
- ```python
563
- >>> from PIL import Image
564
- >>> import requests
565
- >>> from transformers import AutoProcessor, LlavaForConditionalGeneration
566
-
567
- >>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
568
- >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
569
-
570
- >>> prompt = "<image>\nUSER: What's the content of the image?\nASSISTANT:"
571
- >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
572
- >>> image = Image.open(requests.get(url, stream=True).raw)
573
-
574
- >>> inputs = processor(text=prompt, images=image, return_tensors="pt")
575
-
576
- >>> # Generate
577
- >>> generate_ids = model.generate(**inputs, max_length=30)
578
- >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
579
- "\nUSER: What's the content of the image?\nASSISTANT: The image features a stop sign on a street corner"
580
- ```"""
581
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
582
  output_hidden_states = (
583
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
584
  )
585
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
586
- vision_feature_layer = (
587
- vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
588
- )
589
- vision_feature_select_strategy = (
590
- vision_feature_select_strategy
591
- if vision_feature_select_strategy is not None
592
- else self.config.vision_feature_select_strategy
593
- )
594
 
 
 
595
  image_features = None
596
- if inputs_embeds is None:
597
- # 1. Extract the input embeddings
598
- inputs_embeds = self.get_input_embeddings()(input_ids)
599
-
600
- # 2. Merge text and images
601
- if pixel_values is not None and input_ids.shape[1] != 1:
602
- pixel_values = pixel_values.to(dtype=self.vision_tower.dtype)
603
  image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
604
- # this is not memory efficient at all (output_hidden_states=True) will save all the hidden stated.
605
- selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
606
-
607
- if vision_feature_select_strategy == "default":
608
- selected_image_feature = selected_image_feature[:, 1:]
609
- elif vision_feature_select_strategy == "full":
610
- selected_image_feature = selected_image_feature
611
- else:
612
- raise ValueError(
613
- f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
614
- )
615
-
616
- image_features = self.multi_modal_projector(selected_image_feature)
617
-
618
- special_image_token_mask = input_ids == self.config.image_token_index
619
- num_special_image_tokens = torch.sum(special_image_token_mask, dim = -1)
620
-
621
- image_features = self.add_split_tokens(image_features)
622
-
623
- if sum(num_special_image_tokens) > 0:
624
- # print(f'num_special_image_tokens: {num_special_image_tokens}')
625
- inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
626
- image_features, inputs_embeds, input_ids, attention_mask, labels
627
- )
628
- else:
629
- inputs_embeds = image_features.sum(dim=(0,1))[None, None, :] * 0. + inputs_embeds
630
-
631
- if labels is None:
632
- labels = torch.full_like(attention_mask, self.config.ignore_index).to(torch.long)
633
  else:
634
- # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
635
- # generation with cache
636
- if past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
637
- # Retrieve the first layer to inspect the logits and mask out the hidden states
638
- # that are set to 0
639
- first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
640
-
641
- # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
642
- batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
643
-
644
- # Get the target length
645
- target_seqlen = first_layer_past_key_value.shape[-1] + 1
646
- extended_attention_mask = torch.ones(
647
- (attention_mask.shape[0], target_seqlen),
648
- dtype=attention_mask.dtype,
649
- device=attention_mask.device,
650
- )
651
-
652
- extended_attention_mask[batch_index, non_attended_tokens] = 0
653
-
654
- valid_indices = torch.ones_like(attention_mask)
655
- valid_indices[:, 0] = target_seqlen - extended_attention_mask.sum(dim=-1)
656
- valid_indices = torch.cumsum(valid_indices, dim=-1)
657
- extended_attention_mask = extended_attention_mask.scatter(1, valid_indices, attention_mask)
658
- attention_mask = extended_attention_mask
659
- position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
660
  outputs = self.language_model(
661
  attention_mask=attention_mask,
662
  position_ids=position_ids,
@@ -665,27 +404,35 @@ class TarsierForConditionalGeneration(TarsierPreTrainedModel):
665
  use_cache=use_cache,
666
  output_attentions=output_attentions,
667
  output_hidden_states=output_hidden_states,
668
- # use_rmpad=kwargs.get("use_rmpad", False),
669
  return_dict=return_dict,
 
 
670
  )
671
 
672
  logits = outputs[0]
673
 
674
  loss = None
675
  if labels is not None:
676
- # Shift so that tokens < n predict n
677
- if attention_mask is not None:
678
- shift_attention_mask = attention_mask[..., 1:]
679
- shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
680
- shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
 
 
 
681
  else:
 
682
  shift_logits = logits[..., :-1, :].contiguous()
683
  shift_labels = labels[..., 1:].contiguous()
684
- # Flatten the tokens
685
- loss_fct = nn.CrossEntropyLoss()
686
- loss = loss_fct(
687
- shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
688
- )
 
 
 
689
 
690
  if not return_dict:
691
  output = (logits,) + outputs[1:]
@@ -697,61 +444,59 @@ class TarsierForConditionalGeneration(TarsierPreTrainedModel):
697
  past_key_values=outputs.past_key_values,
698
  hidden_states=outputs.hidden_states,
699
  attentions=outputs.attentions,
700
- llm_attn_mask=attention_mask
701
  )
702
 
703
  def prepare_inputs_for_generation(
704
- self, input_ids, past_key_values=None, inputs_embeds=None, pixel_values=None, attention_mask=None, **kwargs
 
 
 
 
 
 
 
 
 
705
  ):
706
  if past_key_values is not None:
707
- if isinstance(past_key_values, Cache):
708
- cache_length = past_key_values.get_seq_length()
709
- past_length = past_key_values.seen_tokens
710
- else:
711
- cache_length = past_length = past_key_values[0][0].shape[2]
712
-
713
- # Keep only the unprocessed tokens:
714
- # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
715
- # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
716
- # input)
717
- if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
718
- input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
719
- # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
720
- # input_ids based on the past_length.
721
- elif past_length < input_ids.shape[1]:
722
- input_ids = input_ids[:, past_length:]
723
- # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
724
- elif self.config.image_token_index in input_ids:
725
- input_ids = input_ids[:, input_ids.shape[1] - 1 :]
726
- # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
727
- # older attention values, as their corresponding values are not part of the input.
728
- if cache_length < past_length and attention_mask is not None:
729
- attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
730
-
731
- position_ids = kwargs.get("position_ids", None)
732
- if attention_mask is not None and position_ids is None:
733
- # create position_ids on the fly for batch generation
734
- position_ids = attention_mask.long().cumsum(-1) - 1
735
- position_ids.masked_fill_(attention_mask == 0, 1)
736
- if past_key_values:
737
- position_ids = position_ids[:, -input_ids.shape[1] :]
738
-
739
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
740
- if inputs_embeds is not None and past_key_values is None:
741
- model_inputs = {"inputs_embeds": inputs_embeds}
742
  else:
743
- model_inputs = {"input_ids": input_ids}
744
-
745
- model_inputs.update(
746
- {
747
- "position_ids": position_ids,
748
- "past_key_values": past_key_values,
749
- "use_cache": kwargs.get("use_cache"),
750
- "attention_mask": attention_mask,
751
- "pixel_values": pixel_values,
752
- }
753
- )
754
  return model_inputs
755
 
756
- def _reorder_cache(self, *args, **kwargs):
757
- return self.language_model._reorder_cache(*args, **kwargs)
1
  from dataclasses import dataclass
2
+ from typing import List, Optional, Tuple, Union, Dict, Any
3
  import math
 
4
 
 
5
  import torch.utils.checkpoint
6
  from torch import nn
7
  import torch.nn.functional as F
8
 
9
+ from transformers import PreTrainedModel, AutoConfig, AutoModel
10
  from transformers.activations import ACT2FN
11
  from transformers.cache_utils import Cache
12
  from transformers.modeling_outputs import ModelOutput
13
+ from transformers.utils import logging
 
 
 
 
 
 
 
14
  from transformers.configuration_utils import PretrainedConfig
15
+ from transformers.dynamic_module_utils import get_class_from_dynamic_module
16
+ from transformers.models.auto import AutoModel, AutoModelForCausalLM, CONFIG_MAPPING
17
+ from transformers.generation import GenerationMixin
18
 
19
+ from transformers import LlamaForCausalLM, Qwen2ForCausalLM
20
+ # from models.modeling_qwen2 import Qwen2ForCausalLM
21
+ from models.modeling_qwen2_vl_fast import Qwen2VLForCausalLM
22
+ from models.utils import _pad_input, _unpad_input
23
 
24
  logger = logging.get_logger(__name__)
25
 
 
 
 
26
 
27
  class LlavaConfig(PretrainedConfig):
28
 
29
  model_type = "llava"
30
  is_composition = False
 
38
  projector_hidden_act="gelu",
39
  vision_feature_select_strategy="default",
40
  vision_feature_layer=-2,
 
41
  image_newline_idx=32002,
42
  image_new_idx=32003,
43
+ projection_head="MLP",
44
  **kwargs,
45
  ):
46
  self.ignore_index = ignore_index
 
48
  self.projector_hidden_act = projector_hidden_act
49
  self.vision_feature_select_strategy = vision_feature_select_strategy
50
  self.vision_feature_layer = vision_feature_layer
 
51
  self.image_newline_idx = image_newline_idx
52
  self.image_new_idx = image_new_idx
53
+ self.projection_head = projection_head
54
 
55
  self.vision_config = vision_config
56
 
 
58
  vision_config["model_type"] = (
59
  vision_config["model_type"] if "model_type" in vision_config else "clip_vision_model"
60
  )
61
+ if 'auto_map' in vision_config:
62
+ repo_id, class_ref = vision_config['auto_map']['AutoConfig'].split("--")
63
+ config_class = get_class_from_dynamic_module(class_ref, repo_id, **kwargs)
64
+ self.vision_config = config_class(**vision_config)
65
+ elif vision_config["model_type"] in CONFIG_MAPPING:
66
+ self.vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
67
+ else:
68
+ raise ValueError(f'vision_config["model_type"] = {vision_config["model_type"]} not supported!')
69
+
 
 
 
 
 
70
  self.text_config = text_config
71
 
72
  if isinstance(self.text_config, dict):
73
  text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama"
74
+ if 'auto_map' in text_config:
75
+ repo_id, class_ref = text_config['auto_map']['AutoConfig'].split("--")
76
+ config_class = get_class_from_dynamic_module(class_ref, repo_id, **kwargs)
77
+ self.text_config = config_class(**text_config)
78
+ elif text_config["model_type"] in CONFIG_MAPPING:
79
+ self.text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
80
+ else:
81
+ raise ValueError(f'text_config["model_type"] = {text_config["model_type"]} not supported!')
82
+
83
 
84
  super().__init__(**kwargs)
85
 
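The constructor above promotes plain-dict `vision_config` / `text_config` values to concrete config classes, either through `CONFIG_MAPPING` or through a dynamic `auto_map` entry. A minimal sketch of instantiating it with built-in model types follows; the hidden sizes and other field values are illustrative assumptions, not the shipped Tarsier2 config:

```python
# Illustrative only: model types resolve via CONFIG_MAPPING; field values are assumptions.
from models.modeling_tarsier import LlavaConfig

vision_config = {"model_type": "clip_vision_model", "hidden_size": 1024, "image_size": 336, "patch_size": 14}
text_config = {"model_type": "qwen2", "hidden_size": 3584, "vocab_size": 152064}

config = LlavaConfig(
    vision_config=vision_config,
    text_config=text_config,
    projection_head="MLP",  # or "Pixel_Shuffle" / "auto_map" / None
)
print(type(config.vision_config).__name__)  # CLIPVisionConfig
print(type(config.text_config).__name__)    # Qwen2Config
```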
86
 
87
 
88
  @dataclass
89
  # Copied from transformers.models.idefics.modeling_idefics.IdeficsCausalLMOutputWithPast with Idefics->Llava
90
  class LlavaCausalLMOutputWithPast(ModelOutput):
91
 
92
  loss: Optional[torch.FloatTensor] = None
93
  logits: torch.FloatTensor = None
94
  past_key_values: Optional[List[torch.FloatTensor]] = None
95
  hidden_states: Optional[Tuple[torch.FloatTensor]] = None
96
  attentions: Optional[Tuple[torch.FloatTensor]] = None
97
+ position_ids: Optional[torch.LongTensor] = None
98
+
99
+ def add_split_tokens(image_features, image_newline_embed, image_new_embed):
100
+ num_images, num_image_patches, embed_dim = image_features.shape
101
+ num_height_patches, num_width_patches = int(math.sqrt(num_image_patches)), int(math.sqrt(num_image_patches))
102
+
103
+ # add image_newline
104
+ image_features = image_features.view(num_images, num_height_patches, num_width_patches, embed_dim)
105
+ image_features = torch.cat([
106
+ image_features,
107
+ image_newline_embed.expand((num_images, num_height_patches, 1, embed_dim))
108
+ ], dim=2)
109
+ num_image_patches += num_height_patches
110
+ image_features = image_features.view(num_images, num_image_patches, embed_dim)
111
+
112
+ # add image_new
113
+ image_features = torch.cat([
114
+ image_features,
115
+ image_new_embed.expand((num_images, 1, embed_dim))
116
+ ], dim = 1)
117
+
118
+ return image_features
119
 
120
 
121
  class LlavaMultiModalProjector(nn.Module):
122
  def __init__(self, config: LlavaConfig):
123
  super().__init__()
124
+ self.config = config
125
 
126
  self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
127
  self.act = ACT2FN[config.projector_hidden_act]
128
  self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
129
 
130
+ image_newline_idx = torch.tensor([config.image_newline_idx], dtype=torch.long)
131
+ image_new_idx = torch.tensor([config.image_new_idx], dtype=torch.long)
132
+ self.register_buffer('image_newline_idx', image_newline_idx, persistent=False)
133
+ self.register_buffer('image_new_idx', image_new_idx, persistent=False)
134
+
135
+
136
+ def forward(self, image_features, input_embeddings):
137
+
138
+ selected_image_feature = image_features[self.config.vision_feature_layer]
139
+
140
+ if self.config.vision_feature_select_strategy == "default":
141
+ selected_image_feature = selected_image_feature[:, 1:]
142
+ elif self.config.vision_feature_select_strategy == "full":
143
+ selected_image_feature = selected_image_feature
144
+ else:
145
+ raise ValueError(
146
+ f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
147
+ )
148
+
149
+ hidden_states = self.linear_1(selected_image_feature)
150
  hidden_states = self.act(hidden_states)
151
  hidden_states = self.linear_2(hidden_states)
152
+
153
+ image_newline_embed = input_embeddings(self.image_newline_idx).squeeze()
154
+ image_new_embed = input_embeddings(self.image_new_idx).squeeze()
155
+ hidden_states = add_split_tokens(hidden_states, image_newline_embed, image_new_embed)
156
  return hidden_states
157
 
158
+ class PixelShuffleMultiModalProjector(nn.Module):
159
+ def __init__(self, config: LlavaConfig):
160
+ super().__init__()
161
+ self.config = config
162
+
163
+ self.downsample_ratio = 0.5
164
+ vit_hidden_size = config.vision_config.hidden_size
165
+ llm_hidden_size = config.text_config.hidden_size
166
+
167
+ self.mlp = nn.Sequential(
168
+ nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
169
+ nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
170
+ nn.GELU(),
171
+ nn.Linear(llm_hidden_size, llm_hidden_size)
172
+ )
173
 
174
+ image_newline_idx = torch.tensor([config.image_newline_idx], dtype=torch.long)
175
+ image_new_idx = torch.tensor([config.image_new_idx], dtype=torch.long)
176
+ self.register_buffer('image_newline_idx', image_newline_idx, persistent=False)
177
+ self.register_buffer('image_new_idx', image_new_idx, persistent=False)
178
+
179
+ def forward(self, image_features, input_embeddings):
180
+ selected_image_feature = image_features[self.config.vision_feature_layer]
181
+
182
+ if self.config.vision_feature_select_strategy == "default":
183
+ selected_image_feature = selected_image_feature[:, 1:]
184
+ elif self.config.vision_feature_select_strategy == "full":
185
+ selected_image_feature = selected_image_feature
186
+ else:
187
+ raise ValueError(
188
+ f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
189
+ )
190
+
191
+ image_features = self.pixel_shuffle(selected_image_feature)
192
+ hidden_states = self.mlp(image_features)
193
+
194
+ image_newline_embed = input_embeddings(self.image_newline_idx).squeeze()
195
+ image_new_embed = input_embeddings(self.image_new_idx).squeeze()
196
+ hidden_states = add_split_tokens(hidden_states, image_newline_embed, image_new_embed)
197
+
198
+ return hidden_states
199
+
200
+ def pixel_shuffle(self, x, scale_factor=0.5):
201
+ if scale_factor == 1:
202
+ return x
203
+ n, wh, c = x.shape
204
+ h, w = int(math.sqrt(wh)), int(math.sqrt(wh))
205
+ x = x.view(n, h, w, c)
206
+
207
+ n, w, h, c = x.size()
208
+ # N, W, H, C --> N, W, H * scale, C // scale
209
+ x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
210
+ # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
211
+ x = x.permute(0, 2, 1, 3).contiguous()
212
+ # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
213
+ x = x.view(n, int(h * scale_factor), int(w * scale_factor),
214
+ int(c / (scale_factor * scale_factor)))
215
+ x = x.permute(0, 2, 1, 3).contiguous()
216
+ x = x.view(x.shape[0], -1, x.shape[-1])
217
+ return x
218
+
219
+
220
+ LLAVA_START_DOCSTRING = r"""
221
  This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
222
  library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
223
  etc.)
 
233
  [`~PreTrainedModel.from_pretrained`] method to load the model weights.
234
  """
235
 
 
 
 
 
 
236
  class TarsierPreTrainedModel(PreTrainedModel):
237
  config_class = LlavaConfig
238
+ base_model_prefix = "llm"
239
+ supports_gradient_checkpointing = True # TODO: support latest gc
 
240
  _skip_keys_device_placement = "past_key_values"
241
  _supports_flash_attn_2 = True
242
+ _supports_sdpa = False
243
+ _supports_cache_class = True # TODO: support different cache
244
+ _supports_static_cache = True
245
 
246
  def _init_weights(self, module):
 
 
 
247
  std = (
248
  self.config.initializer_range
249
  if hasattr(self.config, "initializer_range")
 
253
  if hasattr(module, "class_embedding"):
254
  module.class_embedding.data.normal_(mean=0.0, std=std)
255
 
256
+ if isinstance(module, (nn.Linear, nn.Conv2d, nn.Conv3d)):
257
  module.weight.data.normal_(mean=0.0, std=std)
258
  if module.bias is not None:
259
  module.bias.data.zero_()
 
261
  module.weight.data.normal_(mean=0.0, std=std)
262
  if module.padding_idx is not None:
263
  module.weight.data[module.padding_idx].zero_()
264
+ elif isinstance(module, nn.LayerNorm):
265
+ module.weight.data.fill_(1.0)
266
+ if module.bias is not None:
267
+ module.bias.data.zero_()
268
  @property
269
+ def _no_split_modules(self):
270
+ return self.language_model._no_split_modules + self.vision_tower._no_split_modules
271
 
272
 
273
+ class TarsierForConditionalGeneration(TarsierPreTrainedModel, GenerationMixin):
 
 
 
 
274
  def __init__(self, config: LlavaConfig):
275
  super().__init__(config)
276
  self.vision_tower = AutoModel.from_config(config.vision_config, trust_remote_code=True)
277
+ if config.text_config.model_type == 'qwen2':
278
+ self.language_model = Qwen2ForCausalLM(config.text_config)
279
+ elif config.text_config.model_type == 'qwen2_vl':
280
+ self.language_model = Qwen2VLForCausalLM(config.text_config)
281
+ elif config.text_config.model_type == 'llama':
282
+ self.language_model = LlamaForCausalLM(config.text_config)
283
+ else:
284
+ raise ValueError(f'{config.text_config.model_type} not supported!')
285
+
286
+ if config.projection_head == 'Pixel_Shuffle':
287
+ self.multi_modal_projector = PixelShuffleMultiModalProjector(config)
288
+ elif config.projection_head == 'MLP':
289
+ self.multi_modal_projector = LlavaMultiModalProjector(config)
290
+ elif config.projection_head == 'auto_map':
291
+ repo_id, class_ref = config.auto_map['ProjectionLayer'].split("--")
292
+ model_class = get_class_from_dynamic_module(class_ref, repo_id)
293
+ self.multi_modal_projector = model_class(config)
294
+ elif config.projection_head is None:
295
+ self.multi_modal_projector = lambda x, *args, **kwargs: x
296
+
297
  self.post_init()
298
 
299
  def get_input_embeddings(self):
 
321
  model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
322
  # update vocab size
323
  self.config.text_config.vocab_size = model_embeds.num_embeddings
 
 
324
  return model_embeds
325
 
326
  def forward(
327
  self,
328
  input_ids: torch.LongTensor = None,
 
329
  attention_mask: Optional[torch.Tensor] = None,
330
  position_ids: Optional[torch.LongTensor] = None,
331
+ pixel_values: torch.FloatTensor = None,
332
+ image_grid_thw: Optional[torch.Tensor] = None,
333
  past_key_values: Optional[List[torch.FloatTensor]] = None,
 
 
 
334
  labels: Optional[torch.LongTensor] = None,
335
+ num_images: Optional[torch.Tensor] = None,
336
  use_cache: Optional[bool] = None,
337
  output_attentions: Optional[bool] = None,
338
  output_hidden_states: Optional[bool] = None,
339
  return_dict: Optional[bool] = None,
340
+ use_rmpad: Optional[bool] = False,
341
  **kwargs,
342
  ) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
343
+
344
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
345
  output_hidden_states = (
346
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
347
  )
348
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
349
+
350
+
351
+ if input_ids is None:
352
+ raise ValueError("You must specify input_ids")
353
+
354
+ bsz, max_seq_len = input_ids.shape[0], input_ids.shape[1]
355
+
356
+ if max_seq_len > 1:
357
+ special_image_mask = input_ids == self.config.image_token_index
358
+ print(f'[{input_ids.device}] num_images: {num_images.tolist()} num_image_tokens: {special_image_mask.sum(-1).tolist()}', flush=True)
359
+
360
+ if position_ids is None:
361
+ if 'Qwen2VLForCausalLM' in self.language_model.__class__.__name__:
362
+ position_ids = self.language_model.get_rope_index(input_ids, image_grid_thw, attention_mask) # [bsz, seqlen, 3]
363
+ else:
364
+ position_ids = attention_mask.long().cumsum(-1) - 1 # # [bsz, seqlen]
365
+ position_ids.masked_fill_(attention_mask == 0, 1)
366
+
367
+
368
+ if use_rmpad:
369
+ input_ids, input_ids_indices, cu_seqlens, _ = _unpad_input(input_ids, attention_mask) # [bsz, seqlen] -> [1, seqlen]
370
+ position_ids, _, _, _ = _unpad_input(position_ids, attention_mask)
371
+ input_ids, position_ids = input_ids.unsqueeze(0), position_ids.unsqueeze(0)
372
+ else:
373
+ input_ids_indices, cu_seqlens = None, None
374
 
375
+ inputs_embeds = self.get_input_embeddings()(input_ids) # [1, seqlen, dim]
376
+
377
  image_features = None
378
+ if pixel_values is not None: # training / first step in generation
379
+ if 'Qwen2VLForCausalLM' in self.language_model.__class__.__name__:
380
+ pixel_values = pixel_values.type(self.vision_tower.get_dtype())
381
+ image_features = self.vision_tower(pixel_values, image_grid_thw)
382
+ else:
 
 
383
  image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
384
+ image_features = self.multi_modal_projector(
385
+ image_outputs.hidden_states,
386
+ self.get_input_embeddings(),
387
+ )
388
+
389
+ special_image_mask = input_ids == self.config.image_token_index
390
+ if special_image_mask.sum() > 0:
391
+ image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
392
+ inputs_embeds = inputs_embeds.masked_scatter(
393
+ special_image_mask.unsqueeze(-1).expand_as(inputs_embeds),
394
+ image_features
395
+ )
396
  else:
397
+ inputs_embeds = image_features.sum(dim=(0,1)) * 0. + inputs_embeds
398
+
399
  outputs = self.language_model(
400
  attention_mask=attention_mask,
401
  position_ids=position_ids,
 
404
  use_cache=use_cache,
405
  output_attentions=output_attentions,
406
  output_hidden_states=output_hidden_states,
 
407
  return_dict=return_dict,
408
+ use_rmpad=use_rmpad,
409
+ cu_seqlens=cu_seqlens,
410
  )
411
 
412
  logits = outputs[0]
413
 
414
  loss = None
415
  if labels is not None:
416
+ loss_fct = nn.CrossEntropyLoss()
417
+ if use_rmpad:
418
+ labels = labels.view(-1)[input_ids_indices.long()]
419
+ shift_labels = torch.cat((labels[1:], labels.new_ones((1))*-100))
420
+ shift_labels.requires_grad = False
421
+ lbl_seq_lens = (cu_seqlens[1:]-1).long()
422
+ shift_labels[lbl_seq_lens] = -100
423
+ loss = loss_fct(logits.squeeze(0), shift_labels)
424
  else:
425
+ # Shift so that tokens < n predict n
426
  shift_logits = logits[..., :-1, :].contiguous()
427
  shift_labels = labels[..., 1:].contiguous()
428
+ # Flatten the tokens
429
+ shift_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
430
+ shift_labels = shift_labels.view(-1)
431
+ # Enable model parallelism
432
+ shift_labels = shift_labels.to(shift_logits.device)
433
+ loss = loss_fct(shift_logits, shift_labels)
434
+ elif use_rmpad: # during training we skip re-padding the logits here, to save GPU memory
435
+ logits = _pad_input(logits.squeeze(0), input_ids_indices, bsz, max_seq_len)
436
 
437
  if not return_dict:
438
  output = (logits,) + outputs[1:]
 
444
  past_key_values=outputs.past_key_values,
445
  hidden_states=outputs.hidden_states,
446
  attentions=outputs.attentions,
447
+ position_ids=position_ids,
448
  )
449
 
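Unlike the removed `_merge_input_ids_with_image_features`, the new forward keeps the sequence length fixed: the prompt is expected to already contain one image placeholder token per visual token, and `masked_scatter` writes the projected image features into exactly those embedding slots. A toy illustration of the mechanism (the token id, sizes and values are made up):

```python
# Toy illustration of the masked_scatter merge; ids, sizes and values are assumptions.
import torch

image_token_index = 5
embed_dim = 4

input_ids = torch.tensor([[7, 5, 5, 5, 9]])            # text, 3 image placeholders, text
inputs_embeds = torch.zeros(1, 5, embed_dim)           # stand-in for the text embeddings
image_features = torch.arange(3 * embed_dim, dtype=torch.float).view(3, embed_dim)

special_image_mask = input_ids == image_token_index    # [1, 5] boolean mask
inputs_embeds = inputs_embeds.masked_scatter(
    special_image_mask.unsqueeze(-1).expand_as(inputs_embeds),
    image_features,                                     # must provide exactly mask.sum() values
)
print(inputs_embeds[0, 1:4])                            # the three image-feature rows
```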
450
  def prepare_inputs_for_generation(
451
+ self,
452
+ input_ids,
453
+ attention_mask=None,
454
+ position_ids=None,
455
+ past_key_values=None,
456
+ cache_position=None,
457
+ use_cache=True,
458
+ pixel_values=None,
459
+ image_grid_thw=None,
460
+ **kwargs,
461
  ):
462
  if past_key_values is not None:
463
+ past_length = past_key_values.get_seq_length()
464
+ input_ids = input_ids[:, past_length:]
465
+
466
+ model_inputs = {
467
+ "input_ids": input_ids,
468
+ "attention_mask": attention_mask,
469
+ "past_key_values": past_key_values,
470
+ "use_cache": use_cache,
471
+ }
472
+ if kwargs.get('num_images') is not None:
473
+ model_inputs['num_images'] = kwargs['num_images']
474
+
475
+ if cache_position[0] == 0:
476
+ # pixel values are only needed on the prefill step (cache_position[0] == 0);
477
+ # in the cached decoding stage the input ids no longer contain image tokens, so pixel_values stays None
478
+ model_inputs["pixel_values"] = pixel_values
479
+ model_inputs["image_grid_thw"] = image_grid_thw
480
  else:
481
+ model_inputs['position_ids'] = position_ids[:, -1, ...].unsqueeze(1).to(device=input_ids.device) + 1
482
  return model_inputs
483
 
484
+
485
+ def _update_model_kwargs_for_generation(
486
+ self,
487
+ outputs: ModelOutput,
488
+ model_kwargs: Dict[str, Any],
489
+ is_encoder_decoder: bool = False,
490
+ num_new_tokens: int = 1,
491
+ ) -> Dict[str, Any]:
492
+ model_kwargs = super()._update_model_kwargs_for_generation(
493
+ outputs=outputs,
494
+ model_kwargs=model_kwargs,
495
+ is_encoder_decoder=is_encoder_decoder,
496
+ num_new_tokens=num_new_tokens,
497
+ )
498
+
499
+ if getattr(outputs, "position_ids", None) is not None:
500
+ model_kwargs["position_ids"] = outputs.position_ids
501
+
502
+ return model_kwargs
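The generation overrides work together: `forward` returns the `position_ids` it used, `_update_model_kwargs_for_generation` carries them into the next step's kwargs, and `prepare_inputs_for_generation` advances the last position by one during cached decoding (which also covers Qwen2-VL's 3-D rope positions). A toy 1-D illustration of that increment:

```python
# Toy 1-D illustration of how cached decoding advances position_ids; values are assumptions.
import torch

position_ids = torch.tensor([[0, 1, 2, 3]])              # positions used at prefill
next_position_ids = position_ids[:, -1, ...].unsqueeze(1) + 1
print(next_position_ids)                                 # tensor([[4]])
```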
models/utils.py ADDED
@@ -0,0 +1,17 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from einops import rearrange
4
+
5
+ def _unpad_input(input_ids, attention_mask):
6
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
7
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
8
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
9
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
10
+ input_ids = rearrange(input_ids, 'b s ... -> (b s) ...')[indices]
11
+ return input_ids, indices, cu_seqlens, max_seqlen_in_batch
12
+
13
+ def _pad_input(hidden_states, indices, batch, seqlen):
14
+ output = torch.zeros(batch * seqlen, *hidden_states.shape[1:], device=hidden_states.device,
15
+ dtype=hidden_states.dtype)
16
+ output[indices] = hidden_states
17
+ return rearrange(output, '(b s) ... -> b s ...', b=batch)
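`_unpad_input` / `_pad_input` implement the usual flash-attention style variable-length packing: padding tokens are dropped, the batch is flattened into one packed sequence, and `cu_seqlens` records where each sample starts. A minimal round-trip check (shapes and values are assumptions):

```python
# Minimal round-trip check for the rmpad helpers; shapes and values are assumptions.
import torch
from models.utils import _pad_input, _unpad_input

input_ids = torch.tensor([[11, 12, 13, 0], [21, 22, 0, 0]])
attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])

flat_ids, indices, cu_seqlens, max_len = _unpad_input(input_ids, attention_mask)
print(flat_ids.tolist())    # [11, 12, 13, 21, 22]
print(cu_seqlens.tolist())  # [0, 3, 5] -> sample boundaries in the packed sequence
print(max_len)              # 3

hidden = flat_ids.unsqueeze(-1).float()                  # pretend these are per-token outputs
restored = _pad_input(hidden, indices, batch=2, seqlen=4)
print(restored.squeeze(-1).long().tolist())              # [[11, 12, 13, 0], [21, 22, 0, 0]]
```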
requirements.txt CHANGED
@@ -19,5 +19,6 @@ torch==2.1.0
19
  torchvision==0.16.0
20
  torchaudio==2.1.0
21
  https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.7/flash_attn-2.5.7+cu122torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
22
- transformers==4.44.2
23
  triton==2.1.0
 
 
19
  torchvision==0.16.0
20
  torchaudio==2.1.0
21
  https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.7/flash_attn-2.5.7+cu122torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
22
+ transformers==4.47.0
23
  triton==2.1.0
24
+ func_timeout==4.3.5
tools/color.py ADDED
@@ -0,0 +1,36 @@
1
+ # Copyright (2024) Bytedance Ltd. and/or its affiliates
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ class Color:
15
+
16
+ @staticmethod
17
+ def red(x):
18
+ return '\33[31m' +x + '\033[0m'
19
+
20
+ @staticmethod
21
+ def green(x):
22
+ return '\33[32m' +x + '\033[0m'
23
+
24
+ @staticmethod
25
+ def yellow(x):
26
+ return '\33[33m' +x + '\033[0m'
27
+
28
+ @staticmethod
29
+ def blue(x):
30
+ return '\33[34m' +x + '\033[0m'
31
+
32
+ @staticmethod
33
+ def violet(x):
34
+ return '\33[35m' +x + '\033[0m'
35
+
36
+
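`Color` simply wraps its argument in ANSI escape codes, e.g.:

```python
from tools.color import Color

print(Color.green("Model loaded"))  # rendered green on ANSI-capable terminals
print(Color.red("Load failed"))
```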
tools/conversation.py CHANGED
@@ -16,12 +16,43 @@
16
  from PIL import Image
17
  import torch
18
  from transformers import StoppingCriteria, StoppingCriteriaList
 
 
 
19
 
20
  from enum import auto, Enum
21
  import os
22
- from dataset.processor import Processor
23
  import re
24
 
25
 
26
  IMAGE_TOKEN = "<image>"
27
  VIDEO_TOKEN = "<video>"
@@ -31,24 +62,48 @@ class SeparatorStyle(Enum):
31
  SINGLE = auto()
32
  TWO = auto()
33
 
34
- def get_prompt(conv):
35
- ret = ""
36
- if conv.system:
37
- ret = conv.system + conv.sep1
38
  for i, (role, message) in enumerate(conv.messages):
39
  if message:
40
- # In current version, the image should be add at the first conversation round.
41
- # So we need to remove the special image tokens in following user input.
42
- if i > 0:
43
- message = re.sub(f"({IMAGE_TOKEN}|{VIDEO_TOKEN})\n*", "", message)
44
- ret += role + ": " + message
45
- if i % 2:
46
- ret += conv.sep2
47
  else:
48
- ret += conv.sep1
49
- else:
50
- ret += role + ": "
51
- return ret
52
 
53
 
54
  class StoppingCriteriaSub(StoppingCriteria):
@@ -64,53 +119,36 @@ class StoppingCriteriaSub(StoppingCriteria):
64
 
65
 
66
  class Chat:
67
- def __init__(self, model, processor: Processor, device='cuda', debug=False):
68
  self.model = model
69
  self.processor = processor
70
  self.device = device
71
  self.debug = debug
72
- stop_words_ids = [torch.tensor([self.processor.tokenizer.eos_token_id]).to(device)]
73
  self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
74
 
75
  def ask(self,text,conv):
76
- conv.messages.append([conv.roles[0], text])
77
  return conv
78
 
79
- def prepare_model_inputs(self, conv, visual_data_file=None, images=None, n_frames=None):
80
- conv.messages.append([conv.roles[1], None])
81
- print(conv.messages)
82
- conv.messages[0][1] = re.sub(f"({IMAGE_TOKEN}|{VIDEO_TOKEN})\n*", "", conv.messages[0][1])
83
-
84
- if images is None or isinstance(images, list) and len(images) == 0:
85
- if isinstance(visual_data_file, str) and os.path.exists(visual_data_file):
86
- images = self.processor.load_images(visual_data_file, n_frames)
87
- elif isinstance(visual_data_file, Image.Image):
88
- images = [visual_data_file]
89
- elif visual_data_file is None or visual_data_file == "":
90
- images = None
91
- else:
92
- raise NotImplementedError
93
-
94
- # os.system("rm tmp_images/*")
95
- # for i, img in enumerate(images):
96
- # img.save(f"tmp_images/{i+1}.jpg")
97
-
98
- if isinstance(images, list) and len(images) > 0:
99
- conv.messages[0][1] = IMAGE_TOKEN*len(images) + '\n' + conv.messages[0][1]
100
-
101
- prompt = get_prompt(conv)
102
  if self.debug:
103
- print(f"visual_data_file: {visual_data_file}")
104
- print(f"Prompt: {prompt}", flush=True)
105
-
106
- inputs = self.processor(prompt, images=images, edit_prompt=False, return_prompt=False)
107
- # print(self.processor.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0]))
108
- inputs = {k:v.to(self.device) for k,v in inputs.items() if v is not None}
109
- return inputs, conv, images
110
-
111
- def answer(self, conv, visual_data_file=None, images=None, n_frames=None, max_new_tokens=256, num_beams=1, min_length=1, top_p=1.0,
 
 
 
112
  repetition_penalty=1.0, length_penalty=1, temperature=0):
113
- inputs, conv, images = self.prepare_model_inputs(conv, visual_data_file, images, n_frames)
114
  if self.model is not None:
115
  outputs = self.model.generate(
116
  **inputs,
@@ -124,11 +162,13 @@ class Chat:
124
  length_penalty=length_penalty,
125
  temperature=temperature,
126
  )
127
- output_text = self.processor.tokenizer.decode(outputs[0][inputs['input_ids'][0].shape[0]:], skip_special_tokens=True)
128
  else:
129
  output_text = "Fake respone as launched in debug mode!"
130
- conv.messages[-1][1] = output_text
131
- return output_text, conv, images
 
 
132
 
133
  class EasyDict(dict):
134
  """
@@ -204,19 +244,13 @@ conv_tarsier_yi = EasyDict({
204
  }
205
  )
206
 
207
- conv_tarsier_qwen2 = EasyDict({
208
  "system": "",
209
- "roles": ("USER", "ASSISTANT"),
210
  "messages": [],
211
- "sep1": " ",
212
- "sep2": "<|endoftext|>",
213
  }
214
  )
215
 
216
  conv_templates = {
217
- "tarsier-7b": conv_tarsier,
218
- "tarsier-13b": conv_tarsier,
219
- "tarsier-34b": conv_tarsier_yi,
220
- "tarsier2-7b": conv_tarsier_qwen2
221
  }
222
-
 
16
  from PIL import Image
17
  import torch
18
  from transformers import StoppingCriteria, StoppingCriteriaList
19
+ from dataset.custom_data_parsers.utils import put_pred_to_data_dict, get_prompt_from_data_dict
20
+ from dataset.tarsier_datamodule import TarsierDataProcessor
21
+ from dataset.utils import *
22
 
23
  from enum import auto, Enum
24
  import os
 
25
  import re
26
 
27
+ data_dict_tmp = {
28
+ "messages": [
29
+ {
30
+ "role": "user",
31
+ "content": [
32
+ {
33
+ "type": "video",
34
+ "video": {
35
+ "video_file": "/mnt/hdfs/vlm/videos/movies_aligned_0523/tt8266310/tt8266310_1.50.24-1.50.29.mp4"}
36
+ },
37
+ {
38
+ "type": "text",
39
+ "text": "Describe the video in detail."
40
+ }
41
+ ]
42
+ },
43
+ {
44
+ "role": "assistant",
45
+ "content": [
46
+ {
47
+ "type": "text",
48
+ "text": "A man in the driver's seat, wearing a black jacket with a maroon shirt, fastens his seatbelt while smiling at the man in the passenger seat, who is adjusting his position. The passenger, also wearing a black jacket with a maroon shirt, turns to look forward and smiles. The driver then leans forward to start the car and leans back in his seat. In the background, a beige car is visible through the window."
49
+ }]}
50
+ ],
51
+ "dataset": "video_caption",
52
+ "task": "video/caption",
53
+ "idx": 0,
54
+ }
55
+
56
 
57
  IMAGE_TOKEN = "<image>"
58
  VIDEO_TOKEN = "<video>"
 
62
  SINGLE = auto()
63
  TWO = auto()
64
 
65
+ def get_data_dict(conv, max_n_frames=None):
66
+ data_dict = {
67
+ "messages": []
68
+ }
69
  for i, (role, message) in enumerate(conv.messages):
70
  if message:
71
+ text = message["text"]
72
+ content_type = message["type"]
73
+ content = {}
74
+ if content_type == "text":
75
+ content['type'] = 'text'
76
+ content['text'] = text
77
+ task = "text-only"
78
+ elif content_type == "video":
79
+ content['type'] = 'video'
80
+ content['video'] = {
81
+ "video_file": text
82
+ }
83
+ if max_n_frames is not None:
84
+ content['video']['n_frames'] = max_n_frames
85
+ task = "video/QA"
86
+ elif content_type == "image":
87
+ content['type'] = 'image'
88
+ content['image'] = {
89
+ "image_file": text
90
+ }
91
+ task = "image/QA"
92
  else:
93
+ content['type'] = 'text'
94
+ content['text'] = text
95
+ task = "text-only"
96
+ if data_dict['messages'] and data_dict['messages'][-1]['role'] == role:
97
+ data_dict['messages'][-1]['content'].append(content)
98
+ else:
99
+ data_dict['messages'].append({
100
+ "role": role,
101
+ "content": [content]
102
+ })
103
+ data_dict['dataset'] = task
104
+ data_dict['task'] = task
105
+ check_data_format(data_dict)
106
+ return data_dict
107
 
108
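`get_data_dict` converts the Gradio conversation state into the message format expected by `TarsierDataProcessor` (the same shape as `data_dict_tmp` above). For one user turn with a video plus a question it produces roughly the following; the file path and frame count are placeholders:

```python
# Rough shape of get_data_dict output for one video + one question; path and n_frames are placeholders.
{
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "video", "video": {"video_file": "example.mp4", "n_frames": 16}},
                {"type": "text", "text": "Describe the video in detail."},
            ],
        }
    ],
    # "dataset" / "task" mirror whichever content type was processed last
    # ("text-only" for the turn above, "video/QA" if the video were the final item).
    "dataset": "text-only",
    "task": "text-only",
}
```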
 
109
  class StoppingCriteriaSub(StoppingCriteria):
 
119
 
120
 
121
  class Chat:
122
+ def __init__(self, model, processor: TarsierDataProcessor, device='cuda', debug=False):
123
  self.model = model
124
  self.processor = processor
125
  self.device = device
126
  self.debug = debug
127
+ stop_words_ids = [torch.tensor([self.processor.processor.tokenizer.eos_token_id]).to(device)]
128
  self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
129
 
130
  def ask(self,text,conv):
131
+ conv.messages.append([conv.roles[0], {"text": text, "type": "text"}])
132
  return conv
133
 
134
+ def prepare_model_inputs(self, conv, n_frames=None):
135
+ # print(conv.messages)
136
+ data_dict = get_data_dict(conv, n_frames)
137
  if self.debug:
138
+ # print(f"visual_data_file: {visual_data_file}", flush=True)
139
+ print(f"###Prompt:\n{get_prompt_from_data_dict(data_dict)}")
140
+
141
+ batch_data = self.processor(data_dict)
142
+ model_inputs = {}
143
+ for k, v in batch_data.items():
144
+ if not isinstance(v, torch.Tensor):
145
+ continue
146
+ model_inputs[k] = v.to(self.device)
147
+ return model_inputs, conv
148
+
149
+ def answer(self, conv, n_frames=None, max_new_tokens=256, num_beams=1, min_length=1, top_p=1.0,
150
  repetition_penalty=1.0, length_penalty=1, temperature=0):
151
+ inputs, conv = self.prepare_model_inputs(conv, n_frames)
152
  if self.model is not None:
153
  outputs = self.model.generate(
154
  **inputs,
 
162
  length_penalty=length_penalty,
163
  temperature=temperature,
164
  )
165
+ output_text = self.processor.processor.tokenizer.decode(outputs[0][inputs['input_ids'][0].shape[0]:], skip_special_tokens=True)
166
  else:
167
  output_text = "Fake respone as launched in debug mode!"
168
+ conv.messages.append(
169
+ [conv.roles[1], {"text": output_text, "type": "text"}]
170
+ )
171
+ return output_text, conv
172
 
173
  class EasyDict(dict):
174
  """
 
244
  }
245
  )
246
 
247
+ conv_tarsier_qwen2_vl = EasyDict({
248
  "system": "",
249
+ "roles": ("user", "assistant"),
250
  "messages": [],
 
 
251
  }
252
  )
253
 
254
  conv_templates = {
255
+ "tarsier2-7b": conv_tarsier_qwen2_vl
 
 
 
256
  }
 
tools/rw_utils.py ADDED
@@ -0,0 +1,64 @@
1
+ # Copyright (2024) Bytedance Ltd. and/or its affiliates
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import json
15
+ from json import JSONEncoder
16
+ import numpy
17
+ import pandas as pd
18
+
19
+ class NumpyArrayEncoder(JSONEncoder):
20
+ def default(self, obj):
21
+ if isinstance(obj, numpy.ndarray):
22
+ return obj.tolist()
23
+ return JSONEncoder.default(self, obj)
24
+
25
+ def write_txt(data, path):
26
+ with open(path, 'w', encoding='utf-8')as f:
27
+ for d in data:
28
+ f.write(f'{d}\n')
29
+
30
+ def read_txt(path):
31
+ with open(path, 'r', encoding='utf-8', errors='ignore') as f:
32
+ lines = [l.strip('\n') for l in f.readlines()]
33
+ return lines
34
+
35
+ def read_jsonlines(path):
36
+ objs = []
37
+ with open(path) as f:
38
+ for line in f:
39
+ line = json.loads(line)
40
+ objs.append(line)
41
+ return objs
42
+
43
+ def write_jsonlines(data, path, cls=None, ensure_ascii=False):
44
+ with open(path, 'w') as f:
45
+ for d in data:
46
+ d = json.dumps(d, ensure_ascii=ensure_ascii, cls=cls)
47
+ f.write(d)
48
+ f.write('\n')
49
+
50
+ def read_parquet(path):
51
+ data = pd.read_parquet(path)
52
+ return data.to_dict('records')
53
+
54
+ def write_parquet(data, path):
55
+ data = pd.DataFrame(data)
56
+ data.to_parquet(path)
57
+
58
+ def read_csv(path):
59
+ data = pd.read_csv(path)
60
+ return data.to_dict(orient='records')
61
+
62
+ def write_csv(data, path):
63
+ data = pd.DataFrame(data)
64
+ data.to_csv(path, index=False, sep='\t')
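`tools/rw_utils.py` is a small grab-bag of read/write helpers; a typical jsonlines round trip (the file name is arbitrary):

```python
# Example round trip with the jsonlines helpers; the file name is arbitrary.
from tools.rw_utils import read_jsonlines, write_jsonlines

preds = [
    {"idx": 0, "prediction": "A man fastens his seatbelt and starts the car."},
    {"idx": 1, "prediction": "A beige car is visible through the window."},
]
write_jsonlines(preds, "predictions.jsonl")
assert read_jsonlines("predictions.jsonl") == preds
```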
tools/utils.py CHANGED
@@ -12,46 +12,21 @@
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
  from models.modeling_tarsier import TarsierForConditionalGeneration, LlavaConfig
15
- from dataset.processor import Processor
 
16
  import torch
17
  import base64
 
 
18
  import os
19
 
20
  HF_TOKEN = os.environ.get('HF_TOKEN', '')
21
 
22
- class Color:
23
-
24
- @staticmethod
25
- def red(x):
26
- return '\33[31m' +x + '\033[0m'
27
-
28
- @staticmethod
29
- def green(x):
30
- return '\33[32m' +x + '\033[0m'
31
-
32
- @staticmethod
33
- def yellow(x):
34
- return '\33[33m' +x + '\033[0m'
35
-
36
- @staticmethod
37
- def blue(x):
38
- return '\33[34m' +x + '\033[0m'
39
-
40
- @staticmethod
41
- def violet(x):
42
- return '\33[35m' +x + '\033[0m'
43
-
44
- def file_to_base64(img_path):
45
- with open(img_path, 'rb') as video_file:
46
- video_b64_str = base64.b64encode(video_file.read()).decode()
47
- return video_b64_str
48
-
49
- def load_model_and_processor(model_name_or_path, max_n_frames=8):
50
- print(Color.red(f"Load model and processor from: {model_name_or_path}; with max_n_frames={max_n_frames}"), flush=True)
51
- processor = Processor(
52
- model_name_or_path,
53
- max_n_frames=max_n_frames,
54
- )
55
  model_config = LlavaConfig.from_pretrained(
56
  model_name_or_path,
57
  trust_remote_code=True,
@@ -68,3 +43,8 @@ def load_model_and_processor(model_name_or_path, max_n_frames=8):
68
  model.eval()
69
  return model, processor
70
 
 
 
 
 
 
 
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
  from models.modeling_tarsier import TarsierForConditionalGeneration, LlavaConfig
15
+ # from dataset.processor import Processor
16
+ from dataset.tarsier_datamodule import init_processor
17
  import torch
18
  import base64
19
+ from tools.color import Color
20
+ import yaml
21
  import os
22
 
23
  HF_TOKEN = os.environ.get('HF_TOKEN', '')
24
 
25
+ def load_model_and_processor(model_name_or_path, data_config):
26
+ print(Color.red(f"Load model and processor from: {model_name_or_path}"), flush=True)
27
+ if isinstance(data_config, str):
28
+ data_config = yaml.safe_load(open(data_config, 'r'))
29
+ processor = init_processor(model_name_or_path, data_config)
30
  model_config = LlavaConfig.from_pretrained(
31
  model_name_or_path,
32
  trust_remote_code=True,
 
43
  model.eval()
44
  return model, processor
45
 
46
+ def file_to_base64(img_path):
47
+ with open(img_path, 'rb') as video_file:
48
+ video_b64_str = base64.b64encode(video_file.read()).decode()
49
+ return video_b64_str
50
+
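With the refactored loader, demo startup reduces to something like the sketch below; the default model id and the YAML data-config path are assumptions, not values confirmed by this commit:

```python
# Hedged sketch of the new loading path; the default model id and YAML path are assumptions.
import os
from tools.utils import load_model_and_processor

model_path = os.getenv("MODEL_PATH", "omni-research/Tarsier2-7b-0115")
model, processor = load_model_and_processor(
    model_path,
    data_config="configs/tarsier2_default_config.yaml",  # hypothetical config consumed by init_processor
)
```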