import os
from typing import Optional

from PIL import Image
from transformers import (
    BlipProcessor,
    BlipForConditionalGeneration,
    BlipConfig,
    BlipTextConfig,
    BlipVisionConfig,
)
import torch

# ComfyUI runtime helpers (available when this module runs inside ComfyUI).
import model_management
import folder_paths
class BLIPImg2Txt:
    def __init__(
        self,
        conditional_caption: str,
        min_words: int,
        max_words: int,
        temperature: float,
        repetition_penalty: float,
        search_beams: int,
        model_id: str = "Salesforce/blip-image-captioning-large",
        custom_model_path: Optional[str] = None,
    ):
        self.conditional_caption = conditional_caption
        self.model_id = model_id
        self.custom_model_path = custom_model_path

        # Prefer an explicit local checkpoint; otherwise resolve the model id
        # through ComfyUI's registered "blip" model folder.
        if self.custom_model_path and os.path.exists(self.custom_model_path):
            self.model_path = self.custom_model_path
        else:
            self.model_path = folder_paths.get_full_path("blip", model_id)
        # Temperatures well away from 1.0 only take effect with sampling;
        # otherwise fall back to greedy decoding or beam search.
        if temperature > 1.1 or temperature < 0.90:
            do_sample = True
            num_beams = 1  # Beam search is not used when sampling.
        else:
            do_sample = False
            num_beams = search_beams if search_beams > 1 else 1

        self.text_config_kwargs = {
            "do_sample": do_sample,
            "max_length": max_words,
            "min_length": min_words,
            "repetition_penalty": repetition_penalty,
            "padding": "max_length",
        }
        if do_sample:
            # temperature is only meaningful when sampling is enabled.
            self.text_config_kwargs["temperature"] = temperature
        else:
            self.text_config_kwargs["num_beams"] = num_beams
    def generate_caption(self, image: Image.Image) -> str:
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Use the resolved local path when it exists; otherwise download from the Hub.
        if self.model_path and os.path.exists(self.model_path):
            model_path = self.model_path
            local_files_only = True
        else:
            model_path = self.model_id
            local_files_only = False

        processor = BlipProcessor.from_pretrained(model_path, local_files_only=local_files_only)

        # Fold the generation settings into the text config before instantiating the model.
        config_text = BlipTextConfig.from_pretrained(model_path, local_files_only=local_files_only)
        config_text.update(self.text_config_kwargs)
        config_vision = BlipVisionConfig.from_pretrained(model_path, local_files_only=local_files_only)
        config = BlipConfig.from_text_vision_configs(config_text, config_vision)

        model = BlipForConditionalGeneration.from_pretrained(
            model_path,
            config=config,
            torch_dtype=torch.float16,
            local_files_only=local_files_only,
        ).to(model_management.get_torch_device())

        inputs = processor(
            image,
            self.conditional_caption,
            return_tensors="pt",
        ).to(model_management.get_torch_device(), torch.float16)

        with torch.no_grad():
            out = model.generate(**inputs)
        ret = processor.decode(out[0], skip_special_tokens=True)

        # Release the model and free cached GPU memory before returning.
        del model
        torch.cuda.empty_cache()

        return ret