Spaces:
Running
Running
Ron Au
commited on
Commit
·
4c519fd
1
Parent(s):
b49a47d
feat(eta): Improve duration UX
Browse files- Render result without waiting for last poll interval to complete
- Calculate ETA based on past completions
- Update ETA during generation based on place in queue
- app.py +32 -3
- modules/inference.py +4 -8
- static/index.js +29 -21
- templates/index.html +1 -1
app.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
-
import time
|
|
|
2 |
from flask import Flask, jsonify, render_template, request
|
3 |
|
4 |
from modules.details import load_lists, rand_details
|
@@ -6,6 +7,8 @@ from modules.inference import generate_image
|
|
6 |
|
7 |
app = Flask(__name__)
|
8 |
|
|
|
|
|
9 |
|
10 |
@app.route('/')
|
11 |
def index():
|
@@ -19,10 +22,13 @@ tasks = {}
|
|
19 |
def create_task():
|
20 |
prompt = request.args.get('prompt') or "покемон"
|
21 |
|
22 |
-
|
|
|
|
|
23 |
|
24 |
tasks[task_id] = {
|
25 |
"task_id": task_id,
|
|
|
26 |
"prompt": prompt,
|
27 |
"status": "pending",
|
28 |
"poll_count": 0,
|
@@ -37,7 +43,9 @@ def queue_task():
|
|
37 |
|
38 |
tasks[task_id]["value"] = generate_image(tasks[task_id]["prompt"])
|
39 |
|
40 |
-
tasks[task_id]["status"] = "
|
|
|
|
|
41 |
|
42 |
return jsonify(tasks[task_id])
|
43 |
|
@@ -46,6 +54,27 @@ def queue_task():
|
|
46 |
def poll_task():
|
47 |
task_id = request.args.get('task_id')
|
48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
tasks[task_id]["poll_count"] += 1
|
50 |
|
51 |
return jsonify(tasks[task_id])
|
|
|
1 |
+
from time import time
|
2 |
+
from statistics import mean
|
3 |
from flask import Flask, jsonify, render_template, request
|
4 |
|
5 |
from modules.details import load_lists, rand_details
|
|
|
7 |
|
8 |
app = Flask(__name__)
|
9 |
|
10 |
+
TEMPLATES_AUTO_RELOAD = True
|
11 |
+
|
12 |
|
13 |
@app.route('/')
|
14 |
def index():
|
|
|
22 |
def create_task():
|
23 |
prompt = request.args.get('prompt') or "покемон"
|
24 |
|
25 |
+
created_at = time()
|
26 |
+
|
27 |
+
task_id = f"{str(created_at)}_{prompt}"
|
28 |
|
29 |
tasks[task_id] = {
|
30 |
"task_id": task_id,
|
31 |
+
"created_at": created_at,
|
32 |
"prompt": prompt,
|
33 |
"status": "pending",
|
34 |
"poll_count": 0,
|
|
|
43 |
|
44 |
tasks[task_id]["value"] = generate_image(tasks[task_id]["prompt"])
|
45 |
|
46 |
+
tasks[task_id]["status"] = "completed"
|
47 |
+
|
48 |
+
tasks[task_id]["completed_at"] = time()
|
49 |
|
50 |
return jsonify(tasks[task_id])
|
51 |
|
|
|
54 |
def poll_task():
|
55 |
task_id = request.args.get('task_id')
|
56 |
|
57 |
+
pending_tasks = []
|
58 |
+
completed_durations = []
|
59 |
+
|
60 |
+
for task in tasks.values():
|
61 |
+
if task["status"] == "pending":
|
62 |
+
pending_tasks.append(task["task_id"])
|
63 |
+
elif task["status"] == "completed":
|
64 |
+
completed_durations.append(task["completed_at"] - task["created_at"])
|
65 |
+
|
66 |
+
try:
|
67 |
+
place_in_queue = pending_tasks.index(task_id) + 1
|
68 |
+
except:
|
69 |
+
place_in_queue = 0
|
70 |
+
|
71 |
+
if (len(completed_durations)):
|
72 |
+
eta = sum(completed_durations) / len(completed_durations) * (place_in_queue or 1)
|
73 |
+
else:
|
74 |
+
eta = 40 * (place_in_queue or 1)
|
75 |
+
|
76 |
+
tasks[task_id]["place_in_queue"] = place_in_queue
|
77 |
+
tasks[task_id]["eta"] = round(eta, 1)
|
78 |
tasks[task_id]["poll_count"] += 1
|
79 |
|
80 |
return jsonify(tasks[task_id])
|
modules/inference.py
CHANGED
@@ -13,14 +13,11 @@ fp16 = torch.cuda.is_available()
|
|
13 |
|
14 |
file_dir = "./models"
|
15 |
file_name = "pytorch_model.bin"
|
16 |
-
config_file_url = hf_hub_url(
|
17 |
-
repo_id="minimaxir/ai-generated-pokemon-rudalle", filename=file_name)
|
18 |
cached_download(config_file_url, cache_dir=file_dir, force_filename=file_name)
|
19 |
|
20 |
-
model = get_rudalle_model('Malevich', pretrained=False,
|
21 |
-
|
22 |
-
model.load_state_dict(torch.load(
|
23 |
-
f"{file_dir}/{file_name}", map_location=f"{'cuda:0' if torch.cuda.is_available() else 'cpu'}"))
|
24 |
|
25 |
vae = get_vae().to(device)
|
26 |
tokenizer = get_tokenizer()
|
@@ -50,8 +47,7 @@ def generate_image(prompt):
|
|
50 |
if prompt.lower() in ['grass', 'fire', 'water', 'lightning', 'fighting', 'psychic', 'colorless', 'darkness', 'metal', 'dragon', 'fairy']:
|
51 |
prompt = english_to_russian(prompt)
|
52 |
|
53 |
-
result, _ = generate_images(
|
54 |
-
prompt, tokenizer, model, vae, top_k=2048, images_num=1, top_p=0.995)
|
55 |
|
56 |
buffer = BytesIO()
|
57 |
result[0].save(buffer, format="PNG")
|
|
|
13 |
|
14 |
file_dir = "./models"
|
15 |
file_name = "pytorch_model.bin"
|
16 |
+
config_file_url = hf_hub_url(repo_id="minimaxir/ai-generated-pokemon-rudalle", filename=file_name)
|
|
|
17 |
cached_download(config_file_url, cache_dir=file_dir, force_filename=file_name)
|
18 |
|
19 |
+
model = get_rudalle_model('Malevich', pretrained=False, fp16=fp16, device=device)
|
20 |
+
model.load_state_dict(torch.load(f"{file_dir}/{file_name}", map_location=f"{'cuda:0' if torch.cuda.is_available() else 'cpu'}"))
|
|
|
|
|
21 |
|
22 |
vae = get_vae().to(device)
|
23 |
tokenizer = get_tokenizer()
|
|
|
47 |
if prompt.lower() in ['grass', 'fire', 'water', 'lightning', 'fighting', 'psychic', 'colorless', 'darkness', 'metal', 'dragon', 'fairy']:
|
48 |
prompt = english_to_russian(prompt)
|
49 |
|
50 |
+
result, _ = generate_images(prompt, tokenizer, model, vae, top_k=2048, images_num=1, top_p=0.995)
|
|
|
51 |
|
52 |
buffer = BytesIO()
|
53 |
result[0].save(buffer, format="PNG")
|
static/index.js
CHANGED
@@ -129,8 +129,9 @@ const createTask = async (prompt) => {
|
|
129 |
return task;
|
130 |
};
|
131 |
|
132 |
-
const queueTask = (task_id) => {
|
133 |
-
fetch(`${getBasePath()}task/queue?task_id=${task_id}`);
|
|
|
134 |
};
|
135 |
|
136 |
const pollTask = async (task) => {
|
@@ -140,18 +141,16 @@ const pollTask = async (task) => {
|
|
140 |
};
|
141 |
|
142 |
const longPollTask = async (task, interval = 10_000, max) => {
|
143 |
-
|
144 |
-
return task;
|
145 |
-
}
|
146 |
|
147 |
-
|
148 |
|
149 |
-
task
|
150 |
-
|
151 |
-
if (task.status === 'complete' || task.poll_count > max) {
|
152 |
return task;
|
153 |
}
|
154 |
|
|
|
|
|
155 |
await new Promise((resolve) => setTimeout(resolve, interval));
|
156 |
|
157 |
return await longPollTask(task, interval, max);
|
@@ -162,26 +161,34 @@ const longPollTask = async (task, interval = 10_000, max) => {
|
|
162 |
const generateButton = document.querySelector('button.generate');
|
163 |
|
164 |
const durationTimer = () => {
|
|
|
165 |
let duration = 0.0;
|
166 |
|
167 |
-
return (
|
168 |
const startTime = performance.now();
|
169 |
|
170 |
const incrementSeconds = setInterval(() => {
|
171 |
duration += 0.1;
|
172 |
-
|
173 |
}, 100);
|
174 |
|
175 |
-
const updateDuration = () =>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
176 |
|
177 |
window.addEventListener('focus', updateDuration);
|
178 |
|
179 |
return {
|
180 |
-
cleanup: () => {
|
181 |
-
updateDuration();
|
182 |
clearInterval(incrementSeconds);
|
183 |
window.removeEventListener('focus', updateDuration);
|
184 |
-
|
185 |
},
|
186 |
};
|
187 |
};
|
@@ -238,7 +245,7 @@ generateButton.addEventListener('click', async () => {
|
|
238 |
}
|
239 |
|
240 |
const renderSection = document.querySelector('section.render');
|
241 |
-
const
|
242 |
const initialiseCardRotation = cardRotationInitiator(renderSection);
|
243 |
|
244 |
try {
|
@@ -246,14 +253,15 @@ generateButton.addEventListener('click', async () => {
|
|
246 |
|
247 |
const details = await generateDetails();
|
248 |
const task = await createTask(details.energy_type);
|
249 |
-
queueTask(task.task_id);
|
250 |
|
251 |
-
const timer = durationTimer();
|
252 |
-
const
|
|
|
|
|
|
|
253 |
|
254 |
-
const completedTask = await longPollTask(task);
|
255 |
generating = false;
|
256 |
-
|
257 |
|
258 |
renderSection.innerHTML = cardHTML(details);
|
259 |
const picture = document.querySelector('img.picture');
|
|
|
129 |
return task;
|
130 |
};
|
131 |
|
132 |
+
const queueTask = async (task_id) => {
|
133 |
+
const queueResponse = await fetch(`${getBasePath()}task/queue?task_id=${task_id}`);
|
134 |
+
return queueResponse.json();
|
135 |
};
|
136 |
|
137 |
const pollTask = async (task) => {
|
|
|
141 |
};
|
142 |
|
143 |
const longPollTask = async (task, interval = 10_000, max) => {
|
144 |
+
const etaDisplay = document.querySelector('.eta');
|
|
|
|
|
145 |
|
146 |
+
task = await pollTask(task);
|
147 |
|
148 |
+
if (task.status === 'completed' || (max && task.poll_count > max)) {
|
|
|
|
|
149 |
return task;
|
150 |
}
|
151 |
|
152 |
+
etaDisplay.textContent = Math.round(task.eta);
|
153 |
+
|
154 |
await new Promise((resolve) => setTimeout(resolve, interval));
|
155 |
|
156 |
return await longPollTask(task, interval, max);
|
|
|
161 |
const generateButton = document.querySelector('button.generate');
|
162 |
|
163 |
const durationTimer = () => {
|
164 |
+
const elapsedDisplay = document.querySelector('.elapsed');
|
165 |
let duration = 0.0;
|
166 |
|
167 |
+
return () => {
|
168 |
const startTime = performance.now();
|
169 |
|
170 |
const incrementSeconds = setInterval(() => {
|
171 |
duration += 0.1;
|
172 |
+
elapsedDisplay.textContent = duration.toFixed(1);
|
173 |
}, 100);
|
174 |
|
175 |
+
const updateDuration = (task) => {
|
176 |
+
if (task?.status == 'completed') {
|
177 |
+
duration = task.completed_at - task.created_at;
|
178 |
+
return;
|
179 |
+
}
|
180 |
+
|
181 |
+
duration = Number(((performance.now() - startTime) / 1_000).toFixed(1));
|
182 |
+
};
|
183 |
|
184 |
window.addEventListener('focus', updateDuration);
|
185 |
|
186 |
return {
|
187 |
+
cleanup: (completedTask) => {
|
188 |
+
updateDuration(completedTask);
|
189 |
clearInterval(incrementSeconds);
|
190 |
window.removeEventListener('focus', updateDuration);
|
191 |
+
elapsedDisplay.textContent = duration.toFixed(1);
|
192 |
},
|
193 |
};
|
194 |
};
|
|
|
245 |
}
|
246 |
|
247 |
const renderSection = document.querySelector('section.render');
|
248 |
+
const durationDisplay = document.querySelector('.duration');
|
249 |
const initialiseCardRotation = cardRotationInitiator(renderSection);
|
250 |
|
251 |
try {
|
|
|
253 |
|
254 |
const details = await generateDetails();
|
255 |
const task = await createTask(details.energy_type);
|
|
|
256 |
|
257 |
+
const timer = durationTimer(durationDisplay);
|
258 |
+
const timerCleanup = timer().cleanup;
|
259 |
+
|
260 |
+
const longPromises = [queueTask(task.task_id), longPollTask(task)];
|
261 |
+
const completedTask = await Promise.any(longPromises);
|
262 |
|
|
|
263 |
generating = false;
|
264 |
+
timerCleanup(completedTask);
|
265 |
|
266 |
renderSection.innerHTML = cardHTML(details);
|
267 |
const picture = document.querySelector('img.picture');
|
templates/index.html
CHANGED
@@ -15,7 +15,7 @@
|
|
15 |
</head>
|
16 |
<body>
|
17 |
<h1>This Pokémon Does Not Exist</h1>
|
18 |
-
<div class="duration"><span class="
|
19 |
<button class="generate">Generate Pokémon Card with AI</button>
|
20 |
<section class="render"></section>
|
21 |
</body>
|
|
|
15 |
</head>
|
16 |
<body>
|
17 |
<h1>This Pokémon Does Not Exist</h1>
|
18 |
+
<div class="duration"><span class="elapsed">0.0</span>s (ETA: <span class="eta">40</span>s)</div>
|
19 |
<button class="generate">Generate Pokémon Card with AI</button>
|
20 |
<section class="render"></section>
|
21 |
</body>
|