adamelliotfields committed on
Commit
ab11d6f
1 Parent(s): 75fa840
Files changed (4) hide show
  1. app.css +0 -11
  2. app.py +40 -96
  3. lib/__init__.py +1 -10
  4. lib/utils.py +2 -38
app.css CHANGED
@@ -16,10 +16,6 @@
16
  margin-right: 4px;
17
  }
18
 
19
- #embeddings .token {
20
- border-radius: 4px;
21
- }
22
-
23
  .gallery {
24
  background-color: var(--bg);
25
  }
@@ -34,10 +30,6 @@
34
  max-height: none;
35
  }
36
 
37
- .gap-0, .gap-0 * {
38
- gap: 0px;
39
- }
40
-
41
  .icon-button {
42
  max-width: 42px;
43
  }
@@ -92,9 +84,6 @@
92
  .popover#clear:hover::after {
93
  content: 'Clear';
94
  }
95
- .popover#clear-control:hover::after {
96
- content: 'Clear';
97
- }
98
  .popover#refresh:hover::after {
99
  content: var(--seed, "-1");
100
  }
 
16
  margin-right: 4px;
17
  }
18
 
 
 
 
 
19
  .gallery {
20
  background-color: var(--bg);
21
  }
 
30
  max-height: none;
31
  }
32
 
 
 
 
 
33
  .icon-button {
34
  max-width: 42px;
35
  }
 
84
  .popover#clear:hover::after {
85
  content: 'Clear';
86
  }
 
 
 
87
  .popover#refresh:hover::after {
88
  content: var(--seed, "-1");
89
  }
app.py CHANGED
@@ -1,12 +1,19 @@
1
  import argparse
 
 
 
 
 
 
 
 
 
2
 
3
  import gradio as gr
 
4
 
5
  from lib import (
6
  Config,
7
- async_call,
8
- disable_progress_bars,
9
- download_repo_files,
10
  generate,
11
  read_file,
12
  read_json,
@@ -60,46 +67,6 @@ random_prompt_js = f"""
60
  }}
61
  """
62
 
63
-
64
- # Transform the raw inputs before generation
65
- async def generate_fn(*args, progress=gr.Progress(track_tqdm=True)):
66
- if len(args) > 0:
67
- prompt = args[0]
68
- else:
69
- prompt = None
70
- if prompt is None or prompt.strip() == "":
71
- raise gr.Error("You must enter a prompt")
72
-
73
- # These are always the last arguments
74
- DISABLE_IMAGE_PROMPT, DISABLE_CONTROL_IMAGE_PROMPT, DISABLE_IP_IMAGE_PROMPT = args[-3:]
75
- gen_args = list(args[:-3])
76
-
77
- # First two arguments are the prompt and negative prompt
78
- if DISABLE_IMAGE_PROMPT:
79
- gen_args[2] = None
80
- if DISABLE_CONTROL_IMAGE_PROMPT:
81
- gen_args[3] = None
82
- if DISABLE_IP_IMAGE_PROMPT:
83
- gen_args[4] = None
84
-
85
- try:
86
- if Config.ZERO_GPU:
87
- progress((0, 100), desc="ZeroGPU init")
88
-
89
- # Remaining arguments are the alert handlers and progress bar
90
- images = await async_call(
91
- generate,
92
- *gen_args,
93
- Error=gr.Error,
94
- Info=gr.Info,
95
- progress=progress,
96
- )
97
- except RuntimeError:
98
- raise gr.Error("Error: Please try again")
99
-
100
- return images
101
-
102
-
103
  with gr.Blocks(
104
  head=read_file("./partials/head.html"),
105
  css="./app.css",
@@ -114,8 +81,8 @@ with gr.Blocks(
114
  radius_size=gr.themes.sizes.radius_sm,
115
  spacing_size=gr.themes.sizes.spacing_md,
116
  # fonts
117
- font=[gr.themes.GoogleFont("Inter"), *Config.SANS_FONTS],
118
- font_mono=[gr.themes.GoogleFont("Ubuntu Mono"), *Config.MONO_FONTS],
119
  ).set(
120
  layout_gap="8px",
121
  block_shadow="0 0 #0000",
@@ -124,11 +91,6 @@ with gr.Blocks(
124
  block_background_fill_dark=gr.themes.colors.gray.c900,
125
  ),
126
  ) as demo:
127
- # Disable image inputs without clearing them
128
- DISABLE_IMAGE_PROMPT = gr.State(False)
129
- DISABLE_IP_IMAGE_PROMPT = gr.State(False)
130
- DISABLE_CONTROL_IMAGE_PROMPT = gr.State(False)
131
-
132
  gr.HTML(read_file("./partials/intro.html"))
133
 
134
  with gr.Tabs():
@@ -144,7 +106,7 @@ with gr.Blocks(
144
  format="png",
145
  columns=2,
146
  )
147
- prompt = gr.Textbox(
148
  placeholder="What do you want to see?",
149
  autoscroll=False,
150
  show_label=False,
@@ -271,7 +233,7 @@ with gr.Blocks(
271
  label="Scale",
272
  )
273
  seed = gr.Number(
274
- value=Config.SEED,
275
  label="Seed",
276
  minimum=-1,
277
  maximum=(2**64) - 1,
@@ -286,7 +248,7 @@ with gr.Blocks(
286
  # Image-to-Image settings
287
  gr.HTML("<h3>Image-to-Image</h3>")
288
  with gr.Row():
289
- image_prompt = gr.Image(
290
  show_share_button=False,
291
  label="Initial Image",
292
  min_width=640,
@@ -294,14 +256,14 @@ with gr.Blocks(
294
  type="pil",
295
  )
296
  with gr.Row():
297
- control_image_prompt = gr.Image(
298
  show_share_button=False,
299
  label="Control Image",
300
  min_width=320,
301
  format="png",
302
  type="pil",
303
  )
304
- ip_image_prompt = gr.Image(
305
  show_share_button=False,
306
  label="IP-Adapter Image",
307
  min_width=320,
@@ -316,7 +278,7 @@ with gr.Blocks(
316
  maximum=1.0,
317
  step=0.1,
318
  )
319
- control_annotator = gr.Dropdown(
320
  label="ControlNet Annotator",
321
  # TODO: annotators should be in config with names
322
  choices=[("Canny", "canny")],
@@ -324,22 +286,7 @@ with gr.Blocks(
324
  filterable=False,
325
  )
326
  with gr.Row():
327
- disable_image = gr.Checkbox(
328
- label="Disable initial image",
329
- elem_classes=["checkbox"],
330
- value=False,
331
- )
332
- disable_control_image = gr.Checkbox(
333
- label="Disable ControlNet",
334
- elem_classes=["checkbox"],
335
- value=False,
336
- )
337
- disable_ip_image = gr.Checkbox(
338
- label="Disable IP-Adapter",
339
- elem_classes=["checkbox"],
340
- value=False,
341
- )
342
- use_ip_face = gr.Checkbox(
343
  label="Use IP-Adapter Face",
344
  elem_classes=["checkbox"],
345
  value=False,
@@ -349,7 +296,9 @@ with gr.Blocks(
349
  gr.Markdown(read_file("DOCS.md"))
350
 
351
  # Random prompt on click
352
- random_btn.click(None, inputs=[prompt], outputs=[prompt], js=random_prompt_js)
 
 
353
 
354
  # Update seed on click
355
  refresh_btn.click(None, inputs=[], outputs=[seed], js=refresh_seed_js)
@@ -374,31 +323,22 @@ with gr.Blocks(
374
  js=custom_aspect_ratio_js,
375
  )
376
 
377
- # Toggle image prompts by updating session state
378
- gr.on(
379
- triggers=[disable_image.input, disable_control_image.input, disable_ip_image.input],
380
- fn=lambda image, control_image, ip_image: (image, control_image, ip_image),
381
- inputs=[disable_image, disable_control_image, disable_ip_image],
382
- outputs=[DISABLE_IMAGE_PROMPT, DISABLE_CONTROL_IMAGE_PROMPT, DISABLE_IP_IMAGE_PROMPT],
383
- show_api=False,
384
- )
385
-
386
  # Generate images
387
  gr.on(
388
- triggers=[generate_btn.click, prompt.submit],
389
- fn=generate_fn,
390
  api_name="generate",
391
  outputs=[output_images],
392
  inputs=[
393
- prompt,
394
  negative_prompt,
395
- image_prompt,
396
- control_image_prompt,
397
- ip_image_prompt,
398
  seed,
399
  model,
400
  scheduler,
401
- control_annotator,
402
  width,
403
  height,
404
  guidance_scale,
@@ -408,10 +348,7 @@ with gr.Blocks(
408
  scale,
409
  num_images,
410
  use_karras,
411
- use_ip_face,
412
- DISABLE_IMAGE_PROMPT,
413
- DISABLE_CONTROL_IMAGE_PROMPT,
414
- DISABLE_IP_IMAGE_PROMPT,
415
  ],
416
  )
417
 
@@ -421,9 +358,16 @@ if __name__ == "__main__":
421
  parser.add_argument("-p", "--port", type=int, metavar="INT", default=7860)
422
  args = parser.parse_args()
423
 
424
- disable_progress_bars()
425
- for repo_id, allow_patterns in Config.HF_MODELS.items():
426
- download_repo_files(repo_id, allow_patterns, token=Config.HF_TOKEN)
 
 
 
 
 
 
 
427
 
428
  # https://www.gradio.app/docs/gradio/interface#interface-queue
429
  demo.queue(default_concurrency_limit=1).launch(
 
1
  import argparse
2
+ import os
3
+ from importlib.util import find_spec
4
+
5
+ # Improved GPU handling and progress bars
6
+ os.environ["ZEROGPU_V2"] = "1"
7
+
8
+ # Use Rust-based downloader
9
+ if find_spec("hf_transfer"):
10
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
11
 
12
  import gradio as gr
13
+ from huggingface_hub._snapshot_download import snapshot_download
14
 
15
  from lib import (
16
  Config,
 
 
 
17
  generate,
18
  read_file,
19
  read_json,
 
67
  }}
68
  """
69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  with gr.Blocks(
71
  head=read_file("./partials/head.html"),
72
  css="./app.css",
 
81
  radius_size=gr.themes.sizes.radius_sm,
82
  spacing_size=gr.themes.sizes.spacing_md,
83
  # fonts
84
+ font=[gr.themes.GoogleFont("Inter"), "sans-serif"],
85
+ font_mono=[gr.themes.GoogleFont("Ubuntu Mono"), "monospace"],
86
  ).set(
87
  layout_gap="8px",
88
  block_shadow="0 0 #0000",
 
91
  block_background_fill_dark=gr.themes.colors.gray.c900,
92
  ),
93
  ) as demo:
 
 
 
 
 
94
  gr.HTML(read_file("./partials/intro.html"))
95
 
96
  with gr.Tabs():
 
106
  format="png",
107
  columns=2,
108
  )
109
+ positive_prompt = gr.Textbox(
110
  placeholder="What do you want to see?",
111
  autoscroll=False,
112
  show_label=False,
 
233
  label="Scale",
234
  )
235
  seed = gr.Number(
236
+ value=-1,
237
  label="Seed",
238
  minimum=-1,
239
  maximum=(2**64) - 1,
 
248
  # Image-to-Image settings
249
  gr.HTML("<h3>Image-to-Image</h3>")
250
  with gr.Row():
251
+ image_input = gr.Image(
252
  show_share_button=False,
253
  label="Initial Image",
254
  min_width=640,
 
256
  type="pil",
257
  )
258
  with gr.Row():
259
+ controlnet_input = gr.Image(
260
  show_share_button=False,
261
  label="Control Image",
262
  min_width=320,
263
  format="png",
264
  type="pil",
265
  )
266
+ ip_adapter_input = gr.Image(
267
  show_share_button=False,
268
  label="IP-Adapter Image",
269
  min_width=320,
 
278
  maximum=1.0,
279
  step=0.1,
280
  )
281
+ controlnet_annotator = gr.Dropdown(
282
  label="ControlNet Annotator",
283
  # TODO: annotators should be in config with names
284
  choices=[("Canny", "canny")],
 
286
  filterable=False,
287
  )
288
  with gr.Row():
289
+ use_ip_adapter_face = gr.Checkbox(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
290
  label="Use IP-Adapter Face",
291
  elem_classes=["checkbox"],
292
  value=False,
 
296
  gr.Markdown(read_file("DOCS.md"))
297
 
298
  # Random prompt on click
299
+ random_btn.click(
300
+ None, inputs=[positive_prompt], outputs=[positive_prompt], js=random_prompt_js
301
+ )
302
 
303
  # Update seed on click
304
  refresh_btn.click(None, inputs=[], outputs=[seed], js=refresh_seed_js)
 
323
  js=custom_aspect_ratio_js,
324
  )
325
 
 
 
 
 
 
 
 
 
 
326
  # Generate images
327
  gr.on(
328
+ triggers=[generate_btn.click, positive_prompt.submit],
329
+ fn=generate,
330
  api_name="generate",
331
  outputs=[output_images],
332
  inputs=[
333
+ positive_prompt,
334
  negative_prompt,
335
+ image_input,
336
+ controlnet_input,
337
+ ip_adapter_input,
338
  seed,
339
  model,
340
  scheduler,
341
+ controlnet_annotator,
342
  width,
343
  height,
344
  guidance_scale,
 
348
  scale,
349
  num_images,
350
  use_karras,
351
+ use_ip_adapter_face,
 
 
 
352
  ],
353
  )
354
 
 
358
  parser.add_argument("-p", "--port", type=int, metavar="INT", default=7860)
359
  args = parser.parse_args()
360
 
361
+ token = os.environ.get("HF_TOKEN", None)
362
+ for repo_id, allow_patterns in Config.HF_REPOS.items():
363
+ snapshot_download(
364
+ repo_id,
365
+ repo_type="model",
366
+ revision="main",
367
+ token=token,
368
+ allow_patterns=allow_patterns,
369
+ ignore_patterns=None,
370
+ )
371
 
372
  # https://www.gradio.app/docs/gradio/interface#interface-queue
373
  demo.queue(default_concurrency_limit=1).launch(
lib/__init__.py CHANGED
@@ -1,18 +1,9 @@
1
  from .config import Config
2
  from .inference import generate
3
- from .utils import (
4
- async_call,
5
- disable_progress_bars,
6
- download_repo_files,
7
- read_file,
8
- read_json,
9
- )
10
 
11
  __all__ = [
12
  "Config",
13
- "async_call",
14
- "disable_progress_bars",
15
- "download_repo_files",
16
  "generate",
17
  "read_file",
18
  "read_json",
 
1
  from .config import Config
2
  from .inference import generate
3
+ from .utils import read_file, read_json
 
 
 
 
 
 
4
 
5
  __all__ = [
6
  "Config",
 
 
 
7
  "generate",
8
  "read_file",
9
  "read_json",
lib/utils.py CHANGED
@@ -1,18 +1,14 @@
1
  import functools
2
- import inspect
3
  import json
4
  import os
5
  import time
6
  from contextlib import contextmanager
7
- from typing import Callable, Tuple, TypeVar
8
 
9
- import anyio
10
  import numpy as np
11
  import torch
12
  from anyio import Semaphore
13
  from diffusers.utils import logging as diffusers_logging
14
- from huggingface_hub._snapshot_download import snapshot_download
15
- from huggingface_hub.utils import are_progress_bars_disabled
16
  from PIL import Image
17
  from transformers import logging as transformers_logging
18
  from typing_extensions import ParamSpec
@@ -61,12 +57,7 @@ def enable_progress_bars():
61
  diffusers_logging.enable_progress_bar()
62
 
63
 
64
- def safe_progress(progress, current=0, total=0, desc=""):
65
- if progress is not None:
66
- progress((current, total), desc=desc)
67
-
68
-
69
- def clear_cuda_cache():
70
  if torch.cuda.is_available():
71
  torch.cuda.empty_cache()
72
  torch.cuda.ipc_collect()
@@ -74,22 +65,6 @@ def clear_cuda_cache():
74
  torch.cuda.synchronize()
75
 
76
 
77
- def download_repo_files(repo_id, allow_patterns, token=None):
78
- was_disabled = are_progress_bars_disabled()
79
- enable_progress_bars()
80
- snapshot_path = snapshot_download(
81
- repo_id=repo_id,
82
- repo_type="model",
83
- revision="main",
84
- token=token,
85
- allow_patterns=allow_patterns,
86
- ignore_patterns=None,
87
- )
88
- if was_disabled:
89
- disable_progress_bars()
90
- return snapshot_path
91
-
92
-
93
  def image_to_pil(image: Image.Image):
94
  """Converts various image inputs to RGB PIL Image."""
95
  if isinstance(image, str) and os.path.isfile(image):
@@ -159,14 +134,3 @@ def annotate_image(image: Image.Image, annotator="canny"):
159
  canny = CannyAnnotator()
160
  return canny(image, size)
161
  raise ValueError(f"Invalid annotator: {annotator}")
162
-
163
-
164
- # Like the original but supports args and kwargs instead of a dict
165
- # https://github.com/huggingface/huggingface-inference-toolkit/blob/0.2.0/src/huggingface_inference_toolkit/async_utils.py
166
- async def async_call(fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
167
- async with MAX_THREADS_GUARD:
168
- sig = inspect.signature(fn)
169
- bound_args = sig.bind(*args, **kwargs)
170
- bound_args.apply_defaults()
171
- partial_fn = functools.partial(fn, **bound_args.arguments)
172
- return await anyio.to_thread.run_sync(partial_fn)
 
1
  import functools
 
2
  import json
3
  import os
4
  import time
5
  from contextlib import contextmanager
6
+ from typing import Tuple, TypeVar
7
 
 
8
  import numpy as np
9
  import torch
10
  from anyio import Semaphore
11
  from diffusers.utils import logging as diffusers_logging
 
 
12
  from PIL import Image
13
  from transformers import logging as transformers_logging
14
  from typing_extensions import ParamSpec
 
57
  diffusers_logging.enable_progress_bar()
58
 
59
 
60
+ def cuda_collect():
 
 
 
 
 
61
  if torch.cuda.is_available():
62
  torch.cuda.empty_cache()
63
  torch.cuda.ipc_collect()
 
65
  torch.cuda.synchronize()
66
 
67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  def image_to_pil(image: Image.Image):
69
  """Converts various image inputs to RGB PIL Image."""
70
  if isinstance(image, str) and os.path.isfile(image):
 
134
  canny = CannyAnnotator()
135
  return canny(image, size)
136
  raise ValueError(f"Invalid annotator: {annotator}")