radames commited on
Commit
d02351b
·
verified ·
1 Parent(s): 593fdd9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +315 -113
app.py CHANGED
@@ -1,149 +1,351 @@
1
- import random
2
  import spaces
3
-
4
  import gradio as gr
5
- import numpy as np
6
  import torch
7
- from diffusers import LCMScheduler, PixArtAlphaPipeline, Transformer2DModel
8
- from peft import PeftModel
 
 
 
 
 
 
 
 
9
  import os
 
 
10
 
11
- device = "cuda" if torch.cuda.is_available() else "cpu"
12
  IS_SPACES_ZERO = os.environ.get("SPACES_ZERO_GPU", "0") == "1"
 
13
 
14
- transformer = Transformer2DModel.from_pretrained(
15
- "PixArt-alpha/PixArt-XL-2-1024-MS",
16
- subfolder="transformer",
17
- torch_dtype=torch.float16,
18
- )
19
- transformer = PeftModel.from_pretrained(transformer, "jasperai/flash-pixart")
20
 
 
21
 
22
- if torch.cuda.is_available():
23
- torch.cuda.max_memory_allocated(device=device)
24
- pipe = PixArtAlphaPipeline.from_pretrained(
25
- "PixArt-alpha/PixArt-XL-2-1024-MS",
26
- transformer=transformer,
27
- torch_dtype=torch.float16,
28
- )
29
- if not IS_SPACES_ZERO:
30
- pipe.enable_xformers_memory_efficient_attention()
31
- pipe = pipe.to(device)
32
- else:
33
- pipe = PixArtAlphaPipeline.from_pretrained(
34
- "PixArt-alpha/PixArt-XL-2-1024-MS",
35
- transformer=transformer,
36
- torch_dtype=torch.float16,
37
- )
38
- pipe = pipe.to(device)
39
 
40
- pipe.text_encoder.to_bettertransformer()
41
 
42
- pipe.scheduler = LCMScheduler.from_pretrained(
43
- "PixArt-alpha/PixArt-XL-2-1024-MS",
44
- subfolder="scheduler",
45
- timestep_spacing="trailing",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  )
47
 
48
- MAX_SEED = np.iinfo(np.int32).max
49
- MAX_IMAGE_SIZE = 1024
50
- NUM_INFERENCE_STEPS = 4
 
 
 
 
51
 
 
 
 
52
 
53
- @spaces.GPU
54
- def infer(prompt, seed, randomize_seed):
55
- if randomize_seed:
56
- seed = random.randint(0, MAX_SEED)
57
 
58
- generator = torch.Generator().manual_seed(seed)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
- image = pipe(
61
- prompt=prompt,
62
- guidance_scale=0,
63
- num_inference_steps=NUM_INFERENCE_STEPS,
64
- generator=generator,
65
- ).images[0]
66
 
67
- return image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
- examples = [
71
- "The image showcases a freshly baked bread, possibly focaccia, with rosemary sprigs and red pepper flakes sprinkled on top. It's sliced and placed on a wire cooling rack, with a bowl of mixed peppercorns beside it.",
72
- "A raccoon reading a book in a lush forest.",
73
- "A small cactus with a happy face in the Sahara desert.",
74
- "A super-realistic close-up of a snake eye",
75
- "A cute cheetah looking amazed and surprised",
76
- "Pirate ship sailing on a sea with the milky way galaxy in the sky and purple glow lights",
77
- "a cute fluffy rabbit pilot walking on a military aircraft carrier, 8k, cinematic",
78
- "A close up of an old elderly man with green eyes looking straight at the camera",
79
- "A beautiful sunflower in rainy day",
80
- ]
81
 
82
  css = """
83
- #col-container {
84
- margin: 0 auto;
85
- max-width: 512px;
 
86
  }
87
  """
88
 
89
- if torch.cuda.is_available():
90
- power_device = "GPU"
91
- else:
92
- power_device = "CPU"
93
-
94
  with gr.Blocks(css=css) as demo:
95
- with gr.Column(elem_id="col-container"):
96
- gr.Markdown(
97
- f"""
98
- # ⚡ Flash Diffusion: FlashPixart ⚡
99
- This is an interactive demo of [Flash Diffusion](https://gojasper.github.io/flash-diffusion-project/), a diffusion distillation method proposed in [Flash Diffusion: Accelerating Any Conditional
100
- Diffusion Model for Few Steps Image Generation](http://arxiv.org/abs/2406.02347) *by Clément Chadebec, Onur Tasar, Eyal Benaroche and Benjamin Aubin.*
101
- [This model](https://huggingface.co/jasperai/flash-pixart) is a **66.5M** LoRA distilled version of [Pixart-α](https://huggingface.co/PixArt-alpha/PixArt-XL-2-1024-MS) model that is able to generate 1024x1024 images in **4 steps**.
102
- Currently running on {power_device}.
103
  """
104
- )
105
- gr.Markdown(
106
- "💡 *Hint:* To better appreciate the low latency of our method, run the demo locally !"
107
- )
108
-
109
- with gr.Row():
110
- prompt = gr.Text(
 
 
 
 
 
 
111
  label="Prompt",
112
- show_label=False,
113
- max_lines=1,
114
- placeholder="Enter your prompt",
115
- container=False,
 
116
  )
117
-
118
- run_button = gr.Button("Run", scale=0)
119
-
120
- result = gr.Image(label="Result", show_label=False)
121
-
122
- with gr.Accordion("Advanced Settings", open=False):
123
  seed = gr.Slider(
124
- label="Seed",
125
  minimum=0,
126
- maximum=MAX_SEED,
 
127
  step=1,
128
- value=0,
 
129
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
132
-
133
- examples = gr.Examples(examples=examples, inputs=[prompt])
134
-
135
- gr.Markdown("**Disclaimer:**")
136
- gr.Markdown(
137
- "This demo is only for research purpose. Jasper cannot be held responsible for the generation of NSFW (Not Safe For Work) content through the use of this demo. Users are solely responsible for any content they create, and it is their obligation to ensure that it adheres to appropriate and ethical standards. Jasper provides the tools, but the responsibility for their use lies with the individual user."
138
- )
139
- gr.on(
140
- [run_button.click, seed.change, prompt.change, randomize_seed.change],
141
- fn=infer,
142
- inputs=[prompt, seed, randomize_seed],
143
- outputs=[result],
144
- show_progress="minimal",
145
- show_api=False,
146
- trigger_mode="always_last",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
  )
148
 
149
- demo.queue().launch(show_api=False)
 
 
 
 
1
  import spaces
 
2
  import gradio as gr
3
+ from gradio_imageslider import ImageSlider
4
  import torch
5
+
6
+ torch.jit.script = lambda f: f
7
+ from diffusers import (
8
+ ControlNetModel,
9
+ StableDiffusionXLControlNetImg2ImgPipeline,
10
+ DDIMScheduler,
11
+ )
12
+ from controlnet_aux import AnylineDetector
13
+ from compel import Compel, ReturnedEmbeddingsType
14
+ from PIL import Image
15
  import os
16
+ import time
17
+ import numpy as np
18
 
 
19
  IS_SPACES_ZERO = os.environ.get("SPACES_ZERO_GPU", "0") == "1"
20
+ IS_SPACE = os.environ.get("SPACE_ID", None) is not None
21
 
22
+ device = "cuda" if torch.cuda.is_available() else "cpu"
23
+ dtype = torch.float16
 
 
 
 
24
 
25
+ LOW_MEMORY = os.getenv("LOW_MEMORY", "0") == "1"
26
 
27
+ print(f"device: {device}")
28
+ print(f"dtype: {dtype}")
29
+ print(f"low memory: {LOW_MEMORY}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
 
31
 
32
+ model = "stabilityai/stable-diffusion-xl-base-1.0"
33
+ # model = "stabilityai/sdxl-turbo"
34
+ # vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=dtype)
35
+ scheduler = DDIMScheduler.from_pretrained(model, subfolder="scheduler")
36
+ # controlnet = ControlNetModel.from_pretrained(
37
+ # "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16
38
+ # )
39
+ controlnet = ControlNetModel.from_pretrained(
40
+ "TheMistoAI/MistoLine",
41
+ torch_dtype=torch.float16,
42
+ revision="refs/pr/3",
43
+ variant="fp16",
44
+ )
45
+ pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
46
+ model,
47
+ controlnet=controlnet,
48
+ torch_dtype=dtype,
49
+ variant="fp16",
50
+ use_safetensors=True,
51
+ scheduler=scheduler,
52
  )
53
 
54
+ compel = Compel(
55
+ tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
56
+ text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
57
+ returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
58
+ requires_pooled=[False, True],
59
+ )
60
+ pipe = pipe.to(device)
61
 
62
+ anyline = AnylineDetector.from_pretrained(
63
+ "TheMistoAI/MistoLine", filename="MTEED.pth", subfolder="Anyline"
64
+ ).to(device)
65
 
 
 
 
 
66
 
67
+ def pad_image(image):
68
+ w, h = image.size
69
+ if w == h:
70
+ return image
71
+ elif w > h:
72
+ new_image = Image.new(image.mode, (w, w), (0, 0, 0))
73
+ pad_w = 0
74
+ pad_h = (w - h) // 2
75
+ new_image.paste(image, (0, pad_h))
76
+ return new_image
77
+ else:
78
+ new_image = Image.new(image.mode, (h, h), (0, 0, 0))
79
+ pad_w = (h - w) // 2
80
+ pad_h = 0
81
+ new_image.paste(image, (pad_w, 0))
82
+ return new_image
83
 
 
 
 
 
 
 
84
 
85
+ @spaces.GPU
86
+ def predict(
87
+ input_image,
88
+ prompt,
89
+ negative_prompt,
90
+ seed,
91
+ guidance_scale=8.5,
92
+ controlnet_conditioning_scale=0.5,
93
+ strength=1.0,
94
+ controlnet_start=0.0,
95
+ controlnet_end=1.0,
96
+ guassian_sigma=2.0,
97
+ intensity_threshold=3,
98
+ progress=gr.Progress(track_tqdm=True),
99
+ ):
100
+ if input_image is None:
101
+ raise gr.Error("Please upload an image.")
102
+ padded_image = pad_image(input_image).resize((1024, 1024)).convert("RGB")
103
+ conditioning, pooled = compel([prompt, negative_prompt])
104
+ generator = torch.manual_seed(seed)
105
+ last_time = time.time()
106
+ anyline_image = anyline(
107
+ padded_image,
108
+ detect_resolution=1280,
109
+ guassian_sigma=max(0.01, guassian_sigma),
110
+ intensity_threshold=intensity_threshold,
111
+ )
112
 
113
+ images = pipe(
114
+ image=padded_image,
115
+ control_image=anyline_image,
116
+ strength=strength,
117
+ prompt_embeds=conditioning[0:1],
118
+ pooled_prompt_embeds=pooled[0:1],
119
+ negative_prompt_embeds=conditioning[1:2],
120
+ negative_pooled_prompt_embeds=pooled[1:2],
121
+ width=1024,
122
+ height=1024,
123
+ controlnet_conditioning_scale=float(controlnet_conditioning_scale),
124
+ controlnet_start=float(controlnet_start),
125
+ controlnet_end=float(controlnet_end),
126
+ generator=generator,
127
+ num_inference_steps=30,
128
+ guidance_scale=guidance_scale,
129
+ eta=1.0,
130
+ )
131
+ print(f"Time taken: {time.time() - last_time}")
132
+ return (padded_image, images.images[0]), padded_image, anyline_image
133
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
  css = """
136
+ #intro{
137
+ # max-width: 32rem;
138
+ # text-align: center;
139
+ # margin: 0 auto;
140
  }
141
  """
142
 
 
 
 
 
 
143
  with gr.Blocks(css=css) as demo:
144
+ gr.Markdown(
 
 
 
 
 
 
 
145
  """
146
+ # MistoLine ControlNet demo
147
+
148
+ You can upload an initial image and prompt to generate an enhanced version.
149
+ SDXL Controlnet [TheMistoAI/MistoLine](https://huggingface.co/TheMistoAI/MistoLine)
150
+ [Anyline with Controlnet Aux ](https://github.com/huggingface/controlnet_aux)
151
+ For upscaling see [Enhance This Demo](https://huggingface.co/spaces/radames/Enhance-This-HiDiffusion-SDXL)
152
+ """,
153
+ elem_id="intro",
154
+ )
155
+ with gr.Row():
156
+ with gr.Column(scale=1):
157
+ image_input = gr.Image(type="pil", label="Input Image")
158
+ prompt = gr.Textbox(
159
  label="Prompt",
160
+ info="The prompt is very important to get the desired results. Please try to describe the image as best as you can. Accepts Compel Syntax",
161
+ )
162
+ negative_prompt = gr.Textbox(
163
+ label="Negative Prompt",
164
+ value="blurry, ugly, duplicate, poorly drawn, deformed, mosaic",
165
  )
 
 
 
 
 
 
166
  seed = gr.Slider(
 
167
  minimum=0,
168
+ maximum=2**64 - 1,
169
+ value=1415926535897932,
170
  step=1,
171
+ label="Seed",
172
+ randomize=True,
173
  )
174
+ with gr.Accordion(label="Advanced", open=False):
175
+ guidance_scale = gr.Slider(
176
+ minimum=0,
177
+ maximum=50,
178
+ value=8.5,
179
+ step=0.001,
180
+ label="Guidance Scale",
181
+ )
182
+ controlnet_conditioning_scale = gr.Slider(
183
+ minimum=0,
184
+ maximum=1,
185
+ step=0.001,
186
+ value=0.5,
187
+ label="ControlNet Conditioning Scale",
188
+ )
189
+ strength = gr.Slider(
190
+ minimum=0,
191
+ maximum=1,
192
+ step=0.001,
193
+ value=1,
194
+ label="Strength",
195
+ )
196
+ controlnet_start = gr.Slider(
197
+ minimum=0,
198
+ maximum=1,
199
+ step=0.001,
200
+ value=0.0,
201
+ label="ControlNet Start",
202
+ )
203
+ controlnet_end = gr.Slider(
204
+ minimum=0.0,
205
+ maximum=1.0,
206
+ step=0.001,
207
+ value=1.0,
208
+ label="ControlNet End",
209
+ )
210
+ guassian_sigma = gr.Slider(
211
+ minimum=0.01,
212
+ maximum=10.0,
213
+ step=0.1,
214
+ value=2.0,
215
+ label="(Anyline) Guassian Sigma",
216
+ )
217
+ intensity_threshold = gr.Slider(
218
+ minimum=0,
219
+ maximum=255,
220
+ step=1,
221
+ value=3,
222
+ label="(Anyline) Intensity Threshold",
223
+ )
224
 
225
+ btn = gr.Button()
226
+ with gr.Column(scale=2):
227
+ with gr.Group():
228
+ image_slider = ImageSlider(position=0.5)
229
+ with gr.Row():
230
+ padded_image = gr.Image(type="pil", label="Padded Image")
231
+ anyline_image = gr.Image(type="pil", label="Anyline Image")
232
+ inputs = [
233
+ image_input,
234
+ prompt,
235
+ negative_prompt,
236
+ seed,
237
+ guidance_scale,
238
+ controlnet_conditioning_scale,
239
+ strength,
240
+ controlnet_start,
241
+ controlnet_end,
242
+ guassian_sigma,
243
+ intensity_threshold,
244
+ ]
245
+ outputs = [image_slider, padded_image, anyline_image]
246
+ btn.click(lambda x: None, inputs=None, outputs=image_slider).then(
247
+ fn=predict, inputs=inputs, outputs=outputs
248
+ )
249
+ gr.Examples(
250
+ fn=predict,
251
+ inputs=inputs,
252
+ outputs=outputs,
253
+ examples=[
254
+ [
255
+ "./examples/city.png",
256
+ "hyperrealistic surreal cityscape scene at sunset, buildings",
257
+ "blurry, ugly, duplicate, poorly drawn, deformed, mosaic",
258
+ 13113544138610326000,
259
+ 8.5,
260
+ 0.481,
261
+ 1.0,
262
+ 0.0,
263
+ 0.9,
264
+ 2,
265
+ 3,
266
+ ],
267
+ [
268
+ "./examples/lara.jpeg",
269
+ "photography of lara croft 8k high definition award winning",
270
+ "blurry, ugly, duplicate, poorly drawn, deformed, mosaic",
271
+ 5436236241,
272
+ 8.5,
273
+ 0.8,
274
+ 1.0,
275
+ 0.0,
276
+ 0.9,
277
+ 2,
278
+ 3,
279
+ ],
280
+ [
281
+ "./examples/cybetruck.jpeg",
282
+ "photo of tesla cybertruck futuristic car 8k high definition on a sand dune in mars, future",
283
+ "blurry, ugly, duplicate, poorly drawn, deformed, mosaic",
284
+ 383472451451,
285
+ 8.5,
286
+ 0.8,
287
+ 0.8,
288
+ 0.0,
289
+ 0.9,
290
+ 2,
291
+ 3,
292
+ ],
293
+ [
294
+ "./examples/jesus.png",
295
+ "a photorealistic painting of Jesus Christ, 4k high definition",
296
+ "blurry, ugly, duplicate, poorly drawn, deformed, mosaic",
297
+ 13317204146129588000,
298
+ 8.5,
299
+ 0.8,
300
+ 0.8,
301
+ 0.0,
302
+ 0.9,
303
+ 2,
304
+ 3,
305
+ ],
306
+ [
307
+ "./examples/anna-sullivan-DioLM8ViiO8-unsplash.jpg",
308
+ "A crowded stadium with enthusiastic fans watching a daytime sporting event, the stands filled with colorful attire and the sun casting a warm glow",
309
+ "blurry, ugly, duplicate, poorly drawn, deformed, mosaic",
310
+ 5623124123512,
311
+ 8.5,
312
+ 0.8,
313
+ 0.8,
314
+ 0.0,
315
+ 0.9,
316
+ 2,
317
+ 3,
318
+ ],
319
+ [
320
+ "./examples/img_aef651cb-2919-499d-aa49-6d4e2e21a56e_1024.jpg",
321
+ "a large red flower on a black background 4k high definition",
322
+ "blurry, ugly, duplicate, poorly drawn, deformed, mosaic",
323
+ 23123412341234,
324
+ 8.5,
325
+ 0.8,
326
+ 0.8,
327
+ 0.0,
328
+ 0.9,
329
+ 2,
330
+ 3,
331
+ ],
332
+ [
333
+ "./examples/huggingface.jpg",
334
+ "photo realistic huggingface human emoji costume, round, yellow, (human skin)+++ (human texture)+++",
335
+ "blurry, ugly, duplicate, poorly drawn, deformed, mosaic, emoji cartoon, drawing, pixelated",
336
+ 12312353423,
337
+ 15.206,
338
+ 0.364,
339
+ 0.8,
340
+ 0.0,
341
+ 0.9,
342
+ 2,
343
+ 3,
344
+ ],
345
+ ],
346
+ cache_examples="lazy",
347
  )
348
 
349
+
350
+ demo.queue(api_open=True)
351
+ demo.launch(show_api=True)