ZYMPKU commited on
Commit
295fba8
·
1 Parent(s): 251f521

text length

Browse files
Files changed (3) hide show
  1. app.py +5 -2
  2. configs/demo.yaml +1 -0
  3. configs/test/textdesign_sd_2.yaml +1 -25
app.py CHANGED
@@ -40,6 +40,9 @@ def demo_predict(input_blk, text, num_samples, steps, scale, seed, show_detail):
40
 
41
  global cfgs, global_index
42
 
 
 
 
43
  global_index += 1
44
 
45
  if num_samples > 1: cfgs.noise_iters = 0
@@ -51,7 +54,7 @@ def demo_predict(input_blk, text, num_samples, steps, scale, seed, show_detail):
51
  seed_everything(seed)
52
 
53
  sampler = init_sampling(cfgs)
54
-
55
  image = input_blk["image"]
56
  mask = input_blk["mask"]
57
  image = cv2.resize(image, (cfgs.W, cfgs.H))
@@ -156,7 +159,7 @@ if __name__ == "__main__":
156
  with gr.Column():
157
 
158
  input_blk = gr.Image(source='upload', tool='sketch', type="numpy", label="Input", height=512)
159
- text = gr.Textbox(label="Text to render:", info="the text you want to render at the masked region")
160
  run_button = gr.Button(variant="primary")
161
 
162
  with gr.Accordion("Advanced options", open=False):
 
40
 
41
  global cfgs, global_index
42
 
43
+ if len(text) < cfgs.txt_len[0] or len(text) > cfgs.txt_len[1]:
44
+ raise gr.Error("Illegal text length!")
45
+
46
  global_index += 1
47
 
48
  if num_samples > 1: cfgs.noise_iters = 0
 
54
  seed_everything(seed)
55
 
56
  sampler = init_sampling(cfgs)
57
+
58
  image = input_blk["image"]
59
  mask = input_blk["mask"]
60
  image = cv2.resize(image, (cfgs.W, cfgs.H))
 
159
  with gr.Column():
160
 
161
  input_blk = gr.Image(source='upload', tool='sketch', type="numpy", label="Input", height=512)
162
+ text = gr.Textbox(label="Text to render: (1~12 characters)", info="the text you want to render at the masked region")
163
  run_button = gr.Button(variant="primary")
164
 
165
  with gr.Accordion("Advanced options", open=False):
configs/demo.yaml CHANGED
@@ -7,6 +7,7 @@ model_cfg_path: "./configs/test/textdesign_sd_2.yaml"
7
  # param
8
  H: 512
9
  W: 512
 
10
  seq_len: 12
11
  batch_size: 1
12
 
 
7
  # param
8
  H: 512
9
  W: 512
10
+ txt_len: [1, 12]
11
  seq_len: 12
12
  batch_size: 1
13
 
configs/test/textdesign_sd_2.yaml CHANGED
@@ -104,28 +104,4 @@ model:
104
  attn_resolutions: []
105
  dropout: 0.0
106
  lossconfig:
107
- target: torch.nn.Identity
108
-
109
- loss_fn_config:
110
- target: sgm.modules.diffusionmodules.loss.FullLoss # StandardDiffusionLoss
111
- params:
112
- seq_len: 12
113
- kernel_size: 3
114
- gaussian_sigma: 1.0
115
- min_attn_size: 16
116
- lambda_local_loss: 0.01
117
- lambda_ocr_loss: 0.001
118
- ocr_enabled: False
119
-
120
- predictor_config:
121
- target: sgm.modules.predictors.model.ParseqPredictor
122
- params:
123
- ckpt_path: "./checkpoints/predictors/parseq-bb5792a6.pt"
124
-
125
- sigma_sampler_config:
126
- target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
127
- params:
128
- num_idx: 1000
129
-
130
- discretization_config:
131
- target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
 
104
  attn_resolutions: []
105
  dropout: 0.0
106
  lossconfig:
107
+ target: torch.nn.Identity