radames commited on
Commit
8a96a46
·
1 Parent(s): d8c251e

add img2imgSD21Turbo pipeline

Browse files
Files changed (5) hide show
  1. app.py +3 -0
  2. app_init.py +4 -0
  3. config.py +14 -0
  4. pipelines/img2imgSD21Turbo.py +186 -0
  5. requirements.txt +1 -1
app.py CHANGED
@@ -12,6 +12,9 @@ print("TORCH_DTYPE:", torch_dtype)
12
  print("PIPELINE:", args.pipeline)
13
  print("SAFETY_CHECKER:", args.safety_checker)
14
  print("TORCH_COMPILE:", args.torch_compile)
 
 
 
15
 
16
 
17
  app = FastAPI()
 
12
  print("PIPELINE:", args.pipeline)
13
  print("SAFETY_CHECKER:", args.safety_checker)
14
  print("TORCH_COMPILE:", args.torch_compile)
15
+ print("USE_TAESD:", args.use_taesd)
16
+ print("COMPEL:", args.compel)
17
+ print("DEBUG:", args.debug)
18
 
19
 
20
  app = FastAPI()
app_init.py CHANGED
@@ -15,6 +15,7 @@ from types import SimpleNamespace
15
  from util import pil_to_frame, bytes_to_pil, is_firefox
16
  import asyncio
17
  import os
 
18
 
19
 
20
  def init_app(app: FastAPI, user_data: UserData, args: Args, pipeline):
@@ -105,6 +106,7 @@ def init_app(app: FastAPI, user_data: UserData, args: Args, pipeline):
105
  websocket = user_data.get_websocket(user_id)
106
  last_params = SimpleNamespace()
107
  while True:
 
108
  params = await user_data.get_latest_data(user_id)
109
  if not vars(params) or params.__dict__ == last_params.__dict__:
110
  await websocket.send_json({"status": "send_frame"})
@@ -122,6 +124,8 @@ def init_app(app: FastAPI, user_data: UserData, args: Args, pipeline):
122
  if not is_firefox(request.headers["user-agent"]):
123
  yield frame
124
  await websocket.send_json({"status": "send_frame"})
 
 
125
 
126
  return StreamingResponse(
127
  generate(),
 
15
  from util import pil_to_frame, bytes_to_pil, is_firefox
16
  import asyncio
17
  import os
18
+ import time
19
 
20
 
21
  def init_app(app: FastAPI, user_data: UserData, args: Args, pipeline):
 
106
  websocket = user_data.get_websocket(user_id)
107
  last_params = SimpleNamespace()
108
  while True:
109
+ last_time = time.time()
110
  params = await user_data.get_latest_data(user_id)
111
  if not vars(params) or params.__dict__ == last_params.__dict__:
112
  await websocket.send_json({"status": "send_frame"})
 
124
  if not is_firefox(request.headers["user-agent"]):
125
  yield frame
126
  await websocket.send_json({"status": "send_frame"})
127
+ if args.debug:
128
+ print(f"Time taken: {time.time() - last_time}")
129
 
130
  return StreamingResponse(
131
  generate(),
config.py CHANGED
@@ -16,6 +16,8 @@ class Args(NamedTuple):
16
  pipeline: str
17
  ssl_certfile: str
18
  ssl_keyfile: str
 
 
19
 
20
 
21
  MAX_QUEUE_SIZE = int(os.environ.get("MAX_QUEUE_SIZE", 0))
@@ -83,5 +85,17 @@ parser.add_argument(
83
  default=None,
84
  help="SSL keyfile",
85
  )
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
  args = Args(**vars(parser.parse_args()))
 
16
  pipeline: str
17
  ssl_certfile: str
18
  ssl_keyfile: str
19
+ compel: bool = False
20
+ debug: bool = False
21
 
22
 
23
  MAX_QUEUE_SIZE = int(os.environ.get("MAX_QUEUE_SIZE", 0))
 
85
  default=None,
86
  help="SSL keyfile",
87
  )
88
+ parser.add_argument(
89
+ "--debug",
90
+ type=bool,
91
+ default=False,
92
+ help="Debug",
93
+ )
94
+ parser.add_argument(
95
+ "--compel",
96
+ type=bool,
97
+ default=False,
98
+ help="Compel",
99
+ )
100
 
101
  args = Args(**vars(parser.parse_args()))
pipelines/img2imgSD21Turbo.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import (
2
+ AutoPipelineForImage2Image,
3
+ AutoencoderTiny,
4
+ )
5
+ import torch
6
+
7
+ try:
8
+ import intel_extension_for_pytorch as ipex # type: ignore
9
+ except:
10
+ pass
11
+
12
+ import psutil
13
+ from config import Args
14
+ from pydantic import BaseModel, Field
15
+ from PIL import Image
16
+ import math
17
+
18
+ base_model = "stabilityai/sd-turbo"
19
+ taesd_model = "madebyollin/taesd"
20
+
21
+ default_prompt = "close-up photography of old man standing in the rain at night, in a street lit by lamps, leica 35mm summilux"
22
+ default_negative_prompt = "blurry, low quality, render, 3D, oversaturated"
23
+ page_content = """
24
+ <h1 class="text-3xl font-bold">Real-Time SDXL Turbo</h1>
25
+ <h3 class="text-xl font-bold">Image-to-Image</h3>
26
+ <p class="text-sm">
27
+ This demo showcases
28
+ <a
29
+ href="https://huggingface.co/stabilityai/sdxl-turbo"
30
+ target="_blank"
31
+ class="text-blue-500 underline hover:no-underline">SDXL Turbo</a>
32
+ Image to Image pipeline using
33
+ <a
34
+ href="https://huggingface.co/docs/diffusers/main/en/using-diffusers/sdxl_turbo"
35
+ target="_blank"
36
+ class="text-blue-500 underline hover:no-underline">Diffusers</a
37
+ > with a MJPEG stream server.
38
+ </p>
39
+ <p class="text-sm text-gray-500">
40
+ Change the prompt to generate different images, accepts <a
41
+ href="https://github.com/damian0815/compel/blob/main/doc/syntax.md"
42
+ target="_blank"
43
+ class="text-blue-500 underline hover:no-underline">Compel</a
44
+ > syntax.
45
+ </p>
46
+ """
47
+
48
+
49
+ class Pipeline:
50
+ class Info(BaseModel):
51
+ name: str = "img2img"
52
+ title: str = "Image-to-Image SDXL"
53
+ description: str = "Generates an image from a text prompt"
54
+ input_mode: str = "image"
55
+ page_content: str = page_content
56
+
57
+ class InputParams(BaseModel):
58
+ prompt: str = Field(
59
+ default_prompt,
60
+ title="Prompt",
61
+ field="textarea",
62
+ id="prompt",
63
+ )
64
+ negative_prompt: str = Field(
65
+ default_negative_prompt,
66
+ title="Negative Prompt",
67
+ field="textarea",
68
+ id="negative_prompt",
69
+ hide=True,
70
+ )
71
+ seed: int = Field(
72
+ 2159232, min=0, title="Seed", field="seed", hide=True, id="seed"
73
+ )
74
+ steps: int = Field(
75
+ 4, min=1, max=15, title="Steps", field="range", hide=True, id="steps"
76
+ )
77
+ width: int = Field(
78
+ 512, min=2, max=15, title="Width", disabled=True, hide=True, id="width"
79
+ )
80
+ height: int = Field(
81
+ 512, min=2, max=15, title="Height", disabled=True, hide=True, id="height"
82
+ )
83
+ guidance_scale: float = Field(
84
+ 0.0,
85
+ min=0,
86
+ max=1,
87
+ step=0.001,
88
+ title="Guidance Scale",
89
+ field="range",
90
+ hide=True,
91
+ id="guidance_scale",
92
+ )
93
+ strength: float = Field(
94
+ 0.5,
95
+ min=0.25,
96
+ max=1.0,
97
+ step=0.001,
98
+ title="Strength",
99
+ field="range",
100
+ hide=True,
101
+ id="strength",
102
+ )
103
+
104
+ def __init__(self, args: Args, device: torch.device, torch_dtype: torch.dtype):
105
+ if args.safety_checker:
106
+ self.pipe = AutoPipelineForImage2Image.from_pretrained(base_model)
107
+ else:
108
+ self.pipe = AutoPipelineForImage2Image.from_pretrained(
109
+ base_model,
110
+ safety_checker=None,
111
+ )
112
+ if args.use_taesd:
113
+ self.pipe.vae = AutoencoderTiny.from_pretrained(
114
+ taesd_model, torch_dtype=torch_dtype, use_safetensors=True
115
+ ).to(device)
116
+
117
+ self.pipe.set_progress_bar_config(disable=True)
118
+ self.pipe.to(device=device, dtype=torch_dtype)
119
+ if device.type != "mps":
120
+ self.pipe.unet.to(memory_format=torch.channels_last)
121
+
122
+ # check if computer has less than 64GB of RAM using sys or os
123
+ if psutil.virtual_memory().total < 64 * 1024**3:
124
+ self.pipe.enable_attention_slicing()
125
+
126
+ if args.torch_compile:
127
+ print("Running torch compile")
128
+ self.pipe.unet = torch.compile(
129
+ self.pipe.unet, mode="reduce-overhead", fullgraph=True
130
+ )
131
+ self.pipe.vae = torch.compile(
132
+ self.pipe.vae, mode="reduce-overhead", fullgraph=True
133
+ )
134
+
135
+ self.pipe(
136
+ prompt="warmup",
137
+ image=[Image.new("RGB", (768, 768))],
138
+ )
139
+ if args.compel:
140
+ from compel import Compel
141
+
142
+ self.pipe.compel_proc = Compel(
143
+ tokenizer=self.pipe.tokenizer,
144
+ text_encoder=self.pipe.text_encoder,
145
+ truncate_long_prompts=True,
146
+ )
147
+
148
+ def predict(self, params: "Pipeline.InputParams") -> Image.Image:
149
+ generator = torch.manual_seed(params.seed)
150
+ steps = params.steps
151
+ strength = params.strength
152
+ if int(steps * strength) < 1:
153
+ steps = math.ceil(1 / max(0.10, strength))
154
+
155
+ prompt = params.prompt
156
+ prompt_embeds = None
157
+ if hasattr(self.pipe, "compel_proc"):
158
+ prompt_embeds = self.pipe.compel_proc(
159
+ [params.prompt, params.negative_prompt]
160
+ )
161
+ prompt = None
162
+
163
+ results = self.pipe(
164
+ image=params.image,
165
+ prompt_embeds=prompt_embeds,
166
+ prompt=prompt,
167
+ negative_prompt=params.negative_prompt,
168
+ generator=generator,
169
+ strength=strength,
170
+ num_inference_steps=steps,
171
+ guidance_scale=params.guidance_scale,
172
+ width=params.width,
173
+ height=params.height,
174
+ output_type="pil",
175
+ )
176
+
177
+ nsfw_content_detected = (
178
+ results.nsfw_content_detected[0]
179
+ if "nsfw_content_detected" in results
180
+ else False
181
+ )
182
+ if nsfw_content_detected:
183
+ return None
184
+ result_image = results.images[0]
185
+
186
+ return result_image
requirements.txt CHANGED
@@ -1,4 +1,4 @@
1
- git+https://github.com/huggingface/diffusers@dadd55fb36acc862254cf935826d54349b0fcd8c
2
  transformers==4.35.2
3
  --extra-index-url https://download.pytorch.org/whl/cu121;
4
  torch==2.1.0
 
1
+ git+https://github.com/huggingface/diffusers@29dfe22a8e6f1ea1e1f6cd4fbb8381f08064091e
2
  transformers==4.35.2
3
  --extra-index-url https://download.pytorch.org/whl/cu121;
4
  torch==2.1.0