text length
Browse files
- app.py +5 -2
- configs/demo.yaml +1 -0
- configs/test/textdesign_sd_2.yaml +1 -25
app.py
CHANGED
@@ -40,6 +40,9 @@ def demo_predict(input_blk, text, num_samples, steps, scale, seed, show_detail):
|
|
40 |
|
41 |
global cfgs, global_index
|
42 |
|
|
|
|
|
|
|
43 |
global_index += 1
|
44 |
|
45 |
if num_samples > 1: cfgs.noise_iters = 0
|
@@ -51,7 +54,7 @@ def demo_predict(input_blk, text, num_samples, steps, scale, seed, show_detail):
|
|
51 |
seed_everything(seed)
|
52 |
|
53 |
sampler = init_sampling(cfgs)
|
54 |
-
|
55 |
image = input_blk["image"]
|
56 |
mask = input_blk["mask"]
|
57 |
image = cv2.resize(image, (cfgs.W, cfgs.H))
|
@@ -156,7 +159,7 @@ if __name__ == "__main__":
|
|
156 |
with gr.Column():
|
157 |
|
158 |
input_blk = gr.Image(source='upload', tool='sketch', type="numpy", label="Input", height=512)
|
159 |
-
text = gr.Textbox(label="Text to render:", info="the text you want to render at the masked region")
|
160 |
run_button = gr.Button(variant="primary")
|
161 |
|
162 |
with gr.Accordion("Advanced options", open=False):
|
|
|
40 |
|
41 |
global cfgs, global_index
|
42 |
|
43 |
+
if len(text) < cfgs.txt_len[0] or len(text) > cfgs.txt_len[1]:
|
44 |
+
raise gr.Error("Illegal text length!")
|
45 |
+
|
46 |
global_index += 1
|
47 |
|
48 |
if num_samples > 1: cfgs.noise_iters = 0
|
|
|
54 |
seed_everything(seed)
|
55 |
|
56 |
sampler = init_sampling(cfgs)
|
57 |
+
|
58 |
image = input_blk["image"]
|
59 |
mask = input_blk["mask"]
|
60 |
image = cv2.resize(image, (cfgs.W, cfgs.H))
|
|
|
159 |
with gr.Column():
|
160 |
|
161 |
input_blk = gr.Image(source='upload', tool='sketch', type="numpy", label="Input", height=512)
|
162 |
+
text = gr.Textbox(label="Text to render: (1~12 characters)", info="the text you want to render at the masked region")
|
163 |
run_button = gr.Button(variant="primary")
|
164 |
|
165 |
with gr.Accordion("Advanced options", open=False):
|
configs/demo.yaml
CHANGED
@@ -7,6 +7,7 @@ model_cfg_path: "./configs/test/textdesign_sd_2.yaml"
|
|
7 |
# param
|
8 |
H: 512
|
9 |
W: 512
|
|
|
10 |
seq_len: 12
|
11 |
batch_size: 1
|
12 |
|
|
|
7 |
# param
|
8 |
H: 512
|
9 |
W: 512
|
10 |
+
txt_len: [1, 12]
|
11 |
seq_len: 12
|
12 |
batch_size: 1
|
13 |
|
configs/test/textdesign_sd_2.yaml
CHANGED
@@ -104,28 +104,4 @@ model:
|
|
104 |
attn_resolutions: []
|
105 |
dropout: 0.0
|
106 |
lossconfig:
|
107 |
-
target: torch.nn.Identity
|
108 |
-
|
109 |
-
loss_fn_config:
|
110 |
-
target: sgm.modules.diffusionmodules.loss.FullLoss # StandardDiffusionLoss
|
111 |
-
params:
|
112 |
-
seq_len: 12
|
113 |
-
kernel_size: 3
|
114 |
-
gaussian_sigma: 1.0
|
115 |
-
min_attn_size: 16
|
116 |
-
lambda_local_loss: 0.01
|
117 |
-
lambda_ocr_loss: 0.001
|
118 |
-
ocr_enabled: False
|
119 |
-
|
120 |
-
predictor_config:
|
121 |
-
target: sgm.modules.predictors.model.ParseqPredictor
|
122 |
-
params:
|
123 |
-
ckpt_path: "./checkpoints/predictors/parseq-bb5792a6.pt"
|
124 |
-
|
125 |
-
sigma_sampler_config:
|
126 |
-
target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
|
127 |
-
params:
|
128 |
-
num_idx: 1000
|
129 |
-
|
130 |
-
discretization_config:
|
131 |
-
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
|
|
|
104 |
attn_resolutions: []
|
105 |
dropout: 0.0
|
106 |
lossconfig:
|
107 |
+
target: torch.nn.Identity
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|