"""
@author: christian-byrne
@title: Img2Txt auto captioning. Choose from models: BLIP, Llava, MiniCPM, MS-GIT. Use model combos and merge results. Specify questions to ask about images (medium, art style, background). Supports Chinese 🇨🇳 questions via MiniCPM.
@nickname: Image to Text - Auto Caption
"""
import torch
from torchvision import transforms
from .img_tensor_utils import TensorImgUtils
from .llava_img2txt import LlavaImg2Txt
from .blip_img2txt import BLIPImg2Txt
from .mini_cpm_img2txt import MiniPCMImg2Txt
from typing import Tuple
import os
import folder_paths
class Img2TxtNode:
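    """ComfyUI node that captions an input image with one or more
    image-to-text models (BLIP, LLaVA, MiniCPM) and merges their outputs
    into a single caption string."""
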
CATEGORY = "img2txt"
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"input_image": ("IMAGE",),
},
"optional": {
"use_blip_model": (
"BOOLEAN",
{
"default": True,
"label_on": "Use BLIP (Requires 2Gb Disk)",
"label_off": "Don't use BLIP",
},
),
"use_llava_model": (
"BOOLEAN",
{
"default": False,
"label_on": "Use Llava (Requires 15Gb Disk)",
"label_off": "Don't use Llava",
},
),
"use_mini_pcm_model": (
"BOOLEAN",
{
"default": False,
"label_on": "Use MiniCPM (Requires 6Gb Disk)",
"label_off": "Don't use MiniCPM",
},
),
"use_all_models": (
"BOOLEAN",
{
"default": False,
"label_on": "Use all models and combine outputs (Total Size: 20+Gb)",
"label_off": "Use selected models only",
},
),
"blip_caption_prefix": (
"STRING",
{
"default": "a photograph of",
},
),
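                # One question per line; these prompts drive the LLaVA and MiniCPM models.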
"prompt_questions": (
"STRING",
{
"default": "What is the subject of this image?\nWhat are the mediums used to make this?\nWhat are the artistic styles this is reminiscent of?\nWhich famous artists is this reminiscent of?\nHow sharp or detailed is this image?\nWhat is the environment and background of this image?\nWhat are the objects in this image?\nWhat is the composition of this image?\nWhat is the color palette in this image?\nWhat is the lighting in this image?",
"multiline": True,
},
),
"temperature": (
"FLOAT",
{
"default": 0.8,
"min": 0.1,
"max": 2.0,
"step": 0.01,
"display": "slider",
},
),
"repetition_penalty": (
"FLOAT",
{
"default": 1.2,
"min": 0.1,
"max": 2.0,
"step": 0.01,
"display": "slider",
},
),
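                # Caption length bounds and beam-search width, used by the BLIP captioner.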
"min_words": ("INT", {"default": 36}),
"max_words": ("INT", {"default": 128}),
"search_beams": ("INT", {"default": 5}),
"exclude_terms": (
"STRING",
{
"default": "watermark, text, writing",
},
),
},
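            # Hidden inputs are supplied by ComfyUI itself and are not shown as widgets.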
"hidden": {
"unique_id": "UNIQUE_ID",
"extra_pnginfo": "EXTRA_PNGINFO",
"output_text": (
"STRING",
{
"default": "",
},
),
},
}
RETURN_TYPES = ("STRING",)
RETURN_NAMES = ("caption",)
FUNCTION = "main"
OUTPUT_NODE = True
def main(
self,
input_image: torch.Tensor, # [Batch_n, H, W, 3-channel]
use_blip_model: bool,
use_llava_model: bool,
use_all_models: bool,
use_mini_pcm_model: bool,
blip_caption_prefix: str,
prompt_questions: str,
temperature: float,
repetition_penalty: float,
min_words: int,
max_words: int,
search_beams: int,
exclude_terms: str,
output_text: str = "",
unique_id=None,
extra_pnginfo=None,
) -> Tuple[str, ...]:
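        """Caption `input_image` with the selected model(s), interleave the
        per-model captions, strip excluded terms, and return the result to
        both the UI and the node's STRING output."""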
raw_image = transforms.ToPILImage()(
TensorImgUtils.convert_to_type(input_image, "CHW")
).convert("RGB")
if blip_caption_prefix == "":
blip_caption_prefix = "a photograph of"
captions = []
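
        # BLIP: conditional captioning seeded with `blip_caption_prefix`.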
if use_all_models or use_blip_model:
blip_model_path = folder_paths.get_folder_paths("blip")[0]
print(f"blip_model_path: {blip_model_path}")
            if not blip_model_path or not os.path.exists(blip_model_path):
                raise ValueError(
                    "BLIP model 'blip-image-captioning-large' not found in the "
                    "ComfyUI models directory. Please ensure it is in the "
                    "'models/blip' folder."
                )
blip = BLIPImg2Txt(
conditional_caption=blip_caption_prefix,
min_words=min_words,
max_words=max_words,
temperature=temperature,
repetition_penalty=repetition_penalty,
search_beams=search_beams,
custom_model_path=blip_model_path
)
captions.append(blip.generate_caption(raw_image))
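
        # LLaVA: answers each prompt question about the image.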
if use_all_models or use_llava_model:
            llava_questions = prompt_questions.split("\n")
            llava_questions = [q for q in llava_questions if q.strip() != ""]
if len(llava_questions) > 0:
llava = LlavaImg2Txt(
question_list=llava_questions,
model_id="llava-hf/llava-1.5-7b-hf",
use_4bit_quantization=True,
use_low_cpu_mem=True,
use_flash2_attention=False,
max_tokens_per_chunk=300,
)
captions.append(llava.generate_caption(raw_image))
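
        # MiniCPM: answers the same question list; also supports Chinese questions.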
if use_all_models or use_mini_pcm_model:
mini_pcm = MiniPCMImg2Txt(
question_list=prompt_questions.split("\n"),
temperature=temperature,
)
captions.append(mini_pcm.generate_captions(raw_image))
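
        # Interleave the per-model captions and remove excluded terms.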
out_string = self.exclude(exclude_terms, self.merge_captions(captions))
return {"ui": {"text": out_string}, "result": (out_string,)}
def merge_captions(self, captions: list) -> str:
"""Merge captions from multiple models into one string.
Necessary because we can expect the generated captions will generally
be comma-separated fragments ordered by relevance - so combine
fragments in an alternating order."""
merged_caption = ""
captions = [c.split(",") for c in captions]
for i in range(max(len(c) for c in captions)):
for j in range(len(captions)):
if i < len(captions[j]) and captions[j][i].strip() != "":
merged_caption += captions[j][i].strip() + ", "
return merged_caption
def exclude(self, exclude_terms: str, out_string: str) -> str:
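        """Remove the user-specified exclude terms from the merged caption."""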
        # "arafed" is a spurious token BLIP often emits, so it is always excluded.
        # https://huggingface.co/Salesforce/blip-image-captioning-large/discussions/20
        exclude_terms = "arafed," + exclude_terms
        exclude_terms = [
            term.strip().lower()
            for term in exclude_terms.split(",")
            if term.strip() != ""
        ]
for term in exclude_terms:
out_string = out_string.replace(term, "")
return out_string