# Uploaded by Freak-ppa ("Upload 2 files", commit f4d058d, verified)
import os
from PIL import Image
from transformers import (
BlipProcessor,
BlipForConditionalGeneration,
BlipConfig,
BlipTextConfig,
BlipVisionConfig,
)
import torch
import model_management
import folder_paths
class BLIPImg2Txt:
    """Caption images with a BLIP conditional-generation model.

    Generation options are collected once at construction time and merged into
    the model's text config every time :meth:`generate_caption` loads the model.
    """

    def __init__(
        self,
        conditional_caption: str,
        min_words: int,
        max_words: int,
        temperature: float,
        repetition_penalty: float,
        search_beams: int,
        model_id: str = "Salesforce/blip-image-captioning-large",
        custom_model_path: str = None,
    ):
        """Store captioning options and resolve where the model is loaded from.

        Args:
            conditional_caption: Text passed to the processor together with the
                image; the caption is generated conditioned on it.
            min_words: Forwarded as ``min_length`` (transformers counts this in
                tokens, so "words" is only approximate).
            max_words: Forwarded as ``max_length`` (also a token count).
            temperature: Values outside [0.90, 1.1] switch decoding to sampling
                with this temperature; values inside that range use
                deterministic greedy/beam decoding, where temperature is unused.
            repetition_penalty: Forwarded as ``repetition_penalty``.
            search_beams: Beam count for deterministic decoding; values below 2
                mean greedy decoding. Ignored when sampling.
            model_id: Hugging Face repo id; also the name looked up in the
                "blip" model folder.
            custom_model_path: Optional local model directory, preferred when
                it exists on disk.
        """
        self.conditional_caption = conditional_caption
        self.model_id = model_id
        self.custom_model_path = custom_model_path
        if self.custom_model_path and os.path.exists(self.custom_model_path):
            self.model_path = self.custom_model_path
        else:
            # May not resolve to an existing path; generate_caption then falls
            # back to loading by model_id.
            self.model_path = folder_paths.get_full_path("blip", model_id)
        self.text_config_kwargs = self._generation_kwargs(
            min_words, max_words, temperature, repetition_penalty, search_beams
        )

    @staticmethod
    def _generation_kwargs(
        min_words: int,
        max_words: int,
        temperature: float,
        repetition_penalty: float,
        search_beams: int,
    ) -> dict:
        """Build the generation settings merged into the BLIP text config."""
        # A temperature noticeably away from 1.0 signals that the caller wants
        # randomness, so sample; otherwise decode deterministically.
        do_sample = temperature > 1.1 or temperature < 0.90
        kwargs = {
            "do_sample": do_sample,
            "max_length": max_words,
            "min_length": min_words,
            "repetition_penalty": repetition_penalty,
            # NOTE(review): "padding" looks like a tokenizer option; unclear it
            # has any effect on the text config. Kept for compatibility.
            "padding": "max_length",
            # Sampling uses a single beam; deterministic decoding uses the
            # requested beam count (at least 1, i.e. greedy).
            "num_beams": 1 if do_sample else max(1, search_beams),
        }
        if do_sample:
            # Temperature only influences sampling, so set it only in that mode.
            kwargs["temperature"] = temperature
        return kwargs

    def _model_source(self):
        """Return ``(path_or_repo_id, local_files_only)`` for ``from_pretrained``."""
        if self.model_path and os.path.exists(self.model_path):
            return self.model_path, True
        # No usable local copy: load by repo id and allow transformers to fetch
        # it from the Hugging Face Hub.
        return self.model_id, False

    def _load_model(self, model_path, local_files_only, dtype):
        """Load BLIP with this instance's generation settings in its text config."""
        config_text = BlipTextConfig.from_pretrained(
            model_path, local_files_only=local_files_only
        )
        config_text.update(self.text_config_kwargs)
        config_vision = BlipVisionConfig.from_pretrained(
            model_path, local_files_only=local_files_only
        )
        config = BlipConfig.from_text_vision_configs(config_text, config_vision)
        return BlipForConditionalGeneration.from_pretrained(
            model_path,
            config=config,
            torch_dtype=dtype,
            local_files_only=local_files_only,
        )

    def generate_caption(self, image: "Image.Image") -> str:
        """Generate a caption for *image*.

        The processor, configs and model are loaded on every call, and the model
        is released again afterwards (also when generation raises).

        Args:
            image: PIL image; converted to RGB when it is in another mode.

        Returns:
            The decoded caption with special tokens removed.
        """
        if image.mode != "RGB":
            image = image.convert("RGB")

        model_path, local_files_only = self._model_source()
        device = model_management.get_torch_device()
        # Half precision only off-CPU: several CPU kernels lack float16
        # implementations on older PyTorch builds.
        dtype = torch.float32 if torch.device(device).type == "cpu" else torch.float16

        processor = BlipProcessor.from_pretrained(
            model_path, local_files_only=local_files_only
        )
        model = self._load_model(model_path, local_files_only, dtype).to(device)
        try:
            inputs = processor(
                image,
                self.conditional_caption,
                return_tensors="pt",
            ).to(device, dtype)
            with torch.no_grad():
                out = model.generate(**inputs)
            return processor.decode(out[0], skip_special_tokens=True)
        finally:
            # Free the model even if preprocessing or generation failed.
            del model
            torch.cuda.empty_cache()