v 2.0
Browse files- app.py +3 -1
- sgm/modules/diffusionmodules/guiders.py +3 -4
app.py
CHANGED
@@ -53,7 +53,8 @@ def demo_predict(input_blk, text, num_samples, steps, scale, seed, show_detail):
|
|
53 |
cfgs.detailed = show_detail
|
54 |
seed_everything(seed)
|
55 |
|
56 |
-
sampler =
|
|
|
57 |
|
58 |
image = input_blk["image"]
|
59 |
mask = input_blk["mask"]
|
@@ -128,6 +129,7 @@ if __name__ == "__main__":
|
|
128 |
cfgs = OmegaConf.load("./configs/demo.yaml")
|
129 |
|
130 |
model = init_model(cfgs)
|
|
|
131 |
global_index = 0
|
132 |
|
133 |
block = gr.Blocks().queue()
|
|
|
53 |
cfgs.detailed = show_detail
|
54 |
seed_everything(seed)
|
55 |
|
56 |
+
sampler.num_steps = steps
|
57 |
+
sampler.guider.scale_value = scale
|
58 |
|
59 |
image = input_blk["image"]
|
60 |
mask = input_blk["mask"]
|
|
|
129 |
cfgs = OmegaConf.load("./configs/demo.yaml")
|
130 |
|
131 |
model = init_model(cfgs)
|
132 |
+
sampler = init_sampling(cfgs)
|
133 |
global_index = 0
|
134 |
|
135 |
block = gr.Blocks().queue()
|
sgm/modules/diffusionmodules/guiders.py
CHANGED
@@ -11,8 +11,8 @@ class VanillaCFG:
|
|
11 |
"""
|
12 |
|
13 |
def __init__(self, scale, dyn_thresh_config=None):
|
14 |
-
|
15 |
-
self.
|
16 |
self.dyn_thresh = instantiate_from_config(
|
17 |
default(
|
18 |
dyn_thresh_config,
|
@@ -24,8 +24,7 @@ class VanillaCFG:
|
|
24 |
|
25 |
def __call__(self, x, sigma):
|
26 |
x_u, x_c = x.chunk(2)
|
27 |
-
|
28 |
-
x_pred = self.dyn_thresh(x_u, x_c, scale_value)
|
29 |
return x_pred
|
30 |
|
31 |
def prepare_inputs(self, x, s, c, uc):
|
|
|
11 |
"""
|
12 |
|
13 |
def __init__(self, scale, dyn_thresh_config=None):
|
14 |
+
|
15 |
+
self.scale_value = scale
|
16 |
self.dyn_thresh = instantiate_from_config(
|
17 |
default(
|
18 |
dyn_thresh_config,
|
|
|
24 |
|
25 |
def __call__(self, x, sigma):
|
26 |
x_u, x_c = x.chunk(2)
|
27 |
+
x_pred = self.dyn_thresh(x_u, x_c, self.scale_value)
|
|
|
28 |
return x_pred
|
29 |
|
30 |
def prepare_inputs(self, x, s, c, uc):
|