Spaces:
Running
on
Zero
Running
on
Zero
drscotthawley
committed on
Commit
·
8a80eb5
1
Parent(s):
5e340e8
fixed typo
Browse files
app.py
CHANGED
@@ -118,7 +118,7 @@ def process_image(image, repaint, busyness):
|
|
118 |
print("Saving masked image file to ", masked_img_file)
|
119 |
image.save(masked_img_file)
|
120 |
num = 64 # number of images to generate; we'll take the one with the most notes in the masked region
|
121 |
-
bs =
|
122 |
repaint = repaint
|
123 |
seed_scale = 1.0
|
124 |
CT_HOME = '.'
|
|
|
118 |
print("Saving masked image file to ", masked_img_file)
|
119 |
image.save(masked_img_file)
|
120 |
num = 64 # number of images to generate; we'll take the one with the most notes in the masked region
|
121 |
+
bs = num
|
122 |
repaint = repaint
|
123 |
seed_scale = 1.0
|
124 |
CT_HOME = '.'
|
sample.py
CHANGED
@@ -5,9 +5,7 @@
|
|
5 |
|
6 |
"""Samples from k-diffusion models."""
|
7 |
|
8 |
-
|
9 |
-
import spaces
|
10 |
-
import natten
|
11 |
import argparse
|
12 |
from pathlib import Path
|
13 |
|
@@ -24,11 +22,11 @@ from pom.v_diffusion import DDPM, LogSchedule, CrashSchedule
|
|
24 |
#CHORD_BORDER = 8 # chord border size in pixels
|
25 |
from pom.chords import CHORD_BORDER, img_batch_to_seq_emb, ChordSeqEncoder
|
26 |
|
|
|
27 |
|
28 |
# ---- my mangled sampler that includes repaint
|
29 |
import torchsde
|
30 |
|
31 |
-
#@spaces.GPU
|
32 |
class BatchedBrownianTree:
|
33 |
"""A wrapper around torchsde.BrownianTree that enables batches of entropy."""
|
34 |
|
@@ -56,7 +54,6 @@ class BatchedBrownianTree:
|
|
56 |
return w if self.batched else w[0]
|
57 |
|
58 |
|
59 |
-
#@spaces.GPU
|
60 |
class BrownianTreeNoiseSampler:
|
61 |
"""A noise sampler backed by a torchsde.BrownianTree.
|
62 |
|
@@ -94,7 +91,6 @@ def to_d(x, sigma, denoised):
|
|
94 |
return (x - denoised) / append_dims(sigma, x.ndim)
|
95 |
|
96 |
|
97 |
-
#@spaces.GPU
|
98 |
@torch.no_grad()
|
99 |
def my_sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., repaint=1):
|
100 |
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
|
@@ -129,7 +125,6 @@ def get_scalings(sigma, sigma_data=0.5):
|
|
129 |
return c_skip, c_out, c_in
|
130 |
|
131 |
|
132 |
-
#@spaces.GPU
|
133 |
@torch.no_grad()
|
134 |
def my_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None,
|
135 |
disable=None, eta=1., s_noise=1., noise_sampler=None,
|
@@ -289,14 +284,12 @@ def sample(model, x, steps, eta, **extra_args):
|
|
289 |
|
290 |
# Soft mask inpainting is just shrinking hard (binary) mask inpainting
|
291 |
# Given a float-valued soft mask (values between 0 and 1), get the binary mask for this particular step
|
292 |
-
#@spaces.GPU
|
293 |
def get_bmask(i, steps, mask):
    """Return the binary mask to use at sampling step ``i``.

    The threshold ``(i + 1) / steps`` grows with the step index; every
    element of ``mask`` that is less than or equal to that threshold maps
    to 1 and every other element maps to 0.

    Args:
        i: zero-based index of the current step.
        steps: total number of steps (the divisor of the threshold).
        mask: tensor of soft-mask values (per the comment preceding this
            function, floats between 0 and 1).

    Returns:
        A tensor shaped like ``mask`` holding 1 where
        ``mask <= (i + 1) / steps`` and 0 elsewhere.
    """
    threshold = (i + 1) / steps
    # Binarize: elements at or below the current threshold become 1.
    return torch.where(mask <= threshold, 1, 0)
|
298 |
|
299 |
-
#@spaces.GPU
|
300 |
def make_cond_model_fn(model, cond_fn):
|
301 |
def cond_model_fn(x, sigma, **kwargs):
|
302 |
with torch.enable_grad():
|
@@ -312,7 +305,6 @@ def make_cond_model_fn(model, cond_fn):
|
|
312 |
# For sampling, set both init_data and mask to None
|
313 |
# For variations, set init_data
|
314 |
# For inpainting, set both init_data & mask
|
315 |
-
#@spaces.GPU
|
316 |
def sample_k(
|
317 |
model_fn,
|
318 |
noise,
|
@@ -425,7 +417,7 @@ def infer_mask_from_init_img(img, mask_with='white'):
|
|
425 |
mask[img[2,:,:]==1] = 1 # blue
|
426 |
return mask*1.0
|
427 |
|
428 |
-
|
429 |
def grow_mask(init_mask, grow_by=2):
|
430 |
"adds a border of grow_by pixels to the mask, by growing it grow_by times. If grow_by=0, does nothing"
|
431 |
new_mask = init_mask.clone()
|
@@ -434,7 +426,7 @@ def grow_mask(init_mask, grow_by=2):
|
|
434 |
new_mask[1:-1,1:-1] = (new_mask[1:-1,1:-1] + new_mask[0:-2,1:-1] + new_mask[2:,1:-1] + new_mask[1:-1,0:-2] + new_mask[1:-1,2:]) > 0
|
435 |
return new_mask
|
436 |
|
437 |
-
|
438 |
def add_seeding(init_image, init_mask, grow_by=0, seed_scale=1.0):
|
439 |
"adds extra noise inside mask"
|
440 |
init_mask = grow_mask(init_mask, grow_by=grow_by) # make the mask bigger
|
@@ -448,15 +440,13 @@ def add_seeding(init_image, init_mask, grow_by=0, seed_scale=1.0):
|
|
448 |
init_image[2,:,:] = init_image[2,:,:] * (1-init_mask) - 1.0*init_mask
|
449 |
return init_image
|
450 |
|
451 |
-
|
452 |
def get_init_image_and_mask(args, device):
|
453 |
convert_tensor = transforms.ToTensor()
|
454 |
init_image = Image.open(args.init_image).convert('RGB')
|
455 |
init_image = convert_tensor(init_image)
|
456 |
#normalize image from 0..1 to -1..1
|
457 |
init_image = (2.0 * init_image) - 1.0
|
458 |
-
|
459 |
-
|
460 |
init_mask = torch.ones(init_image.shape[-2:]) # ones are where stuff will change, zeros will stay the same
|
461 |
|
462 |
inpaint_task = 'infer' # infer mask from init_image
|
@@ -522,7 +512,115 @@ def get_init_image_and_mask(args, device):
|
|
522 |
init_mask = init_mask.unsqueeze(0).unsqueeze(1).repeat(args.batch_size,3,1,1).float()
|
523 |
return init_image.to(device), init_mask.to(device)
|
524 |
|
525 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
526 |
def main():
|
527 |
global init_image, init_mask
|
528 |
p = argparse.ArgumentParser(description=__doc__,
|
@@ -574,12 +672,7 @@ def main():
|
|
574 |
sigma_min = model_config['sigma_min']
|
575 |
sigma_max = model_config['sigma_max']
|
576 |
|
577 |
-
# SHH modified
|
578 |
torch.set_float32_matmul_precision('high')
|
579 |
-
#class_cond = torch.tensor([0]).to(device)
|
580 |
-
#num_classes = 10
|
581 |
-
#class_cond = torch.remainder(torch.arange(0, args.n), num_classes).int().to(device)
|
582 |
-
#extra_args = {'class_cond':class_cond}
|
583 |
extra_args = {}
|
584 |
init_image, init_mask = None, None
|
585 |
if args.init_image is not None:
|
@@ -595,11 +688,6 @@ def main():
|
|
595 |
tqdm.write('Sampling...')
|
596 |
sigmas = K.sampling.get_sigmas_karras(args.steps, sigma_min, sigma_max, rho=7., device=device)
|
597 |
|
598 |
-
#ddpm_sampler = DDPM(model)
|
599 |
-
#model_fn = model
|
600 |
-
#ddpm_sampler = K.external.VDenoiser(model_fn)
|
601 |
-
|
602 |
-
#@spaces.GPU
|
603 |
def sample_fn(n, debug=True):
|
604 |
x = torch.randn([n, model_config['input_channels'], size[0], size[1]], device=device) * sigma_max
|
605 |
print("n, sigma_max, x.min, x.max = ", n, sigma_max, x.min(), x.max())
|
|
|
5 |
|
6 |
"""Samples from k-diffusion models."""
|
7 |
|
8 |
+
|
|
|
|
|
9 |
import argparse
|
10 |
from pathlib import Path
|
11 |
|
|
|
22 |
#CHORD_BORDER = 8 # chord border size in pixels
|
23 |
from pom.chords import CHORD_BORDER, img_batch_to_seq_emb, ChordSeqEncoder
|
24 |
|
25 |
+
import spaces
|
26 |
|
27 |
# ---- my mangled sampler that includes repaint
|
28 |
import torchsde
|
29 |
|
|
|
30 |
class BatchedBrownianTree:
|
31 |
"""A wrapper around torchsde.BrownianTree that enables batches of entropy."""
|
32 |
|
|
|
54 |
return w if self.batched else w[0]
|
55 |
|
56 |
|
|
|
57 |
class BrownianTreeNoiseSampler:
|
58 |
"""A noise sampler backed by a torchsde.BrownianTree.
|
59 |
|
|
|
91 |
return (x - denoised) / append_dims(sigma, x.ndim)
|
92 |
|
93 |
|
|
|
94 |
@torch.no_grad()
|
95 |
def my_sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., repaint=1):
|
96 |
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
|
|
|
125 |
return c_skip, c_out, c_in
|
126 |
|
127 |
|
|
|
128 |
@torch.no_grad()
|
129 |
def my_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None,
|
130 |
disable=None, eta=1., s_noise=1., noise_sampler=None,
|
|
|
284 |
|
285 |
# Soft mask inpainting is just shrinking hard (binary) mask inpainting
|
286 |
# Given a float-valued soft mask (values between 0 and 1), get the binary mask for this particular step
|
|
|
287 |
def get_bmask(i, steps, mask):
    """Return the binary mask to use at sampling step ``i``.

    The threshold ``(i + 1) / steps`` grows with the step index; every
    element of ``mask`` that is less than or equal to that threshold maps
    to 1 and every other element maps to 0.

    Args:
        i: zero-based index of the current step.
        steps: total number of steps (the divisor of the threshold).
        mask: tensor of soft-mask values (per the comment preceding this
            function, floats between 0 and 1).

    Returns:
        A tensor shaped like ``mask`` holding 1 where
        ``mask <= (i + 1) / steps`` and 0 elsewhere.
    """
    threshold = (i + 1) / steps
    # Binarize: elements at or below the current threshold become 1.
    return torch.where(mask <= threshold, 1, 0)
|
292 |
|
|
|
293 |
def make_cond_model_fn(model, cond_fn):
|
294 |
def cond_model_fn(x, sigma, **kwargs):
|
295 |
with torch.enable_grad():
|
|
|
305 |
# For sampling, set both init_data and mask to None
|
306 |
# For variations, set init_data
|
307 |
# For inpainting, set both init_data & mask
|
|
|
308 |
def sample_k(
|
309 |
model_fn,
|
310 |
noise,
|
|
|
417 |
mask[img[2,:,:]==1] = 1 # blue
|
418 |
return mask*1.0
|
419 |
|
420 |
+
|
421 |
def grow_mask(init_mask, grow_by=2):
|
422 |
"adds a border of grow_by pixels to the mask, by growing it grow_by times. If grow_by=0, does nothing"
|
423 |
new_mask = init_mask.clone()
|
|
|
426 |
new_mask[1:-1,1:-1] = (new_mask[1:-1,1:-1] + new_mask[0:-2,1:-1] + new_mask[2:,1:-1] + new_mask[1:-1,0:-2] + new_mask[1:-1,2:]) > 0
|
427 |
return new_mask
|
428 |
|
429 |
+
|
430 |
def add_seeding(init_image, init_mask, grow_by=0, seed_scale=1.0):
|
431 |
"adds extra noise inside mask"
|
432 |
init_mask = grow_mask(init_mask, grow_by=grow_by) # make the mask bigger
|
|
|
440 |
init_image[2,:,:] = init_image[2,:,:] * (1-init_mask) - 1.0*init_mask
|
441 |
return init_image
|
442 |
|
443 |
+
|
444 |
def get_init_image_and_mask(args, device):
|
445 |
convert_tensor = transforms.ToTensor()
|
446 |
init_image = Image.open(args.init_image).convert('RGB')
|
447 |
init_image = convert_tensor(init_image)
|
448 |
#normalize image from 0..1 to -1..1
|
449 |
init_image = (2.0 * init_image) - 1.0
|
|
|
|
|
450 |
init_mask = torch.ones(init_image.shape[-2:]) # ones are where stuff will change, zeros will stay the same
|
451 |
|
452 |
inpaint_task = 'infer' # infer mask from init_image
|
|
|
512 |
init_mask = init_mask.unsqueeze(0).unsqueeze(1).repeat(args.batch_size,3,1,1).float()
|
513 |
return init_image.to(device), init_mask.to(device)
|
514 |
|
515 |
+
|
516 |
+
|
517 |
+
|
518 |
+
# wrapper compatible with ZeroGPU, callable from outside
|
519 |
+
@spaces.GPU
def zero_wrapper(args, device):
    """Load the model named by ``args`` and write sampled images to disk.

    This mirrors the sampling path of ``main()`` but is decorated with
    ``spaces.GPU`` so it can be called from outside this module.

    Steps performed:
      1. Load the config from ``args.config`` (or ``args.checkpoint`` when
         no config is given) and require a square 2-D ``input_size``.
      2. Build the inner model on ``device`` and load its weights, first via
         ``safetorch.load_file`` and, if that fails, via ``torch.load``
         reading the ``'model'`` entry of the checkpoint.
      3. If ``args.init_image`` is set, store the image and mask in the
         module-level globals ``init_image`` / ``init_mask``.
      4. Sample through ``K.evaluation.compute_features`` (given ``args.n``
         and ``args.batch_size``) and, on the main process, save each output
         as ``f'{args.prefix}_{i:05}.png'``.

    Args:
        args: namespace read for ``config``, ``checkpoint``, ``init_image``,
            ``seed_scale``, ``steps``, ``repaint``, ``n``, ``batch_size``
            and ``prefix``.
        device: device the model and all sampled tensors are placed on.

    Returns:
        None. Results are written to PNG files.

    Raises:
        AssertionError: if the configured input size is not square, or if
            the sampled batch contains NaNs.

    A ``KeyboardInterrupt`` raised during sampling is swallowed so the
    caller gets a normal return.
    """
    global init_image, init_mask

    config = K.config.load_config(args.config if args.config else args.checkpoint)
    model_config = config['model']
    # TODO: allow non-square input sizes
    assert len(model_config['input_size']) == 2 and model_config['input_size'][0] == model_config['input_size'][1]
    size = model_config['input_size']

    print('zero_wrapper: Using device:', device, flush=True)

    inner_model = K.config.make_model(config).eval().requires_grad_(False).to(device)
    cse = None  # ChordSeqEncoder().eval().requires_grad_(False).to(device) # add chord embedding-maker to main model
    if cse is not None:
        inner_model.cse = cse
    try:
        inner_model.load_state_dict(safetorch.load_file(args.checkpoint))
    except Exception as err:
        # Fall back to a regular torch checkpoint. ``Exception`` (not a bare
        # ``except``) keeps KeyboardInterrupt/SystemExit propagating, and the
        # original failure is reported instead of silently discarded.
        print('zero_wrapper: safetensors load failed, falling back to torch.load:', err, flush=True)
        ckpt = torch.load(args.checkpoint, map_location='cpu')
        inner_model.load_state_dict(ckpt['model'])

    print('Parameters:', K.utils.n_params(inner_model))
    model = K.Denoiser(inner_model, sigma_data=model_config['sigma_data'])

    sigma_min = model_config['sigma_min']
    sigma_max = model_config['sigma_max']
    torch.set_float32_matmul_precision('high')
    extra_args = {}
    init_image, init_mask = None, None
    if args.init_image is not None:
        # get_init_image_and_mask already returns both tensors on `device`.
        init_image, init_mask = get_init_image_and_mask(args, device)

    # NOTE(review): `accelerator` is not defined inside this function;
    # presumably it is expected at module level -- confirm, otherwise the
    # references below raise NameError.
    @torch.no_grad()
    @K.utils.eval_mode(model)
    def run():
        if accelerator.is_local_main_process:
            tqdm.write('Sampling...')
        # Only consumed by the commented-out K.sampling.* calls below.
        sigmas = K.sampling.get_sigmas_karras(args.steps, sigma_min, sigma_max, rho=7., device=device)

        def sample_fn(n, debug=True):
            x = torch.randn([n, model_config['input_channels'], size[0], size[1]], device=device) * sigma_max
            print("n, sigma_max, x.min, x.max = ", n, sigma_max, x.min(), x.max())

            if args.init_image is not None:
                init_data, mask = get_init_image_and_mask(args, device)
                # Blend scaled noise into the masked region (extra nucleation?)
                init_data = args.seed_scale*x*mask + (1-mask)*init_data
                if cse is not None:
                    chord_cond = img_batch_to_seq_emb(init_data, inner_model.cse).to(device)
                else:
                    chord_cond = None
                #print("init_data.shape, init_data.min, init_data.max = ", init_data.shape, init_data.min(), init_data.max())
            else:
                init_data, mask, chord_cond = None, None, None
            # Chord conditioning is not functional, so it is always disabled.
            chord_cond = None

            print("chord_cond = ", chord_cond)
            if chord_cond is not None:
                extra_args['chord_cond'] = chord_cond
            # these two work:
            #x_0 = K.sampling.sample_lms(model, x, sigmas, disable=not accelerator.is_local_main_process, extra_args=extra_args)
            #x_0 = K.sampling.sample_dpmpp_2m_sde(model, x, sigmas, disable=not accelerator.is_local_main_process, extra_args=extra_args)

            noise = torch.randn([n, model_config['input_channels'], size[0], size[1]], device=device)

            sampler_type = "my-dpmpp-2m-sde"  # "k-lms"
            #sampler_type="my-sample-euler"
            #sampler_type="dpmpp-2m-sde"
            #sampler_type = "dpmpp-3m-sde"
            #sampler_type = "k-dpmpp-2s-ancestral"
            print("dtypes:", [t.dtype if t is not None else None for t in [noise, init_data, mask, chord_cond]])
            x_0 = sample_k(inner_model, noise, sampler_type=sampler_type,
                           init_data=init_data, mask=mask, steps=args.steps,
                           sigma_min=sigma_min, sigma_max=sigma_max, rho=7.,
                           device=device, model_config=model_config, repaint=args.repaint,
                           **extra_args)
            #x_0 = sample_k(inner_model, noise, sampler_type="dpmpp-2m-sde", steps=100, sigma_min=0.5, sigma_max=50, rho=1., device=device, model_config=model_config, **extra_args)
            print("x_0.min, x_0.max = ", x_0.min(), x_0.max())
            if x_0.isnan().any():
                # Raised explicitly so the check survives `python -O`.
                raise AssertionError("x_0 has NaNs")

            # do gpu garbage collection before proceeding
            torch.cuda.empty_cache()
            return x_0

        x_0 = K.evaluation.compute_features(accelerator, sample_fn, lambda x: x, args.n, args.batch_size)
        if accelerator.is_main_process:
            for i, out in enumerate(x_0):
                filename = f'{args.prefix}_{i:05}.png'
                K.utils.to_pil_image(out).save(filename)

    try:
        run()
    except KeyboardInterrupt:
        pass
|
618 |
+
|
619 |
+
|
620 |
+
|
621 |
+
|
622 |
+
|
623 |
+
|
624 |
def main():
|
625 |
global init_image, init_mask
|
626 |
p = argparse.ArgumentParser(description=__doc__,
|
|
|
672 |
sigma_min = model_config['sigma_min']
|
673 |
sigma_max = model_config['sigma_max']
|
674 |
|
|
|
675 |
torch.set_float32_matmul_precision('high')
|
|
|
|
|
|
|
|
|
676 |
extra_args = {}
|
677 |
init_image, init_mask = None, None
|
678 |
if args.init_image is not None:
|
|
|
688 |
tqdm.write('Sampling...')
|
689 |
sigmas = K.sampling.get_sigmas_karras(args.steps, sigma_min, sigma_max, rho=7., device=device)
|
690 |
|
|
|
|
|
|
|
|
|
|
|
691 |
def sample_fn(n, debug=True):
|
692 |
x = torch.randn([n, model_config['input_channels'], size[0], size[1]], device=device) * sigma_max
|
693 |
print("n, sigma_max, x.min, x.max = ", n, sigma_max, x.min(), x.max())
|