radames commited on
Commit
4b58964
·
1 Parent(s): 592470d

using sfast go brrrr

Browse files
app_init.py CHANGED
@@ -110,11 +110,11 @@ def init_app(app: FastAPI, user_data: UserData, args: Args, pipeline):
110
  params = await user_data.get_latest_data(user_id)
111
  if not vars(params) or params.__dict__ == last_params.__dict__:
112
  await websocket.send_json({"status": "send_frame"})
113
- await asyncio.sleep(0.1)
114
  continue
115
 
116
  last_params = params
117
  image = pipeline.predict(params)
 
118
  if image is None:
119
  await websocket.send_json({"status": "send_frame"})
120
  continue
 
110
  params = await user_data.get_latest_data(user_id)
111
  if not vars(params) or params.__dict__ == last_params.__dict__:
112
  await websocket.send_json({"status": "send_frame"})
 
113
  continue
114
 
115
  last_params = params
116
  image = pipeline.predict(params)
117
+
118
  if image is None:
119
  await websocket.send_json({"status": "send_frame"})
120
  continue
config.py CHANGED
@@ -16,6 +16,7 @@ class Args(NamedTuple):
16
  pipeline: str
17
  ssl_certfile: str
18
  ssl_keyfile: str
 
19
  compel: bool = False
20
  debug: bool = False
21
 
@@ -102,6 +103,12 @@ parser.add_argument(
102
  default=False,
103
  help="Compel",
104
  )
 
 
 
 
 
 
105
  parser.set_defaults(taesd=USE_TAESD)
106
 
107
  args = Args(**vars(parser.parse_args()))
 
16
  pipeline: str
17
  ssl_certfile: str
18
  ssl_keyfile: str
19
+ sfast: bool
20
  compel: bool = False
21
  debug: bool = False
22
 
 
103
  default=False,
104
  help="Compel",
105
  )
106
+ parser.add_argument(
107
+ "--sfast",
108
+ action="store_true",
109
+ default=False,
110
+ help="Enable Stable Fast",
111
+ )
112
  parser.set_defaults(taesd=USE_TAESD)
113
 
114
  args = Args(**vars(parser.parse_args()))
pipelines/controlnelSD21Turbo.py CHANGED
@@ -180,6 +180,19 @@ class Pipeline:
180
  self.pipe.vae = AutoencoderTiny.from_pretrained(
181
  taesd_model, torch_dtype=torch_dtype, use_safetensors=True
182
  ).to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  self.canny_torch = SobelOperator(device=device)
184
 
185
  self.pipe.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config)
@@ -188,14 +201,15 @@ class Pipeline:
188
  if device.type != "mps":
189
  self.pipe.unet.to(memory_format=torch.channels_last)
190
 
191
- if psutil.virtual_memory().total < 64 * 1024**3:
192
- self.pipe.enable_attention_slicing()
 
 
 
 
 
 
193
 
194
- self.pipe.compel_proc = Compel(
195
- tokenizer=self.pipe.tokenizer,
196
- text_encoder=self.pipe.text_encoder,
197
- truncate_long_prompts=True,
198
- )
199
  if args.taesd:
200
  self.pipe.vae = AutoencoderTiny.from_pretrained(
201
  taesd_model, torch_dtype=torch_dtype, use_safetensors=True
@@ -216,7 +230,13 @@ class Pipeline:
216
 
217
  def predict(self, params: "Pipeline.InputParams") -> Image.Image:
218
  generator = torch.manual_seed(params.seed)
219
- prompt_embeds = self.pipe.compel_proc(params.prompt)
 
 
 
 
 
 
220
  control_image = self.canny_torch(
221
  params.image, params.canny_low_threshold, params.canny_high_threshold
222
  )
@@ -224,10 +244,10 @@ class Pipeline:
224
  strength = params.strength
225
  if int(steps * strength) < 1:
226
  steps = math.ceil(1 / max(0.10, strength))
227
- last_time = time.time()
228
  results = self.pipe(
229
  image=params.image,
230
  control_image=control_image,
 
231
  prompt_embeds=prompt_embeds,
232
  generator=generator,
233
  strength=strength,
@@ -240,8 +260,6 @@ class Pipeline:
240
  control_guidance_start=params.controlnet_start,
241
  control_guidance_end=params.controlnet_end,
242
  )
243
- print(f"Time taken: {time.time() - last_time}")
244
-
245
  nsfw_content_detected = (
246
  results.nsfw_content_detected[0]
247
  if "nsfw_content_detected" in results
 
180
  self.pipe.vae = AutoencoderTiny.from_pretrained(
181
  taesd_model, torch_dtype=torch_dtype, use_safetensors=True
182
  ).to(device)
183
+
184
+ if args.sfast:
185
+ from sfast.compilers.stable_diffusion_pipeline_compiler import (
186
+ compile,
187
+ CompilationConfig,
188
+ )
189
+
190
+ config = CompilationConfig.Default()
191
+ config.enable_xformers = True
192
+ config.enable_triton = True
193
+ config.enable_cuda_graph = True
194
+ self.pipe = compile(self.pipe, config=config)
195
+
196
  self.canny_torch = SobelOperator(device=device)
197
 
198
  self.pipe.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config)
 
201
  if device.type != "mps":
202
  self.pipe.unet.to(memory_format=torch.channels_last)
203
 
204
+ if args.compel:
205
+ from compel import Compel
206
+
207
+ self.pipe.compel_proc = Compel(
208
+ tokenizer=self.pipe.tokenizer,
209
+ text_encoder=self.pipe.text_encoder,
210
+ truncate_long_prompts=True,
211
+ )
212
 
 
 
 
 
 
213
  if args.taesd:
214
  self.pipe.vae = AutoencoderTiny.from_pretrained(
215
  taesd_model, torch_dtype=torch_dtype, use_safetensors=True
 
230
 
231
  def predict(self, params: "Pipeline.InputParams") -> Image.Image:
232
  generator = torch.manual_seed(params.seed)
233
+ prompt = params.prompt
234
+ prompt_embeds = None
235
+ if hasattr(self.pipe, "compel_proc"):
236
+ prompt_embeds = self.pipe.compel_proc(
237
+ [params.prompt, params.negative_prompt]
238
+ )
239
+ prompt = None
240
  control_image = self.canny_torch(
241
  params.image, params.canny_low_threshold, params.canny_high_threshold
242
  )
 
244
  strength = params.strength
245
  if int(steps * strength) < 1:
246
  steps = math.ceil(1 / max(0.10, strength))
 
247
  results = self.pipe(
248
  image=params.image,
249
  control_image=control_image,
250
+ prompt=prompt,
251
  prompt_embeds=prompt_embeds,
252
  generator=generator,
253
  strength=strength,
 
260
  control_guidance_start=params.controlnet_start,
261
  control_guidance_end=params.controlnet_end,
262
  )
 
 
263
  nsfw_content_detected = (
264
  results.nsfw_content_detected[0]
265
  if "nsfw_content_detected" in results
pipelines/controlnetSDXLTurbo.py CHANGED
@@ -185,20 +185,31 @@ class Pipeline:
185
  )
186
  self.canny_torch = SobelOperator(device=device)
187
 
 
 
 
 
 
 
 
 
 
 
 
 
188
  self.pipe.set_progress_bar_config(disable=True)
189
  self.pipe.to(device=device, dtype=torch_dtype).to(device)
190
  if device.type != "mps":
191
  self.pipe.unet.to(memory_format=torch.channels_last)
192
 
193
- if psutil.virtual_memory().total < 64 * 1024**3:
194
- self.pipe.enable_attention_slicing()
 
 
 
 
 
195
 
196
- self.pipe.compel_proc = Compel(
197
- tokenizer=[self.pipe.tokenizer, self.pipe.tokenizer_2],
198
- text_encoder=[self.pipe.text_encoder, self.pipe.text_encoder_2],
199
- returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
200
- requires_pooled=[False, True],
201
- )
202
  if args.taesd:
203
  self.pipe.vae = AutoencoderTiny.from_pretrained(
204
  taesd_model, torch_dtype=torch_dtype, use_safetensors=True
@@ -220,9 +231,23 @@ class Pipeline:
220
  def predict(self, params: "Pipeline.InputParams") -> Image.Image:
221
  generator = torch.manual_seed(params.seed)
222
 
223
- prompt_embeds, pooled_prompt_embeds = self.pipe.compel_proc(
224
- [params.prompt, params.negative_prompt]
225
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
  control_image = self.canny_torch(
227
  params.image, params.canny_low_threshold, params.canny_high_threshold
228
  )
@@ -234,10 +259,12 @@ class Pipeline:
234
  results = self.pipe(
235
  image=params.image,
236
  control_image=control_image,
237
- prompt_embeds=prompt_embeds[0:1],
238
- pooled_prompt_embeds=pooled_prompt_embeds[0:1],
239
- negative_prompt_embeds=prompt_embeds[1:2],
240
- negative_pooled_prompt_embeds=pooled_prompt_embeds[1:2],
 
 
241
  generator=generator,
242
  strength=strength,
243
  num_inference_steps=steps,
 
185
  )
186
  self.canny_torch = SobelOperator(device=device)
187
 
188
+ if args.sfast:
189
+ from sfast.compilers.stable_diffusion_pipeline_compiler import (
190
+ compile,
191
+ CompilationConfig,
192
+ )
193
+
194
+ config = CompilationConfig.Default()
195
+ config.enable_xformers = True
196
+ config.enable_triton = True
197
+ config.enable_cuda_graph = True
198
+ self.pipe = compile(self.pipe, config=config)
199
+
200
  self.pipe.set_progress_bar_config(disable=True)
201
  self.pipe.to(device=device, dtype=torch_dtype).to(device)
202
  if device.type != "mps":
203
  self.pipe.unet.to(memory_format=torch.channels_last)
204
 
205
+ if args.compel:
206
+ self.pipe.compel_proc = Compel(
207
+ tokenizer=[self.pipe.tokenizer, self.pipe.tokenizer_2],
208
+ text_encoder=[self.pipe.text_encoder, self.pipe.text_encoder_2],
209
+ returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
210
+ requires_pooled=[False, True],
211
+ )
212
 
 
 
 
 
 
 
213
  if args.taesd:
214
  self.pipe.vae = AutoencoderTiny.from_pretrained(
215
  taesd_model, torch_dtype=torch_dtype, use_safetensors=True
 
231
  def predict(self, params: "Pipeline.InputParams") -> Image.Image:
232
  generator = torch.manual_seed(params.seed)
233
 
234
+ prompt = params.prompt
235
+ negative_prompt = params.negative_prompt
236
+ prompt_embeds = None
237
+ pooled_prompt_embeds = None
238
+ negative_prompt_embeds = None
239
+ negative_pooled_prompt_embeds = None
240
+ if hasattr(self.pipe, "compel_proc"):
241
+ _prompt_embeds, pooled_prompt_embeds = self.pipe.compel_proc(
242
+ [params.prompt, params.negative_prompt]
243
+ )
244
+ prompt = None
245
+ negative_prompt = None
246
+ prompt_embeds = _prompt_embeds[0:1]
247
+ pooled_prompt_embeds = pooled_prompt_embeds[0:1]
248
+ negative_prompt_embeds = _prompt_embeds[1:2]
249
+ negative_pooled_prompt_embeds = pooled_prompt_embeds[1:2]
250
+
251
  control_image = self.canny_torch(
252
  params.image, params.canny_low_threshold, params.canny_high_threshold
253
  )
 
259
  results = self.pipe(
260
  image=params.image,
261
  control_image=control_image,
262
+ prompt=prompt,
263
+ negative_prompt=negative_prompt,
264
+ prompt_embeds=prompt_embeds,
265
+ pooled_prompt_embeds=pooled_prompt_embeds,
266
+ negative_prompt_embeds=negative_prompt_embeds,
267
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
268
  generator=generator,
269
  strength=strength,
270
  num_inference_steps=steps,
pipelines/img2imgSD21Turbo.py CHANGED
@@ -14,6 +14,10 @@ from config import Args
14
  from pydantic import BaseModel, Field
15
  from PIL import Image
16
  import math
 
 
 
 
17
 
18
  base_model = "stabilityai/sd-turbo"
19
  taesd_model = "madebyollin/taesd"
@@ -104,15 +108,23 @@ class Pipeline:
104
  taesd_model, torch_dtype=torch_dtype, use_safetensors=True
105
  ).to(device)
106
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  self.pipe.set_progress_bar_config(disable=True)
108
  self.pipe.to(device=device, dtype=torch_dtype)
109
  if device.type != "mps":
110
  self.pipe.unet.to(memory_format=torch.channels_last)
111
 
112
- # check if computer has less than 64GB of RAM using sys or os
113
- if psutil.virtual_memory().total < 64 * 1024**3:
114
- self.pipe.enable_attention_slicing()
115
-
116
  if args.torch_compile:
117
  print("Running torch compile")
118
  self.pipe.unet = torch.compile(
 
14
  from pydantic import BaseModel, Field
15
  from PIL import Image
16
  import math
17
+ from sfast.compilers.stable_diffusion_pipeline_compiler import (
18
+ compile,
19
+ CompilationConfig,
20
+ )
21
 
22
  base_model = "stabilityai/sd-turbo"
23
  taesd_model = "madebyollin/taesd"
 
108
  taesd_model, torch_dtype=torch_dtype, use_safetensors=True
109
  ).to(device)
110
 
111
+ if args.sfast:
112
+ from sfast.compilers.stable_diffusion_pipeline_compiler import (
113
+ compile,
114
+ CompilationConfig,
115
+ )
116
+
117
+ config = CompilationConfig.Default()
118
+ config.enable_xformers = True
119
+ config.enable_triton = True
120
+ config.enable_cuda_graph = True
121
+ self.pipe = compile(self.pipe, config=config)
122
+
123
  self.pipe.set_progress_bar_config(disable=True)
124
  self.pipe.to(device=device, dtype=torch_dtype)
125
  if device.type != "mps":
126
  self.pipe.unet.to(memory_format=torch.channels_last)
127
 
 
 
 
 
128
  if args.torch_compile:
129
  print("Running torch compile")
130
  self.pipe.unet = torch.compile(
requirements.txt CHANGED
@@ -10,4 +10,5 @@ compel==2.0.2
10
  controlnet-aux==0.0.7
11
  peft==0.6.0
12
  xformers; sys_platform != 'darwin' or platform_machine != 'arm64'
13
- markdown2
 
 
10
  controlnet-aux==0.0.7
11
  peft==0.6.0
12
  xformers; sys_platform != 'darwin' or platform_machine != 'arm64'
13
+ markdown2
14
+ stable_fast @ https://github.com/chengzeyi/stable-fast/releases/download/v0.0.15.post1/stable_fast-0.0.15.post1+torch211cu121-cp310-cp310-manylinux2014_x86_64.whl