ZYMPKU commited on
Commit
0b659a7
·
1 Parent(s): 84902eb
app.py CHANGED
@@ -53,7 +53,8 @@ def demo_predict(input_blk, text, num_samples, steps, scale, seed, show_detail):
53
  cfgs.detailed = show_detail
54
  seed_everything(seed)
55
 
56
- sampler = init_sampling(cfgs)
 
57
 
58
  image = input_blk["image"]
59
  mask = input_blk["mask"]
@@ -128,6 +129,7 @@ if __name__ == "__main__":
128
  cfgs = OmegaConf.load("./configs/demo.yaml")
129
 
130
  model = init_model(cfgs)
 
131
  global_index = 0
132
 
133
  block = gr.Blocks().queue()
 
53
  cfgs.detailed = show_detail
54
  seed_everything(seed)
55
 
56
+ sampler.num_steps = steps
57
+ sampler.guider.scale_value = scale
58
 
59
  image = input_blk["image"]
60
  mask = input_blk["mask"]
 
129
  cfgs = OmegaConf.load("./configs/demo.yaml")
130
 
131
  model = init_model(cfgs)
132
+ sampler = init_sampling(cfgs)
133
  global_index = 0
134
 
135
  block = gr.Blocks().queue()
sgm/modules/diffusionmodules/guiders.py CHANGED
@@ -11,8 +11,8 @@ class VanillaCFG:
11
  """
12
 
13
  def __init__(self, scale, dyn_thresh_config=None):
14
- scale_schedule = lambda scale, sigma: scale # independent of step
15
- self.scale_schedule = partial(scale_schedule, scale)
16
  self.dyn_thresh = instantiate_from_config(
17
  default(
18
  dyn_thresh_config,
@@ -24,8 +24,7 @@ class VanillaCFG:
24
 
25
  def __call__(self, x, sigma):
26
  x_u, x_c = x.chunk(2)
27
- scale_value = self.scale_schedule(sigma)
28
- x_pred = self.dyn_thresh(x_u, x_c, scale_value)
29
  return x_pred
30
 
31
  def prepare_inputs(self, x, s, c, uc):
 
11
  """
12
 
13
  def __init__(self, scale, dyn_thresh_config=None):
14
+
15
+ self.scale_value = scale
16
  self.dyn_thresh = instantiate_from_config(
17
  default(
18
  dyn_thresh_config,
 
24
 
25
  def __call__(self, x, sigma):
26
  x_u, x_c = x.chunk(2)
27
+ x_pred = self.dyn_thresh(x_u, x_c, self.scale_value)
 
28
  return x_pred
29
 
30
  def prepare_inputs(self, x, s, c, uc):