radames commited on
Commit
c7af5cd
·
1 Parent(s): ec8114e
app-controlnetlora.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import json
3
+ import logging
4
+ import traceback
5
+ from pydantic import BaseModel
6
+
7
+ from fastapi import FastAPI, WebSocket, HTTPException, WebSocketDisconnect
8
+ from fastapi.middleware.cors import CORSMiddleware
9
+ from fastapi.responses import (
10
+ StreamingResponse,
11
+ JSONResponse,
12
+ HTMLResponse,
13
+ FileResponse,
14
+ )
15
+
16
+ from diffusers import (
17
+ StableDiffusionControlNetImg2ImgPipeline,
18
+ ControlNetModel,
19
+ LCMScheduler,
20
+ )
21
+ from compel import Compel
22
+ import torch
23
+
24
+ from canny_gpu import SobelOperator
25
+
26
+ # from controlnet_aux import OpenposeDetector
27
+ # import cv2
28
+
29
+ try:
30
+ import intel_extension_for_pytorch as ipex
31
+ except:
32
+ pass
33
+ from PIL import Image
34
+ import numpy as np
35
+ import gradio as gr
36
+ import io
37
+ import uuid
38
+ import os
39
+ import time
40
+ import psutil
41
+
42
+
43
+ MAX_QUEUE_SIZE = int(os.environ.get("MAX_QUEUE_SIZE", 0))
44
+ TIMEOUT = float(os.environ.get("TIMEOUT", 0))
45
+ SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", None)
46
+ TORCH_COMPILE = os.environ.get("TORCH_COMPILE", None)
47
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)
48
+
49
+ WIDTH = 512
50
+ HEIGHT = 512
51
+
52
+
53
+ # check if MPS is available OSX only M1/M2/M3 chips
54
+ mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
55
+ xpu_available = hasattr(torch, "xpu") and torch.xpu.is_available()
56
+ device = torch.device(
57
+ "cuda" if torch.cuda.is_available() else "xpu" if xpu_available else "cpu"
58
+ )
59
+
60
+ # change to torch.float16 to save GPU memory
61
+ torch_dtype = torch.float16
62
+
63
+ print(f"TIMEOUT: {TIMEOUT}")
64
+ print(f"SAFETY_CHECKER: {SAFETY_CHECKER}")
65
+ print(f"MAX_QUEUE_SIZE: {MAX_QUEUE_SIZE}")
66
+ print(f"device: {device}")
67
+
68
+ if mps_available:
69
+ device = torch.device("mps")
70
+ device = "cpu"
71
+ torch_dtype = torch.float32
72
+
73
+ controlnet_canny = ControlNetModel.from_pretrained(
74
+ "lllyasviel/control_v11p_sd15_canny", torch_dtype=torch_dtype
75
+ ).to(device)
76
+
77
+ canny_torch = SobelOperator(device=device)
78
+
79
+ model_id = "nitrosocke/mo-di-diffusion"
80
+ lcm_lora_id = "lcm-sd/lcm-sd1.5-lora"
81
+
82
+ if SAFETY_CHECKER == "True":
83
+ pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
84
+ model_id,
85
+ controlnet=controlnet_canny,
86
+ )
87
+ else:
88
+ pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
89
+ model_id,
90
+ safety_checker=None,
91
+ controlnet=controlnet_canny,
92
+ )
93
+
94
+ pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
95
+ pipe.set_progress_bar_config(disable=True)
96
+ pipe.to(device=device, dtype=torch_dtype).to(device)
97
+ pipe.unet.to(memory_format=torch.channels_last)
98
+
99
+
100
+ if psutil.virtual_memory().total < 64 * 1024**3:
101
+ pipe.enable_attention_slicing()
102
+
103
+ # Load LCM LoRA
104
+ pipe.load_lora_weights(
105
+ lcm_lora_id,
106
+ weight_name="lcm_sd_lora.safetensors",
107
+ adapter_name="lcm",
108
+ use_auth_token=HF_TOKEN,
109
+ )
110
+
111
+ compel_proc = Compel(
112
+ tokenizer=pipe.tokenizer,
113
+ text_encoder=pipe.text_encoder,
114
+ truncate_long_prompts=False,
115
+ )
116
+ if TORCH_COMPILE:
117
+ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
118
+ pipe.vae = torch.compile(pipe.vae, mode="reduce-overhead", fullgraph=True)
119
+
120
+ pipe(
121
+ prompt="warmup",
122
+ image=[Image.new("RGB", (768, 768))],
123
+ control_image=[Image.new("RGB", (768, 768))],
124
+ )
125
+
126
+
127
+ user_queue_map = {}
128
+
129
+
130
+ class InputParams(BaseModel):
131
+ seed: int = 2159232
132
+ prompt: str
133
+ guidance_scale: float = 8.0
134
+ strength: float = 0.5
135
+ steps: int = 4
136
+ lcm_steps: int = 50
137
+ width: int = WIDTH
138
+ height: int = HEIGHT
139
+ controlnet_scale: float = 0.8
140
+ controlnet_start: float = 0.0
141
+ controlnet_end: float = 1.0
142
+ canny_low_threshold: float = 0.31
143
+ canny_high_threshold: float = 0.78
144
+ debug_canny: bool = False
145
+
146
+
147
+ def predict(
148
+ input_image: Image.Image, params: InputParams, prompt_embeds: torch.Tensor = None
149
+ ):
150
+ generator = torch.manual_seed(params.seed)
151
+
152
+ control_image = canny_torch(
153
+ input_image, params.canny_low_threshold, params.canny_high_threshold
154
+ )
155
+ results = pipe(
156
+ control_image=control_image,
157
+ prompt_embeds=prompt_embeds,
158
+ generator=generator,
159
+ image=input_image,
160
+ strength=params.strength,
161
+ num_inference_steps=params.steps,
162
+ guidance_scale=params.guidance_scale,
163
+ width=params.width,
164
+ height=params.height,
165
+ output_type="pil",
166
+ controlnet_conditioning_scale=params.controlnet_scale,
167
+ control_guidance_start=params.controlnet_start,
168
+ control_guidance_end=params.controlnet_end,
169
+ )
170
+ nsfw_content_detected = (
171
+ results.nsfw_content_detected[0]
172
+ if "nsfw_content_detected" in results
173
+ else False
174
+ )
175
+ if nsfw_content_detected:
176
+ return None
177
+ result_image = results.images[0]
178
+ if params.debug_canny:
179
+ # paste control_image on top of result_image
180
+ w0, h0 = (200, 200)
181
+ control_image = control_image.resize((w0, h0))
182
+ w1, h1 = result_image.size
183
+ result_image.paste(control_image, (w1 - w0, h1 - h0))
184
+
185
+ return result_image
186
+
187
+
188
+ app = FastAPI()
189
+ app.add_middleware(
190
+ CORSMiddleware,
191
+ allow_origins=["*"],
192
+ allow_credentials=True,
193
+ allow_methods=["*"],
194
+ allow_headers=["*"],
195
+ )
196
+
197
+
198
+ @app.websocket("/ws")
199
+ async def websocket_endpoint(websocket: WebSocket):
200
+ await websocket.accept()
201
+ if MAX_QUEUE_SIZE > 0 and len(user_queue_map) >= MAX_QUEUE_SIZE:
202
+ print("Server is full")
203
+ await websocket.send_json({"status": "error", "message": "Server is full"})
204
+ await websocket.close()
205
+ return
206
+
207
+ try:
208
+ uid = str(uuid.uuid4())
209
+ print(f"New user connected: {uid}")
210
+ await websocket.send_json(
211
+ {"status": "success", "message": "Connected", "userId": uid}
212
+ )
213
+ user_queue_map[uid] = {"queue": asyncio.Queue()}
214
+ await websocket.send_json(
215
+ {"status": "start", "message": "Start Streaming", "userId": uid}
216
+ )
217
+ await handle_websocket_data(websocket, uid)
218
+ except WebSocketDisconnect as e:
219
+ logging.error(f"WebSocket Error: {e}, {uid}")
220
+ traceback.print_exc()
221
+ finally:
222
+ print(f"User disconnected: {uid}")
223
+ queue_value = user_queue_map.pop(uid, None)
224
+ queue = queue_value.get("queue", None)
225
+ if queue:
226
+ while not queue.empty():
227
+ try:
228
+ queue.get_nowait()
229
+ except asyncio.QueueEmpty:
230
+ continue
231
+
232
+
233
+ @app.get("/queue_size")
234
+ async def get_queue_size():
235
+ queue_size = len(user_queue_map)
236
+ return JSONResponse({"queue_size": queue_size})
237
+
238
+
239
+ @app.get("/stream/{user_id}")
240
+ async def stream(user_id: uuid.UUID):
241
+ uid = str(user_id)
242
+ try:
243
+ user_queue = user_queue_map[uid]
244
+ queue = user_queue["queue"]
245
+
246
+ async def generate():
247
+ last_prompt: str = None
248
+ prompt_embeds: torch.Tensor = None
249
+ while True:
250
+ data = await queue.get()
251
+ input_image = data["image"]
252
+ params = data["params"]
253
+ if input_image is None:
254
+ continue
255
+ # avoid recalculate prompt embeds
256
+ if last_prompt != params.prompt:
257
+ print("new prompt")
258
+ prompt_embeds = compel_proc(params.prompt)
259
+ last_prompt = params.prompt
260
+
261
+ image = predict(
262
+ input_image,
263
+ params,
264
+ prompt_embeds,
265
+ )
266
+ if image is None:
267
+ continue
268
+ frame_data = io.BytesIO()
269
+ image.save(frame_data, format="JPEG")
270
+ frame_data = frame_data.getvalue()
271
+ if frame_data is not None and len(frame_data) > 0:
272
+ yield b"--frame\r\nContent-Type: image/jpeg\r\n\r\n" + frame_data + b"\r\n"
273
+
274
+ await asyncio.sleep(1.0 / 120.0)
275
+
276
+ return StreamingResponse(
277
+ generate(), media_type="multipart/x-mixed-replace;boundary=frame"
278
+ )
279
+ except Exception as e:
280
+ logging.error(f"Streaming Error: {e}, {user_queue_map}")
281
+ traceback.print_exc()
282
+ return HTTPException(status_code=404, detail="User not found")
283
+
284
+
285
+ async def handle_websocket_data(websocket: WebSocket, user_id: uuid.UUID):
286
+ uid = str(user_id)
287
+ user_queue = user_queue_map[uid]
288
+ queue = user_queue["queue"]
289
+ if not queue:
290
+ return HTTPException(status_code=404, detail="User not found")
291
+ last_time = time.time()
292
+ try:
293
+ while True:
294
+ data = await websocket.receive_bytes()
295
+ params = await websocket.receive_json()
296
+ params = InputParams(**params)
297
+ pil_image = Image.open(io.BytesIO(data))
298
+
299
+ while not queue.empty():
300
+ try:
301
+ queue.get_nowait()
302
+ except asyncio.QueueEmpty:
303
+ continue
304
+ await queue.put({"image": pil_image, "params": params})
305
+ if TIMEOUT > 0 and time.time() - last_time > TIMEOUT:
306
+ await websocket.send_json(
307
+ {
308
+ "status": "timeout",
309
+ "message": "Your session has ended",
310
+ "userId": uid,
311
+ }
312
+ )
313
+ await websocket.close()
314
+ return
315
+
316
+ except Exception as e:
317
+ logging.error(f"Error: {e}")
318
+ traceback.print_exc()
319
+
320
+
321
+ @app.get("/", response_class=HTMLResponse)
322
+ async def root():
323
+ return FileResponse("./static/controlnetlora.html")
static/controlnetlora.html ADDED
@@ -0,0 +1,412 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!doctype html>
2
+ <html>
3
+
4
+ <head>
5
+ <meta charset="UTF-8">
6
+ <title>Real-Time Latent Consistency Model ControlNet</title>
7
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
8
+ <script
9
+ src="https://cdnjs.cloudflare.com/ajax/libs/iframe-resizer/4.3.1/iframeResizer.contentWindow.min.js"></script>
10
+ <script src="https://cdn.jsdelivr.net/npm/[email protected]/piexif.min.js"></script>
11
+ <script src="https://cdn.tailwindcss.com"></script>
12
+ <style type="text/tailwindcss">
13
+ .button {
14
+ @apply bg-gray-700 hover:bg-gray-800 text-white font-normal p-2 rounded disabled:bg-gray-300 dark:disabled:bg-gray-700 disabled:cursor-not-allowed dark:disabled:text-black
15
+ }
16
+ </style>
17
+ <script type="module">
18
+ const getValue = (id) => {
19
+ const el = document.querySelector(`${id}`)
20
+ if (el.type === "checkbox")
21
+ return el.checked;
22
+ return el.value;
23
+ }
24
+ const startBtn = document.querySelector("#start");
25
+ const stopBtn = document.querySelector("#stop");
26
+ const videoEl = document.querySelector("#webcam");
27
+ const imageEl = document.querySelector("#player");
28
+ const queueSizeEl = document.querySelector("#queue_size");
29
+ const errorEl = document.querySelector("#error");
30
+ const snapBtn = document.querySelector("#snap");
31
+ const webcamsEl = document.querySelector("#webcams");
32
+
33
+ function LCMLive(webcamVideo, liveImage) {
34
+ let websocket;
35
+
36
+ async function start() {
37
+ return new Promise((resolve, reject) => {
38
+ const websocketURL = `${window.location.protocol === "https:" ? "wss" : "ws"
39
+ }:${window.location.host}/ws`;
40
+
41
+ const socket = new WebSocket(websocketURL);
42
+ socket.onopen = () => {
43
+ console.log("Connected to websocket");
44
+ };
45
+ socket.onclose = () => {
46
+ console.log("Disconnected from websocket");
47
+ stop();
48
+ resolve({ "status": "disconnected" });
49
+ };
50
+ socket.onerror = (err) => {
51
+ console.error(err);
52
+ reject(err);
53
+ };
54
+ socket.onmessage = (event) => {
55
+ const data = JSON.parse(event.data);
56
+ switch (data.status) {
57
+ case "success":
58
+ break;
59
+ case "start":
60
+ const userId = data.userId;
61
+ initVideoStream(userId);
62
+ break;
63
+ case "timeout":
64
+ stop();
65
+ resolve({ "status": "timeout" });
66
+ case "error":
67
+ stop();
68
+ reject(data.message);
69
+
70
+ }
71
+ };
72
+ websocket = socket;
73
+ })
74
+ }
75
+ function switchCamera() {
76
+ const constraints = {
77
+ audio: false,
78
+ video: { width: 1024, height: 1024, deviceId: mediaDevices[webcamsEl.value].deviceId }
79
+ };
80
+ navigator.mediaDevices
81
+ .getUserMedia(constraints)
82
+ .then((mediaStream) => {
83
+ webcamVideo.removeEventListener("timeupdate", videoTimeUpdateHandler);
84
+ webcamVideo.srcObject = mediaStream;
85
+ webcamVideo.onloadedmetadata = () => {
86
+ webcamVideo.play();
87
+ webcamVideo.addEventListener("timeupdate", videoTimeUpdateHandler);
88
+ };
89
+ })
90
+ .catch((err) => {
91
+ console.error(`${err.name}: ${err.message}`);
92
+ });
93
+ }
94
+
95
+ async function videoTimeUpdateHandler() {
96
+ const dimension = getValue("input[name=dimension]:checked");
97
+ const [WIDTH, HEIGHT] = JSON.parse(dimension);
98
+
99
+ const canvas = new OffscreenCanvas(WIDTH, HEIGHT);
100
+ const videoW = webcamVideo.videoWidth;
101
+ const videoH = webcamVideo.videoHeight;
102
+ const aspectRatio = WIDTH / HEIGHT;
103
+
104
+ const ctx = canvas.getContext("2d");
105
+ ctx.drawImage(webcamVideo, videoW / 2 - videoH * aspectRatio / 2, 0, videoH * aspectRatio, videoH, 0, 0, WIDTH, HEIGHT)
106
+ const blob = await canvas.convertToBlob({ type: "image/jpeg", quality: 1 });
107
+ websocket.send(blob);
108
+ websocket.send(JSON.stringify({
109
+ "seed": getValue("#seed"),
110
+ "prompt": getValue("#prompt"),
111
+ "guidance_scale": getValue("#guidance-scale"),
112
+ "strength": getValue("#strength"),
113
+ "steps": getValue("#steps"),
114
+ "width": WIDTH,
115
+ "height": HEIGHT,
116
+ "controlnet_scale": getValue("#controlnet_scale"),
117
+ "controlnet_start": getValue("#controlnet_start"),
118
+ "controlnet_end": getValue("#controlnet_end"),
119
+ "canny_low_threshold": getValue("#canny_low_threshold"),
120
+ "canny_high_threshold": getValue("#canny_high_threshold"),
121
+ "debug_canny": getValue("#debug_canny")
122
+ }));
123
+ }
124
+ let mediaDevices = [];
125
+ async function initVideoStream(userId) {
126
+ liveImage.src = `/stream/${userId}`;
127
+ await navigator.mediaDevices.enumerateDevices()
128
+ .then(devices => {
129
+ const cameras = devices.filter(device => device.kind === 'videoinput');
130
+ mediaDevices = cameras;
131
+ webcamsEl.innerHTML = "";
132
+ cameras.forEach((camera, index) => {
133
+ const option = document.createElement("option");
134
+ option.value = index;
135
+ option.innerText = camera.label;
136
+ webcamsEl.appendChild(option);
137
+ option.selected = index === 0;
138
+ });
139
+ webcamsEl.addEventListener("change", switchCamera);
140
+ })
141
+ .catch(err => {
142
+ console.error(err);
143
+ });
144
+ const constraints = {
145
+ audio: false,
146
+ video: { width: 1024, height: 1024, deviceId: mediaDevices[0].deviceId }
147
+ };
148
+ navigator.mediaDevices
149
+ .getUserMedia(constraints)
150
+ .then((mediaStream) => {
151
+ webcamVideo.srcObject = mediaStream;
152
+ webcamVideo.onloadedmetadata = () => {
153
+ webcamVideo.play();
154
+ webcamVideo.addEventListener("timeupdate", videoTimeUpdateHandler);
155
+ };
156
+ })
157
+ .catch((err) => {
158
+ console.error(`${err.name}: ${err.message}`);
159
+ });
160
+ }
161
+
162
+
163
+ async function stop() {
164
+ websocket.close();
165
+ navigator.mediaDevices.getUserMedia({ video: true }).then((mediaStream) => {
166
+ mediaStream.getTracks().forEach((track) => track.stop());
167
+ });
168
+ webcamVideo.removeEventListener("timeupdate", videoTimeUpdateHandler);
169
+ webcamsEl.removeEventListener("change", switchCamera);
170
+ webcamVideo.srcObject = null;
171
+ }
172
+ return {
173
+ start,
174
+ stop
175
+ }
176
+ }
177
+ function toggleMessage(type) {
178
+ errorEl.hidden = false;
179
+ errorEl.scrollIntoView();
180
+ switch (type) {
181
+ case "error":
182
+ errorEl.innerText = "To many users are using the same GPU, please try again later.";
183
+ errorEl.classList.toggle("bg-red-300", "text-red-900");
184
+ break;
185
+ case "success":
186
+ errorEl.innerText = "Your session has ended, please start a new one.";
187
+ errorEl.classList.toggle("bg-green-300", "text-green-900");
188
+ break;
189
+ }
190
+ setTimeout(() => {
191
+ errorEl.hidden = true;
192
+ }, 2000);
193
+ }
194
+ function snapImage() {
195
+ try {
196
+ const zeroth = {};
197
+ const exif = {};
198
+ const gps = {};
199
+ zeroth[piexif.ImageIFD.Make] = "LCM Image-to-Image ControNet";
200
+ zeroth[piexif.ImageIFD.ImageDescription] = `prompt: ${getValue("#prompt")} | seed: ${getValue("#seed")} | guidance_scale: ${getValue("#guidance-scale")} | strength: ${getValue("#strength")} | controlnet_start: ${getValue("#controlnet_start")} | controlnet_end: ${getValue("#controlnet_end")} | steps: ${getValue("#steps")}`;
201
+ zeroth[piexif.ImageIFD.Software] = "https://github.com/radames/Real-Time-Latent-Consistency-Model";
202
+ exif[piexif.ExifIFD.DateTimeOriginal] = new Date().toISOString();
203
+
204
+ const exifObj = { "0th": zeroth, "Exif": exif, "GPS": gps };
205
+ const exifBytes = piexif.dump(exifObj);
206
+
207
+ const canvas = document.createElement("canvas");
208
+ canvas.width = imageEl.naturalWidth;
209
+ canvas.height = imageEl.naturalHeight;
210
+ const ctx = canvas.getContext("2d");
211
+ ctx.drawImage(imageEl, 0, 0);
212
+ const dataURL = canvas.toDataURL("image/jpeg");
213
+ const withExif = piexif.insert(exifBytes, dataURL);
214
+
215
+ const a = document.createElement("a");
216
+ a.href = withExif;
217
+ a.download = `lcm_txt_2_img${Date.now()}.png`;
218
+ a.click();
219
+ } catch (err) {
220
+ console.log(err);
221
+ }
222
+ }
223
+
224
+
225
+ const lcmLive = LCMLive(videoEl, imageEl);
226
+ startBtn.addEventListener("click", async () => {
227
+ try {
228
+ startBtn.disabled = true;
229
+ snapBtn.disabled = false;
230
+ const res = await lcmLive.start();
231
+ startBtn.disabled = false;
232
+ if (res.status === "timeout")
233
+ toggleMessage("success")
234
+ } catch (err) {
235
+ console.log(err);
236
+ toggleMessage("error")
237
+ startBtn.disabled = false;
238
+ }
239
+ });
240
+ stopBtn.addEventListener("click", () => {
241
+ lcmLive.stop();
242
+ });
243
+ window.addEventListener("beforeunload", () => {
244
+ lcmLive.stop();
245
+ });
246
+ snapBtn.addEventListener("click", snapImage);
247
+ setInterval(() =>
248
+ fetch("/queue_size")
249
+ .then((res) => res.json())
250
+ .then((data) => {
251
+ queueSizeEl.innerText = data.queue_size;
252
+ })
253
+ .catch((err) => {
254
+ console.log(err);
255
+ })
256
+ , 5000);
257
+ </script>
258
+ </head>
259
+
260
+ <body class="text-black dark:bg-gray-900 dark:text-white">
261
+ <div class="fixed right-2 top-2 p-4 font-bold text-sm rounded-lg max-w-xs text-center" id="error">
262
+ </div>
263
+ <main class="container mx-auto px-4 py-4 max-w-4xl flex flex-col gap-4">
264
+ <article class="text-center max-w-xl mx-auto">
265
+ <h1 class="text-3xl font-bold">Real-Time Latent Consistency Model</h1>
266
+ <h2 class="text-2xl font-bold mb-4">ControlNet Lora</h2>
267
+ <p class="text-sm">
268
+ This demo showcases
269
+ <a href="https://huggingface.co/SimianLuo/LCM_Dreamshaper_v7" target="_blank"
270
+ class="text-blue-500 underline hover:no-underline">LCM</a> Image to Image pipeline
271
+ using
272
+ <a href="https://github.com/huggingface/diffusers/tree/main/examples/community#latent-consistency-pipeline"
273
+ target="_blank" class="text-blue-500 underline hover:no-underline">Diffusers</a> with a MJPEG
274
+ stream server. Featuring <a href="https://huggingface.co/nitrosocke/mo-di-diffusion" target="_blank"
275
+ class="text-blue-500 underline hover:no-underline">Nitrosocke Mo-Di Diffusion</a>Model.
276
+ </p>
277
+ </article>
278
+ <div>
279
+ <h2 class="font-medium">Prompt</h2>
280
+ <p class="text-sm text-gray-500">
281
+ Change the prompt to generate different images, accepts <a
282
+ href="https://github.com/damian0815/compel/blob/main/doc/syntax.md" target="_blank"
283
+ class="text-blue-500 underline hover:no-underline">Compel</a> syntax.
284
+ </p>
285
+ <div class="flex text-normal px-1 py-1 border border-gray-700 rounded-md items-center">
286
+ <textarea type="text" id="prompt" class="font-light w-full px-3 py-2 mx-1 outline-none dark:text-black"
287
+ title="Prompt, this is an example, feel free to modify"
288
+ placeholder="Add your prompt here...">a magical princess with golden hair, modern disney style</textarea>
289
+ </div>
290
+ </div>
291
+ <div class="">
292
+ <details>
293
+ <summary class="font-medium cursor-pointer">Advanced Options</summary>
294
+ <div class="grid grid-cols-3 sm:grid-cols-6 items-center gap-3 py-3">
295
+ <label for="webcams" class="text-sm font-medium">Camera Options: </label>
296
+ <select id="webcams" class="text-sm border-2 border-gray-500 rounded-md font-light dark:text-black">
297
+ </select>
298
+ <div></div>
299
+ <label class="text-sm font-medium " for="steps">Inference Steps
300
+ </label>
301
+ <input type="range" id="steps" name="steps" min="1" max="8" value="4"
302
+ oninput="this.nextElementSibling.value = Number(this.value)">
303
+ <output class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
304
+ 4</output>
305
+ <label class="text-sm font-medium" for="guidance-scale">Guidance Scale
306
+ </label>
307
+ <input type="range" id="guidance-scale" name="guidance-scale" min="0" max="5" step="0.001"
308
+ value="0.3" oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)">
309
+ <output class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
310
+ 0.3</output>
311
+ <!-- -->
312
+ <label class="text-sm font-medium" for="strength">Strength</label>
313
+ <input type="range" id="strength" name="strength" min="0.1" max="1" step="0.001" value="0.50"
314
+ oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)">
315
+ <output class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
316
+ 0.5</output>
317
+ <!-- -->
318
+ <label class="text-sm font-medium" for="controlnet_scale">ControlNet Condition Scale</label>
319
+ <input type="range" id="controlnet_scale" name="controlnet_scale" min="0.0" max="1" step="0.001"
320
+ value="0.80" oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)">
321
+ <output class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
322
+ 0.8</output>
323
+ <!-- -->
324
+ <label class="text-sm font-medium" for="controlnet_start">ControlNet Guidance Start</label>
325
+ <input type="range" id="controlnet_start" name="controlnet_start" min="0.0" max="1.0" step="0.001"
326
+ value="0.0" oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)">
327
+ <output class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
328
+ 0.0</output>
329
+ <!-- -->
330
+ <label class="text-sm font-medium" for="controlnet_end">ControlNet Guidance End</label>
331
+ <input type="range" id="controlnet_end" name="controlnet_end" min="0.0" max="1.0" step="0.001"
332
+ value="0.8" oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)">
333
+ <output class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
334
+ 0.8</output>
335
+ <!-- -->
336
+ <label class="text-sm font-medium" for="canny_low_threshold">Canny Low Threshold</label>
337
+ <input type="range" id="canny_low_threshold" name="canny_low_threshold" min="0.0" max="1.0"
338
+ step="0.001" value="0.1"
339
+ oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)">
340
+ <output class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
341
+ 0.1</output>
342
+ <!-- -->
343
+ <label class="text-sm font-medium" for="canny_high_threshold">Canny High Threshold</label>
344
+ <input type="range" id="canny_high_threshold" name="canny_high_threshold" min="0.0" max="1.0"
345
+ step="0.001" value="0.2"
346
+ oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)">
347
+ <output class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
348
+ 0.2</output>
349
+ <!-- -->
350
+ <label class="text-sm font-medium" for="seed">Seed</label>
351
+ <input type="number" id="seed" name="seed" value="299792458"
352
+ class="font-light border border-gray-700 text-right rounded-md p-2 dark:text-black">
353
+ <button
354
+ onclick="document.querySelector('#seed').value = Math.floor(Math.random() * Number.MAX_SAFE_INTEGER)"
355
+ class="button">
356
+ Rand
357
+ </button>
358
+ <!-- -->
359
+ <!-- -->
360
+ <label class="text-sm font-medium" for="dimension">Image Dimensions</label>
361
+ <div class="col-span-2 flex gap-2">
362
+ <div class="flex gap-1">
363
+ <input type="radio" id="dimension512" name="dimension" value="[512,512]" checked
364
+ class="cursor-pointer">
365
+ <label for="dimension512" class="text-sm cursor-pointer">512x512</label>
366
+ </div>
367
+ <div class="flex gap-1">
368
+ <input type="radio" id="dimension768" name="dimension" value="[768,768]"
369
+ lass="cursor-pointer">
370
+ <label for="dimension768" class="text-sm cursor-pointer">768x768</label>
371
+ </div>
372
+ </div>
373
+ <!-- -->
374
+ <!-- -->
375
+ <label class="text-sm font-medium" for="debug_canny">Debug Canny</label>
376
+ <div class="col-span-2 flex gap-2">
377
+ <input type="checkbox" id="debug_canny" name="debug_canny" class="cursor-pointer">
378
+ <label for="debug_canny" class="text-sm cursor-pointer"></label>
379
+ </div>
380
+ <div></div>
381
+ <!-- -->
382
+ </div>
383
+ </details>
384
+ </div>
385
+ <div class="flex gap-3">
386
+ <button id="start" class="button">
387
+ Start
388
+ </button>
389
+ <button id="stop" class="button">
390
+ Stop
391
+ </button>
392
+ <button id="snap" disabled class="button ml-auto">
393
+ Snapshot
394
+ </button>
395
+ </div>
396
+ <div class="relative rounded-lg border border-slate-300 overflow-hidden">
397
+ <img id="player" class="w-full aspect-square rounded-lg"
398
+ src="">
399
+ <div class="absolute top-0 left-0 w-1/4 aspect-square">
400
+ <video id="webcam" class="w-full aspect-square relative z-10 object-cover" playsinline autoplay muted
401
+ loop></video>
402
+ <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 448 448" width="100"
403
+ class="w-full p-4 absolute top-0 opacity-20 z-0">
404
+ <path fill="currentColor"
405
+ d="M224 256a128 128 0 1 0 0-256 128 128 0 1 0 0 256zm-45.7 48A178.3 178.3 0 0 0 0 482.3 29.7 29.7 0 0 0 29.7 512h388.6a29.7 29.7 0 0 0 29.7-29.7c0-98.5-79.8-178.3-178.3-178.3h-91.4z" />
406
+ </svg>
407
+ </div>
408
+ </div>
409
+ </main>
410
+ </body>
411
+
412
+ </html>
static/txt2imglora.html CHANGED
@@ -201,7 +201,7 @@
201
  <main class="container mx-auto px-4 py-4 max-w-4xl flex flex-col gap-4">
202
  <article class="text-center max-w-xl mx-auto">
203
  <h1 class="text-3xl font-bold">Real-Time Latent Consistency Model</h1>
204
- <h2 class="text-2xl font-bold mb-4">Text to Image</h2>
205
  <p class="text-sm">
206
  This demo showcases
207
  <a href="https://huggingface.co/SimianLuo/LCM_Dreamshaper_v7" target="_blank"
 
201
  <main class="container mx-auto px-4 py-4 max-w-4xl flex flex-col gap-4">
202
  <article class="text-center max-w-xl mx-auto">
203
  <h1 class="text-3xl font-bold">Real-Time Latent Consistency Model</h1>
204
+ <h2 class="text-2xl font-bold mb-4">Text to Image Lora</h2>
205
  <p class="text-sm">
206
  This demo showcases
207
  <a href="https://huggingface.co/SimianLuo/LCM_Dreamshaper_v7" target="_blank"