Kohaku-Blueleaf
commited on
Commit
·
06f0d78
1
Parent(s):
e73123b
Fix encode prompt impl
Browse files
app-local.py
ADDED
@@ -0,0 +1,378 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys, os
|
2 |
+
import gradio as gr
|
3 |
+
|
4 |
+
## if kgen not exist
|
5 |
+
try:
|
6 |
+
import kgen
|
7 |
+
except:
|
8 |
+
GH_TOKEN = os.getenv("GITHUB_TOKEN")
|
9 |
+
git_url = f"https://{GH_TOKEN}@github.com/KohakuBlueleaf/TIPO-KGen@tipo"
|
10 |
+
|
11 |
+
## call pip install
|
12 |
+
os.system(f"pip install git+{git_url}")
|
13 |
+
|
14 |
+
import re
|
15 |
+
import random
|
16 |
+
from time import time
|
17 |
+
|
18 |
+
import torch
|
19 |
+
from transformers import set_seed
|
20 |
+
|
21 |
+
if sys.platform == "win32":
|
22 |
+
# dev env in windows, @spaces.GPU will cause problem
|
23 |
+
def GPU(**kwargs):
|
24 |
+
return lambda x: x
|
25 |
+
|
26 |
+
else:
|
27 |
+
from spaces import GPU
|
28 |
+
|
29 |
+
import kgen.models as models
|
30 |
+
import kgen.executor.tipo as tipo
|
31 |
+
from kgen.formatter import seperate_tags, apply_format
|
32 |
+
from kgen.generate import generate
|
33 |
+
|
34 |
+
from diff import load_model, encode_prompts
|
35 |
+
from meta import DEFAULT_NEGATIVE_PROMPT, DEFAULT_FORMAT
|
36 |
+
|
37 |
+
|
38 |
+
sdxl_pipe = load_model()
|
39 |
+
sdxl_pipe.text_encoder.to("cpu")
|
40 |
+
sdxl_pipe.text_encoder_2.to("cpu")
|
41 |
+
sdxl_pipe.vae.to("cpu")
|
42 |
+
sdxl_pipe.k_diffusion_model.to("cpu")
|
43 |
+
|
44 |
+
models.load_model("Amber-River/tipo", device="cuda", subfolder="500M-epoch3")
|
45 |
+
generate(max_new_tokens=4)
|
46 |
+
torch.cuda.empty_cache()
|
47 |
+
|
48 |
+
|
49 |
+
DEFAULT_TAGS = """
|
50 |
+
1girl, king halo (umamusume), umamusume,
|
51 |
+
ningen mame, ciloranko, ogipote, misu kasumi,
|
52 |
+
solo, leaning forward, sky,
|
53 |
+
masterpiece, absurdres, sensitive, newest
|
54 |
+
""".strip()
|
55 |
+
DEFAULT_NL = """
|
56 |
+
An illustration of a girl
|
57 |
+
""".strip()
|
58 |
+
|
59 |
+
|
60 |
+
def format_time(timing):
|
61 |
+
total = timing["total"]
|
62 |
+
generate_pass = timing["generate_pass"]
|
63 |
+
|
64 |
+
result = ""
|
65 |
+
|
66 |
+
result += f"""
|
67 |
+
### Process Time
|
68 |
+
| Total | {total:5.2f} sec / {generate_pass:5} Passes | {generate_pass/total:7.2f} Passes Per Second|
|
69 |
+
|-|-|-|
|
70 |
+
"""
|
71 |
+
if "generated_tokens" in timing:
|
72 |
+
total_generated_tokens = timing["generated_tokens"]
|
73 |
+
total_input_tokens = timing["input_tokens"]
|
74 |
+
if "generated_tokens" in timing and "total_sampling" in timing:
|
75 |
+
sampling_time = timing["total_sampling"] / 1000
|
76 |
+
process_time = timing["prompt_process"] / 1000
|
77 |
+
model_time = timing["total_eval"] / 1000
|
78 |
+
|
79 |
+
result += f"""| Process | {process_time:5.2f} sec / {total_input_tokens:5} Tokens | {total_input_tokens/process_time:7.2f} Tokens Per Second|
|
80 |
+
| Sampling | {sampling_time:5.2f} sec / {total_generated_tokens:5} Tokens | {total_generated_tokens/sampling_time:7.2f} Tokens Per Second|
|
81 |
+
| Eval | {model_time:5.2f} sec / {total_generated_tokens:5} Tokens | {total_generated_tokens/model_time:7.2f} Tokens Per Second|
|
82 |
+
"""
|
83 |
+
|
84 |
+
if "generated_tokens" in timing:
|
85 |
+
result += f"""
|
86 |
+
### Processed Tokens:
|
87 |
+
* {total_input_tokens:} Input Tokens
|
88 |
+
* {total_generated_tokens:} Output Tokens
|
89 |
+
"""
|
90 |
+
return result
|
91 |
+
|
92 |
+
|
93 |
+
@GPU(duration=10)
|
94 |
+
@torch.no_grad()
|
95 |
+
def generate(
|
96 |
+
tags,
|
97 |
+
nl_prompt,
|
98 |
+
black_list,
|
99 |
+
temp,
|
100 |
+
output_format,
|
101 |
+
target_length,
|
102 |
+
top_p,
|
103 |
+
min_p,
|
104 |
+
top_k,
|
105 |
+
seed,
|
106 |
+
escape_brackets,
|
107 |
+
):
|
108 |
+
torch.cuda.empty_cache()
|
109 |
+
default_format = DEFAULT_FORMAT[output_format]
|
110 |
+
tipo.BAN_TAGS = [t.strip() for t in black_list.split(",") if t.strip()]
|
111 |
+
generation_setting = {
|
112 |
+
"seed": seed,
|
113 |
+
"temperature": temp,
|
114 |
+
"top_p": top_p,
|
115 |
+
"min_p": min_p,
|
116 |
+
"top_k": top_k,
|
117 |
+
}
|
118 |
+
inputs = seperate_tags(tags.split(","))
|
119 |
+
if nl_prompt:
|
120 |
+
if "<|extended|>" in default_format:
|
121 |
+
inputs["extended"] = nl_prompt
|
122 |
+
elif "<|generated|>" in default_format:
|
123 |
+
inputs["generated"] = nl_prompt
|
124 |
+
input_prompt = apply_format(inputs, default_format)
|
125 |
+
if escape_brackets:
|
126 |
+
input_prompt = re.sub(r"([()\[\]])", r"\\\1", input_prompt)
|
127 |
+
|
128 |
+
meta, operations, general, nl_prompt = tipo.parse_tipo_request(
|
129 |
+
seperate_tags(tags.split(",")),
|
130 |
+
nl_prompt,
|
131 |
+
tag_length_target=target_length,
|
132 |
+
generate_extra_nl_prompt="<|generated|>" in default_format or not nl_prompt,
|
133 |
+
)
|
134 |
+
t0 = time()
|
135 |
+
for result, timing in tipo.tipo_runner_generator(
|
136 |
+
meta, operations, general, nl_prompt, **generation_setting
|
137 |
+
):
|
138 |
+
result = apply_format(result, default_format)
|
139 |
+
if escape_brackets:
|
140 |
+
result = re.sub(r"([()\[\]])", r"\\\1", result)
|
141 |
+
timing["total"] = time() - t0
|
142 |
+
yield result, input_prompt, format_time(timing)
|
143 |
+
torch.cuda.empty_cache()
|
144 |
+
|
145 |
+
|
146 |
+
@GPU(duration=20)
|
147 |
+
@torch.no_grad()
|
148 |
+
def generate_image(
|
149 |
+
seed,
|
150 |
+
prompt,
|
151 |
+
prompt2,
|
152 |
+
):
|
153 |
+
torch.cuda.empty_cache()
|
154 |
+
set_seed(seed)
|
155 |
+
sdxl_pipe.text_encoder.to("cuda")
|
156 |
+
sdxl_pipe.text_encoder_2.to("cuda")
|
157 |
+
prompt_embeds, negative_prompt_embeds, pooled_embeds2, neg_pooled_embeds2 = (
|
158 |
+
encode_prompts(sdxl_pipe, prompt2, DEFAULT_NEGATIVE_PROMPT)
|
159 |
+
)
|
160 |
+
sdxl_pipe.vae.to("cuda")
|
161 |
+
sdxl_pipe.k_diffusion_model.to("cuda")
|
162 |
+
print(prompt_embeds.device)
|
163 |
+
result2 = sdxl_pipe(
|
164 |
+
prompt_embeds=prompt_embeds,
|
165 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
166 |
+
pooled_prompt_embeds=pooled_embeds2,
|
167 |
+
negative_pooled_prompt_embeds=neg_pooled_embeds2,
|
168 |
+
num_inference_steps=24,
|
169 |
+
width=1024,
|
170 |
+
height=1024,
|
171 |
+
guidance_scale=6.0,
|
172 |
+
).images[0]
|
173 |
+
sdxl_pipe.text_encoder.to("cpu")
|
174 |
+
sdxl_pipe.text_encoder_2.to("cpu")
|
175 |
+
sdxl_pipe.vae.to("cpu")
|
176 |
+
sdxl_pipe.k_diffusion_model.to("cpu")
|
177 |
+
torch.cuda.empty_cache()
|
178 |
+
yield result2, None
|
179 |
+
|
180 |
+
set_seed(seed)
|
181 |
+
sdxl_pipe.text_encoder.to("cuda")
|
182 |
+
sdxl_pipe.text_encoder_2.to("cuda")
|
183 |
+
prompt_embeds, negative_prompt_embeds, pooled_embeds2, neg_pooled_embeds2 = (
|
184 |
+
encode_prompts(sdxl_pipe, prompt, DEFAULT_NEGATIVE_PROMPT)
|
185 |
+
)
|
186 |
+
sdxl_pipe.vae.to("cuda")
|
187 |
+
sdxl_pipe.k_diffusion_model.to("cuda")
|
188 |
+
result = sdxl_pipe(
|
189 |
+
prompt_embeds=prompt_embeds,
|
190 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
191 |
+
pooled_prompt_embeds=pooled_embeds2,
|
192 |
+
negative_pooled_prompt_embeds=neg_pooled_embeds2,
|
193 |
+
num_inference_steps=24,
|
194 |
+
width=1024,
|
195 |
+
height=1024,
|
196 |
+
guidance_scale=6.0,
|
197 |
+
).images[0]
|
198 |
+
sdxl_pipe.text_encoder.to("cpu")
|
199 |
+
sdxl_pipe.text_encoder_2.to("cpu")
|
200 |
+
sdxl_pipe.vae.to("cpu")
|
201 |
+
sdxl_pipe.k_diffusion_model.to("cpu")
|
202 |
+
torch.cuda.empty_cache()
|
203 |
+
yield result2, result
|
204 |
+
|
205 |
+
|
206 |
+
if __name__ == "__main__":
|
207 |
+
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
208 |
+
with gr.Accordion("Introduction and Instructions", open=False):
|
209 |
+
gr.Markdown(
|
210 |
+
"""
|
211 |
+
## TIPO Demo
|
212 |
+
### What is this
|
213 |
+
TIPO is a tool to extend, generate, refine the input prompt for T2I models.
|
214 |
+
<br>It can work on both Danbooru tags and Natural Language. Which means you can use it on almost all the existed T2I models.
|
215 |
+
<br>You can take it as "pro max" version of [DTG](https://huggingface.co/KBlueLeaf/DanTagGen-delta-rev2)
|
216 |
+
|
217 |
+
### How to use this demo
|
218 |
+
1. Enter your tags(optional): put the desired tags into "danboru tags" box
|
219 |
+
2. Enter your NL Prompt(optional): put the desired natural language prompt into "Natural Language Prompt" box
|
220 |
+
3. Enter your black list(optional): put the desired black list into "black list" box
|
221 |
+
4. Adjust the settings: length, temp, top_p, min_p, top_k, seed ...
|
222 |
+
4. Click "TIPO" button: you will see refined prompt on "result" box
|
223 |
+
5. If you like the result, click "Generate Image From Result" button
|
224 |
+
* You will see 2 generated images, left one is based on your prompt, right one is based on refined prompt
|
225 |
+
* The backend is diffusers, there are no weighting mechanism, so Escape Brackets is default to False
|
226 |
+
|
227 |
+
### Why inference code is private? When will it be open sourced?
|
228 |
+
1. This model/tool is still under development, currently is early Alpha version.
|
229 |
+
2. I'm doing some research and projects based on this.
|
230 |
+
3. The model is released under CC-BY-NC-ND License currently. If you have interest, you can implement inference by yourself.
|
231 |
+
4. Once the project/research are done, I will open source all these models/codes with Apache2 license.
|
232 |
+
|
233 |
+
### Notification
|
234 |
+
**TIPO is NOT a T2I model. It is Prompt Gen, or, Text-to-Text model.
|
235 |
+
<br>The generated image is come from [Kohaku-XL-Zeta](https://huggingface.co/KBlueLeaf/Kohaku-XL-Zeta) model**
|
236 |
+
"""
|
237 |
+
)
|
238 |
+
with gr.Row():
|
239 |
+
with gr.Column(scale=5):
|
240 |
+
with gr.Row():
|
241 |
+
with gr.Column(scale=3):
|
242 |
+
tags_input = gr.TextArea(
|
243 |
+
label="Danbooru Tags",
|
244 |
+
lines=7,
|
245 |
+
show_copy_button=True,
|
246 |
+
interactive=True,
|
247 |
+
value=DEFAULT_TAGS,
|
248 |
+
placeholder="Enter danbooru tags here",
|
249 |
+
)
|
250 |
+
nl_prompt_input = gr.Textbox(
|
251 |
+
label="Natural Language Prompt",
|
252 |
+
lines=7,
|
253 |
+
show_copy_button=True,
|
254 |
+
interactive=True,
|
255 |
+
value=DEFAULT_NL,
|
256 |
+
placeholder="Enter Natural Language Prompt here",
|
257 |
+
)
|
258 |
+
black_list = gr.TextArea(
|
259 |
+
label="Black List (seperated by comma)",
|
260 |
+
lines=4,
|
261 |
+
interactive=True,
|
262 |
+
value="monochrome",
|
263 |
+
placeholder="Enter tag/nl black list here",
|
264 |
+
)
|
265 |
+
with gr.Column(scale=2):
|
266 |
+
output_format = gr.Dropdown(
|
267 |
+
label="Output Format",
|
268 |
+
choices=list(DEFAULT_FORMAT.keys()),
|
269 |
+
value="Both, tag first (recommend)",
|
270 |
+
)
|
271 |
+
target_length = gr.Dropdown(
|
272 |
+
label="Target Length",
|
273 |
+
choices=["very_short", "short", "long", "very_long"],
|
274 |
+
value="long",
|
275 |
+
)
|
276 |
+
temp = gr.Slider(
|
277 |
+
label="Temp",
|
278 |
+
minimum=0.0,
|
279 |
+
maximum=1.5,
|
280 |
+
value=0.5,
|
281 |
+
step=0.05,
|
282 |
+
)
|
283 |
+
top_p = gr.Slider(
|
284 |
+
label="Top P",
|
285 |
+
minimum=0.0,
|
286 |
+
maximum=1.0,
|
287 |
+
value=0.95,
|
288 |
+
step=0.05,
|
289 |
+
)
|
290 |
+
min_p = gr.Slider(
|
291 |
+
label="Min P",
|
292 |
+
minimum=0.0,
|
293 |
+
maximum=0.2,
|
294 |
+
value=0.05,
|
295 |
+
step=0.01,
|
296 |
+
)
|
297 |
+
top_k = gr.Slider(
|
298 |
+
label="Top K", minimum=0, maximum=120, value=60, step=1
|
299 |
+
)
|
300 |
+
with gr.Row():
|
301 |
+
seed = gr.Number(
|
302 |
+
label="Seed",
|
303 |
+
minimum=0,
|
304 |
+
maximum=2147483647,
|
305 |
+
value=20090220,
|
306 |
+
step=1,
|
307 |
+
)
|
308 |
+
escape_brackets = gr.Checkbox(
|
309 |
+
label="Escape Brackets", value=False
|
310 |
+
)
|
311 |
+
submit = gr.Button("TIPO!", variant="primary")
|
312 |
+
with gr.Accordion("Speed statstics", open=False):
|
313 |
+
cost_time = gr.Markdown()
|
314 |
+
with gr.Column(scale=5):
|
315 |
+
result = gr.TextArea(
|
316 |
+
label="Result", lines=8, show_copy_button=True, interactive=False
|
317 |
+
)
|
318 |
+
input_prompt = gr.Textbox(
|
319 |
+
label="Input Prompt", lines=1, interactive=False, visible=False
|
320 |
+
)
|
321 |
+
gen_img = gr.Button(
|
322 |
+
"Generate Image from Result", variant="primary", interactive=False
|
323 |
+
)
|
324 |
+
with gr.Row():
|
325 |
+
with gr.Column():
|
326 |
+
img1 = gr.Image(label="Original Propmt", interactive=False)
|
327 |
+
with gr.Column():
|
328 |
+
img2 = gr.Image(label="Generated Prompt", interactive=False)
|
329 |
+
|
330 |
+
def generate_wrapper(*args):
|
331 |
+
yield "", "", "", gr.update(interactive=False),
|
332 |
+
for i in generate(*args):
|
333 |
+
yield *i, gr.update(interactive=False)
|
334 |
+
yield *i, gr.update(interactive=True)
|
335 |
+
|
336 |
+
submit.click(
|
337 |
+
generate_wrapper,
|
338 |
+
[
|
339 |
+
tags_input,
|
340 |
+
nl_prompt_input,
|
341 |
+
black_list,
|
342 |
+
temp,
|
343 |
+
output_format,
|
344 |
+
target_length,
|
345 |
+
top_p,
|
346 |
+
min_p,
|
347 |
+
top_k,
|
348 |
+
seed,
|
349 |
+
escape_brackets,
|
350 |
+
],
|
351 |
+
[
|
352 |
+
result,
|
353 |
+
input_prompt,
|
354 |
+
cost_time,
|
355 |
+
gen_img,
|
356 |
+
],
|
357 |
+
queue=True,
|
358 |
+
)
|
359 |
+
|
360 |
+
def generate_image_wrapper(seed, result, input_prompt):
|
361 |
+
for img1, img2 in generate_image(seed, result, input_prompt):
|
362 |
+
yield img1, img2, gr.update(interactive=False)
|
363 |
+
yield img1, img2, gr.update(interactive=True)
|
364 |
+
|
365 |
+
gen_img.click(
|
366 |
+
generate_image_wrapper,
|
367 |
+
[seed, result, input_prompt],
|
368 |
+
[img1, img2, submit],
|
369 |
+
queue=True,
|
370 |
+
)
|
371 |
+
gen_img.click(
|
372 |
+
lambda *args: gr.update(interactive=False),
|
373 |
+
None,
|
374 |
+
[submit],
|
375 |
+
queue=False,
|
376 |
+
)
|
377 |
+
|
378 |
+
demo.launch()
|
app.py
CHANGED
@@ -17,10 +17,12 @@ from time import time
|
|
17 |
|
18 |
import torch
|
19 |
from transformers import set_seed
|
|
|
20 |
if sys.platform == "win32":
|
21 |
-
#dev env in windows, @spaces.GPU will cause problem
|
22 |
def GPU(func, *args, **kwargs):
|
23 |
return func
|
|
|
24 |
else:
|
25 |
from spaces import GPU
|
26 |
|
@@ -33,7 +35,7 @@ from diff import load_model, encode_prompts
|
|
33 |
from meta import DEFAULT_NEGATIVE_PROMPT, DEFAULT_FORMAT
|
34 |
|
35 |
|
36 |
-
sdxl_pipe = load_model()
|
37 |
|
38 |
models.load_model(
|
39 |
"Amber-River/tipo",
|
@@ -145,14 +147,14 @@ def generate_image(
|
|
145 |
):
|
146 |
torch.cuda.empty_cache()
|
147 |
set_seed(seed)
|
148 |
-
prompt_embeds,
|
149 |
encode_prompts(sdxl_pipe, prompt2, DEFAULT_NEGATIVE_PROMPT)
|
150 |
)
|
151 |
result2 = sdxl_pipe(
|
152 |
-
prompt_embeds=prompt_embeds,
|
153 |
-
negative_prompt_embeds=
|
154 |
-
pooled_prompt_embeds=pooled_embeds2,
|
155 |
-
negative_pooled_prompt_embeds=
|
156 |
num_inference_steps=24,
|
157 |
width=1024,
|
158 |
height=1024,
|
@@ -160,15 +162,15 @@ def generate_image(
|
|
160 |
).images[0]
|
161 |
yield result2, None
|
162 |
|
163 |
-
prompt_embeds,
|
164 |
encode_prompts(sdxl_pipe, prompt, DEFAULT_NEGATIVE_PROMPT)
|
165 |
)
|
166 |
set_seed(seed)
|
167 |
result = sdxl_pipe(
|
168 |
-
prompt_embeds=prompt_embeds,
|
169 |
-
negative_prompt_embeds=
|
170 |
-
pooled_prompt_embeds=pooled_embeds2,
|
171 |
-
negative_pooled_prompt_embeds=
|
172 |
num_inference_steps=24,
|
173 |
width=1024,
|
174 |
height=1024,
|
@@ -209,7 +211,7 @@ TIPO is a tool to extend, generate, refine the input prompt for T2I models.
|
|
209 |
|
210 |
### Notification
|
211 |
**TIPO is NOT a T2I model. It is Prompt Gen, or, Text-to-Text model.
|
212 |
-
<br>The generated
|
213 |
"""
|
214 |
)
|
215 |
with gr.Row():
|
@@ -243,7 +245,7 @@ TIPO is a tool to extend, generate, refine the input prompt for T2I models.
|
|
243 |
output_format = gr.Dropdown(
|
244 |
label="Output Format",
|
245 |
choices=list(DEFAULT_FORMAT.keys()),
|
246 |
-
value="Both, tag first (recommend)"
|
247 |
)
|
248 |
target_length = gr.Dropdown(
|
249 |
label="Target Length",
|
@@ -295,17 +297,21 @@ TIPO is a tool to extend, generate, refine the input prompt for T2I models.
|
|
295 |
input_prompt = gr.Textbox(
|
296 |
label="Input Prompt", lines=1, interactive=False, visible=False
|
297 |
)
|
298 |
-
gen_img = gr.Button(
|
|
|
|
|
299 |
with gr.Row():
|
300 |
with gr.Column():
|
301 |
img1 = gr.Image(label="Original Prompt", interactive=False)
|
302 |
with gr.Column():
|
303 |
img2 = gr.Image(label="Generated Prompt", interactive=False)
|
|
|
304 |
def generate_wrapper(*args):
|
305 |
yield "", "", "", gr.update(interactive=False),
|
306 |
for i in generate(*args):
|
307 |
yield *i, gr.update(interactive=False)
|
308 |
yield *i, gr.update(interactive=True)
|
|
|
309 |
submit.click(
|
310 |
generate_wrapper,
|
311 |
[
|
@@ -329,11 +335,12 @@ TIPO is a tool to extend, generate, refine the input prompt for T2I models.
|
|
329 |
],
|
330 |
queue=True,
|
331 |
)
|
332 |
-
|
333 |
def generate_image_wrapper(seed, result, input_prompt):
|
334 |
for img1, img2 in generate_image(seed, result, input_prompt):
|
335 |
yield img1, img2, gr.update(interactive=False)
|
336 |
yield img1, img2, gr.update(interactive=True)
|
|
|
337 |
gen_img.click(
|
338 |
generate_image_wrapper,
|
339 |
[seed, result, input_prompt],
|
|
|
17 |
|
18 |
import torch
|
19 |
from transformers import set_seed
|
20 |
+
|
21 |
if sys.platform == "win32":
|
22 |
+
# dev env in windows, @spaces.GPU will cause problem
|
23 |
def GPU(func, *args, **kwargs):
|
24 |
return func
|
25 |
+
|
26 |
else:
|
27 |
from spaces import GPU
|
28 |
|
|
|
35 |
from meta import DEFAULT_NEGATIVE_PROMPT, DEFAULT_FORMAT
|
36 |
|
37 |
|
38 |
+
sdxl_pipe = load_model("OnomaAIResearch/Illustrious-xl-early-release-v0")
|
39 |
|
40 |
models.load_model(
|
41 |
"Amber-River/tipo",
|
|
|
147 |
):
|
148 |
torch.cuda.empty_cache()
|
149 |
set_seed(seed)
|
150 |
+
prompt_embeds, pooled_embeds2 = (
|
151 |
encode_prompts(sdxl_pipe, prompt2, DEFAULT_NEGATIVE_PROMPT)
|
152 |
)
|
153 |
result2 = sdxl_pipe(
|
154 |
+
prompt_embeds=prompt_embeds[0:1],
|
155 |
+
negative_prompt_embeds=prompt_embeds[1:],
|
156 |
+
pooled_prompt_embeds=pooled_embeds2[0:1],
|
157 |
+
negative_pooled_prompt_embeds=pooled_embeds2[1:],
|
158 |
num_inference_steps=24,
|
159 |
width=1024,
|
160 |
height=1024,
|
|
|
162 |
).images[0]
|
163 |
yield result2, None
|
164 |
|
165 |
+
prompt_embeds, pooled_embeds2 = (
|
166 |
encode_prompts(sdxl_pipe, prompt, DEFAULT_NEGATIVE_PROMPT)
|
167 |
)
|
168 |
set_seed(seed)
|
169 |
result = sdxl_pipe(
|
170 |
+
prompt_embeds=prompt_embeds[0:1],
|
171 |
+
negative_prompt_embeds=prompt_embeds[1:],
|
172 |
+
pooled_prompt_embeds=pooled_embeds2[0:1],
|
173 |
+
negative_pooled_prompt_embeds=pooled_embeds2[1:],
|
174 |
num_inference_steps=24,
|
175 |
width=1024,
|
176 |
height=1024,
|
|
|
211 |
|
212 |
### Notification
|
213 |
**TIPO is NOT a T2I model. It is Prompt Gen, or, Text-to-Text model.
|
214 |
+
<br>The generated images come from OnomaAIResearch/Illustrious-xl-early-release-v0 SDXL-based model**
|
215 |
"""
|
216 |
)
|
217 |
with gr.Row():
|
|
|
245 |
output_format = gr.Dropdown(
|
246 |
label="Output Format",
|
247 |
choices=list(DEFAULT_FORMAT.keys()),
|
248 |
+
value="Both, tag first (recommend)",
|
249 |
)
|
250 |
target_length = gr.Dropdown(
|
251 |
label="Target Length",
|
|
|
297 |
input_prompt = gr.Textbox(
|
298 |
label="Input Prompt", lines=1, interactive=False, visible=False
|
299 |
)
|
300 |
+
gen_img = gr.Button(
|
301 |
+
"Generate Image from Result", variant="primary", interactive=False
|
302 |
+
)
|
303 |
with gr.Row():
|
304 |
with gr.Column():
|
305 |
img1 = gr.Image(label="Original Prompt", interactive=False)
|
306 |
with gr.Column():
|
307 |
img2 = gr.Image(label="Generated Prompt", interactive=False)
|
308 |
+
|
309 |
def generate_wrapper(*args):
|
310 |
yield "", "", "", gr.update(interactive=False),
|
311 |
for i in generate(*args):
|
312 |
yield *i, gr.update(interactive=False)
|
313 |
yield *i, gr.update(interactive=True)
|
314 |
+
|
315 |
submit.click(
|
316 |
generate_wrapper,
|
317 |
[
|
|
|
335 |
],
|
336 |
queue=True,
|
337 |
)
|
338 |
+
|
339 |
def generate_image_wrapper(seed, result, input_prompt):
|
340 |
for img1, img2 in generate_image(seed, result, input_prompt):
|
341 |
yield img1, img2, gr.update(interactive=False)
|
342 |
yield img1, img2, gr.update(interactive=True)
|
343 |
+
|
344 |
gen_img.click(
|
345 |
generate_image_wrapper,
|
346 |
[seed, result, input_prompt],
|
diff.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
from functools import partial
|
2 |
|
3 |
import torch
|
@@ -46,76 +47,104 @@ def load_model(model_id="KBlueLeaf/Kohaku-XL-Zeta", device="cuda"):
|
|
46 |
return pipe
|
47 |
|
48 |
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
)
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
).input_ids.to("cuda")
|
78 |
-
input_ids2 = pipe.tokenizer_2(
|
79 |
-
prompt,
|
80 |
-
truncation=False,
|
81 |
-
padding="max_length",
|
82 |
-
max_length=negative_ids.shape[-1],
|
83 |
-
return_tensors="pt",
|
84 |
-
).input_ids.to("cuda")
|
85 |
|
86 |
concat_embeds = []
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
|
92 |
concat_embeds2 = []
|
93 |
-
neg_embeds2 = []
|
94 |
pooled_embeds2 = []
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
)
|
100 |
-
concat_embeds2.append(hidden_states.hidden_states[-2])
|
101 |
pooled_embeds2.append(hidden_states[0])
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
|
109 |
prompt_embeds = torch.cat(concat_embeds, dim=1)
|
110 |
-
negative_prompt_embeds = torch.cat(neg_embeds, dim=1)
|
111 |
prompt_embeds2 = torch.cat(concat_embeds2, dim=1)
|
112 |
-
negative_prompt_embeds2 = torch.cat(neg_embeds2, dim=1)
|
113 |
prompt_embeds = torch.cat([prompt_embeds, prompt_embeds2], dim=-1)
|
114 |
-
negative_prompt_embeds = torch.cat(
|
115 |
-
[negative_prompt_embeds, negative_prompt_embeds2], dim=-1
|
116 |
-
)
|
117 |
|
118 |
pooled_embeds2 = torch.mean(torch.stack(pooled_embeds2, dim=0), dim=0)
|
119 |
-
neg_pooled_embeds2 = torch.mean(torch.stack(neg_pooled_embeds2, dim=0), dim=0)
|
120 |
|
121 |
-
return prompt_embeds,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
from functools import partial
|
3 |
|
4 |
import torch
|
|
|
47 |
return pipe
|
48 |
|
49 |
|
50 |
+
@torch.no_grad()
|
51 |
+
def encode_prompts(
|
52 |
+
pipe: StableDiffusionXLKDiffusionPipeline, prompt: str, neg_prompt: str = ""
|
53 |
+
):
|
54 |
+
prompts = [prompt, neg_prompt]
|
55 |
+
max_length = pipe.tokenizer.model_max_length - 2
|
56 |
+
|
57 |
+
input_ids = pipe.tokenizer(prompts, padding=True, return_tensors="pt")
|
58 |
+
input_ids2 = pipe.tokenizer_2(prompts, padding=True, return_tensors="pt")
|
59 |
+
length = max(input_ids.input_ids.size(-1), input_ids2.input_ids.size(-1))
|
60 |
+
target_length = math.ceil(length / max_length) * max_length + 2
|
61 |
+
|
62 |
+
input_ids = pipe.tokenizer(
|
63 |
+
prompts, padding="max_length", max_length=target_length, return_tensors="pt"
|
64 |
+
).input_ids
|
65 |
+
input_ids = (
|
66 |
+
input_ids[:, 0:1],
|
67 |
+
input_ids[:, 1:-1],
|
68 |
+
input_ids[:, -1:],
|
69 |
+
)
|
70 |
+
input_ids2 = pipe.tokenizer_2(
|
71 |
+
prompts, padding="max_length", max_length=target_length, return_tensors="pt"
|
72 |
+
).input_ids
|
73 |
+
input_ids2 = (
|
74 |
+
input_ids2[:, 0:1],
|
75 |
+
input_ids2[:, 1:-1],
|
76 |
+
input_ids2[:, -1:],
|
77 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
|
79 |
concat_embeds = []
|
80 |
+
for i in range(0, input_ids[1].shape[-1], max_length):
|
81 |
+
input_id1 = torch.concat(
|
82 |
+
(input_ids[0], input_ids[1][:, i : i + max_length], input_ids[2]), dim=-1
|
83 |
+
).to(pipe.device)
|
84 |
+
result = pipe.text_encoder(input_id1, output_hidden_states=True).hidden_states[
|
85 |
+
-2
|
86 |
+
]
|
87 |
+
if i == 0:
|
88 |
+
concat_embeds.append(result[:, :-1])
|
89 |
+
elif i == input_ids[1].shape[-1] - max_length:
|
90 |
+
concat_embeds.append(result[:, 1:])
|
91 |
+
else:
|
92 |
+
concat_embeds.append(result[:, 1:-1])
|
93 |
|
94 |
concat_embeds2 = []
|
|
|
95 |
pooled_embeds2 = []
|
96 |
+
for i in range(0, input_ids2[1].shape[-1], max_length):
|
97 |
+
input_id2 = torch.concat(
|
98 |
+
(input_ids2[0], input_ids2[1][:, i : i + max_length], input_ids2[2]), dim=-1
|
99 |
+
).to(pipe.device)
|
100 |
+
hidden_states = pipe.text_encoder_2(input_id2, output_hidden_states=True)
|
|
|
101 |
pooled_embeds2.append(hidden_states[0])
|
102 |
+
if i == 0:
|
103 |
+
concat_embeds2.append(hidden_states.hidden_states[-2][:, :-1])
|
104 |
+
elif i == input_ids2[1].shape[-1] - max_length:
|
105 |
+
concat_embeds2.append(hidden_states.hidden_states[-2][:, 1:])
|
106 |
+
else:
|
107 |
+
concat_embeds2.append(hidden_states.hidden_states[-2][:, 1:-1])
|
108 |
|
109 |
prompt_embeds = torch.cat(concat_embeds, dim=1)
|
|
|
110 |
prompt_embeds2 = torch.cat(concat_embeds2, dim=1)
|
|
|
111 |
prompt_embeds = torch.cat([prompt_embeds, prompt_embeds2], dim=-1)
|
|
|
|
|
|
|
112 |
|
113 |
pooled_embeds2 = torch.mean(torch.stack(pooled_embeds2, dim=0), dim=0)
|
|
|
114 |
|
115 |
+
return prompt_embeds, pooled_embeds2
|
116 |
+
|
117 |
+
|
118 |
+
if __name__ == "__main__":
|
119 |
+
from meta import DEFAULT_NEGATIVE_PROMPT
|
120 |
+
prompt = """
|
121 |
+
1girl,
|
122 |
+
king halo (umamusume), umamusume,
|
123 |
+
|
124 |
+
ogipote, misu kasumi, fuzichoco, ciloranko, ninjin nouka, ningen mame, ask (askzy), kita (kitairoha), amano kokoko, maccha (mochancc),
|
125 |
+
|
126 |
+
solo, leaning forward, cleavage, sky, cowboy shot, outdoors, cloud, long hair, looking at viewer, brown hair, day, horse girl, black bikini, cloudy sky, stomach, collarbone, blue sky, swimsuit, navel, thighs, blush, ocean, animal ears, standing, smile, breasts, open mouth, :d, red eyes, horse ears, tail, bare shoulders, wavy hair, bikini, medium breasts,
|
127 |
+
|
128 |
+
masterpiece, newest, absurdres, sensitive
|
129 |
+
""".strip()
|
130 |
+
sdxl_pipe = load_model("KBlueLeaf/xxx")
|
131 |
+
# sdxl_pipe = load_model()
|
132 |
+
prompt_embeds, pooled_embeds2 = encode_prompts(
|
133 |
+
sdxl_pipe, prompt, DEFAULT_NEGATIVE_PROMPT
|
134 |
+
)
|
135 |
+
result = sdxl_pipe(
|
136 |
+
prompt_embeds=prompt_embeds[0:1],
|
137 |
+
negative_prompt_embeds=prompt_embeds[1:],
|
138 |
+
pooled_prompt_embeds=pooled_embeds2[0:1],
|
139 |
+
negative_pooled_prompt_embeds=pooled_embeds2[1:],
|
140 |
+
num_inference_steps=24,
|
141 |
+
width=1024,
|
142 |
+
height=1024,
|
143 |
+
guidance_scale=6.0,
|
144 |
+
).images[0]
|
145 |
+
|
146 |
+
result.save("test.png")
|
147 |
+
|
148 |
+
module = torch.compile(sdxl_pipe)
|
149 |
+
if isinstance(module, torch._dynamo.OptimizedModule):
|
150 |
+
original_module = module._orig_mod
|
meta.py
CHANGED
@@ -15,7 +15,7 @@ multiple tails, multiple views, copyright name, watermark, artist name, signatur
|
|
15 |
"""
|
16 |
|
17 |
DEFAULT_FORMAT = {
|
18 |
-
"tag only (DTG mode)":"""
|
19 |
<|special|>, <|characters|>, <|copyrights|>,
|
20 |
<|artist|>,
|
21 |
|
@@ -55,5 +55,5 @@ DEFAULT_FORMAT = {
|
|
55 |
<|extended|>.
|
56 |
|
57 |
<|quality|>, <|meta|>, <|rating|>
|
58 |
-
""".strip()
|
59 |
}
|
|
|
15 |
"""
|
16 |
|
17 |
DEFAULT_FORMAT = {
|
18 |
+
"tag only (DTG mode)": """
|
19 |
<|special|>, <|characters|>, <|copyrights|>,
|
20 |
<|artist|>,
|
21 |
|
|
|
55 |
<|extended|>.
|
56 |
|
57 |
<|quality|>, <|meta|>, <|rating|>
|
58 |
+
""".strip(),
|
59 |
}
|