fffiloni commited on
Commit
60ac9c1
·
verified ·
1 Parent(s): 4daf597

add use low vram option

Browse files
Files changed (1) hide show
  1. app.py +21 -17
app.py CHANGED
@@ -34,7 +34,7 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
34
  print(device)
35
 
36
  # Flag for low VRAM usage
37
- low_vram = False
38
 
39
  # Function definition for low VRAM usage
40
  def models_to(model, device="cpu", excepts=None):
@@ -107,11 +107,13 @@ models_b = WurstCoreB.Models(
107
  )
108
  models_b.generator.bfloat16().eval().requires_grad_(False)
109
 
 
110
  if low_vram:
111
  # Off-load old generator (which is not used in models_rbm)
112
  models.generator.to("cpu")
113
  torch.cuda.empty_cache()
114
  gc.collect()
 
115
 
116
  generator_rbm = StageCRBM()
117
  for param_name, param in load_or_fail(core.config.generator_checkpoint_path).items():
@@ -128,10 +130,10 @@ models_rbm.generator.eval().requires_grad_(False)
128
 
129
 
130
 
131
- def infer(ref_style_file, style_description, caption, progress):
132
  global models_rbm, models_b, device
133
 
134
- if low_vram:
135
  models_to(models_rbm, device=device, excepts=["generator", "previewer"])
136
  try:
137
 
@@ -167,7 +169,7 @@ def infer(ref_style_file, style_description, caption, progress):
167
  conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
168
  unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
169
 
170
- if low_vram:
171
  # The sampling process uses more vram, so we offload everything except two modules to the cpu.
172
  models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
173
 
@@ -236,10 +238,10 @@ def infer(ref_style_file, style_description, caption, progress):
236
  torch.cuda.empty_cache()
237
  gc.collect()
238
 
239
- def infer_compo(style_description, ref_style_file, caption, ref_sub_file, progress):
240
  global models_rbm, models_b, device
241
  sam_model = LangSAM()
242
- if low_vram:
243
  models_to(models_rbm, device=device, excepts=["generator", "previewer"])
244
  models_to(sam_model, device=device)
245
  models_to(sam_model.sam, device=device)
@@ -288,7 +290,7 @@ def infer_compo(style_description, ref_style_file, caption, ref_sub_file, progre
288
  conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
289
  unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
290
 
291
- if low_vram:
292
  models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
293
  models_to(sam_model, device="cpu")
294
  models_to(sam_model.sam, device="cpu")
@@ -363,13 +365,13 @@ def infer_compo(style_description, ref_style_file, caption, ref_sub_file, progre
363
  torch.cuda.empty_cache()
364
  gc.collect()
365
 
366
- def run(style_reference_image, style_description, subject_prompt, subject_reference, use_subject_ref):
367
  result = None
368
  progress = gr.Progress(track_tqdm=True)
369
  if use_subject_ref is True:
370
- result = infer_compo(style_description, style_reference_image, subject_prompt, subject_reference, progress)
371
  else:
372
- result = infer(style_reference_image, style_description, subject_prompt, progress)
373
  return result
374
 
375
  def show_hide_subject_image_component(use_subject_ref):
@@ -406,7 +408,9 @@ with gr.Blocks(analytics_enabled=False) as demo:
406
  subject_prompt = gr.Textbox(
407
  label = "Subject Prompt"
408
  )
409
- use_subject_ref = gr.Checkbox(label="Use Subject Image as Reference", value=False)
 
 
410
 
411
  with gr.Accordion("Advanced Settings", open=False) as sub_img_panel:
412
  subject_reference = gr.Image(label="Subject Reference", type="filepath")
@@ -418,13 +422,13 @@ with gr.Blocks(analytics_enabled=False) as demo:
418
  output_image = gr.Image(label="Output Image")
419
  gr.Examples(
420
  examples = [
421
- ["./data/cyberpunk.png", "cyberpunk art style", "a car", None, False],
422
- ["./data/mosaic.png", "mosaic art style", "a lighthouse", None, False],
423
- ["./data/glowing.png", "glowing style", "a dwarf", None, False],
424
- ["./data/melting_gold.png", "melting golden 3D rendering style", "a dog", "./data/dog.jpg", True]
425
  ],
426
  fn=run,
427
- inputs=[style_reference_image, style_description, subject_prompt, subject_reference, use_subject_ref],
428
  outputs=[output_image],
429
  cache_examples=False
430
 
@@ -439,7 +443,7 @@ with gr.Blocks(analytics_enabled=False) as demo:
439
 
440
  submit_btn.click(
441
  fn = run,
442
- inputs = [style_reference_image, style_description, subject_prompt, subject_reference, use_subject_ref],
443
  outputs = [output_image],
444
  show_api = False
445
  )
 
34
  print(device)
35
 
36
  # Flag for low VRAM usage
37
+ # low_vram = False
38
 
39
  # Function definition for low VRAM usage
40
  def models_to(model, device="cpu", excepts=None):
 
107
  )
108
  models_b.generator.bfloat16().eval().requires_grad_(False)
109
 
110
+ """
111
  if low_vram:
112
  # Off-load old generator (which is not used in models_rbm)
113
  models.generator.to("cpu")
114
  torch.cuda.empty_cache()
115
  gc.collect()
116
+ """
117
 
118
  generator_rbm = StageCRBM()
119
  for param_name, param in load_or_fail(core.config.generator_checkpoint_path).items():
 
130
 
131
 
132
 
133
+ def infer(ref_style_file, style_description, caption, use_low_vram, progress):
134
  global models_rbm, models_b, device
135
 
136
+ if use_low_vram:
137
  models_to(models_rbm, device=device, excepts=["generator", "previewer"])
138
  try:
139
 
 
169
  conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
170
  unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
171
 
172
+ if use_low_vram:
173
  # The sampling process uses more vram, so we offload everything except two modules to the cpu.
174
  models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
175
 
 
238
  torch.cuda.empty_cache()
239
  gc.collect()
240
 
241
+ def infer_compo(style_description, ref_style_file, caption, ref_sub_file, use_low_vram, progress):
242
  global models_rbm, models_b, device
243
  sam_model = LangSAM()
244
+ if use_low_vram:
245
  models_to(models_rbm, device=device, excepts=["generator", "previewer"])
246
  models_to(sam_model, device=device)
247
  models_to(sam_model.sam, device=device)
 
290
  conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
291
  unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
292
 
293
+ if use_low_vram:
294
  models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
295
  models_to(sam_model, device="cpu")
296
  models_to(sam_model.sam, device="cpu")
 
365
  torch.cuda.empty_cache()
366
  gc.collect()
367
 
368
+ def run(style_reference_image, style_description, subject_prompt, subject_reference, use_subject_ref, use_low_vram):
369
  result = None
370
  progress = gr.Progress(track_tqdm=True)
371
  if use_subject_ref is True:
372
+ result = infer_compo(style_description, style_reference_image, subject_prompt, subject_reference, use_low_vram, progress)
373
  else:
374
+ result = infer(style_reference_image, style_description, subject_prompt, use_low_vram, progress)
375
  return result
376
 
377
  def show_hide_subject_image_component(use_subject_ref):
 
408
  subject_prompt = gr.Textbox(
409
  label = "Subject Prompt"
410
  )
411
+ with gr.Row():
412
+ use_subject_ref = gr.Checkbox(label="Use Subject Image as Reference", value=False)
413
+ use_low_vram = gr.Checkbox(label="Use Low-VRAM", value=False)
414
 
415
  with gr.Accordion("Advanced Settings", open=False) as sub_img_panel:
416
  subject_reference = gr.Image(label="Subject Reference", type="filepath")
 
422
  output_image = gr.Image(label="Output Image")
423
  gr.Examples(
424
  examples = [
425
+ ["./data/cyberpunk.png", "cyberpunk art style", "a car", None, False, False],
426
+ ["./data/mosaic.png", "mosaic art style", "a lighthouse", None, False, False],
427
+ ["./data/glowing.png", "glowing style", "a dwarf", None, False, False],
428
+ ["./data/melting_gold.png", "melting golden 3D rendering style", "a dog", "./data/dog.jpg", True, False]
429
  ],
430
  fn=run,
431
+ inputs=[style_reference_image, style_description, subject_prompt, subject_reference, use_subject_ref, use_low_vram],
432
  outputs=[output_image],
433
  cache_examples=False
434
 
 
443
 
444
  submit_btn.click(
445
  fn = run,
446
+ inputs = [style_reference_image, style_description, subject_prompt, subject_reference, use_subject_ref, use_low_vram],
447
  outputs = [output_image],
448
  show_api = False
449
  )