badayvedat commited on
Commit
85a99d8
·
1 Parent(s): 6f0dfc9

Initial commit

Browse files
Files changed (7) hide show
  1. .gitignore +4 -0
  2. README.md +5 -6
  3. app.py +205 -0
  4. constants.py +209 -0
  5. gradio_examples.py +4 -0
  6. style.css +3 -0
  7. utils.py +14 -0
.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
README.md CHANGED
@@ -1,13 +1,12 @@
1
  ---
2
- title: Realtime Stable Diffusion
3
- emoji:
4
- colorFrom: green
5
- colorTo: green
6
  sdk: gradio
7
- sdk_version: 4.3.0
8
  app_file: app.py
9
  pinned: false
10
- license: other
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Real Time Latent Consistency Models
3
+ emoji: 👀
4
+ colorFrom: pink
5
+ colorTo: indigo
6
  sdk: gradio
7
+ sdk_version: 3.50.2
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from collections import deque
3
+ from dataclasses import dataclass
4
+ from typing import Optional
5
+
6
+ import gradio as gr
7
+ import websockets
8
+ from gradio.processing_utils import decode_base64_to_image, encode_pil_to_base64
9
+ from PIL import Image
10
+ from websockets.sync.client import connect
11
+
12
+ from constants import DESCRIPTION, WS_ADDRESS, LOGO
13
+ from utils import replace_background
14
+
15
+ from gradio_examples import EXAMPLES
16
+
17
+ MAX_QUEUE_SIZE = 4
18
+
19
+
20
+ @dataclass
21
+ class GenerationState:
22
+ prompts: deque
23
+ responses: deque
24
+
25
+
26
+ def get_initial_state() -> GenerationState:
27
+ return GenerationState(
28
+ prompts=deque(maxlen=MAX_QUEUE_SIZE),
29
+ responses=deque(maxlen=MAX_QUEUE_SIZE),
30
+ )
31
+
32
+
33
+ def load_initial_state(request: gr.Request) -> GenerationState:
34
+ print("Loading initial state for", request.client.host)
35
+
36
+ return get_initial_state()
37
+
38
+
39
+ async def put_to_queue(
40
+ image: Optional[Image.Image],
41
+ prompt: str,
42
+ seed: int,
43
+ strength: float,
44
+ state: GenerationState,
45
+ ):
46
+ prompts_queue = state.prompts
47
+
48
+ if prompt and image is not None:
49
+ prompts_queue.append((image, prompt, seed, strength))
50
+
51
+ return state
52
+
53
+
54
+ def send_inference_request(state: GenerationState) -> Image.Image:
55
+ prompts_queue = state.prompts
56
+ response_queue = state.responses
57
+
58
+ if len(prompts_queue) == 0:
59
+ return state
60
+
61
+ image, prompt, seed, strength = prompts_queue.popleft()
62
+ original_image_size = image.size
63
+ image = replace_background(image.resize((512, 512)))
64
+
65
+ arguments = {
66
+ "prompt": prompt,
67
+ "image_url": encode_pil_to_base64(image),
68
+ "strength": strength,
69
+ "negative_prompt": "cartoon, illustration, animation. face. male, female",
70
+ "seed": seed,
71
+ "guidance_scale": 1,
72
+ "num_inference_steps": 4,
73
+ "sync_mode": 1,
74
+ "num_images": 1,
75
+ }
76
+
77
+ connection = connect(WS_ADDRESS)
78
+ connection.send(json.dumps(arguments))
79
+
80
+ try:
81
+ response = json.loads(connection.recv())
82
+ except websockets.exceptions.ConnectionClosedOK:
83
+ print("Connection closed, reconnecting...")
84
+ # TODO: This is a hacky way to reconnect, but it works for now
85
+ # Ideally, we should be able to reconnect to the same connection
86
+ # and not have to create a new one
87
+ connection = connect(WS_ADDRESS)
88
+ try:
89
+ response = json.loads(connection.recv())
90
+ except websockets.exceptions.ConnectionClosedOK:
91
+ print("Connection closed again, aborting...")
92
+ return state
93
+
94
+ # TODO: If a new connection is created, the response do not contain the images.
95
+ if "images" in response:
96
+ response_queue.append((response, original_image_size))
97
+
98
+ return state
99
+
100
+
101
+ def update_output_image(state: GenerationState):
102
+ image_update = gr.update()
103
+ inference_time_update = gr.update()
104
+
105
+ response_queue = state.responses
106
+
107
+ if len(response_queue) > 0:
108
+ response, original_image_size = response_queue.popleft()
109
+ generated_image = decode_base64_to_image(response["images"][0]["url"])
110
+ inference_time = response["timings"]["inference"]
111
+
112
+ generated_image.resize(original_image_size).save("generated.png")
113
+
114
+ image_update = gr.update(value=generated_image.resize(original_image_size))
115
+ inference_time_update = gr.update(value=round(inference_time, 4))
116
+
117
+ return image_update, inference_time_update, state
118
+
119
+
120
+ with gr.Blocks(css="style.css", title=f"Realtime Latent Consistency Model") as demo:
121
+ generation_state = gr.State(get_initial_state())
122
+
123
+ gr.HTML(f'<div style="width: 70px;">{LOGO}</div>')
124
+ gr.Markdown(DESCRIPTION)
125
+ with gr.Row(variant="default"):
126
+ input_image = gr.Image(
127
+ tool="color-sketch",
128
+ source="canvas",
129
+ label="Initial Image",
130
+ type="pil",
131
+ height=512,
132
+ width=512,
133
+ brush_radius=40.0,
134
+ )
135
+
136
+ output_image = gr.Image(
137
+ label="Generated Image",
138
+ type="pil",
139
+ interactive=False,
140
+ elem_id="output_image",
141
+ )
142
+ with gr.Row():
143
+ with gr.Column(scale=23):
144
+ prompt_box = gr.Textbox(label="Prompt")
145
+ with gr.Column(scale=1):
146
+ inference_time_box = gr.Number(
147
+ label="Inference Time (s)", interactive=False
148
+ )
149
+
150
+ with gr.Accordion(label="Advanced Options", open=False):
151
+ with gr.Row():
152
+ with gr.Column():
153
+ strength = gr.Slider(
154
+ label="Strength",
155
+ minimum=0.1,
156
+ maximum=1.0,
157
+ step=0.05,
158
+ value=0.8,
159
+ info="""
160
+ Strength of the initial image that will be applied during inference.
161
+ """,
162
+ )
163
+ with gr.Column():
164
+ seed = gr.Slider(
165
+ label="Seed",
166
+ minimum=0,
167
+ maximum=2**31 - 1,
168
+ step=1,
169
+ randomize=True,
170
+ info="""
171
+ Seed for the random number generator.
172
+ """,
173
+ )
174
+
175
+ demo.load(
176
+ load_initial_state,
177
+ outputs=[generation_state],
178
+ )
179
+ demo.load(
180
+ send_inference_request,
181
+ inputs=[generation_state],
182
+ outputs=[generation_state],
183
+ every=0.1,
184
+ )
185
+ demo.load(
186
+ update_output_image,
187
+ inputs=[generation_state],
188
+ outputs=[output_image, inference_time_box, generation_state],
189
+ every=0.1,
190
+ )
191
+
192
+ for event in [input_image.change, prompt_box.change, strength.change, seed.change]:
193
+ event(
194
+ put_to_queue,
195
+ [input_image, prompt_box, seed, strength, generation_state],
196
+ [generation_state],
197
+ show_progress=False,
198
+ queue=True,
199
+ )
200
+
201
+ gr.Markdown("## Example Prompts")
202
+ gr.Examples(examples=EXAMPLES, inputs=[prompt_box], label="Examples")
203
+
204
+ if __name__ == "__main__":
205
+ demo.queue(concurrency_count=4, api_open=False).launch()
constants.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ DESCRIPTION = """
4
+ # Real Time Latent Consistency Model Hosted on [fal.ai](https://fal.ai)
5
+ """
6
+
7
+ WS_ADDRESS = os.environ["WS_ADDRESS"]
8
+
9
+
10
+ LOGO = """
11
+ <svg
12
+ width="100%"
13
+ height="100%"
14
+ viewBox="0 0 89 32"
15
+ fill="none"
16
+ xmlns="http://www.w3.org/2000/svg"
17
+ >
18
+ <path
19
+ d="M52.308 3.07812H57.8465V4.92428H56.0003V6.77043H54.1541V10.4627H57.8465V12.3089H54.1541V25.232H52.308V27.0781H46.7695V25.232H48.6157V12.3089H46.7695V10.4627H48.6157V6.77043H50.4618V4.92428H52.308V3.07812Z"
20
+ fill="currentColor"
21
+ ></path>
22
+ <path
23
+ d="M79.3849 23.3858H81.2311V25.232H83.0772V27.0781H88.6157V25.232H86.7695V23.3858H84.9234V4.92428H79.3849V23.3858Z"
24
+ fill="currentColor"
25
+ ></path>
26
+ <path
27
+ d="M57.8465 14.155H59.6926V12.3089H61.5388V10.4627H70.7695V12.3089H74.4618V23.3858H76.308V25.232H78.1541V27.0781H72.6157V25.232H70.7695V23.3858H68.9234V14.155H67.0772V12.3089H65.2311V14.155H63.3849V23.3858H65.2311V25.232H67.0772V27.0781H61.5388V25.232H59.6926V23.3858H57.8465V14.155Z"
28
+ fill="currentColor"
29
+ ></path>
30
+ <path
31
+ d="M67.0772 25.232V23.3858H68.9234V25.232H67.0772Z"
32
+ fill="currentColor"
33
+ ></path>
34
+ <rect
35
+ opacity="0.22"
36
+ x="7.38477"
37
+ y="29.5391"
38
+ width="2.46154"
39
+ height="2.46154"
40
+ fill="#5F4CD9"
41
+ ></rect>
42
+ <rect
43
+ opacity="0.85"
44
+ x="2.46094"
45
+ y="19.6914"
46
+ width="12.3077"
47
+ height="2.46154"
48
+ fill="#5F4CD9"
49
+ ></rect>
50
+ <rect
51
+ x="4.92383"
52
+ y="17.2305"
53
+ width="9.84615"
54
+ height="2.46154"
55
+ fill="#5F4CD9"
56
+ ></rect>
57
+ <rect
58
+ opacity="0.4"
59
+ x="7.38477"
60
+ y="27.0781"
61
+ width="4.92308"
62
+ height="2.46154"
63
+ fill="#5F4CD9"
64
+ ></rect>
65
+ <rect
66
+ opacity="0.7"
67
+ y="22.1562"
68
+ width="14.7692"
69
+ height="2.46154"
70
+ fill="#5F4CD9"
71
+ ></rect>
72
+ <rect
73
+ opacity="0.5"
74
+ x="7.38477"
75
+ y="24.6133"
76
+ width="7.38462"
77
+ height="2.46154"
78
+ fill="#5F4CD9"
79
+ ></rect>
80
+ <rect
81
+ opacity="0.22"
82
+ x="7.38477"
83
+ y="12.3086"
84
+ width="2.46154"
85
+ height="2.46154"
86
+ fill="#5F4CD9"
87
+ ></rect>
88
+ <rect
89
+ opacity="0.85"
90
+ x="2.46094"
91
+ y="2.46094"
92
+ width="12.3077"
93
+ height="2.46154"
94
+ fill="#5F4CD9"
95
+ ></rect>
96
+ <rect x="4.92383" width="9.84615" height="2.46154" fill="#5F4CD9"></rect>
97
+ <rect
98
+ opacity="0.4"
99
+ x="7.38477"
100
+ y="9.84375"
101
+ width="4.92308"
102
+ height="2.46154"
103
+ fill="#5F4CD9"
104
+ ></rect>
105
+ <rect
106
+ opacity="0.7"
107
+ y="4.92188"
108
+ width="14.7692"
109
+ height="2.46154"
110
+ fill="#5F4CD9"
111
+ ></rect>
112
+ <rect
113
+ opacity="0.5"
114
+ x="7.38477"
115
+ y="7.38281"
116
+ width="7.38462"
117
+ height="2.46154"
118
+ fill="#5F4CD9"
119
+ ></rect>
120
+ <rect
121
+ opacity="0.22"
122
+ x="24.6152"
123
+ y="29.5391"
124
+ width="2.46154"
125
+ height="2.46154"
126
+ fill="#5F4CD9"
127
+ ></rect>
128
+ <rect
129
+ opacity="0.85"
130
+ x="19.6914"
131
+ y="19.6914"
132
+ width="12.3077"
133
+ height="2.46154"
134
+ fill="#5F4CD9"
135
+ ></rect>
136
+ <rect
137
+ x="22.1543"
138
+ y="17.2305"
139
+ width="9.84615"
140
+ height="2.46154"
141
+ fill="#5F4CD9"
142
+ ></rect>
143
+ <rect
144
+ opacity="0.4"
145
+ x="24.6152"
146
+ y="27.0781"
147
+ width="4.92308"
148
+ height="2.46154"
149
+ fill="#5F4CD9"
150
+ ></rect>
151
+ <rect
152
+ opacity="0.7"
153
+ x="17.2305"
154
+ y="22.1562"
155
+ width="14.7692"
156
+ height="2.46154"
157
+ fill="#5F4CD9"
158
+ ></rect>
159
+ <rect
160
+ opacity="0.5"
161
+ x="24.6152"
162
+ y="24.6133"
163
+ width="7.38462"
164
+ height="2.46154"
165
+ fill="#5F4CD9"
166
+ ></rect>
167
+ <rect
168
+ opacity="0.22"
169
+ x="24.6152"
170
+ y="12.3086"
171
+ width="2.46154"
172
+ height="2.46154"
173
+ fill="#5F4CD9"
174
+ ></rect>
175
+ <rect
176
+ opacity="0.85"
177
+ x="19.6914"
178
+ y="2.46094"
179
+ width="12.3077"
180
+ height="2.46154"
181
+ fill="#5F4CD9"
182
+ ></rect>
183
+ <rect x="22.1543" width="9.84615" height="2.46154" fill="#5F4CD9"></rect>
184
+ <rect
185
+ opacity="0.4"
186
+ x="24.6152"
187
+ y="9.84375"
188
+ width="4.92308"
189
+ height="2.46154"
190
+ fill="#5F4CD9"
191
+ ></rect>
192
+ <rect
193
+ opacity="0.7"
194
+ x="17.2305"
195
+ y="4.92188"
196
+ width="14.7692"
197
+ height="2.46154"
198
+ fill="#5F4CD9"
199
+ ></rect>
200
+ <rect
201
+ opacity="0.5"
202
+ x="24.6152"
203
+ y="7.38281"
204
+ width="7.38462"
205
+ height="2.46154"
206
+ fill="#5F4CD9"
207
+ ></rect>
208
+ </svg>
209
+ """
gradio_examples.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ EXAMPLES = [
2
+ "a house on the water, oil painting",
3
+ "a sunset at a tropical beach with palm trees",
4
+ ]
style.css ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ }
utils.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import numpy as np
3
+
4
+ def replace_background(image: Image.Image, new_background_color=(0, 255, 255)):
5
+ image_np = np.array(image)
6
+
7
+ white_threshold = 255 * 3
8
+ white_pixels = np.sum(image_np, axis=-1) == white_threshold
9
+
10
+ image_np[white_pixels] = new_background_color
11
+
12
+ result = Image.fromarray(image_np)
13
+
14
+ return result