mx262 committed
Commit 59a5edb · verified · 1 Parent(s): 0624b31

Upload internvl_chat.py

Files changed (1)
  1. internvl_chat.py +318 -37
internvl_chat.py CHANGED
@@ -1,5 +1,5 @@
 import torch
-from transformers import AutoTokenizer, AutoModel, CLIPImageProcessor
+from transformers import AutoTokenizer, AutoConfig, AutoModel, CLIPImageProcessor
 import warnings
 from PIL import Image
 from .base import BaseModel
@@ -7,11 +7,13 @@ from ..smp import *
 from ..dataset import DATASET_TYPE
 import pandas as pd
 import string
+import torch.distributed as dist
 import torchvision.transforms as T
 import transformers
 
 from torchvision.transforms.functional import InterpolationMode
-import random
+import re
+
 
 IMAGENET_MEAN = (0.485, 0.456, 0.406)
 IMAGENET_STD = (0.229, 0.224, 0.225)
@@ -143,35 +145,94 @@ def load_image2(image_file, input_size=448, target_aspect_ratio=(1,1), min_num=1
     pixel_values = torch.stack(pixel_values)
     return pixel_values
 
+
+# This function is used to split InternVL2-Llama3-76B
+def split_model(model_name):
+    import math
+    device_map = {}
+    num_gpus = torch.cuda.device_count()
+    rank, world_size = get_rank_and_world_size()
+    num_gpus = num_gpus // world_size
+
+    num_layers = {'InternVL2-8B': 32, 'InternVL2-26B': 48,
+                  'InternVL2-40B': 60, 'InternVL2-Llama3-76B': 80}[model_name]
+    # Since the first GPU will be used for ViT, treat it as 0.8 GPU.
+    num_layers_per_gpu = math.ceil(num_layers / (num_gpus - 0.2))
+    num_layers_per_gpu = [num_layers_per_gpu] * num_gpus
+    num_layers_per_gpu[0] = math.ceil(num_layers_per_gpu[0] * 0.8)
+    layer_cnt = 0
+    for i, num_layer in enumerate(num_layers_per_gpu):
+        for j in range(num_layer):
+            device_map[f'language_model.model.layers.{layer_cnt}'] = rank + world_size * i
+            layer_cnt += 1
+    device_map['vision_model'] = rank
+    device_map['mlp1'] = rank
+    device_map['language_model.model.tok_embeddings'] = rank
+    device_map['language_model.model.embed_tokens'] = rank
+    device_map['language_model.output'] = rank
+    device_map['language_model.model.norm'] = rank
+    device_map['language_model.lm_head'] = rank
+    device_map[f'language_model.model.layers.{num_layers - 1}'] = rank
+    return device_map
+
+
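
A quick sanity check of the split arithmetic above (a standalone sketch, not part of the commit): assuming a single process (rank 0, world_size 1) with 8 visible GPUs, the 80 layers of InternVL2-Llama3-76B are spread so that GPU 0, which also hosts the ViT, embeddings, and output head, takes a reduced share.

    import math

    num_gpus, num_layers = 8, 80                        # assumed setup, not from the commit
    per_gpu = math.ceil(num_layers / (num_gpus - 0.2))  # ceil(80 / 7.8) = 11
    shares = [per_gpu] * num_gpus
    shares[0] = math.ceil(shares[0] * 0.8)              # GPU 0 keeps ceil(11 * 0.8) = 9
    print(shares, sum(shares))  # [9, 11, 11, 11, 11, 11, 11, 11] -> 86 slots for 80 layers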
 class InternVLChat(BaseModel):
 
     INSTALL_REQ = False
-    INTERLEAVE = False
+    INTERLEAVE = True
 
-    def __init__(self, model_path='OpenGVLab/InternVL-Chat-V1-5', load_in_8bit=False, **kwargs):
+    def __init__(self, model_path='OpenGVLab/InternVL-Chat-V1-5', load_in_8bit=False, version='V1.0', **kwargs):
         assert model_path is not None
         assert version_cmp(transformers.__version__, '4.36.2', 'ge')
+
         self.model_path = model_path
         self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False)
-        device = torch.cuda.current_device()
-        self.device = device
-        self.model = AutoModel.from_pretrained(model_path, torch_dtype=torch.bfloat16,
-                                               trust_remote_code=True,
-                                               load_in_8bit=load_in_8bit).eval()
-        if not load_in_8bit:
-            self.model = self.model.to(device)
-        self.image_size = self.model.config.vision_config.image_size
-
-        if 'V1-1' in model_path:
-            kwargs_default = dict(do_sample=False, max_new_tokens=1024, top_p=None, num_beams=5)
+
+        # Regular expression to match the pattern 'Image' followed by a number, e.g. Image1
+        self.pattern = r'Image(\d+)'
+        # Replacement pattern to insert a hyphen between 'Image' and the number, e.g. Image-1
+        self.replacement = r'Image-\1'
+
+        # Convert InternVL2 responses back to the dataset format
+        # e.g. Image-1 -> Image1
+
+        # Regular expression to match the pattern 'Image-' followed by a number
+        self.reverse_pattern = r'Image-(\d+)'
+        # Replacement pattern to remove the hyphen (Image-1 -> Image1)
+        self.reverse_replacement = r'Image\1'
+
+        if listinstr(['InternVL2-Llama3-76B'], model_path):
+            device_map = split_model(model_path.split('/')[-1])
+            self.model = AutoModel.from_pretrained(
+                model_path,
+                torch_dtype=torch.bfloat16,
+                load_in_8bit=load_in_8bit,
+                trust_remote_code=True,
+                low_cpu_mem_usage=True,
+                device_map=device_map).eval()
         else:
-            kwargs_default = dict(do_sample=False, max_new_tokens=1024, top_p=None, num_beams=1)
-        kwargs_default.update(kwargs)
-        self.kwargs = kwargs_default
+            device = torch.cuda.current_device()
+            self.device = device
+            self.model = AutoModel.from_pretrained(
+                model_path,
+                torch_dtype=torch.bfloat16,
+                trust_remote_code=True,
+                load_in_8bit=load_in_8bit).eval()
+            if not load_in_8bit:
+                self.model = self.model.to(device)
+
+        self.image_size = self.model.config.vision_config.image_size
+        self.version = version
+        self.kwargs = kwargs
         warnings.warn(f'Following kwargs received: {self.kwargs}, will use as generation config. ')
 
     def use_custom_prompt(self, dataset):
-        return True
+
+        if dataset is not None and listinstr(['MMDU'], dataset):
+            # For Multi-Turn we don't have custom prompt
+            return False
+        else:
+            return True
 
     def build_multi_choice_prompt(self, line, dataset=None):
         question = line['question']
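
The two regex pairs defined in __init__ translate between the dataset's ImageN references and the Image-N style InternVL2 expects, then back again for the response. A minimal round trip (illustrative only):

    import re

    pattern, replacement = r'Image(\d+)', r'Image-\1'
    reverse_pattern, reverse_replacement = r'Image-(\d+)', r'Image\1'

    q = 'Compare Image1 with Image12.'
    to_model = re.sub(pattern, replacement, q)  # 'Compare Image-1 with Image-12.'
    back = re.sub(reverse_pattern, reverse_replacement, to_model)
    assert back == q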
 
@@ -196,28 +257,41 @@ class InternVLChat(BaseModel):
 
         return prompt
 
+    def build_video_prompt(self, prompt, dataset=None, max_nframe=64):
+        for start in range(0, max_nframe, 8):
+            images_to_remove = ''.join([f'<image-{i}>' for i in range(start + 1, start + 9)])
+            prompt = prompt.replace(images_to_remove, '')
+        for i in range(max_nframe):
+            prompt = prompt.replace(f'<image-{i + 1}>', f'Frame{i + 1}')
+        if listinstr(['MMBench-Video'], dataset):
+            prompt = prompt.replace('\nAnswer:', '')
+            prompt += '\nAnswer the question using a single word or phrase.'
+        elif listinstr(['Video-MME'], dataset):
+            prompt = prompt.replace('\nAnswer:', '')
+            prompt += "\nAnswer with the option's letter from the given choices directly."
+        return prompt
+
     def build_prompt(self, line, dataset=None):
         assert self.use_custom_prompt(dataset)
         assert dataset is None or isinstance(dataset, str)
         tgt_path = self.dump_image(line, dataset)
 
-        if 'V1-1' in self.model_path:
+        if self.version == 'V1.1':
             kwargs_default = dict(do_sample=False, max_new_tokens=1024, top_p=None, num_beams=5)
         else:
             kwargs_default = dict(do_sample=False, max_new_tokens=1024, top_p=None, num_beams=1)
         self.kwargs = kwargs_default
+
         if dataset is not None and listinstr(['MME'], dataset):
             question = line['question']
             prompt = question + ' Answer the question using a single word or phrase.'
-            if 'V1-2' not in self.model_path:
-                self.kwargs = dict(do_sample=True, max_new_tokens=5, top_k=50, num_beams=5, top_p=0.9)
         elif dataset is not None and listinstr(['HallusionBench'], dataset):
             question = line['question']
             prompt = question + ' Please answer yes or no. Answer the question using a single word or phrase.'
-        elif dataset is not None and DATASET_TYPE(dataset) == 'multi-choice':
+        elif dataset is not None and DATASET_TYPE(dataset) == 'MCQ':
             prompt = self.build_multi_choice_prompt(line, dataset)
         elif dataset is not None and DATASET_TYPE(dataset) == 'VQA':
-            if 'MathVista' in dataset:
+            if listinstr(['MathVista', 'MathVision'], dataset):
                 prompt = line['question']
             elif listinstr(['LLaVABench'], dataset):
                 question = line['question']
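
Note on build_video_prompt above: the first loop only strips complete groups of eight consecutive <image-N> placeholders; whatever survives is renamed to FrameN. A toy run with three frames (a sketch, not from the commit):

    prompt = '<image-1><image-2><image-3>\nWhat happens in the video?\nAnswer:'
    max_nframe = 64
    for start in range(0, max_nframe, 8):
        group = ''.join([f'<image-{i}>' for i in range(start + 1, start + 9)])
        prompt = prompt.replace(group, '')  # no complete group of 8 here, so nothing is dropped
    for i in range(max_nframe):
        prompt = prompt.replace(f'<image-{i + 1}>', f'Frame{i + 1}')
    print(prompt)  # 'Frame1Frame2Frame3\nWhat happens in the video?\nAnswer:'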
 
@@ -229,14 +303,11 @@ class InternVLChat(BaseModel):
             prompt = question + '\nAnswer the question using a single word or phrase.'
         else:
             prompt = line['question']
-
         message = [dict(type='text', value=prompt)]
         message.extend([dict(type='image', value=s) for s in tgt_path])
-
         return message
 
-    def generate(self, message, dataset=None):
-        prompt, image_path = self.message_to_promptimg(message)
+    def set_max_num(self, dataset):
         if dataset is not None and listinstr(['ChartQA_TEST'], dataset):
             self.max_num = 12
             self.max_num2 = 3
 
@@ -245,33 +316,243 @@ class InternVLChat(BaseModel):
             self.max_num2 = 15
             self.min_num = 14
             self.min_num2 = 5
-        elif dataset is not None and listinstr(['InfoVQA_VAL', 'InfoVQA_TEST'], dataset):
+        elif dataset is not None and listinstr(['InfoVQA_VAL', 'InfoVQA_TEST', 'SEEDBench_IMG'], dataset):
             self.max_num = 23
             self.max_num2 = 5
             self.min_num = 15
             self.min_num2 = 3
-        elif dataset is not None and listinstr(['OCRBench'], dataset):
+        elif dataset is not None and listinstr(['OCRBench', 'POPE'], dataset):
             self.max_num = 24
             self.max_num2 = 8
             self.min_num = 9
             self.min_num2 = 5
+        elif dataset is not None and listinstr(['MME', 'HallusionBench'], dataset):
+            self.max_num = 11
+            self.max_num2 = 6
+            self.min_num = 4
+            self.min_num2 = 2
+        elif dataset is not None and listinstr(['AI2D_TEST'], dataset):
+            self.max_num = 12
+            self.max_num2 = 6
+            self.min_num = 5
+            self.min_num2 = 2
+        elif dataset is not None and listinstr(['CCBench'], dataset):
+            self.max_num = 24
+            self.max_num2 = 8
+            self.min_num = 9
+            self.min_num2 = 4
         else:
            self.max_num = 8
            self.max_num2 = 4
            self.min_num = 3
            self.min_num2 = 1
-        pixel_values, target_aspect_ratio = load_image(image_path, min_num=self.min_num, max_num=self.max_num)
-        pixel_values = pixel_values.cuda().to(torch.bfloat16)
-        pixel_values2 = load_image2(image_path, target_aspect_ratio=target_aspect_ratio, min_num=self.min_num2, max_num=self.max_num2)
-        pixel_values2 = pixel_values2.cuda().to(torch.bfloat16)
-        pixel_values = torch.cat((pixel_values[:-1], pixel_values2[:-1], pixel_values[-1:]), 0)
 
+    def generate_v1_2(self, message, dataset=None):
+        self.INTERLEAVE = False
+        prompt, image_path = self.message_to_promptimg(message, dataset=dataset)
+        image = Image.open(image_path).convert('RGB')
+        image = image.resize((self.image_size, self.image_size))
+        image_processor = CLIPImageProcessor.from_pretrained(self.model_path)
+        pixel_values = image_processor(images=image, return_tensors='pt').pixel_values
+        pixel_values = pixel_values.to(torch.bfloat16).to(self.device)
         with torch.no_grad():
-            response = self.model.chat(self.tokenizer, pixel_values=pixel_values, target_aspect_ratio=target_aspect_ratio,
+            response = self.model.chat(self.tokenizer, pixel_values=pixel_values,
                                        question=prompt, generation_config=self.kwargs)
-        response = response.split('[UNUSED_TOKEN_145]')[0]
+        return response
+
+    def generate_v1_5(self, message, dataset=None):
+        image_num = len([x for x in message if x['type'] == 'image'])
+        prompt = '\n'.join([x['value'] for x in message if x['type'] == 'text'])
+
+        if dataset is not None and listinstr(['Video'], dataset):
+            prompt = self.build_video_prompt(prompt, dataset)
+
+        if image_num > 1:
+            image_path = [x['value'] for x in message if x['type'] == 'image']
+            pixel_values_list = []
+            for file_name in image_path:
+                # load_image returns (pixel_values, target_aspect_ratio) in this file
+                curr_pixel_values, _ = load_image(file_name, max_num=self.max_num)
+                pixel_values_list.append(curr_pixel_values.cuda().to(torch.bfloat16))
+            pixel_values = torch.cat(pixel_values_list, dim=0)
+        elif image_num == 1:
+            image_path = [x['value'] for x in message if x['type'] == 'image'][0]
+            pixel_values, _ = load_image(image_path, max_num=self.max_num)
+            pixel_values = pixel_values.cuda().to(torch.bfloat16)
+        else:
+            pixel_values = None
+        with torch.no_grad():
+            response = self.model.chat(
+                self.tokenizer,
+                pixel_values=pixel_values,
+                question=prompt,
+                generation_config=self.kwargs,
+                verbose=False)
+        return response
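
For context, messages reach generate_v1_5 in VLMEvalKit's interleaved format, a list of typed dicts; the text parts are joined and each image is tiled and concatenated along dim 0. A hypothetical call (path and dataset are placeholders):

    model = InternVLChat(model_path='OpenGVLab/InternVL-Chat-V1-5', version='V1.5')
    message = [
        dict(type='image', value='/tmp/chart.png'),  # hypothetical file
        dict(type='text', value='What is the peak value in the chart?'),
    ]
    print(model.generate_inner(message, dataset='ChartQA_TEST'))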
+
+    def generate_v2(self, message, dataset=None):
+        image_num = len([x for x in message if x['type'] == 'image'])
+        if image_num == 1:
+            prompt = '<image>\n' + '\n'.join([x['value'] for x in message if x['type'] == 'text'])
+        else:
+            prompt, image_idx = '', 1
+            for x in message:
+                if x['type'] == 'text':
+                    prompt += x['value']
+                elif x['type'] == 'image':
+                    prompt += f'<image-{image_idx}>'
+                    image_idx += 1
+            prompt = ' '.join([f'<image-{i + 1}>: <image>' for i in range(image_num)]) + '\n' + prompt
+
+        if dataset is not None and listinstr(['Video'], dataset):
+            prompt = self.build_video_prompt(prompt, dataset)
+
+        if image_num > 1:
+            image_path = [x['value'] for x in message if x['type'] == 'image']
+            num_patches_list = []
+            pixel_values_list = []
+            for image_idx, file_name in enumerate(image_path):
+                # Two-pass dynamic tiling: a fine grid plus a coarse grid, thumbnail kept last
+                curr_pixel_values, target_aspect_ratio = load_image(
+                    file_name, min_num=self.min_num, max_num=self.max_num)
+                curr_pixel_values = curr_pixel_values.cuda().to(torch.bfloat16)
+                curr_pixel_values2 = load_image2(
+                    file_name, target_aspect_ratio=target_aspect_ratio,
+                    min_num=self.min_num2, max_num=self.max_num2)
+                curr_pixel_values2 = curr_pixel_values2.cuda().to(torch.bfloat16)
+                curr_pixel_values = torch.cat(
+                    (curr_pixel_values[:-1], curr_pixel_values2[:-1], curr_pixel_values[-1:]), 0)
+                num_patches_list.append(curr_pixel_values.size(0))
+                pixel_values_list.append(curr_pixel_values)
+            pixel_values = torch.cat(pixel_values_list, dim=0)
+        elif image_num == 1:
+            image_path = [x['value'] for x in message if x['type'] == 'image'][0]
+            pixel_values, target_aspect_ratio = load_image(image_path, min_num=self.min_num, max_num=self.max_num)
+            pixel_values = pixel_values.cuda().to(torch.bfloat16)
+            pixel_values2 = load_image2(image_path, target_aspect_ratio=target_aspect_ratio, min_num=self.min_num2, max_num=self.max_num2)
+            pixel_values2 = pixel_values2.cuda().to(torch.bfloat16)
+            pixel_values = torch.cat((pixel_values[:-1], pixel_values2[:-1], pixel_values[-1:]), 0)
+            num_patches_list = [pixel_values.size(0)]
+        else:
+            pixel_values = None
+            num_patches_list = []
+
+        with torch.no_grad():
+            response = self.model.chat(
+                self.tokenizer,
+                pixel_values=pixel_values,
+                target_aspect_ratio=(1, 1),
+                num_patches_list=num_patches_list,
+                question=prompt,
+                generation_config=self.kwargs,
+                verbose=False
+            )
         return response
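
The multi-image branch of generate_v2 prefixes a legend mapping each tag to an <image> slot, then keeps the inline tags in reading order. A pure-string sketch of the prompt it assembles for two images:

    message = [
        dict(type='image', value='a.jpg'),
        dict(type='image', value='b.jpg'),
        dict(type='text', value='Which photo is sharper?'),
    ]
    image_num, prompt, image_idx = 2, '', 1
    for x in message:
        if x['type'] == 'text':
            prompt += x['value']
        else:
            prompt += f'<image-{image_idx}>'
            image_idx += 1
    prompt = ' '.join([f'<image-{i + 1}>: <image>' for i in range(image_num)]) + '\n' + prompt
    print(prompt)
    # <image-1>: <image> <image-2>: <image>
    # <image-1><image-2>Which photo is sharper?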
 
     def generate_inner(self, message, dataset=None):
-        return self.generate(message, dataset)
+        self.set_max_num(dataset)
+        print(f'InternVL model version: {self.version}')
+        if self.version in ['V1.1', 'V1.2']:
+            return self.generate_v1_2(message, dataset)
+        elif self.version == 'V1.5':
+            return self.generate_v1_5(message, dataset)
+        elif self.version == 'V2.0':
+            return self.generate_v2(message, dataset)
+        else:
+            raise ValueError(f'Unsupported version: {self.version}')
+
+    def build_history(self, message):
+        # State shared with the nested helper via closure
+        image_path = []
+        image_cnt = 0
+
+        def concat_tilist(tilist):
+            nonlocal image_cnt  # Declare image_cnt as nonlocal to modify it
+            prompt = ''
+            for item in tilist:
+                # Substitute the pattern in the text
+                if item['type'] == 'text':
+                    prompt += re.sub(self.pattern, self.replacement, item['value'])
+                elif item['type'] == 'image':
+                    image_cnt += 1
+                    prompt += '<image>\n'
+                    image_path.append(item['value'])
+            return prompt
+
+        # Only previous messages
+        assert len(message) % 2 == 0
+        history = []
+        for i in range(len(message) // 2):
+            m1, m2 = message[2 * i], message[2 * i + 1]
+            assert m1['role'] == 'user' and m2['role'] == 'assistant'
+            history.append((concat_tilist(m1['content']), concat_tilist(m2['content'])))
+
+        return history, image_path, image_cnt
+
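
build_history folds the alternating user/assistant turns into the (question, answer) tuples that model.chat expects as history, rewriting ImageN references and collecting image paths on the way. Roughly (hypothetical content):

    message = [
        dict(role='user', content=[dict(type='image', value='x.png'),
                                   dict(type='text', value='Describe Image1.')]),
        dict(role='assistant', content=[dict(type='text', value='A cat.')]),
    ]
    # build_history(message) would return approximately:
    #   history    -> [('<image>\nDescribe Image-1.', 'A cat.')]
    #   image_path -> ['x.png'], image_cnt -> 1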
+    def chat_inner_v2(self, message, dataset=None):
+
+        image_cnt = 0
+        if len(message) > 1:
+            history, image_path, image_cnt = self.build_history(message[:-1])
+        else:
+            history, image_path, image_cnt = None, [], 1
+        current_msg = message[-1]
+        question = ''
+
+        # If message is just text in the conversation
+        if len(current_msg['content']) == 1 and current_msg['content'][0]['type'] == 'text':
+            question = current_msg['content'][0]['value']
+            question = re.sub(self.pattern, self.replacement, question)  # Fix pattern as per InternVL
+        else:
+            for msg in current_msg['content']:
+                if msg['type'] == 'text':
+                    question += re.sub(self.pattern, self.replacement, msg['value'])
+                elif msg['type'] == 'image':
+                    image_cnt += 1
+                    question += '<image>\n'
+                    image_path.append(msg['value'])
+
+        if image_cnt > 1:
+            num_patches_list = []
+            pixel_values_list = []
+            for image_idx, file_name in enumerate(image_path):
+                upscale_flag = image_idx == 0 and dataset is not None and listinstr(['MMMU_DEV_VAL'], dataset)
+                # load_image returns (pixel_values, target_aspect_ratio) in this file
+                curr_pixel_values, _ = load_image(file_name, max_num=self.max_num, upscale=upscale_flag)
+                curr_pixel_values = curr_pixel_values.cuda().to(torch.bfloat16)
+                num_patches_list.append(curr_pixel_values.size(0))
+                pixel_values_list.append(curr_pixel_values)
+            pixel_values = torch.cat(pixel_values_list, dim=0)
+        elif image_cnt == 1:
+            upscale_flag = dataset is not None and listinstr(['MMMU_DEV_VAL'], dataset)
+            pixel_values, _ = load_image(image_path[0], max_num=self.max_num, upscale=upscale_flag)
+            pixel_values = pixel_values.cuda().to(torch.bfloat16)
+            num_patches_list = [pixel_values.size(0)]
+        else:
+            pixel_values = None
+            num_patches_list = []
+
+        response, history = self.model.chat(
+            self.tokenizer,
+            pixel_values=pixel_values,
+            target_aspect_ratio=(1, 1),  # fixed ratio, matching generate_v2 above
+            num_patches_list=num_patches_list,
+            question=question,
+            generation_config=self.kwargs,
+            history=history,
+            return_history=True
+        )
+
+        response = re.sub(self.reverse_pattern, self.reverse_replacement, response)
+
+        return response
+
+    def chat_inner(self, message, dataset=None):
+        self.set_max_num(dataset)
+
+        if self.version in ['V1.1', 'V1.2', 'V1.5']:
+            raise ValueError(f'Unsupported version for Multi-Turn: {self.version}')
+        elif self.version == 'V2.0':
+            kwargs_default = dict(do_sample=False, max_new_tokens=512, top_p=None, num_beams=1)
+            self.kwargs = kwargs_default
+            return self.chat_inner_v2(message, dataset)
+        else:
+            raise ValueError(f'Unsupported version for Multi-Turn: {self.version}')
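
Taken together: generate_inner dispatches single-turn requests on the version string given at construction time, while multi-turn chat is only wired up for V2.0. A hypothetical instantiation for an InternVL2 checkpoint:

    model = InternVLChat(model_path='OpenGVLab/InternVL2-8B', version='V2.0')
    # single turn:  model.generate_inner(message, dataset='MMBench_DEV_EN')
    # multi turn:   model.chat_inner(conversation, dataset='MMDU')
    # versions V1.1/V1.2/V1.5 raise ValueError on chat_inner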