DongfuJiang commited on
Commit
86744eb
·
1 Parent(s): c89a453
app.py CHANGED
@@ -5,8 +5,8 @@ import time
5
  from PIL import Image
6
  from models.mllava import MLlavaProcessor, LlavaForConditionalGeneration, chat_mllava, MLlavaForConditionalGeneration
7
  from typing import List
8
- processor = MLlavaProcessor.from_pretrained("TIGER-Lab/Mantis-llava-7b-v1.1")
9
- model = LlavaForConditionalGeneration.from_pretrained("TIGER-Lab/Mantis-llava-7b-v1.1")
10
 
11
  @spaces.GPU
12
  def generate(text:str, images:List[Image.Image], history: List[dict], **kwargs):
 
5
  from PIL import Image
6
  from models.mllava import MLlavaProcessor, LlavaForConditionalGeneration, chat_mllava, MLlavaForConditionalGeneration
7
  from typing import List
8
+ processor = MLlavaProcessor.from_pretrained("TIGER-Lab/Mantis-8B-siglip-llama3")
9
+ model = LlavaForConditionalGeneration.from_pretrained("TIGER-Lab/Mantis-8B-siglip-llama3")
10
 
11
  @spaces.GPU
12
  def generate(text:str, images:List[Image.Image], history: List[dict], **kwargs):
models/conversation.py CHANGED
@@ -10,6 +10,7 @@ class SeparatorStyle(Enum):
10
  MPT = auto()
11
  PLAIN = auto()
12
  LLAMA_2 = auto()
 
13
  MFuyu = auto()
14
 
15
 
@@ -30,6 +31,7 @@ class Conversation:
30
  def get_prompt(self):
31
  messages = self.messages
32
  if len(messages) > 0 and type(messages[0][1]) is tuple:
 
33
  messages = self.messages.copy()
34
  init_role, init_msg = messages[0].copy()
35
  init_msg = init_msg[0].replace("<image>", "").strip()
@@ -39,7 +41,6 @@ class Conversation:
39
  messages.insert(1, (self.roles[1], "Received."))
40
  else:
41
  messages[0] = (init_role, "<image>" + init_msg)
42
-
43
  if self.sep_style == SeparatorStyle.SINGLE:
44
  ret = self.system + self.sep
45
  for role, message in messages:
@@ -89,6 +90,15 @@ class Conversation:
89
  else:
90
  ret += ""
91
  ret = ret.lstrip(self.sep)
 
 
 
 
 
 
 
 
 
92
  elif self.sep_style == SeparatorStyle.MFuyu:
93
  seps = [self.sep, self.sep2]
94
  ret = self.system + "\n"
@@ -393,6 +403,25 @@ conv_mllava_v1_mmtag = Conversation(
393
  version="v1_mmtag",
394
  )
395
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
396
 
397
  default_conversation = conv_mfuyu_v1
398
  conv_templates = {
@@ -409,6 +438,9 @@ conv_templates = {
409
  "llava_v1": conv_llava_v1,
410
  "v1_mmtag": conv_llava_v1_mmtag,
411
  "llava_llama_2": conv_llava_llama_2,
 
 
 
412
 
413
  "mpt": conv_mpt,
414
  }
 
10
  MPT = auto()
11
  PLAIN = auto()
12
  LLAMA_2 = auto()
13
+ LLAMA_3 = auto()
14
  MFuyu = auto()
15
 
16
 
 
31
  def get_prompt(self):
32
  messages = self.messages
33
  if len(messages) > 0 and type(messages[0][1]) is tuple:
34
+
35
  messages = self.messages.copy()
36
  init_role, init_msg = messages[0].copy()
37
  init_msg = init_msg[0].replace("<image>", "").strip()
 
41
  messages.insert(1, (self.roles[1], "Received."))
42
  else:
43
  messages[0] = (init_role, "<image>" + init_msg)
 
44
  if self.sep_style == SeparatorStyle.SINGLE:
45
  ret = self.system + self.sep
46
  for role, message in messages:
 
90
  else:
91
  ret += ""
92
  ret = ret.lstrip(self.sep)
93
+ elif self.sep_style == SeparatorStyle.LLAMA_3:
94
+ ret = self.system + self.sep
95
+ for role, message in messages:
96
+ if message:
97
+ if type(message) is tuple:
98
+ message, _, _ = message
99
+ ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n" + message + self.sep
100
+ else:
101
+ ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
102
  elif self.sep_style == SeparatorStyle.MFuyu:
103
  seps = [self.sep, self.sep2]
104
  ret = self.system + "\n"
 
403
  version="v1_mmtag",
404
  )
405
 
406
+ conv_mllava_v1 = Conversation(
407
+ system="A chat between a curious human and an artificial intelligence assistant. "
408
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
409
+ roles=("USER", "ASSISTANT"),
410
+ version="v1",
411
+ messages=(),
412
+ offset=0,
413
+ sep_style=SeparatorStyle.SINGLE,
414
+ sep="</s>",
415
+ )
416
+
417
+ conv_llama_3 = Conversation(
418
+ system="<|start_header_id|>system<|end_header_id|>\n\nYou are a pirate chatbot who always responds in pirate speak!",
419
+ roles=("user", "assistant"),
420
+ messages=(),
421
+ offset=0,
422
+ sep_style=SeparatorStyle.LLAMA_3,
423
+ sep="<|eot_id|>",
424
+ )
425
 
426
  default_conversation = conv_mfuyu_v1
427
  conv_templates = {
 
438
  "llava_v1": conv_llava_v1,
439
  "v1_mmtag": conv_llava_v1_mmtag,
440
  "llava_llama_2": conv_llava_llama_2,
441
+ "llama_3": conv_llama_3,
442
+ "mllava_v1": conv_mllava_v1,
443
+ "mllava_v1_mmtag": conv_mllava_v1_mmtag,
444
 
445
  "mpt": conv_mpt,
446
  }
models/mllava/__init__.py CHANGED
@@ -1,3 +1,4 @@
1
  from .modeling_llava import LlavaForConditionalGeneration, MLlavaForConditionalGeneration
2
  from .processing_llava import MLlavaProcessor
 
3
  from .utils import chat_mllava
 
1
  from .modeling_llava import LlavaForConditionalGeneration, MLlavaForConditionalGeneration
2
  from .processing_llava import MLlavaProcessor
3
+ from .configuration_llava import LlavaConfig
4
  from .utils import chat_mllava
models/mllava/modeling_llava.py CHANGED
@@ -249,15 +249,15 @@ LLAVA_INPUTS_DOCSTRING = r"""
249
  LLAVA_START_DOCSTRING,
250
  )
251
  class LlavaForConditionalGeneration(LlavaPreTrainedModel):
252
- def __init__(self, config: LlavaConfig):
253
  super().__init__(config)
254
- self.vision_tower = AutoModel.from_config(config.vision_config)
255
 
256
  self.multi_modal_projector = LlavaMultiModalProjector(config)
257
  self.vocab_size = config.vocab_size
258
  self.language_model = AutoModelForCausalLM.from_config(
259
  config.text_config, attn_implementation=config._attn_implementation
260
- )
261
  self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
262
  self.post_init()
263
 
@@ -428,6 +428,11 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
428
 
429
  # 2. Merge text and images
430
  if pixel_values is not None and input_ids.shape[1] != 1:
 
 
 
 
 
431
  image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
432
  # this is not memory efficient at all (output_hidden_states=True) will save all the hidden stated.
433
  selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
 
249
  LLAVA_START_DOCSTRING,
250
  )
251
  class LlavaForConditionalGeneration(LlavaPreTrainedModel):
252
+ def __init__(self, config: LlavaConfig, vision_tower=None, language_model=None):
253
  super().__init__(config)
254
+ self.vision_tower = AutoModel.from_config(config.vision_config) if vision_tower is None else vision_tower
255
 
256
  self.multi_modal_projector = LlavaMultiModalProjector(config)
257
  self.vocab_size = config.vocab_size
258
  self.language_model = AutoModelForCausalLM.from_config(
259
  config.text_config, attn_implementation=config._attn_implementation
260
+ ) if language_model is None else language_model
261
  self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
262
  self.post_init()
263
 
 
428
 
429
  # 2. Merge text and images
430
  if pixel_values is not None and input_ids.shape[1] != 1:
431
+ if isinstance(pixel_values, list):
432
+ pixel_values = torch.cat([x for x in pixel_values if x is not None], dim=0)
433
+ # for siglip, need to transform the pixel_values to the right data type
434
+ if pixel_values.dtype != self.vision_tower.dtype:
435
+ pixel_values = pixel_values.type(self.vision_tower.dtype)
436
  image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
437
  # this is not memory efficient at all (output_hidden_states=True) will save all the hidden stated.
438
  selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
models/mllava/processing_llava.py CHANGED
@@ -16,7 +16,8 @@
16
  Processor class for Llava.
17
  """
18
 
19
-
 
20
  from typing import List, Optional, Union, Dict
21
 
22
  # from ...feature_extraction_utils import BatchFeature
@@ -30,6 +31,9 @@ from transformers.image_utils import ImageInput
30
  from transformers.processing_utils import ProcessorMixin
31
  from transformers.tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
32
  from transformers.utils import TensorType
 
 
 
33
 
34
  from PIL import Image
35
  import logging
@@ -52,8 +56,8 @@ class MLlavaProcessor(ProcessorMixin):
52
  """
53
 
54
  attributes = ["image_processor", "tokenizer"]
55
- image_processor_class = "CLIPImageProcessor"
56
- tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
57
 
58
  def __init__(self, image_processor=None, tokenizer=None):
59
  super().__init__(image_processor, tokenizer)
@@ -109,7 +113,7 @@ class MLlavaProcessor(ProcessorMixin):
109
  if i < num_images:
110
  text[i] = t + "<image>"
111
  text = "".join(text)
112
- logger.warning("Number of <image> tokens exceeds number of images. Automatically removing extra tokens at the end of the text.")
113
  # raise ValueError("Invalid input text. Number of <image> tokens exceeds number of images.")
114
  texts = [text]
115
  elif isinstance(text, list):
@@ -135,7 +139,7 @@ class MLlavaProcessor(ProcessorMixin):
135
  if j < num_images:
136
  t[j] = s + "<image>"
137
  t = "".join(t)
138
- logger.warning("Number of <image> tokens exceeds number of images. Automatically removing extra tokens at the end of the text.")
139
  # raise ValueError("Invalid input text. Number of <image> tokens exceeds number of images.")
140
  text[i] = t
141
  texts = text
@@ -171,6 +175,7 @@ class MLlavaProcessor(ProcessorMixin):
171
  truncation: Union[bool, str, TruncationStrategy] = None,
172
  max_length=None,
173
  return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
 
174
  ) -> BatchFeature:
175
  """
176
  Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
@@ -218,13 +223,14 @@ class MLlavaProcessor(ProcessorMixin):
218
  `None`).
219
  - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
220
  """
221
- texts, images = self.preprocess_interleaved_images_and_text(text, images)
 
222
  if images is not None:
223
  pixel_values = self.image_processor(images, return_tensors=return_tensors)["pixel_values"] # [batch_size, num_channels, height, width], e.g. [1, 3, 336, 336]
224
  else:
225
  pixel_values = None
226
  text_inputs = self.tokenizer(
227
- texts, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length
228
  )
229
  # text_inputs:
230
  # 1. input_ids: [batch_size, sequence_length], e.g. [1, 6]
@@ -259,9 +265,117 @@ class MLlavaProcessor(ProcessorMixin):
259
  results = {}
260
  assert len(model_inputs) == 1, "This method only supports a single input, but get {} inputs".format(len(model_inputs))
261
  for k in model_inputs[0].keys():
262
- if model_inputs[0][k] is not None:
263
- results[k] = torch.cat([inputs[k] for inputs in model_inputs], dim=0)
264
  else:
265
- results[k] = None
266
  return results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267
 
 
16
  Processor class for Llava.
17
  """
18
 
19
+ import os
20
+ import json
21
  from typing import List, Optional, Union, Dict
22
 
23
  # from ...feature_extraction_utils import BatchFeature
 
31
  from transformers.processing_utils import ProcessorMixin
32
  from transformers.tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
33
  from transformers.utils import TensorType
34
+ from transformers.processing_utils import transformers_module
35
+ from transformers.utils.hub import is_remote_url, download_url, cached_file, is_offline_mode
36
+ from transformers.utils import IMAGE_PROCESSOR_NAME
37
 
38
  from PIL import Image
39
  import logging
 
56
  """
57
 
58
  attributes = ["image_processor", "tokenizer"]
59
+ image_processor_class = ("CLIPImageProcessor", "SiglipImageProcessor")
60
+ tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast", "PreTrainedTokenizerFast")
61
 
62
  def __init__(self, image_processor=None, tokenizer=None):
63
  super().__init__(image_processor, tokenizer)
 
113
  if i < num_images:
114
  text[i] = t + "<image>"
115
  text = "".join(text)
116
+ logger.warning(f"Number of <image> tokens: {num_image_tokens} exceeds number of images: {num_images}. Automatically removing extra tokens at the end of the text.")
117
  # raise ValueError("Invalid input text. Number of <image> tokens exceeds number of images.")
118
  texts = [text]
119
  elif isinstance(text, list):
 
139
  if j < num_images:
140
  t[j] = s + "<image>"
141
  t = "".join(t)
142
+ logger.warning(f"Number of <image> tokens: {num_image_tokens} exceeds number of images: {num_images}. Automatically removing extra tokens at the end of the text.")
143
  # raise ValueError("Invalid input text. Number of <image> tokens exceeds number of images.")
144
  text[i] = t
145
  texts = text
 
175
  truncation: Union[bool, str, TruncationStrategy] = None,
176
  max_length=None,
177
  return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
178
+ add_image_ids: bool = True,
179
  ) -> BatchFeature:
180
  """
181
  Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
 
223
  `None`).
224
  - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
225
  """
226
+ if add_image_ids:
227
+ text, images = self.preprocess_interleaved_images_and_text(text, images)
228
  if images is not None:
229
  pixel_values = self.image_processor(images, return_tensors=return_tensors)["pixel_values"] # [batch_size, num_channels, height, width], e.g. [1, 3, 336, 336]
230
  else:
231
  pixel_values = None
232
  text_inputs = self.tokenizer(
233
+ text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length
234
  )
235
  # text_inputs:
236
  # 1. input_ids: [batch_size, sequence_length], e.g. [1, 6]
 
265
  results = {}
266
  assert len(model_inputs) == 1, "This method only supports a single input, but get {} inputs".format(len(model_inputs))
267
  for k in model_inputs[0].keys():
268
+ if k == "pixel_values":
269
+ results[k] = [inputs[k] if inputs[k] is not None else None for inputs in model_inputs]
270
  else:
271
+ results[k] = torch.cat([inputs[k] for inputs in model_inputs], dim=0)
272
  return results
273
+
274
+ @classmethod
275
+ def _get_arguments_from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
276
+ args = []
277
+
278
+ cache_dir = kwargs.pop("cache_dir", None)
279
+ force_download = kwargs.pop("force_download", False)
280
+ resume_download = kwargs.pop("resume_download", False)
281
+ proxies = kwargs.pop("proxies", None)
282
+ token = kwargs.pop("token", None)
283
+ local_files_only = kwargs.pop("local_files_only", False)
284
+ revision = kwargs.pop("revision", None)
285
+ subfolder = kwargs.pop("subfolder", "")
286
+
287
+ from_pipeline = kwargs.pop("_from_pipeline", None)
288
+ from_auto_class = kwargs.pop("_from_auto", False)
289
+
290
+ user_agent = {"file_type": "processor", "from_auto_class": from_auto_class}
291
+ if from_pipeline is not None:
292
+ user_agent["using_pipeline"] = from_pipeline
293
+
294
+ if is_offline_mode() and not local_files_only:
295
+ logger.info("Offline mode: forcing local_files_only=True")
296
+ local_files_only = True
297
+
298
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
299
+ is_local = os.path.isdir(pretrained_model_name_or_path)
300
+ if os.path.isdir(pretrained_model_name_or_path):
301
+ processor_file = os.path.join(pretrained_model_name_or_path, IMAGE_PROCESSOR_NAME)
302
+ if os.path.isfile(pretrained_model_name_or_path):
303
+ resolved_processor_file = pretrained_model_name_or_path
304
+ is_local = True
305
+ elif is_remote_url(pretrained_model_name_or_path):
306
+ processor_file = pretrained_model_name_or_path
307
+ resolved_processor_file = download_url(pretrained_model_name_or_path)
308
+ else:
309
+ processor_file = IMAGE_PROCESSOR_NAME
310
+ try:
311
+ # Load from local folder or from cache or download from model Hub and cache
312
+ resolved_processor_file = cached_file(
313
+ pretrained_model_name_or_path,
314
+ processor_file,
315
+ cache_dir=cache_dir,
316
+ force_download=force_download,
317
+ proxies=proxies,
318
+ resume_download=resume_download,
319
+ local_files_only=local_files_only,
320
+ token=token,
321
+ user_agent=user_agent,
322
+ revision=revision,
323
+ subfolder=subfolder,
324
+ _raise_exceptions_for_missing_entries=True,
325
+ )
326
+ except EnvironmentError:
327
+ # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
328
+ # the original exception.
329
+ raise
330
+ except Exception:
331
+ # For any other exception, we throw a generic error.
332
+ raise EnvironmentError(
333
+ f"Can't load processor for '{pretrained_model_name_or_path}'. If you were trying to load"
334
+ " it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
335
+ f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
336
+ f" directory containing a {IMAGE_PROCESSOR_NAME} file"
337
+ )
338
+
339
+ # Existing processors on the Hub created before #27761 being merged don't have `processor_config.json` (if not
340
+ # updated afterward), and we need to keep `from_pretrained` work. So here it fallbacks to the empty dict.
341
+ # (`cached_file` called using `_raise_exceptions_for_missing_entries=False` to avoid exception)
342
+ # However, for models added in the future, we won't get the expected error if this file is missing.
343
+ if resolved_processor_file is None:
344
+ image_processor_dict = {}
345
+
346
+ try:
347
+ # Load processor dict
348
+ with open(resolved_processor_file, "r", encoding="utf-8") as reader:
349
+ text = reader.read()
350
+ image_processor_dict = json.loads(text)
351
+
352
+ except json.JSONDecodeError:
353
+ raise EnvironmentError(
354
+ f"It looks like the config file at '{resolved_processor_file}' is not a valid JSON file."
355
+ )
356
+
357
+ for attribute_name in cls.attributes:
358
+ class_name = getattr(cls, f"{attribute_name}_class")
359
+ if isinstance(class_name, tuple):
360
+ if attribute_name == "tokenizer":
361
+ classes = tuple(getattr(transformers_module, n) if n is not None else None for n in class_name)
362
+ use_fast = kwargs.get("use_fast", True)
363
+ if use_fast and classes[1] is not None:
364
+ attribute_class = classes[1]
365
+ else:
366
+ attribute_class = classes[0]
367
+ elif attribute_name == "image_processor":
368
+ image_processor_type = image_processor_dict.get("image_processor_type", None)
369
+ if image_processor_type is not None:
370
+ assert image_processor_type in class_name, f"Invalid image processor type: {image_processor_type}"
371
+ attribute_class = getattr(transformers_module, image_processor_type)
372
+ else:
373
+ attribute_class = getattr(transformers_module, class_name[0])
374
+ else:
375
+ raise ValueError(f"Invalid attribute name: {attribute_name}")
376
+ else:
377
+ attribute_class = getattr(transformers_module, class_name)
378
+
379
+ args.append(attribute_class.from_pretrained(pretrained_model_name_or_path, **kwargs))
380
+ return args
381
 
models/mllava/utils.py CHANGED
@@ -2,7 +2,9 @@ import PIL
2
  import torch
3
  from .modeling_llava import LlavaForConditionalGeneration
4
  from .processing_llava import MLlavaProcessor
5
- from ..conversation import conv_mllava_v1_mmtag as default_conv
 
 
6
  from typing import List, Tuple, Union, Tuple
7
 
8
  def chat_mllava(
@@ -12,7 +14,6 @@ def chat_mllava(
12
  processor:MLlavaProcessor,
13
  max_input_length:int=None,
14
  history:List[dict]=None,
15
- stream:bool=False,
16
  **kwargs) -> Tuple[str, List[dict]]:
17
  """
18
  Chat with the Mllava model
@@ -29,7 +30,17 @@ def chat_mllava(
29
 
30
 
31
  """
32
- conv = default_conv.copy()
 
 
 
 
 
 
 
 
 
 
33
  conv.messages = []
34
  if history is not None:
35
  for message in history:
@@ -38,17 +49,8 @@ def chat_mllava(
38
  conv.append_message(message["role"], message["text"])
39
  else:
40
  history = []
41
-
42
- if text is not None:
43
- conv.append_message(conv.roles[0], text)
44
- conv.append_message(conv.roles[1], "")
45
- history.append({"role": conv.roles[0], "text": text})
46
- history.append({"role": conv.roles[1], "text": ""})
47
- else:
48
- assert history, "The history should not be empty if the text is None"
49
- assert history[-1]['role'] == conv.roles[1], "The last message in the history should be the assistant, an empty message"
50
- assert history[-2]['text'], "The last user message in the history should not be empty"
51
- assert history[-1]['text'] == "", "The last assistant message in the history should be empty"
52
 
53
  prompt = conv.get_prompt()
54
  if images:
@@ -57,27 +59,89 @@ def chat_mllava(
57
  images[i] = PIL.Image.open(images[i])
58
 
59
  inputs = processor(images=images, text=prompt, return_tensors="pt", truncation=True, max_length=max_input_length)
60
- inputs = {k: v.to(model.device) if v is not None else v for k, v in inputs.items()}
 
 
 
 
 
 
 
61
 
62
- if stream:
63
- from transformers import TextIteratorStreamer
64
- from threading import Thread
65
- streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
66
- kwargs["streamer"] = streamer
67
- inputs.update(kwargs)
68
- thread = Thread(target=model.generate, kwargs=inputs)
69
- thread.start()
70
- for _output in streamer:
71
- history[-1]["text"] += _output
72
- yield history[-1]["text"], history
73
- else:
74
- output_ids = model.generate(**inputs, **kwargs)
75
- output_ids = output_ids[0]
76
-
77
- # remove the input tokens
78
- generated_ids = output_ids[inputs["input_ids"].shape[-1]:]
79
- generated_text = processor.decode(generated_ids, skip_special_tokens=True)
80
 
81
- history[-1]["text"] = history[-1]["text"].strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
- return generated_text, history
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import torch
3
  from .modeling_llava import LlavaForConditionalGeneration
4
  from .processing_llava import MLlavaProcessor
5
+ # from ..conversation import conv_mllava_v1_mmtag as default_conv
6
+ from ..conversation import conv_mllava_v1 as default_conv, conv_templates
7
+
8
  from typing import List, Tuple, Union, Tuple
9
 
10
  def chat_mllava(
 
14
  processor:MLlavaProcessor,
15
  max_input_length:int=None,
16
  history:List[dict]=None,
 
17
  **kwargs) -> Tuple[str, List[dict]]:
18
  """
19
  Chat with the Mllava model
 
30
 
31
 
32
  """
33
+ if "llama-3" in model.language_model.name_or_path.lower():
34
+ conv = conv_templates['llama_3']
35
+ terminators = [
36
+ processor.tokenizer.eos_token_id,
37
+ processor.tokenizer.convert_tokens_to_ids("<|eot_id|>")
38
+ ]
39
+ else:
40
+ conv = default_conv
41
+ terminators = None
42
+ kwargs["eos_token_id"] = terminators
43
+ conv = conv.copy()
44
  conv.messages = []
45
  if history is not None:
46
  for message in history:
 
49
  conv.append_message(message["role"], message["text"])
50
  else:
51
  history = []
52
+ conv.append_message(conv.roles[0], text)
53
+ conv.append_message(conv.roles[1], "")
 
 
 
 
 
 
 
 
 
54
 
55
  prompt = conv.get_prompt()
56
  if images:
 
59
  images[i] = PIL.Image.open(images[i])
60
 
61
  inputs = processor(images=images, text=prompt, return_tensors="pt", truncation=True, max_length=max_input_length)
62
+ for k, v in inputs.items():
63
+ if v is not None:
64
+ if isinstance(v, torch.Tensor):
65
+ inputs[k] = v.to(model.device)
66
+ elif isinstance(v, list):
67
+ inputs[k] = [x.to(model.device) for x in v]
68
+ else:
69
+ raise ValueError(f"Invalid input type: {type(v)}")
70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
+ output_ids = model.generate(**inputs, **kwargs)
73
+ output_ids = output_ids[0]
74
+
75
+ # remove the input tokens
76
+ generated_ids = output_ids[inputs["input_ids"].shape[-1]:]
77
+ generated_text = processor.decode(generated_ids, skip_special_tokens=True)
78
+
79
+ history.append({"role": conv.roles[0], "text": text})
80
+ history.append({"role": conv.roles[1], "text": generated_text})
81
+
82
+ return generated_text, history
83
+
84
+
85
+ def chat_mllava_stream(
86
+ text:str,
87
+ images: List[Union[PIL.Image.Image, str]],
88
+ model:LlavaForConditionalGeneration,
89
+ processor:MLlavaProcessor,
90
+ max_input_length:int=None,
91
+ history:List[dict]=None,
92
+ **kwargs) -> Tuple[str, List[dict]]:
93
+ """
94
+ Chat with the Mllava model
95
+ Args:
96
+ text: str, the text to be sent to the model, where <image> will be the placeholder for the image
97
+ images: List[PIL.Image.Image], the images to be sent to the model, or None
98
+ model: LlavaForConditionalGeneration, the model to be used
99
+ processor: MLlavaProcessor, the processor to be used
100
+ max_input_length: int, the maximum input length
101
+ history: List[dict], list of messages in the conversation as history. Each message is a dictionary {"role": "ASSISTANT/USER", "text": "the message"}. If None, the conversation will start from scratch
102
+ kwargs: dict, the generation kwargs
103
+ Returns:
104
+ Tuple[str, List[dict]], the generated text and the history of the conversation
105
 
106
+
107
+ """
108
+ conv = default_conv.copy()
109
+ conv.messages = []
110
+ if history is not None:
111
+ for message in history:
112
+ message["role"] = message["role"].upper()
113
+ assert message["role"] in conv.roles
114
+ conv.append_message(message["role"], message["text"])
115
+ else:
116
+ history = []
117
+ conv.append_message(conv.roles[0], text)
118
+ conv.append_message(conv.roles[1], "")
119
+
120
+ prompt = conv.get_prompt()
121
+ if images:
122
+ for i in range(len(images)):
123
+ if isinstance(images[i], str):
124
+ images[i] = PIL.Image.open(images[i])
125
+
126
+ inputs = processor(images=images, text=prompt, return_tensors="pt", truncation=True, max_length=max_input_length)
127
+ for k, v in inputs.items():
128
+ if v is not None:
129
+ if isinstance(v, torch.Tensor):
130
+ inputs[k] = v.to(model.device)
131
+ elif isinstance(v, list):
132
+ inputs[k] = [x.to(model.device) for x in v]
133
+ else:
134
+ raise ValueError(f"Invalid input type: {type(v)}")
135
+
136
+ from transformers import TextIteratorStreamer
137
+ from threading import Thread
138
+ streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
139
+ kwargs["streamer"] = streamer
140
+ inputs.update(kwargs)
141
+ thread = Thread(target=model.generate, kwargs=inputs)
142
+ thread.start()
143
+ history.append({"role": conv.roles[0], "text": text})
144
+ history.append({"role": conv.roles[1], "text": ""})
145
+ for _output in streamer:
146
+ history[-1]["text"] += _output
147
+ yield history[-1]["text"], history