drscotthawley commited on
Commit
f084bcc
·
1 Parent(s): 8a80eb5

now calling zero_wrapper directly

Browse files
Files changed (2) hide show
  1. app.py +18 -8
  2. sample.py +2 -4
app.py CHANGED
@@ -29,6 +29,8 @@ import k_diffusion as K
29
  import natten
30
  import accelerate
31
 
 
 
32
  zero = torch.Tensor([0]).cuda()
33
  print("Zero Device = ",zero.device," <-- this probably says cpu") # <-- 'cpu' 🤔
34
 
@@ -96,6 +98,12 @@ def grab_dense_gen(init_img,
96
  print("Grabbing filename = ", dense_filename)
97
  return dense_filename
98
 
 
 
 
 
 
 
99
 
100
  @spaces.GPU
101
  def process_image(image, repaint, busyness):
@@ -126,14 +134,16 @@ def process_image(image, repaint, busyness):
126
  PREFIX = 'gradiodemo'
127
  # !echo {DEVICES} {CT_HOME} {CKPT} {PREFIX} {masked_img_file}
128
  print("Reading init image from ", masked_img_file,", repaint = ",repaint)
129
- cmd = f'{sys.executable} {CT_HOME}/sample.py --batch-size {bs} --checkpoint {CKPT} --config {CT_HOME}/configs/config_pop909_256x256_chords.json -n {num} --prefix {PREFIX} --init-image {masked_img_file} --steps=100 --repaint={repaint}'
130
- print("Will run command: ", cmd)
131
- args = cmd.split(' ')
132
- #call(cmd, shell=True)
133
- print("Calling: ", args,"\n")
134
- return_value = call(args)
135
- print("Return value = ", return_value)
136
-
 
 
137
 
138
  # find gen'd image and convert to midi piano roll
139
  #gen_file = f'{PREFIX}_00000.png'
 
29
  import natten
30
  import accelerate
31
 
32
+ from sample import zero_wrapper
33
+
34
  zero = torch.Tensor([0]).cuda()
35
  print("Zero Device = ",zero.device," <-- this probably says cpu") # <-- 'cpu' 🤔
36
 
 
98
  print("Grabbing filename = ", dense_filename)
99
  return dense_filename
100
 
101
+ # dummy class to make an args-like object
102
+ class Args:
103
+ def __init__(self, **kwargs):
104
+ for key, value in kwargs.items():
105
+ setattr(self, key, value)
106
+
107
 
108
  @spaces.GPU
109
  def process_image(image, repaint, busyness):
 
134
  PREFIX = 'gradiodemo'
135
  # !echo {DEVICES} {CT_HOME} {CKPT} {PREFIX} {masked_img_file}
136
  print("Reading init image from ", masked_img_file,", repaint = ",repaint)
137
+ # cmd = f'{sys.executable} {CT_HOME}/sample.py --batch-size {bs} --checkpoint {CKPT} --config {CT_HOME}/configs/config_pop909_256x256_chords.json -n {num} --prefix {PREFIX} --init-image {masked_img_file} --steps=100 --repaint={repaint}'
138
+ # print("Will run command: ", cmd)
139
+ # args = cmd.split(' ')
140
+ # #call(cmd, shell=True)
141
+ # print("Calling: ", args,"\n")
142
+ # return_value = call(args)
143
+ # print("Return value = ", return_value)
144
+ args = Args(batch_size=bs, checkpoint=CKPT, config=f'{CT_HOME}/configs/config_pop909_256x256_chords.json', n=num, prefix=PREFIX, init_image=masked_img_file, steps=100, repaint=repaint)
145
+ print(" Now calling zero_wrapper with args = ",args,"\n")
146
+ zero_wrapper(args, accelerator, device)
147
 
148
  # find gen'd image and convert to midi piano roll
149
  #gen_file = f'{PREFIX}_00000.png'
sample.py CHANGED
@@ -22,8 +22,6 @@ from pom.v_diffusion import DDPM, LogSchedule, CrashSchedule
22
  #CHORD_BORDER = 8 # chord border size in pixels
23
  from pom.chords import CHORD_BORDER, img_batch_to_seq_emb, ChordSeqEncoder
24
 
25
- import spaces
26
-
27
  # ---- my mangled sampler that includes repaint
28
  import torchsde
29
 
@@ -516,8 +514,8 @@ def get_init_image_and_mask(args, device):
516
 
517
 
518
  # wrapper compatible with ZeroGPU, callable from outside
519
- @spaces.GPU
520
- def zero_wrapper(args, device):
521
  global init_image, init_mask
522
 
523
  config = K.config.load_config(args.config if args.config else args.checkpoint)
 
22
  #CHORD_BORDER = 8 # chord border size in pixels
23
  from pom.chords import CHORD_BORDER, img_batch_to_seq_emb, ChordSeqEncoder
24
 
 
 
25
  # ---- my mangled sampler that includes repaint
26
  import torchsde
27
 
 
514
 
515
 
516
  # wrapper compatible with ZeroGPU, callable from outside
517
+ #@spaces.GPU
518
+ def zero_wrapper(args, accelerator, device):
519
  global init_image, init_mask
520
 
521
  config = K.config.load_config(args.config if args.config else args.checkpoint)