radames commited on
Commit
0b5ceff
·
1 Parent(s): cd353d4

add txt2imglora pipeline

Browse files
app-txt2imglora.py DELETED
@@ -1,254 +0,0 @@
1
- import asyncio
2
- import json
3
- import logging
4
- import traceback
5
- from pydantic import BaseModel
6
-
7
- from fastapi import FastAPI, WebSocket, HTTPException, WebSocketDisconnect
8
- from fastapi.middleware.cors import CORSMiddleware
9
- from fastapi.responses import (
10
- StreamingResponse,
11
- JSONResponse,
12
- HTMLResponse,
13
- FileResponse,
14
- )
15
-
16
- from diffusers import DiffusionPipeline, LCMScheduler, AutoencoderTiny
17
- from compel import Compel
18
- import torch
19
-
20
- try:
21
- import intel_extension_for_pytorch as ipex
22
- except:
23
- pass
24
- from PIL import Image
25
- import numpy as np
26
- import gradio as gr
27
- import io
28
- import uuid
29
- import os
30
- import time
31
- import psutil
32
-
33
-
34
- MAX_QUEUE_SIZE = int(os.environ.get("MAX_QUEUE_SIZE", 0))
35
- TIMEOUT = float(os.environ.get("TIMEOUT", 0))
36
- SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", None)
37
- TORCH_COMPILE = os.environ.get("TORCH_COMPILE", None)
38
-
39
- WIDTH = 512
40
- HEIGHT = 512
41
-
42
- # check if MPS is available OSX only M1/M2/M3 chips
43
- mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
44
- xpu_available = hasattr(torch, "xpu") and torch.xpu.is_available()
45
- device = torch.device(
46
- "cuda" if torch.cuda.is_available() else "xpu" if xpu_available else "cpu"
47
- )
48
- torch_device = device
49
- # change to torch.float16 to save GPU memory
50
- torch_dtype = torch.float
51
-
52
- print(f"TIMEOUT: {TIMEOUT}")
53
- print(f"SAFETY_CHECKER: {SAFETY_CHECKER}")
54
- print(f"MAX_QUEUE_SIZE: {MAX_QUEUE_SIZE}")
55
- print(f"device: {device}")
56
-
57
- if mps_available:
58
- device = torch.device("mps")
59
- torch_device = "cpu"
60
- torch_dtype = torch.float32
61
-
62
- model_id = "wavymulder/Analog-Diffusion"
63
- lcm_lora_id = "latent-consistency/lcm-lora-sdv1-5"
64
-
65
- if SAFETY_CHECKER == "True":
66
- pipe = DiffusionPipeline.from_pretrained(model_id)
67
- else:
68
- pipe = DiffusionPipeline.from_pretrained(model_id, safety_checker=None)
69
-
70
-
71
- pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
72
- pipe.set_progress_bar_config(disable=True)
73
- pipe.to(device=torch_device, dtype=torch_dtype).to(device)
74
- pipe.unet.to(memory_format=torch.channels_last)
75
-
76
- # check if computer has less than 64GB of RAM using sys or os
77
- if psutil.virtual_memory().total < 64 * 1024**3:
78
- pipe.enable_attention_slicing()
79
-
80
- if TORCH_COMPILE:
81
- pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
82
- pipe.vae = torch.compile(pipe.vae, mode="reduce-overhead", fullgraph=True)
83
-
84
- pipe(prompt="warmup", num_inference_steps=1, guidance_scale=8.0)
85
-
86
- # Load LCM LoRA
87
- pipe.load_lora_weights(
88
- lcm_lora_id,
89
- adapter_name="lcm"
90
- )
91
-
92
- compel_proc = Compel(
93
- tokenizer=pipe.tokenizer,
94
- text_encoder=pipe.text_encoder,
95
- truncate_long_prompts=False,
96
- )
97
- user_queue_map = {}
98
-
99
-
100
- class InputParams(BaseModel):
101
- seed: int = 2159232
102
- prompt: str
103
- guidance_scale: float = 0.5
104
- strength: float = 0.5
105
- steps: int = 4
106
- lcm_steps: int = 50
107
- width: int = WIDTH
108
- height: int = HEIGHT
109
-
110
-
111
- def predict(params: InputParams):
112
- generator = torch.manual_seed(params.seed)
113
- prompt_embeds = compel_proc(params.prompt)
114
- results = pipe(
115
- prompt_embeds=prompt_embeds,
116
- generator=generator,
117
- num_inference_steps=params.steps,
118
- guidance_scale=params.guidance_scale,
119
- width=params.width,
120
- height=params.height,
121
- output_type="pil",
122
- )
123
- nsfw_content_detected = (
124
- results.nsfw_content_detected[0]
125
- if "nsfw_content_detected" in results
126
- else False
127
- )
128
- if nsfw_content_detected:
129
- return None
130
- return results.images[0]
131
-
132
-
133
- app = FastAPI()
134
- app.add_middleware(
135
- CORSMiddleware,
136
- allow_origins=["*"],
137
- allow_credentials=True,
138
- allow_methods=["*"],
139
- allow_headers=["*"],
140
- )
141
-
142
-
143
- @app.websocket("/ws")
144
- async def websocket_endpoint(websocket: WebSocket):
145
- await websocket.accept()
146
- if MAX_QUEUE_SIZE > 0 and len(user_queue_map) >= MAX_QUEUE_SIZE:
147
- print("Server is full")
148
- await websocket.send_json({"status": "error", "message": "Server is full"})
149
- await websocket.close()
150
- return
151
-
152
- try:
153
- uid = str(uuid.uuid4())
154
- print(f"New user connected: {uid}")
155
- await websocket.send_json(
156
- {"status": "success", "message": "Connected", "userId": uid}
157
- )
158
- user_queue_map[uid] = {
159
- "queue": asyncio.Queue(),
160
- }
161
- await websocket.send_json(
162
- {"status": "start", "message": "Start Streaming", "userId": uid}
163
- )
164
- await handle_websocket_data(websocket, uid)
165
- except WebSocketDisconnect as e:
166
- logging.error(f"WebSocket Error: {e}, {uid}")
167
- traceback.print_exc()
168
- finally:
169
- print(f"User disconnected: {uid}")
170
- queue_value = user_queue_map.pop(uid, None)
171
- queue = queue_value.get("queue", None)
172
- if queue:
173
- while not queue.empty():
174
- try:
175
- queue.get_nowait()
176
- except asyncio.QueueEmpty:
177
- continue
178
-
179
-
180
- @app.get("/queue_size")
181
- async def get_queue_size():
182
- queue_size = len(user_queue_map)
183
- return JSONResponse({"queue_size": queue_size})
184
-
185
-
186
- @app.get("/stream/{user_id}")
187
- async def stream(user_id: uuid.UUID):
188
- uid = str(user_id)
189
- try:
190
- user_queue = user_queue_map[uid]
191
- queue = user_queue["queue"]
192
-
193
- async def generate():
194
- while True:
195
- params = await queue.get()
196
- if params is None:
197
- continue
198
-
199
- image = predict(params)
200
- if image is None:
201
- continue
202
- frame_data = io.BytesIO()
203
- image.save(frame_data, format="JPEG")
204
- frame_data = frame_data.getvalue()
205
- if frame_data is not None and len(frame_data) > 0:
206
- yield b"--frame\r\nContent-Type: image/jpeg\r\n\r\n" + frame_data + b"\r\n"
207
-
208
- await asyncio.sleep(1.0 / 120.0)
209
-
210
- return StreamingResponse(
211
- generate(), media_type="multipart/x-mixed-replace;boundary=frame"
212
- )
213
- except Exception as e:
214
- logging.error(f"Streaming Error: {e}, {user_queue_map}")
215
- traceback.print_exc()
216
- return HTTPException(status_code=404, detail="User not found")
217
-
218
-
219
- async def handle_websocket_data(websocket: WebSocket, user_id: uuid.UUID):
220
- uid = str(user_id)
221
- user_queue = user_queue_map[uid]
222
- queue = user_queue["queue"]
223
- if not queue:
224
- return HTTPException(status_code=404, detail="User not found")
225
- last_time = time.time()
226
- try:
227
- while True:
228
- params = await websocket.receive_json()
229
- params = InputParams(**params)
230
- while not queue.empty():
231
- try:
232
- queue.get_nowait()
233
- except asyncio.QueueEmpty:
234
- continue
235
- await queue.put(params)
236
- if TIMEOUT > 0 and time.time() - last_time > TIMEOUT:
237
- await websocket.send_json(
238
- {
239
- "status": "timeout",
240
- "message": "Your session has ended",
241
- "userId": uid,
242
- }
243
- )
244
- await websocket.close()
245
- return
246
-
247
- except Exception as e:
248
- logging.error(f"Error: {e}")
249
- traceback.print_exc()
250
-
251
-
252
- @app.get("/", response_class=HTMLResponse)
253
- async def root():
254
- return FileResponse("./static/txt2imglora.html")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
frontend/src/lib/components/Checkbox.svelte CHANGED
@@ -1,7 +1,11 @@
1
  <script lang="ts">
2
  import type { FieldProps } from '$lib/types';
 
3
  export let value = false;
4
  export let params: FieldProps;
 
 
 
5
  </script>
6
 
7
  <div class="grid max-w-md grid-cols-4 items-center justify-items-start gap-3">
 
1
  <script lang="ts">
2
  import type { FieldProps } from '$lib/types';
3
+ import { onMount } from 'svelte';
4
  export let value = false;
5
  export let params: FieldProps;
6
+ onMount(() => {
7
+ value = Boolean(params?.default) ?? 8.0;
8
+ });
9
  </script>
10
 
11
  <div class="grid max-w-md grid-cols-4 items-center justify-items-start gap-3">
frontend/src/lib/components/PipelineOptions.svelte CHANGED
@@ -20,7 +20,7 @@
20
  {#if params.field === FieldType.RANGE}
21
  <InputRange {params} bind:value={$pipelineValues[params.id]}></InputRange>
22
  {:else if params.field === FieldType.SEED}
23
- <SeedInput bind:value={$pipelineValues[params.id]}></SeedInput>
24
  {:else if params.field === FieldType.TEXTAREA}
25
  <TextArea {params} bind:value={$pipelineValues[params.id]}></TextArea>
26
  {:else if params.field === FieldType.CHECKBOX}
@@ -30,7 +30,7 @@
30
  {/if}
31
  </div>
32
 
33
- <details open>
34
  <summary class="cursor-pointer font-medium">Advanced Options</summary>
35
  <div
36
  class="grid grid-cols-1 items-center gap-3 {pipelineParams.length > 5 ? 'sm:grid-cols-2' : ''}"
@@ -40,7 +40,7 @@
40
  {#if params.field === FieldType.RANGE}
41
  <InputRange {params} bind:value={$pipelineValues[params.id]}></InputRange>
42
  {:else if params.field === FieldType.SEED}
43
- <SeedInput bind:value={$pipelineValues[params.id]}></SeedInput>
44
  {:else if params.field === FieldType.TEXTAREA}
45
  <TextArea {params} bind:value={$pipelineValues[params.id]}></TextArea>
46
  {:else if params.field === FieldType.CHECKBOX}
 
20
  {#if params.field === FieldType.RANGE}
21
  <InputRange {params} bind:value={$pipelineValues[params.id]}></InputRange>
22
  {:else if params.field === FieldType.SEED}
23
+ <SeedInput {params} bind:value={$pipelineValues[params.id]}></SeedInput>
24
  {:else if params.field === FieldType.TEXTAREA}
25
  <TextArea {params} bind:value={$pipelineValues[params.id]}></TextArea>
26
  {:else if params.field === FieldType.CHECKBOX}
 
30
  {/if}
31
  </div>
32
 
33
+ <details>
34
  <summary class="cursor-pointer font-medium">Advanced Options</summary>
35
  <div
36
  class="grid grid-cols-1 items-center gap-3 {pipelineParams.length > 5 ? 'sm:grid-cols-2' : ''}"
 
40
  {#if params.field === FieldType.RANGE}
41
  <InputRange {params} bind:value={$pipelineValues[params.id]}></InputRange>
42
  {:else if params.field === FieldType.SEED}
43
+ <SeedInput {params} bind:value={$pipelineValues[params.id]}></SeedInput>
44
  {:else if params.field === FieldType.TEXTAREA}
45
  <TextArea {params} bind:value={$pipelineValues[params.id]}></TextArea>
46
  {:else if params.field === FieldType.CHECKBOX}
frontend/src/lib/components/SeedInput.svelte CHANGED
@@ -1,7 +1,13 @@
1
  <script lang="ts">
 
 
2
  import Button from './Button.svelte';
3
  export let value = 299792458;
 
4
 
 
 
 
5
  function randomize() {
6
  value = Math.floor(Math.random() * Number.MAX_SAFE_INTEGER);
7
  }
 
1
  <script lang="ts">
2
+ import type { FieldProps } from '$lib/types';
3
+ import { onMount } from 'svelte';
4
  import Button from './Button.svelte';
5
  export let value = 299792458;
6
+ export let params: FieldProps;
7
 
8
+ onMount(() => {
9
+ value = Number(params?.default ?? '');
10
+ });
11
  function randomize() {
12
  value = Math.floor(Math.random() * Number.MAX_SAFE_INTEGER);
13
  }
frontend/src/lib/types.ts CHANGED
@@ -22,6 +22,9 @@ export interface FieldProps {
22
  id: string;
23
  }
24
  export interface PipelineInfo {
 
 
 
25
  name: string;
26
  description: string;
27
  input_mode: {
 
22
  id: string;
23
  }
24
  export interface PipelineInfo {
25
+ title: {
26
+ default: string;
27
+ }
28
  name: string;
29
  description: string;
30
  input_mode: {
frontend/src/routes/+page.svelte CHANGED
@@ -78,6 +78,9 @@
78
  <main class="container mx-auto flex max-w-4xl flex-col gap-3 px-4 py-4">
79
  <article class="flex- mx-auto max-w-xl text-center">
80
  <h1 class="text-3xl font-bold">Real-Time Latent Consistency Model</h1>
 
 
 
81
  <p class="py-2 text-sm">
82
  This demo showcases
83
  <a
 
78
  <main class="container mx-auto flex max-w-4xl flex-col gap-3 px-4 py-4">
79
  <article class="flex- mx-auto max-w-xl text-center">
80
  <h1 class="text-3xl font-bold">Real-Time Latent Consistency Model</h1>
81
+ {#if pipelineInfo?.title?.default}
82
+ <h3 class="text-xl font-bold">{pipelineInfo?.title?.default}</h3>
83
+ {/if}
84
  <p class="py-2 text-sm">
85
  This demo showcases
86
  <a
pipelines/txt2img.py CHANGED
@@ -21,6 +21,7 @@ default_prompt = "Portrait of The Terminator with , glare pose, detailed, intric
21
  class Pipeline:
22
  class Info(BaseModel):
23
  name: str = "txt2img"
 
24
  description: str = "Generates an image from a text prompt"
25
  input_mode: str = "text"
26
 
 
21
  class Pipeline:
22
  class Info(BaseModel):
23
  name: str = "txt2img"
24
+ title: str = "txt2img"
25
  description: str = "Generates an image from a text prompt"
26
  input_mode: str = "text"
27
 
pipelines/txt2imglora.py CHANGED
@@ -1,4 +1,4 @@
1
- from diffusers import DiffusionPipeline, AutoencoderTiny
2
  from compel import Compel
3
  import torch
4
 
@@ -9,85 +9,108 @@ except:
9
 
10
  import psutil
11
  from config import Args
12
- from pydantic import BaseModel
13
  from PIL import Image
14
- from typing import Callable
15
 
16
- base_model = "SimianLuo/LCM_Dreamshaper_v7"
17
- WIDTH = 512
18
- HEIGHT = 512
19
-
20
- model_id = "wavymulder/Analog-Diffusion"
21
  lcm_lora_id = "latent-consistency/lcm-lora-sdv1-5"
 
 
 
22
 
23
 
24
  class Pipeline:
 
 
 
 
 
 
25
  class InputParams(BaseModel):
26
- seed: int = 2159232
27
- prompt: str
28
- guidance_scale: float = 8.0
29
- strength: float = 0.5
30
- steps: int = 4
31
- lcm_steps: int = 50
32
- width: int = WIDTH
33
- height: int = HEIGHT
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
- @staticmethod
36
- def create_pipeline(
37
- args: Args, device: torch.device, torch_dtype: torch.dtype
38
- ) -> Callable[["Pipeline.InputParams"], Image.Image]:
39
  if args.safety_checker:
40
- pipe = DiffusionPipeline.from_pretrained(base_model)
41
  else:
42
- pipe = DiffusionPipeline.from_pretrained(base_model, safety_checker=None)
 
 
43
  if args.use_taesd:
44
- pipe.vae = AutoencoderTiny.from_pretrained(
45
- "madebyollin/taesd", torch_dtype=torch_dtype, use_safetensors=True
46
  )
 
 
 
 
47
 
48
- pipe.set_progress_bar_config(disable=True)
49
- pipe.to(device=device, dtype=torch_dtype)
50
- pipe.unet.to(memory_format=torch.channels_last)
51
-
52
- # Load LCM LoRA
53
- pipe.load_lora_weights(lcm_lora_id, adapter_name="lcm")
54
  # check if computer has less than 64GB of RAM using sys or os
55
  if psutil.virtual_memory().total < 64 * 1024**3:
56
- pipe.enable_attention_slicing()
57
 
58
  if args.torch_compile:
59
- pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
60
- pipe.vae = torch.compile(pipe.vae, mode="reduce-overhead", fullgraph=True)
 
 
 
 
61
 
62
- pipe(prompt="warmup", num_inference_steps=1, guidance_scale=8.0)
63
 
64
- compel_proc = Compel(
65
- tokenizer=pipe.tokenizer,
66
- text_encoder=pipe.text_encoder,
 
 
67
  truncate_long_prompts=False,
68
  )
69
 
70
- def predict(params: "Pipeline.InputParams") -> Image.Image:
71
- generator = torch.manual_seed(params.seed)
72
- prompt_embeds = compel_proc(params.prompt)
73
- # Can be set to 1~50 steps. LCM support fast inference even <= 4 steps. Recommend: 1~8 steps.
74
- results = pipe(
75
- prompt_embeds=prompt_embeds,
76
- generator=generator,
77
- num_inference_steps=params.steps,
78
- guidance_scale=params.guidance_scale,
79
- width=params.width,
80
- height=params.height,
81
- original_inference_steps=params.lcm_steps,
82
- output_type="pil",
83
- )
84
- nsfw_content_detected = (
85
- results.nsfw_content_detected[0]
86
- if "nsfw_content_detected" in results
87
- else False
88
- )
89
- if nsfw_content_detected:
90
- return None
91
- return results.images[0]
92
-
93
- return predict
 
1
+ from diffusers import DiffusionPipeline, AutoencoderTiny, LCMScheduler
2
  from compel import Compel
3
  import torch
4
 
 
9
 
10
  import psutil
11
  from config import Args
12
+ from pydantic import BaseModel, Field
13
  from PIL import Image
 
14
 
15
+ base_model = "wavymulder/Analog-Diffusion"
 
 
 
 
16
  lcm_lora_id = "latent-consistency/lcm-lora-sdv1-5"
17
+ taesd_model = "madebyollin/taesd"
18
+
19
+ default_prompt = "Analog style photograph of young Harrison Ford as Han Solo, star wars behind the scenes"
20
 
21
 
22
  class Pipeline:
23
+ class Info(BaseModel):
24
+ name: str = "txt2imglora"
25
+ title: str = "txt2imglora"
26
+ description: str = "Generates an image from a text prompt"
27
+ input_mode: str = "text"
28
+
29
  class InputParams(BaseModel):
30
+ prompt: str = Field(
31
+ default_prompt,
32
+ title="Prompt",
33
+ field="textarea",
34
+ id="prompt",
35
+ )
36
+ seed: int = Field(
37
+ 8638236174640251, min=0, title="Seed", field="seed", hide=True, id="seed"
38
+ )
39
+ steps: int = Field(
40
+ 4, min=2, max=15, title="Steps", field="range", hide=True, id="steps"
41
+ )
42
+ width: int = Field(
43
+ 512, min=2, max=15, title="Width", disabled=True, hide=True, id="width"
44
+ )
45
+ height: int = Field(
46
+ 512, min=2, max=15, title="Height", disabled=True, hide=True, id="height"
47
+ )
48
+ guidance_scale: float = Field(
49
+ 0.2,
50
+ min=0,
51
+ max=4,
52
+ step=0.001,
53
+ title="Guidance Scale",
54
+ field="range",
55
+ hide=True,
56
+ id="guidance_scale",
57
+ )
58
 
59
+ def __init__(self, args: Args, device: torch.device, torch_dtype: torch.dtype):
 
 
 
60
  if args.safety_checker:
61
+ self.pipe = DiffusionPipeline.from_pretrained(base_model)
62
  else:
63
+ self.pipe = DiffusionPipeline.from_pretrained(
64
+ base_model, safety_checker=None
65
+ )
66
  if args.use_taesd:
67
+ self.pipe.vae = AutoencoderTiny.from_pretrained(
68
+ taesd_model, torch_dtype=torch_dtype, use_safetensors=True
69
  )
70
+ self.pipe.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config)
71
+ self.pipe.set_progress_bar_config(disable=True)
72
+ self.pipe.to(device=device, dtype=torch_dtype)
73
+ self.pipe.unet.to(memory_format=torch.channels_last)
74
 
 
 
 
 
 
 
75
  # check if computer has less than 64GB of RAM using sys or os
76
  if psutil.virtual_memory().total < 64 * 1024**3:
77
+ self.pipe.enable_attention_slicing()
78
 
79
  if args.torch_compile:
80
+ self.pipe.unet = torch.compile(
81
+ self.pipe.unet, mode="reduce-overhead", fullgraph=True
82
+ )
83
+ self.pipe.vae = torch.compile(
84
+ self.pipe.vae, mode="reduce-overhead", fullgraph=True
85
+ )
86
 
87
+ self.pipe(prompt="warmup", num_inference_steps=1, guidance_scale=8.0)
88
 
89
+ self.pipe.load_lora_weights(lcm_lora_id, adapter_name="lcm")
90
+
91
+ self.compel_proc = Compel(
92
+ tokenizer=self.pipe.tokenizer,
93
+ text_encoder=self.pipe.text_encoder,
94
  truncate_long_prompts=False,
95
  )
96
 
97
+ def predict(self, params: "Pipeline.InputParams") -> Image.Image:
98
+ generator = torch.manual_seed(params.seed)
99
+ prompt_embeds = self.compel_proc(params.prompt)
100
+ results = self.pipe(
101
+ prompt_embeds=prompt_embeds,
102
+ generator=generator,
103
+ num_inference_steps=params.steps,
104
+ guidance_scale=params.guidance_scale,
105
+ width=params.width,
106
+ height=params.height,
107
+ output_type="pil",
108
+ )
109
+ nsfw_content_detected = (
110
+ results.nsfw_content_detected[0]
111
+ if "nsfw_content_detected" in results
112
+ else False
113
+ )
114
+ if nsfw_content_detected:
115
+ return None
116
+ return results.images[0]
 
 
 
 
static/txt2imglora.html DELETED
@@ -1,279 +0,0 @@
1
- <!doctype html>
2
- <html>
3
-
4
- <head>
5
- <meta charset="UTF-8">
6
- <title>Real-Time Latent Consistency Model</title>
7
- <meta name="viewport" content="width=device-width, initial-scale=1.0">
8
- <script
9
- src="https://cdnjs.cloudflare.com/ajax/libs/iframe-resizer/4.3.1/iframeResizer.contentWindow.min.js"></script>
10
- <script src="https://cdn.jsdelivr.net/npm/[email protected]/piexif.min.js"></script>
11
- <script src="https://cdn.tailwindcss.com"></script>
12
- <style type="text/tailwindcss">
13
- .button {
14
- @apply bg-gray-700 hover:bg-gray-800 text-white font-normal p-2 rounded disabled:bg-gray-300 dark:disabled:bg-gray-700 disabled:cursor-not-allowed dark:disabled:text-black
15
- }
16
- </style>
17
- <script type="module">
18
- const getValue = (id) => {
19
- const el = document.querySelector(`${id}`)
20
- if (el.type === "checkbox")
21
- return el.checked;
22
- return el.value;
23
- }
24
- const startBtn = document.querySelector("#start");
25
- const stopBtn = document.querySelector("#stop");
26
- const videoEl = document.querySelector("#webcam");
27
- const imageEl = document.querySelector("#player");
28
- const queueSizeEl = document.querySelector("#queue_size");
29
- const errorEl = document.querySelector("#error");
30
- const snapBtn = document.querySelector("#snap");
31
- const paramsEl = document.querySelector("#params");
32
- const promptEl = document.querySelector("#prompt");
33
- paramsEl.addEventListener("submit", (e) => e.preventDefault());
34
- function LCMLive(promptEl, paramsEl, liveImage) {
35
- let websocket;
36
-
37
- async function start() {
38
- return new Promise((resolve, reject) => {
39
- const websocketURL = `${window.location.protocol === "https:" ? "wss" : "ws"
40
- }:${window.location.host}/ws`;
41
-
42
- const socket = new WebSocket(websocketURL);
43
- socket.onopen = () => {
44
- console.log("Connected to websocket");
45
- };
46
- socket.onclose = () => {
47
- console.log("Disconnected from websocket");
48
- stop();
49
- resolve({ "status": "disconnected" });
50
- };
51
- socket.onerror = (err) => {
52
- console.error(err);
53
- reject(err);
54
- };
55
- socket.onmessage = (event) => {
56
- const data = JSON.parse(event.data);
57
- switch (data.status) {
58
- case "success":
59
- break;
60
- case "start":
61
- const userId = data.userId;
62
- initPromptStream(userId);
63
- break;
64
- case "timeout":
65
- stop();
66
- resolve({ "status": "timeout" });
67
- case "error":
68
- stop();
69
- reject(data.message);
70
- }
71
- };
72
- websocket = socket;
73
- })
74
- }
75
-
76
- async function promptUpdateStream(e) {
77
- const [WIDTH, HEIGHT] = [512, 512];
78
- websocket.send(JSON.stringify({
79
- "seed": getValue("#seed"),
80
- "prompt": getValue("#prompt"),
81
- "guidance_scale": getValue("#guidance-scale"),
82
- "steps": getValue("#steps"),
83
- "width": WIDTH,
84
- "height": HEIGHT,
85
- }));
86
- }
87
- function debouceInput(fn, delay) {
88
- let timer;
89
- return function (...args) {
90
- clearTimeout(timer);
91
- timer = setTimeout(() => {
92
- fn(...args);
93
- }, delay);
94
- }
95
- }
96
- const debouncedInput = debouceInput(promptUpdateStream, 200);
97
- function initPromptStream(userId) {
98
- liveImage.src = `/stream/${userId}`;
99
- paramsEl.addEventListener("change", debouncedInput);
100
- promptEl.addEventListener("input", debouncedInput);
101
- }
102
-
103
- async function stop() {
104
- websocket.close();
105
- paramsEl.removeEventListener("change", debouncedInput);
106
- promptEl.removeEventListener("input", debouncedInput);
107
- }
108
- return {
109
- start,
110
- stop
111
- }
112
- }
113
- function toggleMessage(type) {
114
- errorEl.hidden = false;
115
- errorEl.scrollIntoView();
116
- switch (type) {
117
- case "error":
118
- errorEl.innerText = "To many users are using the same GPU, please try again later.";
119
- errorEl.classList.toggle("bg-red-300", "text-red-900");
120
- break;
121
- case "success":
122
- errorEl.innerText = "Your session has ended, please start a new one.";
123
- errorEl.classList.toggle("bg-green-300", "text-green-900");
124
- break;
125
- }
126
- setTimeout(() => {
127
- errorEl.hidden = true;
128
- }, 2000);
129
- }
130
- function snapImage() {
131
- try {
132
- const zeroth = {};
133
- const exif = {};
134
- const gps = {};
135
- zeroth[piexif.ImageIFD.Make] = "LCM Text-to-Image";
136
- zeroth[piexif.ImageIFD.ImageDescription] = `prompt: ${getValue("#prompt")} | seed: ${getValue("#seed")} | guidance_scale: ${getValue("#guidance-scale")} | steps: ${getValue("#steps")}`;
137
- zeroth[piexif.ImageIFD.Software] = "https://github.com/radames/Real-Time-Latent-Consistency-Model";
138
-
139
- exif[piexif.ExifIFD.DateTimeOriginal] = new Date().toISOString();
140
-
141
- const exifObj = { "0th": zeroth, "Exif": exif, "GPS": gps };
142
- const exifBytes = piexif.dump(exifObj);
143
-
144
- const canvas = document.createElement("canvas");
145
- canvas.width = imageEl.naturalWidth;
146
- canvas.height = imageEl.naturalHeight;
147
- const ctx = canvas.getContext("2d");
148
- ctx.drawImage(imageEl, 0, 0);
149
- const dataURL = canvas.toDataURL("image/jpeg");
150
- const withExif = piexif.insert(exifBytes, dataURL);
151
-
152
- const a = document.createElement("a");
153
- a.href = withExif;
154
- a.download = `lcm_txt_2_img${Date.now()}.png`;
155
- a.click();
156
- } catch (err) {
157
- console.log(err);
158
- }
159
- }
160
-
161
-
162
- const lcmLive = LCMLive(promptEl, paramsEl, imageEl);
163
- startBtn.addEventListener("click", async () => {
164
- try {
165
- startBtn.disabled = true;
166
- snapBtn.disabled = false;
167
- const res = await lcmLive.start();
168
- startBtn.disabled = false;
169
- if (res.status === "timeout")
170
- toggleMessage("success")
171
- } catch (err) {
172
- console.log(err);
173
- toggleMessage("error")
174
- startBtn.disabled = false;
175
- }
176
- });
177
- stopBtn.addEventListener("click", () => {
178
- lcmLive.stop();
179
- });
180
- window.addEventListener("beforeunload", () => {
181
- lcmLive.stop();
182
- });
183
- snapBtn.addEventListener("click", snapImage);
184
- setInterval(() =>
185
- fetch("/queue_size")
186
- .then((res) => res.json())
187
- .then((data) => {
188
- queueSizeEl.innerText = data.queue_size;
189
- })
190
- .catch((err) => {
191
- console.log(err);
192
- })
193
- , 5000);
194
- </script>
195
- </head>
196
-
197
- <body class="text-black dark:bg-gray-900 dark:text-white">
198
- <div class="fixed right-2 top-2 p-4 font-bold text-sm rounded-lg max-w-xs text-center" id="error">
199
- </div>  
200
- <main class="container mx-auto px-4 py-4 max-w-4xl flex flex-col gap-4">
201
- <article class="text-center max-w-xl mx-auto">
202
- <h1 class="text-3xl font-bold">Real-Time Latent Consistency Model</h1>
203
- <h2 class="text-2xl font-bold mb-4">Text to Image Lora</h2>
204
- <p class="text-sm">
205
- This demo showcases
206
- <a href="https://huggingface.co/SimianLuo/LCM_Dreamshaper_v7" target="_blank"
207
- class="text-blue-500 underline hover:no-underline">LCM</a> Text to Image model
208
- using
209
- <a href="https://github.com/huggingface/diffusers/tree/main/examples/community#latent-consistency-pipeline"
210
- target="_blank" class="text-blue-500 underline hover:no-underline">Diffusers</a> with a MJPEG
211
- stream server. Featuring <a href="https://huggingface.co/wavymulder/Analog-Diffusion" target="_blank"
212
- class="text-blue-500 underline hover:no-underline">Analog Diffusion</a> Model.
213
- </p>
214
- <p class="text-sm">
215
- There are <span id="queue_size" class="font-bold">0</span> user(s) sharing the same GPU, affecting
216
- real-time performance.
217
- </p>
218
- </article>
219
- <div>
220
- <h2 class="font-medium">Prompt</h2>
221
- <p class="text-sm text-gray-500 dark:text-gray-400">
222
- Start your session and type your prompt here, accepts
223
- <a href="https://github.com/damian0815/compel/blob/main/doc/syntax.md" target="_blank"
224
- class="text-blue-500 underline hover:no-underline">Compel</a> syntax.
225
- </p>
226
- <div class="flex text-normal px-1 py-1 border border-gray-700 rounded-md items-center">
227
- <textarea type="text" id="prompt" class="font-light w-full px-3 py-2 mx-1 outline-none dark:text-black"
228
- title=" Start your session and type your prompt here, you can see the result in real-time."
229
- placeholder="Add your prompt here...">Analog style photograph of young Harrison Ford as Han Solo, star wars behind the scenes</textarea>
230
- </div>
231
-
232
- </div>
233
- <div class="">
234
- <details>
235
- <summary class="font-medium cursor-pointer">Advanced Options</summary>
236
- <form class="grid grid-cols-3 items-center gap-3 py-3" id="params" action="">
237
- <label class="text-sm font-medium " for="steps">Inference Steps
238
- </label>
239
- <input type="range" id="steps" name="steps" min="2" max="10" value="4"
240
- oninput="this.nextElementSibling.value = Number(this.value)">
241
- <output class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
242
- 4</output>
243
- <label class="text-sm font-medium" for="guidance-scale">Guidance Scale
244
- </label>
245
- <input type="range" id="guidance-scale" name="guidance-scale" min="0" max="5" step="0.0001"
246
- value="0.8" oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)">
247
- <output class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
248
- 0.8</output>
249
- <!-- -->
250
- <label class="text-sm font-medium" for="seed">Seed</label>
251
- <input type="number" id="seed" name="seed" value="299792458"
252
- class="font-light border border-gray-700 text-right rounded-md p-2 dark:text-black">
253
- <button class="button"
254
- onclick="document.querySelector('#seed').value = Math.floor(Math.random() * 1000000000); document.querySelector('#params').dispatchEvent(new Event('change'))">
255
- Rand
256
- </button>
257
- <!-- -->
258
- </form>
259
- </details>
260
- </div>
261
- <div class="flex gap-3">
262
- <button id="start" class="button">
263
- Start
264
- </button>
265
- <button id="stop" class="button">
266
- Stop
267
- </button>
268
- <button id="snap" disabled class="button ml-auto">
269
- Snapshot
270
- </button>
271
- </div>
272
- <div class="relative rounded-lg border border-slate-300 overflow-hidden">
273
- <img id="player" class="w-full aspect-square rounded-lg"
274
- src="">
275
- </div>
276
- </main>
277
- </body>
278
-
279
- </html>