drscotthawley committed on
Commit
8a80eb5
·
1 Parent(s): 5e340e8

fixed typo

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. sample.py +114 -26
app.py CHANGED
@@ -118,7 +118,7 @@ def process_image(image, repaint, busyness):
118
  print("Saving masked image file to ", masked_img_file)
119
  image.save(masked_img_file)
120
  num = 64 # number of images to generate; we'll take the one with the most notes in the masked region
121
- bs = numx
122
  repaint = repaint
123
  seed_scale = 1.0
124
  CT_HOME = '.'
 
118
  print("Saving masked image file to ", masked_img_file)
119
  image.save(masked_img_file)
120
  num = 64 # number of images to generate; we'll take the one with the most notes in the masked region
121
+ bs = num
122
  repaint = repaint
123
  seed_scale = 1.0
124
  CT_HOME = '.'
sample.py CHANGED
@@ -5,9 +5,7 @@
5
 
6
  """Samples from k-diffusion models."""
7
 
8
- import gradio
9
- import spaces
10
- import natten
11
  import argparse
12
  from pathlib import Path
13
 
@@ -24,11 +22,11 @@ from pom.v_diffusion import DDPM, LogSchedule, CrashSchedule
24
  #CHORD_BORDER = 8 # chord border size in pixels
25
  from pom.chords import CHORD_BORDER, img_batch_to_seq_emb, ChordSeqEncoder
26
 
 
27
 
28
  # ---- my mangled sampler that includes repaint
29
  import torchsde
30
 
31
- #@spaces.GPU
32
  class BatchedBrownianTree:
33
  """A wrapper around torchsde.BrownianTree that enables batches of entropy."""
34
 
@@ -56,7 +54,6 @@ class BatchedBrownianTree:
56
  return w if self.batched else w[0]
57
 
58
 
59
- #@spaces.GPU
60
  class BrownianTreeNoiseSampler:
61
  """A noise sampler backed by a torchsde.BrownianTree.
62
 
@@ -94,7 +91,6 @@ def to_d(x, sigma, denoised):
94
  return (x - denoised) / append_dims(sigma, x.ndim)
95
 
96
 
97
- #@spaces.GPU
98
  @torch.no_grad()
99
  def my_sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., repaint=1):
100
  """Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
@@ -129,7 +125,6 @@ def get_scalings(sigma, sigma_data=0.5):
129
  return c_skip, c_out, c_in
130
 
131
 
132
- #@spaces.GPU
133
  @torch.no_grad()
134
  def my_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None,
135
  disable=None, eta=1., s_noise=1., noise_sampler=None,
@@ -289,14 +284,12 @@ def sample(model, x, steps, eta, **extra_args):
289
 
290
  # Soft mask inpainting is just shrinking hard (binary) mask inpainting
291
  # Given a float-valued soft mask (values between 0 and 1), get the binary mask for this particular step
292
- #@spaces.GPU
293
  def get_bmask(i, steps, mask):
294
  strength = (i+1)/(steps)
295
  # convert to binary mask
296
  bmask = torch.where(mask<=strength,1,0)
297
  return bmask
298
 
299
- #@spaces.GPU
300
  def make_cond_model_fn(model, cond_fn):
301
  def cond_model_fn(x, sigma, **kwargs):
302
  with torch.enable_grad():
@@ -312,7 +305,6 @@ def make_cond_model_fn(model, cond_fn):
312
  # For sampling, set both init_data and mask to None
313
  # For variations, set init_data
314
  # For inpainting, set both init_data & mask
315
- #@spaces.GPU
316
  def sample_k(
317
  model_fn,
318
  noise,
@@ -425,7 +417,7 @@ def infer_mask_from_init_img(img, mask_with='white'):
425
  mask[img[2,:,:]==1] = 1 # blue
426
  return mask*1.0
427
 
428
- #@spaces.GPU
429
  def grow_mask(init_mask, grow_by=2):
430
  "adds a border of grow_by pixels to the mask, by growing it grow_by times. If grow_by=0, does nothing"
431
  new_mask = init_mask.clone()
@@ -434,7 +426,7 @@ def grow_mask(init_mask, grow_by=2):
434
  new_mask[1:-1,1:-1] = (new_mask[1:-1,1:-1] + new_mask[0:-2,1:-1] + new_mask[2:,1:-1] + new_mask[1:-1,0:-2] + new_mask[1:-1,2:]) > 0
435
  return new_mask
436
 
437
- #@spaces.GPU
438
  def add_seeding(init_image, init_mask, grow_by=0, seed_scale=1.0):
439
  "adds extra noise inside mask"
440
  init_mask = grow_mask(init_mask, grow_by=grow_by) # make the mask bigger
@@ -448,15 +440,13 @@ def add_seeding(init_image, init_mask, grow_by=0, seed_scale=1.0):
448
  init_image[2,:,:] = init_image[2,:,:] * (1-init_mask) - 1.0*init_mask
449
  return init_image
450
 
451
- #@spaces.GPU
452
  def get_init_image_and_mask(args, device):
453
  convert_tensor = transforms.ToTensor()
454
  init_image = Image.open(args.init_image).convert('RGB')
455
  init_image = convert_tensor(init_image)
456
  #normalize image from 0..1 to -1..1
457
  init_image = (2.0 * init_image) - 1.0
458
-
459
-
460
  init_mask = torch.ones(init_image.shape[-2:]) # ones are where stuff will change, zeros will stay the same
461
 
462
  inpaint_task = 'infer' # infer mask from init_image
@@ -522,7 +512,115 @@ def get_init_image_and_mask(args, device):
522
  init_mask = init_mask.unsqueeze(0).unsqueeze(1).repeat(args.batch_size,3,1,1).float()
523
  return init_image.to(device), init_mask.to(device)
524
 
525
- #@spaces.GPU # generates an error
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
526
  def main():
527
  global init_image, init_mask
528
  p = argparse.ArgumentParser(description=__doc__,
@@ -574,12 +672,7 @@ def main():
574
  sigma_min = model_config['sigma_min']
575
  sigma_max = model_config['sigma_max']
576
 
577
- # SHH modified
578
  torch.set_float32_matmul_precision('high')
579
- #class_cond = torch.tensor([0]).to(device)
580
- #num_classes = 10
581
- #class_cond = torch.remainder(torch.arange(0, args.n), num_classes).int().to(device)
582
- #extra_args = {'class_cond':class_cond}
583
  extra_args = {}
584
  init_image, init_mask = None, None
585
  if args.init_image is not None:
@@ -595,11 +688,6 @@ def main():
595
  tqdm.write('Sampling...')
596
  sigmas = K.sampling.get_sigmas_karras(args.steps, sigma_min, sigma_max, rho=7., device=device)
597
 
598
- #ddpm_sampler = DDPM(model)
599
- #model_fn = model
600
- #ddpm_sampler = K.external.VDenoiser(model_fn)
601
-
602
- #@spaces.GPU
603
  def sample_fn(n, debug=True):
604
  x = torch.randn([n, model_config['input_channels'], size[0], size[1]], device=device) * sigma_max
605
  print("n, sigma_max, x.min, x.max = ", n, sigma_max, x.min(), x.max())
 
5
 
6
  """Samples from k-diffusion models."""
7
 
8
+
 
 
9
  import argparse
10
  from pathlib import Path
11
 
 
22
  #CHORD_BORDER = 8 # chord border size in pixels
23
  from pom.chords import CHORD_BORDER, img_batch_to_seq_emb, ChordSeqEncoder
24
 
25
+ import spaces
26
 
27
  # ---- my mangled sampler that includes repaint
28
  import torchsde
29
 
 
30
  class BatchedBrownianTree:
31
  """A wrapper around torchsde.BrownianTree that enables batches of entropy."""
32
 
 
54
  return w if self.batched else w[0]
55
 
56
 
 
57
  class BrownianTreeNoiseSampler:
58
  """A noise sampler backed by a torchsde.BrownianTree.
59
 
 
91
  return (x - denoised) / append_dims(sigma, x.ndim)
92
 
93
 
 
94
  @torch.no_grad()
95
  def my_sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., repaint=1):
96
  """Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
 
125
  return c_skip, c_out, c_in
126
 
127
 
 
128
  @torch.no_grad()
129
  def my_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None,
130
  disable=None, eta=1., s_noise=1., noise_sampler=None,
 
284
 
285
  # Soft mask inpainting is just shrinking hard (binary) mask inpainting
286
  # Given a float-valued soft mask (values between 0 and 1), get the binary mask for this particular step
 
287
  def get_bmask(i, steps, mask):
288
  strength = (i+1)/(steps)
289
  # convert to binary mask
290
  bmask = torch.where(mask<=strength,1,0)
291
  return bmask
292
 
 
293
  def make_cond_model_fn(model, cond_fn):
294
  def cond_model_fn(x, sigma, **kwargs):
295
  with torch.enable_grad():
 
305
  # For sampling, set both init_data and mask to None
306
  # For variations, set init_data
307
  # For inpainting, set both init_data & mask
 
308
  def sample_k(
309
  model_fn,
310
  noise,
 
417
  mask[img[2,:,:]==1] = 1 # blue
418
  return mask*1.0
419
 
420
+
421
  def grow_mask(init_mask, grow_by=2):
422
  "adds a border of grow_by pixels to the mask, by growing it grow_by times. If grow_by=0, does nothing"
423
  new_mask = init_mask.clone()
 
426
  new_mask[1:-1,1:-1] = (new_mask[1:-1,1:-1] + new_mask[0:-2,1:-1] + new_mask[2:,1:-1] + new_mask[1:-1,0:-2] + new_mask[1:-1,2:]) > 0
427
  return new_mask
428
 
429
+
430
  def add_seeding(init_image, init_mask, grow_by=0, seed_scale=1.0):
431
  "adds extra noise inside mask"
432
  init_mask = grow_mask(init_mask, grow_by=grow_by) # make the mask bigger
 
440
  init_image[2,:,:] = init_image[2,:,:] * (1-init_mask) - 1.0*init_mask
441
  return init_image
442
 
443
+
444
  def get_init_image_and_mask(args, device):
445
  convert_tensor = transforms.ToTensor()
446
  init_image = Image.open(args.init_image).convert('RGB')
447
  init_image = convert_tensor(init_image)
448
  #normalize image from 0..1 to -1..1
449
  init_image = (2.0 * init_image) - 1.0
 
 
450
  init_mask = torch.ones(init_image.shape[-2:]) # ones are where stuff will change, zeros will stay the same
451
 
452
  inpaint_task = 'infer' # infer mask from init_image
 
512
  init_mask = init_mask.unsqueeze(0).unsqueeze(1).repeat(args.batch_size,3,1,1).float()
513
  return init_image.to(device), init_mask.to(device)
514
 
515
+
516
+
517
+
518
+ # wrapper compatible with ZeroGPU, callable from outside
519
+ @spaces.GPU
520
+ def zero_wrapper(args, device):
521
+ global init_image, init_mask
522
+
523
+ config = K.config.load_config(args.config if args.config else args.checkpoint)
524
+ model_config = config['model']
525
+ # TODO: allow non-square input sizes
526
+ assert len(model_config['input_size']) == 2 and model_config['input_size'][0] == model_config['input_size'][1]
527
+ size = model_config['input_size']
528
+
529
+ print('zero_wrapper: Using device:', device, flush=True)
530
+
531
+ inner_model = K.config.make_model(config).eval().requires_grad_(False).to(device)
532
+ cse = None # ChordSeqEncoder().eval().requires_grad_(False).to(device) # add chord embedding-maker to main model
533
+ if cse is not None:
534
+ inner_model.cse = cse
535
+ try:
536
+ inner_model.load_state_dict(safetorch.load_file(args.checkpoint))
537
+ except:
538
+ #ckpt = torch.load(args.checkpoint).to(device)
539
+ ckpt = torch.load(args.checkpoint, map_location='cpu')
540
+ inner_model.load_state_dict(ckpt['model'])
541
+
542
+ print('Parameters:', K.utils.n_params(inner_model))
543
+ model = K.Denoiser(inner_model, sigma_data=model_config['sigma_data'])
544
+
545
+ sigma_min = model_config['sigma_min']
546
+ sigma_max = model_config['sigma_max']
547
+ torch.set_float32_matmul_precision('high')
548
+ extra_args = {}
549
+ init_image, init_mask = None, None
550
+ if args.init_image is not None:
551
+ init_image, init_mask = get_init_image_and_mask(args, device)
552
+ init_image = init_image.to(device)
553
+ init_mask = init_mask.to(device)
554
+ @torch.no_grad()
555
+ @K.utils.eval_mode(model)
556
+ def run():
557
+ global init_image, init_mask
558
+ if accelerator.is_local_main_process:
559
+ tqdm.write('Sampling...')
560
+ sigmas = K.sampling.get_sigmas_karras(args.steps, sigma_min, sigma_max, rho=7., device=device)
561
+
562
+ def sample_fn(n, debug=True):
563
+ x = torch.randn([n, model_config['input_channels'], size[0], size[1]], device=device) * sigma_max
564
+ print("n, sigma_max, x.min, x.max = ", n, sigma_max, x.min(), x.max())
565
+
566
+ if args.init_image is not None:
567
+ init_data, mask = get_init_image_and_mask(args, device)
568
+ init_data = args.seed_scale*x*mask + (1-mask)*init_data # extra nucleation?
569
+ if cse is not None:
570
+ chord_cond = img_batch_to_seq_emb(init_data, inner_model.cse).to(device)
571
+ else:
572
+ chord_cond = None
573
+ #print("init_data.shape, init_data.min, init_data.max = ", init_data.shape, init_data.min(), init_data.max())
574
+ else:
575
+ init_data, mask, chord_cond = None, None, None
576
+ # chord_cond doesn't work anyway so f it:
577
+ chord_cond = None
578
+
579
+ print("chord_cond = ", chord_cond)
580
+ if chord_cond is not None:
581
+ extra_args['chord_cond'] = chord_cond
582
+ # these two work:
583
+ #x_0 = K.sampling.sample_lms(model, x, sigmas, disable=not accelerator.is_local_main_process, extra_args=extra_args)
584
+ #x_0 = K.sampling.sample_dpmpp_2m_sde(model, x, sigmas, disable=not accelerator.is_local_main_process, extra_args=extra_args)
585
+
586
+ noise = torch.randn([n, model_config['input_channels'], size[0], size[1]], device=device)
587
+
588
+ sampler_type="my-dpmpp-2m-sde" # "k-lms"
589
+ #sampler_type="my-sample-euler"
590
+ #sampler_type="dpmpp-2m-sde"
591
+ #sampler_type = "dpmpp-3m-sde"
592
+ #sampler_type = "k-dpmpp-2s-ancestral"
593
+ print("dtypes:", [x.dtype if x is not None else None for x in [noise, init_data, mask, chord_cond]])
594
+ x_0 = sample_k(inner_model, noise, sampler_type=sampler_type,
595
+ init_data=init_data, mask=mask, steps=args.steps,
596
+ sigma_min=sigma_min, sigma_max=sigma_max, rho=7.,
597
+ device=device, model_config=model_config, repaint=args.repaint,
598
+ **extra_args)
599
+ #x_0 = sample_k(inner_model, noise, sampler_type="dpmpp-2m-sde", steps=100, sigma_min=0.5, sigma_max=50, rho=1., device=device, model_config=model_config, **extra_args)
600
+ print("x_0.min, x_0.max = ", x_0.min(), x_0.max())
601
+ if x_0.isnan().any():
602
+ assert False, "x_0 has NaNs"
603
+
604
+ # do gpu garbage collection before proceeding
605
+ torch.cuda.empty_cache()
606
+ return x_0
607
+
608
+ x_0 = K.evaluation.compute_features(accelerator, sample_fn, lambda x: x, args.n, args.batch_size)
609
+ if accelerator.is_main_process:
610
+ for i, out in enumerate(x_0):
611
+ filename = f'{args.prefix}_{i:05}.png'
612
+ K.utils.to_pil_image(out).save(filename)
613
+
614
+ try:
615
+ run()
616
+ except KeyboardInterrupt:
617
+ pass
618
+
619
+
620
+
621
+
622
+
623
+
624
  def main():
625
  global init_image, init_mask
626
  p = argparse.ArgumentParser(description=__doc__,
 
672
  sigma_min = model_config['sigma_min']
673
  sigma_max = model_config['sigma_max']
674
 
 
675
  torch.set_float32_matmul_precision('high')
 
 
 
 
676
  extra_args = {}
677
  init_image, init_mask = None, None
678
  if args.init_image is not None:
 
688
  tqdm.write('Sampling...')
689
  sigmas = K.sampling.get_sigmas_karras(args.steps, sigma_min, sigma_max, rho=7., device=device)
690
 
 
 
 
 
 
691
  def sample_fn(n, debug=True):
692
  x = torch.randn([n, model_config['input_channels'], size[0], size[1]], device=device) * sigma_max
693
  print("n, sigma_max, x.min, x.max = ", n, sigma_max, x.min(), x.max())