radames commited on
Commit
8b39cdc
·
1 Parent(s): 8453a45

sd2.1 turbo + controlnet

Browse files
Files changed (1) hide show
  1. pipelines/controlnelSD21Turbo.py +260 -0
pipelines/controlnelSD21Turbo.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import (
2
+ StableDiffusionControlNetImg2ImgPipeline,
3
+ ControlNetModel,
4
+ LCMScheduler,
5
+ AutoencoderTiny,
6
+ )
7
+ from compel import Compel
8
+ import torch
9
+ from pipelines.utils.canny_gpu import SobelOperator
10
+
11
+ try:
12
+ import intel_extension_for_pytorch as ipex # type: ignore
13
+ except:
14
+ pass
15
+
16
+ import psutil
17
+ from config import Args
18
+ from pydantic import BaseModel, Field
19
+ from PIL import Image
20
+ import math
21
+ import time
22
+
23
+ #
24
+ taesd_model = "madebyollin/taesd"
25
+ controlnet_model = "thibaud/controlnet-sd21-canny-diffusers"
26
+ base_model = "stabilityai/sd-turbo"
27
+
28
+ default_prompt = "Portrait of The Joker halloween costume, face painting, with , glare pose, detailed, intricate, full of colour, cinematic lighting, trending on artstation, 8k, hyperrealistic, focused, extreme details, unreal engine 5 cinematic, masterpiece"
29
+ page_content = """
30
+ <h1 class="text-3xl font-bold">Real-Time SDv2.1 Turbo</h1>
31
+ <h3 class="text-xl font-bold">Image-to-Image ControlNet</h3>
32
+ <p class="text-sm">
33
+ This demo showcases
34
+ <a
35
+ href="https://huggingface.co/stabilityai/sdxl-turbo"
36
+ target="_blank"
37
+ class="text-blue-500 underline hover:no-underline">SDXL Turbo</a>
38
+ Image to Image pipeline using
39
+ <a
40
+ href="https://huggingface.co/docs/diffusers/main/en/using-diffusers/sdxl_turbo"
41
+ target="_blank"
42
+ class="text-blue-500 underline hover:no-underline">Diffusers</a
43
+ > with a MJPEG stream server.
44
+ </p>
45
+ <p class="text-sm text-gray-500">
46
+ Change the prompt to generate different images, accepts <a
47
+ href="https://github.com/damian0815/compel/blob/main/doc/syntax.md"
48
+ target="_blank"
49
+ class="text-blue-500 underline hover:no-underline">Compel</a
50
+ > syntax.
51
+ </p>
52
+ """
53
+
54
+
55
+ class Pipeline:
56
+ class Info(BaseModel):
57
+ name: str = "controlnet+sd15Turbo"
58
+ title: str = "SDv1.5 Turbo + Controlnet"
59
+ description: str = "Generates an image from a text prompt"
60
+ input_mode: str = "image"
61
+ page_content: str = page_content
62
+
63
+ class InputParams(BaseModel):
64
+ prompt: str = Field(
65
+ default_prompt,
66
+ title="Prompt",
67
+ field="textarea",
68
+ id="prompt",
69
+ )
70
+ seed: int = Field(
71
+ 4402026899276587, min=0, title="Seed", field="seed", hide=True, id="seed"
72
+ )
73
+ steps: int = Field(
74
+ 1, min=1, max=15, title="Steps", field="range", hide=True, id="steps"
75
+ )
76
+ width: int = Field(
77
+ 512, min=2, max=15, title="Width", disabled=True, hide=True, id="width"
78
+ )
79
+ height: int = Field(
80
+ 512, min=2, max=15, title="Height", disabled=True, hide=True, id="height"
81
+ )
82
+ guidance_scale: float = Field(
83
+ 1.21,
84
+ min=0,
85
+ max=10,
86
+ step=0.001,
87
+ title="Guidance Scale",
88
+ field="range",
89
+ hide=True,
90
+ id="guidance_scale",
91
+ )
92
+ strength: float = Field(
93
+ 0.8,
94
+ min=0.10,
95
+ max=1.0,
96
+ step=0.001,
97
+ title="Strength",
98
+ field="range",
99
+ hide=True,
100
+ id="strength",
101
+ )
102
+ controlnet_scale: float = Field(
103
+ 0.2,
104
+ min=0,
105
+ max=1.0,
106
+ step=0.001,
107
+ title="Controlnet Scale",
108
+ field="range",
109
+ hide=True,
110
+ id="controlnet_scale",
111
+ )
112
+ controlnet_start: float = Field(
113
+ 0.0,
114
+ min=0,
115
+ max=1.0,
116
+ step=0.001,
117
+ title="Controlnet Start",
118
+ field="range",
119
+ hide=True,
120
+ id="controlnet_start",
121
+ )
122
+ controlnet_end: float = Field(
123
+ 1.0,
124
+ min=0,
125
+ max=1.0,
126
+ step=0.001,
127
+ title="Controlnet End",
128
+ field="range",
129
+ hide=True,
130
+ id="controlnet_end",
131
+ )
132
+ canny_low_threshold: float = Field(
133
+ 0.31,
134
+ min=0,
135
+ max=1.0,
136
+ step=0.001,
137
+ title="Canny Low Threshold",
138
+ field="range",
139
+ hide=True,
140
+ id="canny_low_threshold",
141
+ )
142
+ canny_high_threshold: float = Field(
143
+ 0.125,
144
+ min=0,
145
+ max=1.0,
146
+ step=0.001,
147
+ title="Canny High Threshold",
148
+ field="range",
149
+ hide=True,
150
+ id="canny_high_threshold",
151
+ )
152
+ debug_canny: bool = Field(
153
+ False,
154
+ title="Debug Canny",
155
+ field="checkbox",
156
+ hide=True,
157
+ id="debug_canny",
158
+ )
159
+
160
+ def __init__(self, args: Args, device: torch.device, torch_dtype: torch.dtype):
161
+ controlnet_canny = ControlNetModel.from_pretrained(
162
+ controlnet_model, torch_dtype=torch_dtype
163
+ ).to(device)
164
+
165
+ self.pipes = {}
166
+
167
+ if args.safety_checker:
168
+ self.pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
169
+ base_model,
170
+ controlnet=controlnet_canny,
171
+ )
172
+ else:
173
+ self.pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
174
+ base_model,
175
+ controlnet=controlnet_canny,
176
+ safety_checker=None,
177
+ )
178
+
179
+ if args.use_taesd:
180
+ self.pipe.vae = AutoencoderTiny.from_pretrained(
181
+ taesd_model, torch_dtype=torch_dtype, use_safetensors=True
182
+ ).to(device)
183
+ self.canny_torch = SobelOperator(device=device)
184
+
185
+ self.pipe.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config)
186
+ self.pipe.set_progress_bar_config(disable=True)
187
+ self.pipe.to(device=device, dtype=torch_dtype).to(device)
188
+ if device.type != "mps":
189
+ self.pipe.unet.to(memory_format=torch.channels_last)
190
+
191
+ if psutil.virtual_memory().total < 64 * 1024**3:
192
+ self.pipe.enable_attention_slicing()
193
+
194
+ self.pipe.compel_proc = Compel(
195
+ tokenizer=self.pipe.tokenizer,
196
+ text_encoder=self.pipe.text_encoder,
197
+ truncate_long_prompts=True,
198
+ )
199
+ if args.use_taesd:
200
+ self.pipe.vae = AutoencoderTiny.from_pretrained(
201
+ taesd_model, torch_dtype=torch_dtype, use_safetensors=True
202
+ ).to(device)
203
+
204
+ if args.torch_compile:
205
+ self.pipe.unet = torch.compile(
206
+ self.pipe.unet, mode="reduce-overhead", fullgraph=True
207
+ )
208
+ self.pipe.vae = torch.compile(
209
+ self.pipe.vae, mode="reduce-overhead", fullgraph=True
210
+ )
211
+ self.pipe(
212
+ prompt="warmup",
213
+ image=[Image.new("RGB", (768, 768))],
214
+ control_image=[Image.new("RGB", (768, 768))],
215
+ )
216
+
217
+ def predict(self, params: "Pipeline.InputParams") -> Image.Image:
218
+ generator = torch.manual_seed(params.seed)
219
+ prompt_embeds = self.pipe.compel_proc(params.prompt)
220
+ control_image = self.canny_torch(
221
+ params.image, params.canny_low_threshold, params.canny_high_threshold
222
+ )
223
+ steps = params.steps
224
+ strength = params.strength
225
+ if int(steps * strength) < 1:
226
+ steps = math.ceil(1 / max(0.10, strength))
227
+ last_time = time.time()
228
+ results = self.pipe(
229
+ image=params.image,
230
+ control_image=control_image,
231
+ prompt_embeds=prompt_embeds,
232
+ generator=generator,
233
+ strength=strength,
234
+ num_inference_steps=steps,
235
+ guidance_scale=params.guidance_scale,
236
+ width=params.width,
237
+ height=params.height,
238
+ output_type="pil",
239
+ controlnet_conditioning_scale=params.controlnet_scale,
240
+ control_guidance_start=params.controlnet_start,
241
+ control_guidance_end=params.controlnet_end,
242
+ )
243
+ print(f"Time taken: {time.time() - last_time}")
244
+
245
+ nsfw_content_detected = (
246
+ results.nsfw_content_detected[0]
247
+ if "nsfw_content_detected" in results
248
+ else False
249
+ )
250
+ if nsfw_content_detected:
251
+ return None
252
+ result_image = results.images[0]
253
+ if params.debug_canny:
254
+ # paste control_image on top of result_image
255
+ w0, h0 = (200, 200)
256
+ control_image = control_image.resize((w0, h0))
257
+ w1, h1 = result_image.size
258
+ result_image.paste(control_image, (w1 - w0, h1 - h0))
259
+
260
+ return result_image