from PIL import Image
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration, BitsAndBytesConfig

import model_management  # provides get_torch_device() for device selection


class LlavaImg2Txt:
    """
    Generate text captions for images using the Llava model.

    Args:
        question_list (list[str]): Questions to ask the model about the image.
        model_id (str): The model's name in the Hugging Face model hub.
        use_4bit_quantization (bool): Load the model with 4-bit quantization to reduce
            VRAM usage. Quantization lowers parameter precision and may affect the
            quality of generated outputs; use it if VRAM is limited. Default is True.
        use_low_cpu_mem (bool): Initialize the model in low_cpu_mem_usage mode, which
            reduces CPU memory consumption during loading. Helpful for large models or
            limited computational resources. Default is True.
        use_flash2_attention (bool): Use Flash Attention 2, an optimized attention
            implementation that speeds up generation on supported GPUs. Default is False.
        max_tokens_per_chunk (int): The maximum number of tokens to generate per prompt
            chunk. Default is 300.
    """
    def __init__(
        self,
        question_list,
        model_id: str = "llava-hf/llava-1.5-7b-hf",
        use_4bit_quantization: bool = True,
        use_low_cpu_mem: bool = True,
        use_flash2_attention: bool = False,
        max_tokens_per_chunk: int = 300,
    ):
        self.question_list = question_list
        self.model_id = model_id
        self.use_4bit = use_4bit_quantization
        self.use_flash2 = use_flash2_attention
        self.use_low_cpu_mem = use_low_cpu_mem
        self.max_tokens_per_chunk = max_tokens_per_chunk

    def generate_caption(
        self,
        raw_image: Image.Image,
    ) -> str:
        """
        Generate a caption for an image using the Llava model.

        Args:
            raw_image (Image.Image): Image to generate a caption for.

        Returns:
            str: The generated caption.
        """
        # Convert the image to RGB before processing.
        if raw_image.mode != "RGB":
            raw_image = raw_image.convert("RGB")
        dtype = torch.float16
        # Only build a bitsandbytes config when 4-bit loading is requested.
        quant_config = None
        if self.use_4bit:
            quant_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=dtype,
                bnb_4bit_quant_type="fp4",
            )
        model = LlavaForConditionalGeneration.from_pretrained(
            self.model_id,
            torch_dtype=dtype,
            low_cpu_mem_usage=self.use_low_cpu_mem,
            use_flash_attention_2=self.use_flash2,
            quantization_config=quant_config,
        )
        # model.to() is not supported for 4-bit or 8-bit bitsandbytes models: a 4-bit
        # model is already placed on the correct device and cast to the correct dtype.
        if torch.cuda.is_available() and not self.use_4bit:
            model = model.to(model_management.get_torch_device(), torch.float16)
        processor = AutoProcessor.from_pretrained(self.model_id)
        prompt_chunks = self.__get_prompt_chunks(chunk_size=4)
        caption = ""
        with torch.no_grad():
            for prompt_list in prompt_chunks:
                prompt = self.__get_single_answer_prompt(prompt_list)
                # Keyword arguments keep the processor call compatible across
                # transformers versions that changed the positional argument order.
                inputs = processor(text=prompt, images=raw_image, return_tensors="pt").to(
                    model_management.get_torch_device(), torch.float16
                )
                output = model.generate(
                    **inputs, max_new_tokens=self.max_tokens_per_chunk, do_sample=False
                )
                decoded = processor.decode(output[0][2:], skip_special_tokens=True)
                cleaned = self.clean_output(decoded)
                caption += cleaned
        # Release the model and any cached GPU memory before returning.
        del model
        torch.cuda.empty_cache()
        return caption

    def clean_output(self, decoded_output, delimiter=","):
        """Strip the echoed prompt and convert the answer into delimiter-separated phrases."""
        # Everything after the last "ASSISTANT: " marker is the generated answer.
        output_only = decoded_output.split("ASSISTANT: ")[-1]
        lines = output_only.split("\n")
        cleaned_output = ""
        for line in lines:
            cleaned_output += self.__replace_delimiter(line, ".", delimiter)
        return cleaned_output

    def __get_single_answer_prompt(self, questions):
        """
        For multiple turns conversation:
        "USER: <image>\n<prompt1> ASSISTANT: <answer1></s>USER: <prompt2> ASSISTANT: <answer2></s>USER: <prompt3> ASSISTANT:"
        From: https://huggingface.co/docs/transformers/en/model_doc/llava#usage-tips
        Here all questions are packed into a single turn, separated by </s>, so the
        model produces one combined answer after a single "ASSISTANT:" marker.
        """
        prompt = "USER: <image>\n"
        for index, question in enumerate(questions):
            if index != 0:
                prompt += "USER: "
            prompt += f"{question} </s>"
        prompt += "ASSISTANT: "
        return prompt
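
    # Illustrative example (hypothetical questions, not from the source): for
    # ["Describe the image.", "What colors stand out?"] the method above returns
    # "USER: <image>\nDescribe the image. </s>USER: What colors stand out? </s>ASSISTANT: "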

    def __replace_delimiter(self, text: str, old, new=","):
        """Replace only the LAST instance of old with new."""
        if old not in text:
            return text.strip() + " "
        last_old_index = text.rindex(old)
        replaced = text[:last_old_index] + new + text[last_old_index + len(old):]
        return replaced.strip() + " "
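
    # Illustrative example (hypothetical input): __replace_delimiter("A cat sits on a mat.", ".", ",")
    # returns "A cat sits on a mat, ". Only the final period is replaced, and the trailing
    # space keeps the concatenated phrases separated in clean_output.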

    def __get_prompt_chunks(self, chunk_size=4):
        """Split question_list into chunks of at most chunk_size questions."""
        prompt_chunks = []
        for index, feature in enumerate(self.question_list):
            if index % chunk_size == 0:
                prompt_chunks.append([feature])
            else:
                prompt_chunks[-1].append(feature)
        return prompt_chunks
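

# Minimal usage sketch (illustrative; the question list and image path below are
# hypothetical placeholders, not part of the original node).
if __name__ == "__main__":
    questions = [
        "Describe the main subject of the image.",
        "What is the background or setting?",
        "What colors dominate the image?",
        "What is the overall mood of the image?",
    ]
    captioner = LlavaImg2Txt(
        question_list=questions,
        use_4bit_quantization=True,
        max_tokens_per_chunk=300,
    )
    image = Image.open("example.jpg")  # hypothetical image path
    print(captioner.generate_caption(image))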