KingNish commited on
Commit
99c90b0
·
verified ·
1 Parent(s): b3c7dc6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -7
app.py CHANGED
@@ -21,7 +21,7 @@ if not torch.cuda.is_available():
21
 
22
  MAX_SEED = np.iinfo(np.int32).max
23
  CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
24
- MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "6000"))
25
  USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
26
  ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
27
 
@@ -105,7 +105,7 @@ if torch.cuda.is_available():
105
  print("Using DALL-E 3 Consistency Decoder")
106
  pipe.vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16)
107
 
108
- if ENABLE_CPU_OFFLOAD:
109
  pipe.enable_model_cpu_offload()
110
  else:
111
  pipe.to(device)
@@ -118,33 +118,35 @@ if torch.cuda.is_available():
118
  pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead", fullgraph=True)
119
  print("Model Compiled!")
120
 
 
121
  def save_image(img):
122
  unique_name = str(uuid.uuid4()) + ".png"
123
  img.save(unique_name)
124
  return unique_name
125
 
 
126
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
127
  if randomize_seed:
128
  seed = random.randint(0, MAX_SEED)
129
  return seed
130
 
 
131
  def generate(
132
  prompt: str,
133
  negative_prompt: str = "",
134
  style: str = DEFAULT_STYLE_NAME,
135
  use_negative_prompt: bool = False,
136
- num_imgs: int = 1,
137
  seed: int = 0,
138
  width: int = 1024,
139
  height: int = 1024,
140
- num_inference_steps: int = 4,
141
  randomize_seed: bool = False,
142
  use_resolution_binning: bool = True,
143
  progress=gr.Progress(track_tqdm=True),
144
  ):
145
  seed = int(randomize_seed_fn(seed, randomize_seed))
146
  generator = torch.Generator().manual_seed(seed)
147
-
148
  if not use_negative_prompt:
149
  negative_prompt = None # type: ignore
150
  prompt, negative_prompt = apply_style(style, prompt, negative_prompt)
@@ -290,6 +292,7 @@ with gr.Blocks() as demo:
290
  negative_prompt,
291
  style_selection,
292
  use_negative_prompt,
 
293
  seed,
294
  width,
295
  height,
@@ -300,6 +303,6 @@ with gr.Blocks() as demo:
300
  api_name="run",
301
  )
302
 
303
-
304
  if __name__ == "__main__":
305
- demo.queue(max_size=20).launch()
 
 
21
 
22
  MAX_SEED = np.iinfo(np.int32).max
23
  CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
24
+ MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "2048"))
25
  USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
26
  ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
27
 
 
105
  print("Using DALL-E 3 Consistency Decoder")
106
  pipe.vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16)
107
 
108
+ if ENABLE_CPU_OFFLOAD:
109
  pipe.enable_model_cpu_offload()
110
  else:
111
  pipe.to(device)
 
118
  pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead", fullgraph=True)
119
  print("Model Compiled!")
120
 
121
+
122
  def save_image(img):
123
  unique_name = str(uuid.uuid4()) + ".png"
124
  img.save(unique_name)
125
  return unique_name
126
 
127
+
128
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
129
  if randomize_seed:
130
  seed = random.randint(0, MAX_SEED)
131
  return seed
132
 
133
+
134
  def generate(
135
  prompt: str,
136
  negative_prompt: str = "",
137
  style: str = DEFAULT_STYLE_NAME,
138
  use_negative_prompt: bool = False,
 
139
  seed: int = 0,
140
  width: int = 1024,
141
  height: int = 1024,
142
+ inference_steps: int = 4,
143
  randomize_seed: bool = False,
144
  use_resolution_binning: bool = True,
145
  progress=gr.Progress(track_tqdm=True),
146
  ):
147
  seed = int(randomize_seed_fn(seed, randomize_seed))
148
  generator = torch.Generator().manual_seed(seed)
149
+
150
  if not use_negative_prompt:
151
  negative_prompt = None # type: ignore
152
  prompt, negative_prompt = apply_style(style, prompt, negative_prompt)
 
292
  negative_prompt,
293
  style_selection,
294
  use_negative_prompt,
295
+ num_imgs,
296
  seed,
297
  width,
298
  height,
 
303
  api_name="run",
304
  )
305
 
 
306
  if __name__ == "__main__":
307
+ demo.queue(max_size=20).launch()
308
+ # demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=11900, debug=True)