Kohaku-Blueleaf commited on
Commit
06f0d78
·
1 Parent(s): e73123b

Fix encode prompt impl

Browse files
Files changed (4) hide show
  1. app-local.py +378 -0
  2. app.py +23 -16
  3. diff.py +89 -60
  4. meta.py +2 -2
app-local.py ADDED
@@ -0,0 +1,378 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys, os
2
+ import gradio as gr
3
+
4
+ ## if kgen not exist
5
+ try:
6
+ import kgen
7
+ except:
8
+ GH_TOKEN = os.getenv("GITHUB_TOKEN")
9
+ git_url = f"https://{GH_TOKEN}@github.com/KohakuBlueleaf/TIPO-KGen@tipo"
10
+
11
+ ## call pip install
12
+ os.system(f"pip install git+{git_url}")
13
+
14
+ import re
15
+ import random
16
+ from time import time
17
+
18
+ import torch
19
+ from transformers import set_seed
20
+
21
+ if sys.platform == "win32":
22
+ # dev env in windows, @spaces.GPU will cause problem
23
+ def GPU(**kwargs):
24
+ return lambda x: x
25
+
26
+ else:
27
+ from spaces import GPU
28
+
29
+ import kgen.models as models
30
+ import kgen.executor.tipo as tipo
31
+ from kgen.formatter import seperate_tags, apply_format
32
+ from kgen.generate import generate
33
+
34
+ from diff import load_model, encode_prompts
35
+ from meta import DEFAULT_NEGATIVE_PROMPT, DEFAULT_FORMAT
36
+
37
+
38
+ sdxl_pipe = load_model()
39
+ sdxl_pipe.text_encoder.to("cpu")
40
+ sdxl_pipe.text_encoder_2.to("cpu")
41
+ sdxl_pipe.vae.to("cpu")
42
+ sdxl_pipe.k_diffusion_model.to("cpu")
43
+
44
+ models.load_model("Amber-River/tipo", device="cuda", subfolder="500M-epoch3")
45
+ generate(max_new_tokens=4)
46
+ torch.cuda.empty_cache()
47
+
48
+
49
+ DEFAULT_TAGS = """
50
+ 1girl, king halo (umamusume), umamusume,
51
+ ningen mame, ciloranko, ogipote, misu kasumi,
52
+ solo, leaning forward, sky,
53
+ masterpiece, absurdres, sensitive, newest
54
+ """.strip()
55
+ DEFAULT_NL = """
56
+ An illustration of a girl
57
+ """.strip()
58
+
59
+
60
+ def format_time(timing):
61
+ total = timing["total"]
62
+ generate_pass = timing["generate_pass"]
63
+
64
+ result = ""
65
+
66
+ result += f"""
67
+ ### Process Time
68
+ | Total | {total:5.2f} sec / {generate_pass:5} Passes | {generate_pass/total:7.2f} Passes Per Second|
69
+ |-|-|-|
70
+ """
71
+ if "generated_tokens" in timing:
72
+ total_generated_tokens = timing["generated_tokens"]
73
+ total_input_tokens = timing["input_tokens"]
74
+ if "generated_tokens" in timing and "total_sampling" in timing:
75
+ sampling_time = timing["total_sampling"] / 1000
76
+ process_time = timing["prompt_process"] / 1000
77
+ model_time = timing["total_eval"] / 1000
78
+
79
+ result += f"""| Process | {process_time:5.2f} sec / {total_input_tokens:5} Tokens | {total_input_tokens/process_time:7.2f} Tokens Per Second|
80
+ | Sampling | {sampling_time:5.2f} sec / {total_generated_tokens:5} Tokens | {total_generated_tokens/sampling_time:7.2f} Tokens Per Second|
81
+ | Eval | {model_time:5.2f} sec / {total_generated_tokens:5} Tokens | {total_generated_tokens/model_time:7.2f} Tokens Per Second|
82
+ """
83
+
84
+ if "generated_tokens" in timing:
85
+ result += f"""
86
+ ### Processed Tokens:
87
+ * {total_input_tokens:} Input Tokens
88
+ * {total_generated_tokens:} Output Tokens
89
+ """
90
+ return result
91
+
92
+
93
+ @GPU(duration=10)
94
+ @torch.no_grad()
95
+ def generate(
96
+ tags,
97
+ nl_prompt,
98
+ black_list,
99
+ temp,
100
+ output_format,
101
+ target_length,
102
+ top_p,
103
+ min_p,
104
+ top_k,
105
+ seed,
106
+ escape_brackets,
107
+ ):
108
+ torch.cuda.empty_cache()
109
+ default_format = DEFAULT_FORMAT[output_format]
110
+ tipo.BAN_TAGS = [t.strip() for t in black_list.split(",") if t.strip()]
111
+ generation_setting = {
112
+ "seed": seed,
113
+ "temperature": temp,
114
+ "top_p": top_p,
115
+ "min_p": min_p,
116
+ "top_k": top_k,
117
+ }
118
+ inputs = seperate_tags(tags.split(","))
119
+ if nl_prompt:
120
+ if "<|extended|>" in default_format:
121
+ inputs["extended"] = nl_prompt
122
+ elif "<|generated|>" in default_format:
123
+ inputs["generated"] = nl_prompt
124
+ input_prompt = apply_format(inputs, default_format)
125
+ if escape_brackets:
126
+ input_prompt = re.sub(r"([()\[\]])", r"\\\1", input_prompt)
127
+
128
+ meta, operations, general, nl_prompt = tipo.parse_tipo_request(
129
+ seperate_tags(tags.split(",")),
130
+ nl_prompt,
131
+ tag_length_target=target_length,
132
+ generate_extra_nl_prompt="<|generated|>" in default_format or not nl_prompt,
133
+ )
134
+ t0 = time()
135
+ for result, timing in tipo.tipo_runner_generator(
136
+ meta, operations, general, nl_prompt, **generation_setting
137
+ ):
138
+ result = apply_format(result, default_format)
139
+ if escape_brackets:
140
+ result = re.sub(r"([()\[\]])", r"\\\1", result)
141
+ timing["total"] = time() - t0
142
+ yield result, input_prompt, format_time(timing)
143
+ torch.cuda.empty_cache()
144
+
145
+
146
+ @GPU(duration=20)
147
+ @torch.no_grad()
148
+ def generate_image(
149
+ seed,
150
+ prompt,
151
+ prompt2,
152
+ ):
153
+ torch.cuda.empty_cache()
154
+ set_seed(seed)
155
+ sdxl_pipe.text_encoder.to("cuda")
156
+ sdxl_pipe.text_encoder_2.to("cuda")
157
+ prompt_embeds, negative_prompt_embeds, pooled_embeds2, neg_pooled_embeds2 = (
158
+ encode_prompts(sdxl_pipe, prompt2, DEFAULT_NEGATIVE_PROMPT)
159
+ )
160
+ sdxl_pipe.vae.to("cuda")
161
+ sdxl_pipe.k_diffusion_model.to("cuda")
162
+ print(prompt_embeds.device)
163
+ result2 = sdxl_pipe(
164
+ prompt_embeds=prompt_embeds,
165
+ negative_prompt_embeds=negative_prompt_embeds,
166
+ pooled_prompt_embeds=pooled_embeds2,
167
+ negative_pooled_prompt_embeds=neg_pooled_embeds2,
168
+ num_inference_steps=24,
169
+ width=1024,
170
+ height=1024,
171
+ guidance_scale=6.0,
172
+ ).images[0]
173
+ sdxl_pipe.text_encoder.to("cpu")
174
+ sdxl_pipe.text_encoder_2.to("cpu")
175
+ sdxl_pipe.vae.to("cpu")
176
+ sdxl_pipe.k_diffusion_model.to("cpu")
177
+ torch.cuda.empty_cache()
178
+ yield result2, None
179
+
180
+ set_seed(seed)
181
+ sdxl_pipe.text_encoder.to("cuda")
182
+ sdxl_pipe.text_encoder_2.to("cuda")
183
+ prompt_embeds, negative_prompt_embeds, pooled_embeds2, neg_pooled_embeds2 = (
184
+ encode_prompts(sdxl_pipe, prompt, DEFAULT_NEGATIVE_PROMPT)
185
+ )
186
+ sdxl_pipe.vae.to("cuda")
187
+ sdxl_pipe.k_diffusion_model.to("cuda")
188
+ result = sdxl_pipe(
189
+ prompt_embeds=prompt_embeds,
190
+ negative_prompt_embeds=negative_prompt_embeds,
191
+ pooled_prompt_embeds=pooled_embeds2,
192
+ negative_pooled_prompt_embeds=neg_pooled_embeds2,
193
+ num_inference_steps=24,
194
+ width=1024,
195
+ height=1024,
196
+ guidance_scale=6.0,
197
+ ).images[0]
198
+ sdxl_pipe.text_encoder.to("cpu")
199
+ sdxl_pipe.text_encoder_2.to("cpu")
200
+ sdxl_pipe.vae.to("cpu")
201
+ sdxl_pipe.k_diffusion_model.to("cpu")
202
+ torch.cuda.empty_cache()
203
+ yield result2, result
204
+
205
+
206
+ if __name__ == "__main__":
207
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
208
+ with gr.Accordion("Introduction and Instructions", open=False):
209
+ gr.Markdown(
210
+ """
211
+ ## TIPO Demo
212
+ ### What is this
213
+ TIPO is a tool to extend, generate, refine the input prompt for T2I models.
214
+ <br>It can work on both Danbooru tags and Natural Language. Which means you can use it on almost all the existed T2I models.
215
+ <br>You can take it as "pro max" version of [DTG](https://huggingface.co/KBlueLeaf/DanTagGen-delta-rev2)
216
+
217
+ ### How to use this demo
218
+ 1. Enter your tags(optional): put the desired tags into "danboru tags" box
219
+ 2. Enter your NL Prompt(optional): put the desired natural language prompt into "Natural Language Prompt" box
220
+ 3. Enter your black list(optional): put the desired black list into "black list" box
221
+ 4. Adjust the settings: length, temp, top_p, min_p, top_k, seed ...
222
+ 4. Click "TIPO" button: you will see refined prompt on "result" box
223
+ 5. If you like the result, click "Generate Image From Result" button
224
+ * You will see 2 generated images, left one is based on your prompt, right one is based on refined prompt
225
+ * The backend is diffusers, there are no weighting mechanism, so Escape Brackets is default to False
226
+
227
+ ### Why inference code is private? When will it be open sourced?
228
+ 1. This model/tool is still under development, currently is early Alpha version.
229
+ 2. I'm doing some research and projects based on this.
230
+ 3. The model is released under CC-BY-NC-ND License currently. If you have interest, you can implement inference by yourself.
231
+ 4. Once the project/research are done, I will open source all these models/codes with Apache2 license.
232
+
233
+ ### Notification
234
+ **TIPO is NOT a T2I model. It is Prompt Gen, or, Text-to-Text model.
235
+ <br>The generated image is come from [Kohaku-XL-Zeta](https://huggingface.co/KBlueLeaf/Kohaku-XL-Zeta) model**
236
+ """
237
+ )
238
+ with gr.Row():
239
+ with gr.Column(scale=5):
240
+ with gr.Row():
241
+ with gr.Column(scale=3):
242
+ tags_input = gr.TextArea(
243
+ label="Danbooru Tags",
244
+ lines=7,
245
+ show_copy_button=True,
246
+ interactive=True,
247
+ value=DEFAULT_TAGS,
248
+ placeholder="Enter danbooru tags here",
249
+ )
250
+ nl_prompt_input = gr.Textbox(
251
+ label="Natural Language Prompt",
252
+ lines=7,
253
+ show_copy_button=True,
254
+ interactive=True,
255
+ value=DEFAULT_NL,
256
+ placeholder="Enter Natural Language Prompt here",
257
+ )
258
+ black_list = gr.TextArea(
259
+ label="Black List (seperated by comma)",
260
+ lines=4,
261
+ interactive=True,
262
+ value="monochrome",
263
+ placeholder="Enter tag/nl black list here",
264
+ )
265
+ with gr.Column(scale=2):
266
+ output_format = gr.Dropdown(
267
+ label="Output Format",
268
+ choices=list(DEFAULT_FORMAT.keys()),
269
+ value="Both, tag first (recommend)",
270
+ )
271
+ target_length = gr.Dropdown(
272
+ label="Target Length",
273
+ choices=["very_short", "short", "long", "very_long"],
274
+ value="long",
275
+ )
276
+ temp = gr.Slider(
277
+ label="Temp",
278
+ minimum=0.0,
279
+ maximum=1.5,
280
+ value=0.5,
281
+ step=0.05,
282
+ )
283
+ top_p = gr.Slider(
284
+ label="Top P",
285
+ minimum=0.0,
286
+ maximum=1.0,
287
+ value=0.95,
288
+ step=0.05,
289
+ )
290
+ min_p = gr.Slider(
291
+ label="Min P",
292
+ minimum=0.0,
293
+ maximum=0.2,
294
+ value=0.05,
295
+ step=0.01,
296
+ )
297
+ top_k = gr.Slider(
298
+ label="Top K", minimum=0, maximum=120, value=60, step=1
299
+ )
300
+ with gr.Row():
301
+ seed = gr.Number(
302
+ label="Seed",
303
+ minimum=0,
304
+ maximum=2147483647,
305
+ value=20090220,
306
+ step=1,
307
+ )
308
+ escape_brackets = gr.Checkbox(
309
+ label="Escape Brackets", value=False
310
+ )
311
+ submit = gr.Button("TIPO!", variant="primary")
312
+ with gr.Accordion("Speed statstics", open=False):
313
+ cost_time = gr.Markdown()
314
+ with gr.Column(scale=5):
315
+ result = gr.TextArea(
316
+ label="Result", lines=8, show_copy_button=True, interactive=False
317
+ )
318
+ input_prompt = gr.Textbox(
319
+ label="Input Prompt", lines=1, interactive=False, visible=False
320
+ )
321
+ gen_img = gr.Button(
322
+ "Generate Image from Result", variant="primary", interactive=False
323
+ )
324
+ with gr.Row():
325
+ with gr.Column():
326
+ img1 = gr.Image(label="Original Propmt", interactive=False)
327
+ with gr.Column():
328
+ img2 = gr.Image(label="Generated Prompt", interactive=False)
329
+
330
+ def generate_wrapper(*args):
331
+ yield "", "", "", gr.update(interactive=False),
332
+ for i in generate(*args):
333
+ yield *i, gr.update(interactive=False)
334
+ yield *i, gr.update(interactive=True)
335
+
336
+ submit.click(
337
+ generate_wrapper,
338
+ [
339
+ tags_input,
340
+ nl_prompt_input,
341
+ black_list,
342
+ temp,
343
+ output_format,
344
+ target_length,
345
+ top_p,
346
+ min_p,
347
+ top_k,
348
+ seed,
349
+ escape_brackets,
350
+ ],
351
+ [
352
+ result,
353
+ input_prompt,
354
+ cost_time,
355
+ gen_img,
356
+ ],
357
+ queue=True,
358
+ )
359
+
360
+ def generate_image_wrapper(seed, result, input_prompt):
361
+ for img1, img2 in generate_image(seed, result, input_prompt):
362
+ yield img1, img2, gr.update(interactive=False)
363
+ yield img1, img2, gr.update(interactive=True)
364
+
365
+ gen_img.click(
366
+ generate_image_wrapper,
367
+ [seed, result, input_prompt],
368
+ [img1, img2, submit],
369
+ queue=True,
370
+ )
371
+ gen_img.click(
372
+ lambda *args: gr.update(interactive=False),
373
+ None,
374
+ [submit],
375
+ queue=False,
376
+ )
377
+
378
+ demo.launch()
app.py CHANGED
@@ -17,10 +17,12 @@ from time import time
17
 
18
  import torch
19
  from transformers import set_seed
 
20
  if sys.platform == "win32":
21
- #dev env in windows, @spaces.GPU will cause problem
22
  def GPU(func, *args, **kwargs):
23
  return func
 
24
  else:
25
  from spaces import GPU
26
 
@@ -33,7 +35,7 @@ from diff import load_model, encode_prompts
33
  from meta import DEFAULT_NEGATIVE_PROMPT, DEFAULT_FORMAT
34
 
35
 
36
- sdxl_pipe = load_model()
37
 
38
  models.load_model(
39
  "Amber-River/tipo",
@@ -145,14 +147,14 @@ def generate_image(
145
  ):
146
  torch.cuda.empty_cache()
147
  set_seed(seed)
148
- prompt_embeds, negative_prompt_embeds, pooled_embeds2, neg_pooled_embeds2 = (
149
  encode_prompts(sdxl_pipe, prompt2, DEFAULT_NEGATIVE_PROMPT)
150
  )
151
  result2 = sdxl_pipe(
152
- prompt_embeds=prompt_embeds,
153
- negative_prompt_embeds=negative_prompt_embeds,
154
- pooled_prompt_embeds=pooled_embeds2,
155
- negative_pooled_prompt_embeds=neg_pooled_embeds2,
156
  num_inference_steps=24,
157
  width=1024,
158
  height=1024,
@@ -160,15 +162,15 @@ def generate_image(
160
  ).images[0]
161
  yield result2, None
162
 
163
- prompt_embeds, negative_prompt_embeds, pooled_embeds2, neg_pooled_embeds2 = (
164
  encode_prompts(sdxl_pipe, prompt, DEFAULT_NEGATIVE_PROMPT)
165
  )
166
  set_seed(seed)
167
  result = sdxl_pipe(
168
- prompt_embeds=prompt_embeds,
169
- negative_prompt_embeds=negative_prompt_embeds,
170
- pooled_prompt_embeds=pooled_embeds2,
171
- negative_pooled_prompt_embeds=neg_pooled_embeds2,
172
  num_inference_steps=24,
173
  width=1024,
174
  height=1024,
@@ -209,7 +211,7 @@ TIPO is a tool to extend, generate, refine the input prompt for T2I models.
209
 
210
  ### Notification
211
  **TIPO is NOT a T2I model. It is Prompt Gen, or, Text-to-Text model.
212
- <br>The generated image is come from [Kohaku-XL-Zeta](https://huggingface.co/KBlueLeaf/Kohaku-XL-Zeta) model**
213
  """
214
  )
215
  with gr.Row():
@@ -243,7 +245,7 @@ TIPO is a tool to extend, generate, refine the input prompt for T2I models.
243
  output_format = gr.Dropdown(
244
  label="Output Format",
245
  choices=list(DEFAULT_FORMAT.keys()),
246
- value="Both, tag first (recommend)"
247
  )
248
  target_length = gr.Dropdown(
249
  label="Target Length",
@@ -295,17 +297,21 @@ TIPO is a tool to extend, generate, refine the input prompt for T2I models.
295
  input_prompt = gr.Textbox(
296
  label="Input Prompt", lines=1, interactive=False, visible=False
297
  )
298
- gen_img = gr.Button("Generate Image from Result", variant="primary", interactive=False)
 
 
299
  with gr.Row():
300
  with gr.Column():
301
  img1 = gr.Image(label="Original Prompt", interactive=False)
302
  with gr.Column():
303
  img2 = gr.Image(label="Generated Prompt", interactive=False)
 
304
  def generate_wrapper(*args):
305
  yield "", "", "", gr.update(interactive=False),
306
  for i in generate(*args):
307
  yield *i, gr.update(interactive=False)
308
  yield *i, gr.update(interactive=True)
 
309
  submit.click(
310
  generate_wrapper,
311
  [
@@ -329,11 +335,12 @@ TIPO is a tool to extend, generate, refine the input prompt for T2I models.
329
  ],
330
  queue=True,
331
  )
332
-
333
  def generate_image_wrapper(seed, result, input_prompt):
334
  for img1, img2 in generate_image(seed, result, input_prompt):
335
  yield img1, img2, gr.update(interactive=False)
336
  yield img1, img2, gr.update(interactive=True)
 
337
  gen_img.click(
338
  generate_image_wrapper,
339
  [seed, result, input_prompt],
 
17
 
18
  import torch
19
  from transformers import set_seed
20
+
21
  if sys.platform == "win32":
22
+ # dev env in windows, @spaces.GPU will cause problem
23
  def GPU(func, *args, **kwargs):
24
  return func
25
+
26
  else:
27
  from spaces import GPU
28
 
 
35
  from meta import DEFAULT_NEGATIVE_PROMPT, DEFAULT_FORMAT
36
 
37
 
38
+ sdxl_pipe = load_model("OnomaAIResearch/Illustrious-xl-early-release-v0")
39
 
40
  models.load_model(
41
  "Amber-River/tipo",
 
147
  ):
148
  torch.cuda.empty_cache()
149
  set_seed(seed)
150
+ prompt_embeds, pooled_embeds2 = (
151
  encode_prompts(sdxl_pipe, prompt2, DEFAULT_NEGATIVE_PROMPT)
152
  )
153
  result2 = sdxl_pipe(
154
+ prompt_embeds=prompt_embeds[0:1],
155
+ negative_prompt_embeds=prompt_embeds[1:],
156
+ pooled_prompt_embeds=pooled_embeds2[0:1],
157
+ negative_pooled_prompt_embeds=pooled_embeds2[1:],
158
  num_inference_steps=24,
159
  width=1024,
160
  height=1024,
 
162
  ).images[0]
163
  yield result2, None
164
 
165
+ prompt_embeds, pooled_embeds2 = (
166
  encode_prompts(sdxl_pipe, prompt, DEFAULT_NEGATIVE_PROMPT)
167
  )
168
  set_seed(seed)
169
  result = sdxl_pipe(
170
+ prompt_embeds=prompt_embeds[0:1],
171
+ negative_prompt_embeds=prompt_embeds[1:],
172
+ pooled_prompt_embeds=pooled_embeds2[0:1],
173
+ negative_pooled_prompt_embeds=pooled_embeds2[1:],
174
  num_inference_steps=24,
175
  width=1024,
176
  height=1024,
 
211
 
212
  ### Notification
213
  **TIPO is NOT a T2I model. It is Prompt Gen, or, Text-to-Text model.
214
+ <br>The generated images come from OnomaAIResearch/Illustrious-xl-early-release-v0 SDXL-based model**
215
  """
216
  )
217
  with gr.Row():
 
245
  output_format = gr.Dropdown(
246
  label="Output Format",
247
  choices=list(DEFAULT_FORMAT.keys()),
248
+ value="Both, tag first (recommend)",
249
  )
250
  target_length = gr.Dropdown(
251
  label="Target Length",
 
297
  input_prompt = gr.Textbox(
298
  label="Input Prompt", lines=1, interactive=False, visible=False
299
  )
300
+ gen_img = gr.Button(
301
+ "Generate Image from Result", variant="primary", interactive=False
302
+ )
303
  with gr.Row():
304
  with gr.Column():
305
  img1 = gr.Image(label="Original Prompt", interactive=False)
306
  with gr.Column():
307
  img2 = gr.Image(label="Generated Prompt", interactive=False)
308
+
309
  def generate_wrapper(*args):
310
  yield "", "", "", gr.update(interactive=False),
311
  for i in generate(*args):
312
  yield *i, gr.update(interactive=False)
313
  yield *i, gr.update(interactive=True)
314
+
315
  submit.click(
316
  generate_wrapper,
317
  [
 
335
  ],
336
  queue=True,
337
  )
338
+
339
  def generate_image_wrapper(seed, result, input_prompt):
340
  for img1, img2 in generate_image(seed, result, input_prompt):
341
  yield img1, img2, gr.update(interactive=False)
342
  yield img1, img2, gr.update(interactive=True)
343
+
344
  gen_img.click(
345
  generate_image_wrapper,
346
  [seed, result, input_prompt],
diff.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from functools import partial
2
 
3
  import torch
@@ -46,76 +47,104 @@ def load_model(model_id="KBlueLeaf/Kohaku-XL-Zeta", device="cuda"):
46
  return pipe
47
 
48
 
49
- def encode_prompts(pipe: StableDiffusionXLKDiffusionPipeline, prompt, neg_prompt):
50
- max_length = pipe.tokenizer.model_max_length
51
-
52
- input_ids = pipe.tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
53
- input_ids2 = pipe.tokenizer_2(prompt, return_tensors="pt").input_ids.to("cuda")
54
-
55
- negative_ids = pipe.tokenizer(
56
- neg_prompt,
57
- truncation=False,
58
- padding="max_length",
59
- max_length=input_ids.shape[-1],
60
- return_tensors="pt",
61
- ).input_ids.to("cuda")
62
- negative_ids2 = pipe.tokenizer_2(
63
- neg_prompt,
64
- truncation=False,
65
- padding="max_length",
66
- max_length=input_ids.shape[-1],
67
- return_tensors="pt",
68
- ).input_ids.to("cuda")
69
-
70
- if negative_ids.size() > input_ids.size():
71
- input_ids = pipe.tokenizer(
72
- prompt,
73
- truncation=False,
74
- padding="max_length",
75
- max_length=negative_ids.shape[-1],
76
- return_tensors="pt",
77
- ).input_ids.to("cuda")
78
- input_ids2 = pipe.tokenizer_2(
79
- prompt,
80
- truncation=False,
81
- padding="max_length",
82
- max_length=negative_ids.shape[-1],
83
- return_tensors="pt",
84
- ).input_ids.to("cuda")
85
 
86
  concat_embeds = []
87
- neg_embeds = []
88
- for i in range(0, input_ids.shape[-1], max_length):
89
- concat_embeds.append(pipe.text_encoder(input_ids[:, i : i + max_length])[0])
90
- neg_embeds.append(pipe.text_encoder(negative_ids[:, i : i + max_length])[0])
 
 
 
 
 
 
 
 
 
91
 
92
  concat_embeds2 = []
93
- neg_embeds2 = []
94
  pooled_embeds2 = []
95
- neg_pooled_embeds2 = []
96
- for i in range(0, input_ids.shape[-1], max_length):
97
- hidden_states = pipe.text_encoder_2(
98
- input_ids2[:, i : i + max_length], output_hidden_states=True
99
- )
100
- concat_embeds2.append(hidden_states.hidden_states[-2])
101
  pooled_embeds2.append(hidden_states[0])
102
-
103
- hidden_states = pipe.text_encoder_2(
104
- negative_ids2[:, i : i + max_length], output_hidden_states=True
105
- )
106
- neg_embeds2.append(hidden_states.hidden_states[-2])
107
- neg_pooled_embeds2.append(hidden_states[0])
108
 
109
  prompt_embeds = torch.cat(concat_embeds, dim=1)
110
- negative_prompt_embeds = torch.cat(neg_embeds, dim=1)
111
  prompt_embeds2 = torch.cat(concat_embeds2, dim=1)
112
- negative_prompt_embeds2 = torch.cat(neg_embeds2, dim=1)
113
  prompt_embeds = torch.cat([prompt_embeds, prompt_embeds2], dim=-1)
114
- negative_prompt_embeds = torch.cat(
115
- [negative_prompt_embeds, negative_prompt_embeds2], dim=-1
116
- )
117
 
118
  pooled_embeds2 = torch.mean(torch.stack(pooled_embeds2, dim=0), dim=0)
119
- neg_pooled_embeds2 = torch.mean(torch.stack(neg_pooled_embeds2, dim=0), dim=0)
120
 
121
- return prompt_embeds, negative_prompt_embeds, pooled_embeds2, neg_pooled_embeds2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
  from functools import partial
3
 
4
  import torch
 
47
  return pipe
48
 
49
 
50
+ @torch.no_grad()
51
+ def encode_prompts(
52
+ pipe: StableDiffusionXLKDiffusionPipeline, prompt: str, neg_prompt: str = ""
53
+ ):
54
+ prompts = [prompt, neg_prompt]
55
+ max_length = pipe.tokenizer.model_max_length - 2
56
+
57
+ input_ids = pipe.tokenizer(prompts, padding=True, return_tensors="pt")
58
+ input_ids2 = pipe.tokenizer_2(prompts, padding=True, return_tensors="pt")
59
+ length = max(input_ids.input_ids.size(-1), input_ids2.input_ids.size(-1))
60
+ target_length = math.ceil(length / max_length) * max_length + 2
61
+
62
+ input_ids = pipe.tokenizer(
63
+ prompts, padding="max_length", max_length=target_length, return_tensors="pt"
64
+ ).input_ids
65
+ input_ids = (
66
+ input_ids[:, 0:1],
67
+ input_ids[:, 1:-1],
68
+ input_ids[:, -1:],
69
+ )
70
+ input_ids2 = pipe.tokenizer_2(
71
+ prompts, padding="max_length", max_length=target_length, return_tensors="pt"
72
+ ).input_ids
73
+ input_ids2 = (
74
+ input_ids2[:, 0:1],
75
+ input_ids2[:, 1:-1],
76
+ input_ids2[:, -1:],
77
+ )
 
 
 
 
 
 
 
 
78
 
79
  concat_embeds = []
80
+ for i in range(0, input_ids[1].shape[-1], max_length):
81
+ input_id1 = torch.concat(
82
+ (input_ids[0], input_ids[1][:, i : i + max_length], input_ids[2]), dim=-1
83
+ ).to(pipe.device)
84
+ result = pipe.text_encoder(input_id1, output_hidden_states=True).hidden_states[
85
+ -2
86
+ ]
87
+ if i == 0:
88
+ concat_embeds.append(result[:, :-1])
89
+ elif i == input_ids[1].shape[-1] - max_length:
90
+ concat_embeds.append(result[:, 1:])
91
+ else:
92
+ concat_embeds.append(result[:, 1:-1])
93
 
94
  concat_embeds2 = []
 
95
  pooled_embeds2 = []
96
+ for i in range(0, input_ids2[1].shape[-1], max_length):
97
+ input_id2 = torch.concat(
98
+ (input_ids2[0], input_ids2[1][:, i : i + max_length], input_ids2[2]), dim=-1
99
+ ).to(pipe.device)
100
+ hidden_states = pipe.text_encoder_2(input_id2, output_hidden_states=True)
 
101
  pooled_embeds2.append(hidden_states[0])
102
+ if i == 0:
103
+ concat_embeds2.append(hidden_states.hidden_states[-2][:, :-1])
104
+ elif i == input_ids2[1].shape[-1] - max_length:
105
+ concat_embeds2.append(hidden_states.hidden_states[-2][:, 1:])
106
+ else:
107
+ concat_embeds2.append(hidden_states.hidden_states[-2][:, 1:-1])
108
 
109
  prompt_embeds = torch.cat(concat_embeds, dim=1)
 
110
  prompt_embeds2 = torch.cat(concat_embeds2, dim=1)
 
111
  prompt_embeds = torch.cat([prompt_embeds, prompt_embeds2], dim=-1)
 
 
 
112
 
113
  pooled_embeds2 = torch.mean(torch.stack(pooled_embeds2, dim=0), dim=0)
 
114
 
115
+ return prompt_embeds, pooled_embeds2
116
+
117
+
118
+ if __name__ == "__main__":
119
+ from meta import DEFAULT_NEGATIVE_PROMPT
120
+ prompt = """
121
+ 1girl,
122
+ king halo (umamusume), umamusume,
123
+
124
+ ogipote, misu kasumi, fuzichoco, ciloranko, ninjin nouka, ningen mame, ask (askzy), kita (kitairoha), amano kokoko, maccha (mochancc),
125
+
126
+ solo, leaning forward, cleavage, sky, cowboy shot, outdoors, cloud, long hair, looking at viewer, brown hair, day, horse girl, black bikini, cloudy sky, stomach, collarbone, blue sky, swimsuit, navel, thighs, blush, ocean, animal ears, standing, smile, breasts, open mouth, :d, red eyes, horse ears, tail, bare shoulders, wavy hair, bikini, medium breasts,
127
+
128
+ masterpiece, newest, absurdres, sensitive
129
+ """.strip()
130
+ sdxl_pipe = load_model("KBlueLeaf/xxx")
131
+ # sdxl_pipe = load_model()
132
+ prompt_embeds, pooled_embeds2 = encode_prompts(
133
+ sdxl_pipe, prompt, DEFAULT_NEGATIVE_PROMPT
134
+ )
135
+ result = sdxl_pipe(
136
+ prompt_embeds=prompt_embeds[0:1],
137
+ negative_prompt_embeds=prompt_embeds[1:],
138
+ pooled_prompt_embeds=pooled_embeds2[0:1],
139
+ negative_pooled_prompt_embeds=pooled_embeds2[1:],
140
+ num_inference_steps=24,
141
+ width=1024,
142
+ height=1024,
143
+ guidance_scale=6.0,
144
+ ).images[0]
145
+
146
+ result.save("test.png")
147
+
148
+ module = torch.compile(sdxl_pipe)
149
+ if isinstance(module, torch._dynamo.OptimizedModule):
150
+ original_module = module._orig_mod
meta.py CHANGED
@@ -15,7 +15,7 @@ multiple tails, multiple views, copyright name, watermark, artist name, signatur
15
  """
16
 
17
  DEFAULT_FORMAT = {
18
- "tag only (DTG mode)":"""
19
  <|special|>, <|characters|>, <|copyrights|>,
20
  <|artist|>,
21
 
@@ -55,5 +55,5 @@ DEFAULT_FORMAT = {
55
  <|extended|>.
56
 
57
  <|quality|>, <|meta|>, <|rating|>
58
- """.strip()
59
  }
 
15
  """
16
 
17
  DEFAULT_FORMAT = {
18
+ "tag only (DTG mode)": """
19
  <|special|>, <|characters|>, <|copyrights|>,
20
  <|artist|>,
21
 
 
55
  <|extended|>.
56
 
57
  <|quality|>, <|meta|>, <|rating|>
58
+ """.strip(),
59
  }