r3gm commited on
Commit
7ecc2b1
·
verified ·
1 Parent(s): 3463ce4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -22
app.py CHANGED
@@ -1,6 +1,11 @@
1
  import spaces
2
  import os
3
- from stablepy import Model_Diffusers
 
 
 
 
 
4
  from constants import (
5
  DIRECTORY_MODELS,
6
  DIRECTORY_LORAS,
@@ -51,6 +56,7 @@ from utils import (
51
  download_diffuser_repo,
52
  progress_step_bar,
53
  html_template_message,
 
54
  )
55
  from datetime import datetime
56
  import gradio as gr
@@ -95,6 +101,7 @@ lora_model_list = get_model_list(DIRECTORY_LORAS)
95
  lora_model_list.insert(0, "None")
96
  lora_model_list = lora_model_list + DIFFUSERS_FORMAT_LORAS
97
  vae_model_list = get_model_list(DIRECTORY_VAES)
 
98
  vae_model_list.insert(0, "None")
99
 
100
  print('\033[33m🏁 Download and listing of valid models completed.\033[0m')
@@ -153,7 +160,12 @@ class GuiSD:
153
 
154
  yield f"Loading model: {model_name}"
155
 
156
- if vae_model:
 
 
 
 
 
157
  vae_type = "SDXL" if "sdxl" in vae_model.lower() else "SD 1.5"
158
  if model_type != vae_type:
159
  gr.Warning(WARNING_MSG_VAE)
@@ -226,6 +238,8 @@ class GuiSD:
226
  lora5,
227
  lora_scale5,
228
  sampler,
 
 
229
  img_height,
230
  img_width,
231
  model_name,
@@ -265,6 +279,7 @@ class GuiSD:
265
  image_previews,
266
  display_images,
267
  save_generated_images,
 
268
  image_storage_location,
269
  retain_compel_previous_load,
270
  retain_detailfix_model_previous_load,
@@ -334,14 +349,15 @@ class GuiSD:
334
  (image_ip2, mask_ip2, model_ip2, mode_ip2, scale_ip2),
335
  ]
336
 
337
- for imgip, mskip, modelip, modeip, scaleip in all_adapters:
338
- if imgip:
339
- params_ip_img.append(imgip)
340
- if mskip:
341
- params_ip_msk.append(mskip)
342
- params_ip_model.append(modelip)
343
- params_ip_mode.append(modeip)
344
- params_ip_scale.append(scaleip)
 
345
 
346
  concurrency = 5
347
  self.model.stream_config(concurrency=concurrency, latent_resize_by=1, vae_decoding=False)
@@ -430,6 +446,8 @@ class GuiSD:
430
  "textual_inversion": embed_list if textual_inversion else [],
431
  "syntax_weights": syntax_weights, # "Classic"
432
  "sampler": sampler,
 
 
433
  "xformers_memory_efficient_attention": xformers_memory_efficient_attention,
434
  "gui_active": True,
435
  "loop_generation": loop_generation,
@@ -447,6 +465,7 @@ class GuiSD:
447
  "image_previews": image_previews,
448
  "display_images": display_images,
449
  "save_generated_images": save_generated_images,
 
450
  "image_storage_location": image_storage_location,
451
  "retain_compel_previous_load": retain_compel_previous_load,
452
  "retain_detailfix_model_previous_load": retain_detailfix_model_previous_load,
@@ -479,7 +498,7 @@ class GuiSD:
479
 
480
  actual_progress = 0
481
  info_images = gr.update()
482
- for img, seed, image_path, metadata in self.model(**pipe_params):
483
  info_state = progress_step_bar(actual_progress, steps)
484
  actual_progress += concurrency
485
  if image_path:
@@ -501,7 +520,7 @@ class GuiSD:
501
  if msg_lora:
502
  info_images += msg_lora
503
 
504
- info_images = info_images + "<br>" + "GENERATION DATA:<br>" + metadata[0].replace("\n", "<br>") + "<br>-------<br>"
505
 
506
  download_links = "<br>".join(
507
  [
@@ -562,6 +581,14 @@ def sd_gen_generate_pipeline(*args):
562
  )
563
  print(lora_status)
564
 
 
 
 
 
 
 
 
 
565
  if verbose_arg:
566
  for status, lora in zip(lora_status, lora_list):
567
  if status:
@@ -688,7 +715,8 @@ with gr.Blocks(theme="NoCrypt/miku", css=CSS) as app:
688
  with gr.Column(scale=1):
689
  steps_gui = gr.Slider(minimum=1, maximum=100, step=1, value=30, label="Steps")
690
  cfg_gui = gr.Slider(minimum=0, maximum=30, step=0.5, value=7., label="CFG")
691
- sampler_gui = gr.Dropdown(label="Sampler", choices=scheduler_names, value="Euler a")
 
692
  img_width_gui = gr.Slider(minimum=64, maximum=4096, step=8, value=1024, label="Img Width")
693
  img_height_gui = gr.Slider(minimum=64, maximum=4096, step=8, value=1024, label="Img Height")
694
  seed_gui = gr.Number(minimum=-1, maximum=9999999999, value=-1, label="Seed")
@@ -707,14 +735,26 @@ with gr.Blocks(theme="NoCrypt/miku", css=CSS) as app:
707
  "width": gr.update(value=1024),
708
  "height": gr.update(value=1024),
709
  "Seed": gr.update(value=-1),
710
- "Sampler": gr.update(value="Euler a"),
711
- "scale": gr.update(value=7.), # cfg
712
- "skip": gr.update(value=True),
713
  "Model": gr.update(value=name_model),
 
 
 
714
  }
715
  valid_keys = list(valid_receptors.keys())
716
 
717
  parameters = extract_parameters(base_prompt)
 
 
 
 
 
 
 
 
 
718
 
719
  for key, val in parameters.items():
720
  # print(val)
@@ -723,7 +763,10 @@ with gr.Blocks(theme="NoCrypt/miku", css=CSS) as app:
723
  if key == "Sampler":
724
  if val not in scheduler_names:
725
  continue
726
- elif key == "skip":
 
 
 
727
  if "," in str(val):
728
  val = val.replace(",", "")
729
  if int(val) >= 2:
@@ -736,7 +779,9 @@ with gr.Blocks(theme="NoCrypt/miku", css=CSS) as app:
736
  val = re.sub(r'\s+', ' ', re.sub(r',+', ',', val)).strip()
737
  if key in ["Steps", "width", "height", "Seed"]:
738
  val = int(val)
739
- if key == "scale":
 
 
740
  val = float(val)
741
  if key == "Model":
742
  filtered_models = [m for m in model_list if val in m]
@@ -765,6 +810,9 @@ with gr.Blocks(theme="NoCrypt/miku", css=CSS) as app:
765
  cfg_gui,
766
  clip_skip_gui,
767
  model_name_gui,
 
 
 
768
  ],
769
  )
770
 
@@ -816,9 +864,14 @@ with gr.Blocks(theme="NoCrypt/miku", css=CSS) as app:
816
  lora_scale_5_gui = lora_scale_slider("Lora Scale 5")
817
 
818
  with gr.Accordion("From URL", open=False, visible=True):
819
- text_lora = gr.Textbox(label="LoRA URL", placeholder="https://civitai.com/api/download/models/28907", lines=1)
 
 
 
 
 
820
  romanize_text = gr.Checkbox(value=False, label="Transliterate name")
821
- button_lora = gr.Button("Obtain and refresh the LoRAs lists")
822
  new_lora_status = gr.HTML()
823
  button_lora.click(
824
  get_my_lora,
@@ -851,7 +904,10 @@ with gr.Blocks(theme="NoCrypt/miku", css=CSS) as app:
851
  minimum=0.01, maximum=1.0, step=0.01, value=0.55, label="Strength",
852
  info="This option adjusts the level of changes for img2img and inpainting."
853
  )
854
- image_resolution_gui = gr.Slider(minimum=64, maximum=2048, step=64, value=1024, label="Image Resolution")
 
 
 
855
  preprocessor_name_gui = gr.Dropdown(label="Preprocessor Name", choices=PREPROCESSOR_CONTROLNET["canny"])
856
 
857
  def change_preprocessor_choices(task):
@@ -950,7 +1006,9 @@ with gr.Blocks(theme="NoCrypt/miku", css=CSS) as app:
950
  mask_padding_b_gui = gr.Number(label="Mask padding:", value=32, minimum=1)
951
 
952
  with gr.Accordion("Other settings", open=False, visible=True):
 
953
  save_generated_images_gui = gr.Checkbox(value=True, label="Create a download link for the images")
 
954
  hires_before_adetailer_gui = gr.Checkbox(value=False, label="Hires Before Adetailer")
955
  hires_after_adetailer_gui = gr.Checkbox(value=True, label="Hires After Adetailer")
956
  generator_in_cpu_gui = gr.Checkbox(value=False, label="Generator in CPU")
@@ -1102,6 +1160,8 @@ with gr.Blocks(theme="NoCrypt/miku", css=CSS) as app:
1102
  lora5_gui,
1103
  lora_scale_5_gui,
1104
  sampler_gui,
 
 
1105
  img_height_gui,
1106
  img_width_gui,
1107
  model_name_gui,
@@ -1141,6 +1201,7 @@ with gr.Blocks(theme="NoCrypt/miku", css=CSS) as app:
1141
  image_previews_gui,
1142
  display_images_gui,
1143
  save_generated_images_gui,
 
1144
  image_storage_location_gui,
1145
  retain_compel_previous_load_gui,
1146
  retain_detailfix_model_previous_load_gui,
@@ -1201,4 +1262,4 @@ app.launch(
1201
  show_error=True,
1202
  debug=True,
1203
  allowed_paths=["./images/"],
1204
- )
 
1
  import spaces
2
  import os
3
+ from stablepy import (
4
+ Model_Diffusers,
5
+ SCHEDULE_TYPE_OPTIONS,
6
+ SCHEDULE_PREDICTION_TYPE_OPTIONS,
7
+ check_scheduler_compatibility,
8
+ )
9
  from constants import (
10
  DIRECTORY_MODELS,
11
  DIRECTORY_LORAS,
 
56
  download_diffuser_repo,
57
  progress_step_bar,
58
  html_template_message,
59
+ escape_html,
60
  )
61
  from datetime import datetime
62
  import gradio as gr
 
101
  lora_model_list.insert(0, "None")
102
  lora_model_list = lora_model_list + DIFFUSERS_FORMAT_LORAS
103
  vae_model_list = get_model_list(DIRECTORY_VAES)
104
+ vae_model_list.insert(0, "BakedVAE")
105
  vae_model_list.insert(0, "None")
106
 
107
  print('\033[33m🏁 Download and listing of valid models completed.\033[0m')
 
160
 
161
  yield f"Loading model: {model_name}"
162
 
163
+ if vae_model == "BakedVAE":
164
+ if not os.path.exists(model_name):
165
+ vae_model = model_name
166
+ else:
167
+ vae_model = None
168
+ elif vae_model:
169
  vae_type = "SDXL" if "sdxl" in vae_model.lower() else "SD 1.5"
170
  if model_type != vae_type:
171
  gr.Warning(WARNING_MSG_VAE)
 
238
  lora5,
239
  lora_scale5,
240
  sampler,
241
+ schedule_type,
242
+ schedule_prediction_type,
243
  img_height,
244
  img_width,
245
  model_name,
 
279
  image_previews,
280
  display_images,
281
  save_generated_images,
282
+ filename_pattern,
283
  image_storage_location,
284
  retain_compel_previous_load,
285
  retain_detailfix_model_previous_load,
 
349
  (image_ip2, mask_ip2, model_ip2, mode_ip2, scale_ip2),
350
  ]
351
 
352
+ if not hasattr(self.model.pipe, "transformer"):
353
+ for imgip, mskip, modelip, modeip, scaleip in all_adapters:
354
+ if imgip:
355
+ params_ip_img.append(imgip)
356
+ if mskip:
357
+ params_ip_msk.append(mskip)
358
+ params_ip_model.append(modelip)
359
+ params_ip_mode.append(modeip)
360
+ params_ip_scale.append(scaleip)
361
 
362
  concurrency = 5
363
  self.model.stream_config(concurrency=concurrency, latent_resize_by=1, vae_decoding=False)
 
446
  "textual_inversion": embed_list if textual_inversion else [],
447
  "syntax_weights": syntax_weights, # "Classic"
448
  "sampler": sampler,
449
+ "schedule_type": schedule_type,
450
+ "schedule_prediction_type": schedule_prediction_type,
451
  "xformers_memory_efficient_attention": xformers_memory_efficient_attention,
452
  "gui_active": True,
453
  "loop_generation": loop_generation,
 
465
  "image_previews": image_previews,
466
  "display_images": display_images,
467
  "save_generated_images": save_generated_images,
468
+ "filename_pattern": filename_pattern,
469
  "image_storage_location": image_storage_location,
470
  "retain_compel_previous_load": retain_compel_previous_load,
471
  "retain_detailfix_model_previous_load": retain_detailfix_model_previous_load,
 
498
 
499
  actual_progress = 0
500
  info_images = gr.update()
501
+ for img, [seed, image_path, metadata] in self.model(**pipe_params):
502
  info_state = progress_step_bar(actual_progress, steps)
503
  actual_progress += concurrency
504
  if image_path:
 
520
  if msg_lora:
521
  info_images += msg_lora
522
 
523
+ info_images = info_images + "<br>" + "GENERATION DATA:<br>" + escape_html(metadata[0]) + "<br>-------<br>"
524
 
525
  download_links = "<br>".join(
526
  [
 
581
  )
582
  print(lora_status)
583
 
584
+ sampler_name = args[17]
585
+ schedule_type_name = args[18]
586
+ _, _, msg_sampler = check_scheduler_compatibility(
587
+ sd_gen.model.class_name, sampler_name, schedule_type_name
588
+ )
589
+ if msg_sampler:
590
+ gr.Warning(msg_sampler)
591
+
592
  if verbose_arg:
593
  for status, lora in zip(lora_status, lora_list):
594
  if status:
 
715
  with gr.Column(scale=1):
716
  steps_gui = gr.Slider(minimum=1, maximum=100, step=1, value=30, label="Steps")
717
  cfg_gui = gr.Slider(minimum=0, maximum=30, step=0.5, value=7., label="CFG")
718
+ sampler_gui = gr.Dropdown(label="Sampler", choices=scheduler_names, value="Euler")
719
+ schedule_type_gui = gr.Dropdown(label="Schedule type", choices=SCHEDULE_TYPE_OPTIONS, value=SCHEDULE_TYPE_OPTIONS[0])
720
  img_width_gui = gr.Slider(minimum=64, maximum=4096, step=8, value=1024, label="Img Width")
721
  img_height_gui = gr.Slider(minimum=64, maximum=4096, step=8, value=1024, label="Img Height")
722
  seed_gui = gr.Number(minimum=-1, maximum=9999999999, value=-1, label="Seed")
 
735
  "width": gr.update(value=1024),
736
  "height": gr.update(value=1024),
737
  "Seed": gr.update(value=-1),
738
+ "Sampler": gr.update(value="Euler"),
739
+ "CFG scale": gr.update(value=7.), # cfg
740
+ "Clip skip": gr.update(value=True),
741
  "Model": gr.update(value=name_model),
742
+ "Schedule type": gr.update(value="Automatic"),
743
+ "PAG": gr.update(value=.0),
744
+ "FreeU": gr.update(value=False),
745
  }
746
  valid_keys = list(valid_receptors.keys())
747
 
748
  parameters = extract_parameters(base_prompt)
749
+ # print(parameters)
750
+
751
+ if "Sampler" in parameters:
752
+ value_sampler = parameters["Sampler"]
753
+ for s_type in SCHEDULE_TYPE_OPTIONS:
754
+ if s_type in value_sampler:
755
+ value_sampler = value_sampler.replace(s_type, "").strip()
756
+ parameters["Sampler"] = value_sampler
757
+ parameters["Schedule type"] = s_type
758
 
759
  for key, val in parameters.items():
760
  # print(val)
 
763
  if key == "Sampler":
764
  if val not in scheduler_names:
765
  continue
766
+ if key == "Schedule type":
767
+ if val not in SCHEDULE_TYPE_OPTIONS:
768
+ val = "Automatic"
769
+ elif key == "Clip skip":
770
  if "," in str(val):
771
  val = val.replace(",", "")
772
  if int(val) >= 2:
 
779
  val = re.sub(r'\s+', ' ', re.sub(r',+', ',', val)).strip()
780
  if key in ["Steps", "width", "height", "Seed"]:
781
  val = int(val)
782
+ if key == "FreeU":
783
+ val = True
784
+ if key in ["CFG scale", "PAG"]:
785
  val = float(val)
786
  if key == "Model":
787
  filtered_models = [m for m in model_list if val in m]
 
810
  cfg_gui,
811
  clip_skip_gui,
812
  model_name_gui,
813
+ schedule_type_gui,
814
+ pag_scale_gui,
815
+ free_u_gui,
816
  ],
817
  )
818
 
 
864
  lora_scale_5_gui = lora_scale_slider("Lora Scale 5")
865
 
866
  with gr.Accordion("From URL", open=False, visible=True):
867
+ text_lora = gr.Textbox(
868
+ label="LoRA's download URL",
869
+ placeholder="https://civitai.com/api/download/models/28907",
870
+ lines=1,
871
+ info="It has to be .safetensors files, and you can also download them from Hugging Face.",
872
+ )
873
  romanize_text = gr.Checkbox(value=False, label="Transliterate name")
874
+ button_lora = gr.Button("Get and Refresh the LoRA Lists")
875
  new_lora_status = gr.HTML()
876
  button_lora.click(
877
  get_my_lora,
 
904
  minimum=0.01, maximum=1.0, step=0.01, value=0.55, label="Strength",
905
  info="This option adjusts the level of changes for img2img and inpainting."
906
  )
907
+ image_resolution_gui = gr.Slider(
908
+ minimum=64, maximum=2048, step=64, value=1024, label="Image Resolution",
909
+ info="The maximum proportional size of the generated image based on the uploaded image."
910
+ )
911
  preprocessor_name_gui = gr.Dropdown(label="Preprocessor Name", choices=PREPROCESSOR_CONTROLNET["canny"])
912
 
913
  def change_preprocessor_choices(task):
 
1006
  mask_padding_b_gui = gr.Number(label="Mask padding:", value=32, minimum=1)
1007
 
1008
  with gr.Accordion("Other settings", open=False, visible=True):
1009
+ schedule_prediction_type_gui = gr.Dropdown(label="Discrete Sampling Type", choices=SCHEDULE_PREDICTION_TYPE_OPTIONS, value=SCHEDULE_PREDICTION_TYPE_OPTIONS[0])
1010
  save_generated_images_gui = gr.Checkbox(value=True, label="Create a download link for the images")
1011
+ filename_pattern_gui = gr.Textbox(label="Filename pattern", value="model,seed", placeholder="model,seed,sampler,schedule_type,img_width,img_height,guidance_scale,num_steps,vae,prompt_section,neg_prompt_section", lines=1)
1012
  hires_before_adetailer_gui = gr.Checkbox(value=False, label="Hires Before Adetailer")
1013
  hires_after_adetailer_gui = gr.Checkbox(value=True, label="Hires After Adetailer")
1014
  generator_in_cpu_gui = gr.Checkbox(value=False, label="Generator in CPU")
 
1160
  lora5_gui,
1161
  lora_scale_5_gui,
1162
  sampler_gui,
1163
+ schedule_type_gui,
1164
+ schedule_prediction_type_gui,
1165
  img_height_gui,
1166
  img_width_gui,
1167
  model_name_gui,
 
1201
  image_previews_gui,
1202
  display_images_gui,
1203
  save_generated_images_gui,
1204
+ filename_pattern_gui,
1205
  image_storage_location_gui,
1206
  retain_compel_previous_load_gui,
1207
  retain_detailfix_model_previous_load_gui,
 
1262
  show_error=True,
1263
  debug=True,
1264
  allowed_paths=["./images/"],
1265
+ )