multimodalart HF staff commited on
Commit
7696de6
·
verified ·
1 Parent(s): ee58b95

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +90 -70
app.py CHANGED
@@ -1,85 +1,106 @@
1
  import gradio as gr
2
  import numpy as np
3
  import random
4
- from diffusers import AuraFlowPipeline
5
  import torch
6
- import spaces
 
7
 
8
- device = "cuda" if torch.cuda.is_available() else "cpu"
9
-
10
- #torch.set_float32_matmul_precision("high")
11
-
12
- #torch._inductor.config.conv_1x1_as_mm = True
13
- #torch._inductor.config.coordinate_descent_tuning = True
14
- #torch._inductor.config.epilogue_fusion = False
15
- #torch._inductor.config.coordinate_descent_check_all_directions = True
16
-
17
- pipe = AuraFlowPipeline.from_pretrained(
18
- "fal/AuraFlow",
19
- torch_dtype=torch.float16
20
- ).to("cuda")
21
 
22
- #pipe.transformer.to(memory_format=torch.channels_last)
23
- #pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead", fullgraph=True)
24
- #pipe.transformer.to(memory_format=torch.channels_last)
25
- #pipe.vae.to(memory_format=torch.channels_last)
26
 
27
- #pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
28
- #pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune", fullgraph=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
  MAX_SEED = np.iinfo(np.int32).max
31
  MAX_IMAGE_SIZE = 1024
32
 
33
- @spaces.GPU
34
- def infer(prompt, negative_prompt="", seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=5.0, num_inference_steps=28, progress=gr.Progress(track_tqdm=True)):
35
-
36
  if randomize_seed:
37
  seed = random.randint(0, MAX_SEED)
38
-
39
- generator = torch.Generator().manual_seed(seed)
 
 
40
 
41
  image = pipe(
42
- prompt = prompt,
43
- negative_prompt = negative_prompt,
44
- width=width,
45
  height=height,
46
- guidance_scale = guidance_scale,
47
- num_inference_steps = num_inference_steps,
48
- generator = generator
49
- ).images[0]
 
 
50
 
51
  return image, seed
52
 
53
  examples = [
54
- "A photo of a lavender cat",
55
- "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
56
- "An astronaut riding a green horse",
57
- "A delicious ceviche cheesecake slice",
58
  ]
59
 
60
  css="""
61
  #col-container {
62
  margin: 0 auto;
63
- max-width: 520px;
64
  }
65
  """
66
-
67
- if torch.cuda.is_available():
68
- power_device = "GPU"
69
- else:
70
- power_device = "CPU"
71
-
72
  with gr.Blocks(css=css) as demo:
73
-
74
  with gr.Column(elem_id="col-container"):
75
  gr.Markdown(f"""
76
- # AuraFlow 0.1
77
- Demo of the [AuraFlow 0.1](https://huggingface.co/fal/AuraFlow) 6.8B parameters open source diffusion transformer model
78
- [[blog](https://blog.fal.ai/auraflow/)] [[model](https://huggingface.co/fal/AuraFlow)] [[fal](https://fal.ai/models/fal-ai/aura-flow)]
79
  """)
80
 
81
  with gr.Row():
82
-
83
  prompt = gr.Text(
84
  label="Prompt",
85
  show_label=False,
@@ -87,19 +108,18 @@ with gr.Blocks(css=css) as demo:
87
  placeholder="Enter your prompt",
88
  container=False,
89
  )
90
-
91
  run_button = gr.Button("Run", scale=0)
92
 
93
- result = gr.Image(label="Result", show_label=False)
 
 
94
 
95
  with gr.Accordion("Advanced Settings", open=False):
96
-
97
  negative_prompt = gr.Text(
98
  label="Negative prompt",
99
  max_lines=1,
100
  placeholder="Enter a negative prompt",
101
  )
102
-
103
  seed = gr.Slider(
104
  label="Seed",
105
  minimum=0,
@@ -107,11 +127,8 @@ with gr.Blocks(css=css) as demo:
107
  step=1,
108
  value=0,
109
  )
110
-
111
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
112
-
113
  with gr.Row():
114
-
115
  width = gr.Slider(
116
  label="Width",
117
  minimum=256,
@@ -119,7 +136,6 @@ with gr.Blocks(css=css) as demo:
119
  step=32,
120
  value=1024,
121
  )
122
-
123
  height = gr.Slider(
124
  label="Height",
125
  minimum=256,
@@ -127,9 +143,7 @@ with gr.Blocks(css=css) as demo:
127
  step=32,
128
  value=1024,
129
  )
130
-
131
  with gr.Row():
132
-
133
  guidance_scale = gr.Slider(
134
  label="Guidance scale",
135
  minimum=0.0,
@@ -137,28 +151,34 @@ with gr.Blocks(css=css) as demo:
137
  step=0.1,
138
  value=5.0,
139
  )
140
-
141
  num_inference_steps = gr.Slider(
142
  label="Number of inference steps",
143
  minimum=1,
144
- maximum=50,
145
  step=1,
146
- value=28,
147
  )
 
 
 
 
 
 
 
148
 
149
  gr.Examples(
150
- examples = examples,
151
- fn = infer,
152
- inputs = [prompt],
153
- outputs = [result, seed],
154
  cache_examples="lazy"
155
  )
156
 
157
  gr.on(
158
- triggers=[run_button.click, prompt.submit, negative_prompt.submit],
159
- fn = infer,
160
- inputs = [prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
161
- outputs = [result, seed]
162
  )
163
 
164
  demo.queue().launch()
 
1
  import gradio as gr
2
  import numpy as np
3
  import random
 
4
  import torch
5
+ from PIL import Image
6
+ import os
7
 
8
+ from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
9
+ from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256_ipadapter import StableDiffusionXLPipeline
10
+ from kolors.models.modeling_chatglm import ChatGLMModel
11
+ from kolors.models.tokenization_chatglm import ChatGLMTokenizer
12
+ from kolors.models.unet_2d_condition import UNet2DConditionModel
13
+ from diffusers import AutoencoderKL, EulerDiscreteScheduler
 
 
 
 
 
 
 
14
 
15
+ from huggingface_hub import snapshot_download
 
 
 
16
 
17
+ device = "cuda" if torch.cuda.is_available() else "cpu"
18
+ root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
19
+ ckpt_dir = f'{root_dir}/weights/Kolors'
20
+
21
+ snapshot_download(repo_id="Kwai-Kolors/Kolors", local_dir=ckpt_dir)
22
+ snapshot_download(repo_id="Kwai-Kolors/Kolors-IP-Adapter-Plus", local_dir=f"{root_dir}/weights/Kolors-IP-Adapter-Plus")
23
+
24
+ # Load models
25
+ text_encoder = ChatGLMModel.from_pretrained(f'{ckpt_dir}/text_encoder', torch_dtype=torch.float16).half().to(device)
26
+ tokenizer = ChatGLMTokenizer.from_pretrained(f'{ckpt_dir}/text_encoder')
27
+ vae = AutoencoderKL.from_pretrained(f"{ckpt_dir}/vae", revision=None).half().to(device)
28
+ scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
29
+ unet = UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet", revision=None).half().to(device)
30
+
31
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
32
+ f'{root_dir}/weights/Kolors-IP-Adapter-Plus/image_encoder',
33
+ ignore_mismatched_sizes=True
34
+ ).to(dtype=torch.float16, device=device)
35
+
36
+ ip_img_size = 336
37
+ clip_image_processor = CLIPImageProcessor(size=ip_img_size, crop_size=ip_img_size)
38
+
39
+ pipe = StableDiffusionXLPipeline(
40
+ vae=vae,
41
+ text_encoder=text_encoder,
42
+ tokenizer=tokenizer,
43
+ unet=unet,
44
+ scheduler=scheduler,
45
+ image_encoder=image_encoder,
46
+ feature_extractor=clip_image_processor,
47
+ force_zeros_for_empty_prompt=False
48
+ )
49
+
50
+ pipe = pipe.to(device)
51
+ #pipe.enable_model_cpu_offload()
52
+
53
+ if hasattr(pipe.unet, 'encoder_hid_proj'):
54
+ pipe.unet.text_encoder_hid_proj = pipe.unet.encoder_hid_proj
55
+
56
+ pipe.load_ip_adapter(f'{root_dir}/weights/Kolors-IP-Adapter-Plus', subfolder="", weight_name=["ip_adapter_plus_general.bin"])
57
 
58
  MAX_SEED = np.iinfo(np.int32).max
59
  MAX_IMAGE_SIZE = 1024
60
 
61
+ def infer(prompt, ip_adapter_image, negative_prompt="", seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=5.0, num_inference_steps=50, ip_adapter_scale=0.5, progress=gr.Progress(track_tqdm=True)):
 
 
62
  if randomize_seed:
63
  seed = random.randint(0, MAX_SEED)
64
+
65
+ generator = torch.Generator(device="cpu").manual_seed(seed)
66
+
67
+ pipe.set_ip_adapter_scale([ip_adapter_scale])
68
 
69
  image = pipe(
70
+ prompt=prompt,
71
+ ip_adapter_image=[ip_adapter_image],
72
+ negative_prompt=negative_prompt,
73
  height=height,
74
+ width=width,
75
+ num_inference_steps=num_inference_steps,
76
+ guidance_scale=guidance_scale,
77
+ num_images_per_prompt=1,
78
+ generator=generator,
79
+ ).images[0]
80
 
81
  return image, seed
82
 
83
  examples = [
84
+ ["A photo of a lavender cat", "https://upload.wikimedia.org/wikipedia/commons/thumb/4/4d/Cat_November_2010-1a.jpg/640px-Cat_November_2010-1a.jpg"],
85
+ ["Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", "https://upload.wikimedia.org/wikipedia/commons/thumb/b/b5/Astronaut_EVA.jpg/640px-Astronaut_EVA.jpg"],
86
+ ["An astronaut riding a green horse", "https://upload.wikimedia.org/wikipedia/commons/thumb/f/f7/Haflinger_in-motion.jpg/640px-Haflinger_in-motion.jpg"],
87
+ ["A delicious ceviche cheesecake slice", "https://upload.wikimedia.org/wikipedia/commons/thumb/9/9c/Ceviche_mixto.jpg/640px-Ceviche_mixto.jpg"],
88
  ]
89
 
90
  css="""
91
  #col-container {
92
  margin: 0 auto;
93
+ max-width: 720px;
94
  }
95
  """
 
 
 
 
 
 
96
  with gr.Blocks(css=css) as demo:
 
97
  with gr.Column(elem_id="col-container"):
98
  gr.Markdown(f"""
99
+ # Kolors Demo
100
+ Demo of the Kolors model with IP-Adapter integration
 
101
  """)
102
 
103
  with gr.Row():
 
104
  prompt = gr.Text(
105
  label="Prompt",
106
  show_label=False,
 
108
  placeholder="Enter your prompt",
109
  container=False,
110
  )
 
111
  run_button = gr.Button("Run", scale=0)
112
 
113
+ with gr.Row():
114
+ ip_adapter_image = gr.Image(label="IP-Adapter Image", type="pil")
115
+ result = gr.Image(label="Result", show_label=False)
116
 
117
  with gr.Accordion("Advanced Settings", open=False):
 
118
  negative_prompt = gr.Text(
119
  label="Negative prompt",
120
  max_lines=1,
121
  placeholder="Enter a negative prompt",
122
  )
 
123
  seed = gr.Slider(
124
  label="Seed",
125
  minimum=0,
 
127
  step=1,
128
  value=0,
129
  )
 
130
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
 
131
  with gr.Row():
 
132
  width = gr.Slider(
133
  label="Width",
134
  minimum=256,
 
136
  step=32,
137
  value=1024,
138
  )
 
139
  height = gr.Slider(
140
  label="Height",
141
  minimum=256,
 
143
  step=32,
144
  value=1024,
145
  )
 
146
  with gr.Row():
 
147
  guidance_scale = gr.Slider(
148
  label="Guidance scale",
149
  minimum=0.0,
 
151
  step=0.1,
152
  value=5.0,
153
  )
 
154
  num_inference_steps = gr.Slider(
155
  label="Number of inference steps",
156
  minimum=1,
157
+ maximum=100,
158
  step=1,
159
+ value=50,
160
  )
161
+ ip_adapter_scale = gr.Slider(
162
+ label="IP-Adapter Scale",
163
+ minimum=0.0,
164
+ maximum=1.0,
165
+ step=0.01,
166
+ value=0.5,
167
+ )
168
 
169
  gr.Examples(
170
+ examples=examples,
171
+ fn=infer,
172
+ inputs=[prompt, ip_adapter_image],
173
+ outputs=[result, seed],
174
  cache_examples="lazy"
175
  )
176
 
177
  gr.on(
178
+ triggers=[run_button.click, prompt.submit],
179
+ fn=infer,
180
+ inputs=[prompt, ip_adapter_image, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, ip_adapter_scale],
181
+ outputs=[result, seed]
182
  )
183
 
184
  demo.queue().launch()