Ron Au committed on
Commit
9fbb486
·
1 Parent(s): 67f60f6

refactor(queue): Improve reliability of handling many requests

Browse files
Files changed (4) hide show
  1. app.py +51 -73
  2. modules/inference.py +6 -3
  3. static/js/index.js +2 -6
  4. static/js/network.js +1 -6
app.py CHANGED
@@ -1,11 +1,10 @@
1
- from fastapi import FastAPI
2
- from fastapi.staticfiles import StaticFiles
3
- from fastapi.responses import FileResponse
4
-
5
-
6
  from time import time
7
  from statistics import mean
8
 
 
 
 
 
9
  from modules.details import rand_details
10
  from modules.inference import generate_image
11
 
@@ -13,31 +12,59 @@ app = FastAPI()
13
 
14
  app.mount("/static", StaticFiles(directory="static"), name="static")
15
 
16
-
17
  tasks = {}
18
 
19
 
20
  def get_place_in_queue(task_id):
 
 
 
 
21
 
22
- pending_tasks = list(task for task in tasks.values()
23
- if task["status"] == "pending")
24
 
25
  try:
26
- return pending_tasks.index(task_id) + 1
27
  except:
28
  return 0
29
 
30
 
31
  def calculate_eta(task_id):
32
- total_durations = list(task["completed_at"] - task["created_at"]
33
  for task in tasks.values() if "completed_at" in task)
34
 
35
- place = tasks[task_id]["initial_place_in_queue"] or 1
36
 
37
  if len(total_durations):
38
- return mean(total_durations) * place
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  else:
40
- return 40 * place
 
 
 
 
 
 
 
 
41
 
42
 
43
  @app.get('/')
@@ -45,8 +72,13 @@ def index():
45
  return FileResponse(path="static/index.html", media_type="text/html")
46
 
47
 
 
 
 
 
 
48
  @app.get('/task/create')
49
- def create_task(prompt: str = "покемон"):
50
  created_at = time()
51
 
52
  task_id = f"{str(created_at)}_{prompt}"
@@ -55,75 +87,21 @@ def create_task(prompt: str = "покемон"):
55
  "task_id": task_id,
56
  "created_at": created_at,
57
  "prompt": prompt,
58
- "initial_place_in_queue": get_place_in_queue(task_id),
59
- "status": "pending",
60
  "poll_count": 0,
61
  }
62
 
63
- print("Place in queue: ", get_place_in_queue(task_id))
64
- print("ETA: ", calculate_eta(task_id))
65
-
66
- return tasks[task_id]
67
-
68
 
69
- @app.get('/task/queue')
70
- def queue_task(task_id: str):
71
- try:
72
- tasks[task_id]["value"] = generate_image(tasks[task_id]["prompt"])
73
- except Exception as ex:
74
- tasks[task_id]["status"] = "failed"
75
- tasks[task_id]["error"] = repr(ex)
76
- else:
77
- tasks[task_id]["status"] = "completed"
78
- finally:
79
- tasks[task_id]["completed_at"] = time()
80
 
81
  return tasks[task_id]
82
 
83
 
84
  @app.get('/task/poll')
85
  def poll_task(task_id: str):
86
- pending_tasks = []
87
- completed_durations = []
88
-
89
- for task in tasks.values():
90
- if task["status"] == "pending":
91
- pending_tasks.append(task["task_id"])
92
- elif task["status"] == "completed":
93
- completed_durations.append(
94
- task["completed_at"] - task["created_at"])
95
-
96
- try:
97
- place_in_queue = pending_tasks.index(task_id) + 1
98
- except:
99
- place_in_queue = 0
100
-
101
- if (len(completed_durations)):
102
- eta = sum(completed_durations) / \
103
- len(completed_durations) * (place_in_queue or 1)
104
- else:
105
- eta = 40 * (place_in_queue or 1)
106
-
107
- tasks[task_id]["place_in_queue"] = place_in_queue
108
- tasks[task_id]["eta"] = round(eta, 1)
109
  tasks[task_id]["poll_count"] += 1
110
 
111
  return tasks[task_id]
112
-
113
-
114
- # @app.route('/details')
115
- @app.get('/details')
116
- async def generate_details():
117
- return rand_details()
118
-
119
-
120
- @app.get('/duck/quack')
121
- async def test(query: str = "quack"):
122
- print(query)
123
- return {"duck": query}
124
-
125
-
126
- @app.get('/test')
127
- async def test(query: str = "test"):
128
- print(query)
129
- return {"query": query}
 
 
 
 
 
 
1
  from time import time
2
  from statistics import mean
3
 
4
+ from fastapi import BackgroundTasks, FastAPI
5
+ from fastapi.staticfiles import StaticFiles
6
+ from fastapi.responses import FileResponse
7
+
8
  from modules.details import rand_details
9
  from modules.inference import generate_image
10
 
 
12
 
13
  app.mount("/static", StaticFiles(directory="static"), name="static")
14
 
 
15
  tasks = {}
16
 
17
 
18
  def get_place_in_queue(task_id):
19
+ queued_tasks = list(task for task in tasks.values()
20
+ if task["status"] == "queued" or task["status"] == "processing")
21
+
22
+ queued_tasks.sort(key=lambda task: task["created_at"])
23
 
24
+ queued_task_ids = list(task["task_id"] for task in queued_tasks)
 
25
 
26
  try:
27
+ return queued_task_ids.index(task_id) + 1
28
  except:
29
  return 0
30
 
31
 
32
  def calculate_eta(task_id):
33
+ total_durations = list(task["completed_at"] - task["started_at"]
34
  for task in tasks.values() if "completed_at" in task)
35
 
36
+ initial_place_in_queue = tasks[task_id]["initial_place_in_queue"]
37
 
38
  if len(total_durations):
39
+ eta = initial_place_in_queue * mean(total_durations)
40
+ else:
41
+ eta = initial_place_in_queue * 40
42
+
43
+ return round(eta, 1)
44
+
45
+
46
+ def process_task(task_id):
47
+ if 'processing' in list(task['status'] for task in tasks.values()):
48
+ return
49
+
50
+ tasks[task_id]["status"] = "processing"
51
+ tasks[task_id]["started_at"] = time()
52
+
53
+ try:
54
+ tasks[task_id]["value"] = generate_image(tasks[task_id]["prompt"])
55
+ except Exception as ex:
56
+ tasks[task_id]["status"] = "failed"
57
+ tasks[task_id]["error"] = repr(ex)
58
  else:
59
+ tasks[task_id]["status"] = "completed"
60
+ finally:
61
+ tasks[task_id]["completed_at"] = time()
62
+
63
+ queued_tasks = list(task for task in tasks.values() if task["status"] == "queued")
64
+
65
+ if queued_tasks:
66
+ print(f"Tasks remaining: {len(queued_tasks)}")
67
+ process_task(queued_tasks[0]["task_id"])
68
 
69
 
70
  @app.get('/')
 
72
  return FileResponse(path="static/index.html", media_type="text/html")
73
 
74
 
75
+ @app.get('/details')
76
+ def generate_details():
77
+ return rand_details()
78
+
79
+
80
  @app.get('/task/create')
81
+ def create_task(background_tasks: BackgroundTasks, prompt: str = "покемон"):
82
  created_at = time()
83
 
84
  task_id = f"{str(created_at)}_{prompt}"
 
87
  "task_id": task_id,
88
  "created_at": created_at,
89
  "prompt": prompt,
90
+ "status": "queued",
 
91
  "poll_count": 0,
92
  }
93
 
94
+ tasks[task_id]["initial_place_in_queue"] = get_place_in_queue(task_id)
 
 
 
 
95
 
96
+ background_tasks.add_task(process_task, task_id)
 
 
 
 
 
 
 
 
 
 
97
 
98
  return tasks[task_id]
99
 
100
 
101
  @app.get('/task/poll')
102
  def poll_task(task_id: str):
103
+ tasks[task_id]["place_in_queue"] = get_place_in_queue(task_id)
104
+ tasks[task_id]["eta"] = calculate_eta(task_id)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  tasks[task_id]["poll_count"] += 1
106
 
107
  return tasks[task_id]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
modules/inference.py CHANGED
@@ -1,4 +1,6 @@
1
- print("Preparing for inference...") # noqa
 
 
2
 
3
  from rudalle.pipelines import generate_images
4
  from rudalle import get_rudalle_model, get_tokenizer, get_vae
@@ -26,7 +28,7 @@ model.load_state_dict(torch.load(f"{file_dir}/{file_name}", map_location=f"{'cud
26
  vae = get_vae().to(device)
27
  tokenizer = get_tokenizer()
28
 
29
- print("Ready for inference")
30
 
31
 
32
  def english_to_russian(english_string):
@@ -48,7 +50,8 @@ def english_to_russian(english_string):
48
 
49
 
50
  def generate_image(prompt):
51
- if prompt.lower() in ['grass', 'fire', 'water', 'lightning', 'fighting', 'psychic', 'colorless', 'darkness', 'metal', 'dragon', 'fairy']:
 
52
  prompt = english_to_russian(prompt)
53
 
54
  result, _ = generate_images(prompt, tokenizer, model, vae, top_k=2048, images_num=1, top_p=0.995)
 
1
+ from time import gmtime, strftime
2
+
3
+ print(f'{strftime("%Y-%m-%d %H:%M:%S", gmtime())} Preparing for inference...') # noqa
4
 
5
  from rudalle.pipelines import generate_images
6
  from rudalle import get_rudalle_model, get_tokenizer, get_vae
 
28
  vae = get_vae().to(device)
29
  tokenizer = get_tokenizer()
30
 
31
+ print(f'{strftime("%Y-%m-%d %H:%M:%S", gmtime())} Ready for inference')
32
 
33
 
34
  def english_to_russian(english_string):
 
50
 
51
 
52
  def generate_image(prompt):
53
+ if prompt.lower() in ['grass', 'fire', 'water', 'lightning', 'fighting', 'psychic', 'colorless', 'darkness',
54
+ 'metal', 'dragon', 'fairy']:
55
  prompt = english_to_russian(prompt)
56
 
57
  result, _ = generate_images(prompt, tokenizer, model, vae, top_k=2048, images_num=1, top_p=0.995)
static/js/index.js CHANGED
@@ -1,4 +1,4 @@
1
- import { generateDetails, createTask, queueTask, longPollTask } from './network.js';
2
  import { cardHTML } from './card-html.js';
3
  import {
4
  durationTimer,
@@ -30,11 +30,8 @@ const generate = async () => {
30
  scene.removeEventListener('mousemove', mousemoveHandlerForPreviousCard, true);
31
  cardSlot.innerHTML = '';
32
  generating = true;
33
-
34
  setOutput('booster', 'generating');
35
 
36
- await new Promise((resolve) => setTimeout(resolve, 9999));
37
-
38
  const details = await generateDetails();
39
  pokeName = details.name;
40
  const task = await createTask(details.energy_type);
@@ -44,8 +41,7 @@ const generate = async () => {
44
  const timer = durationTimer(durationDisplay);
45
  const timerCleanup = timer().cleanup;
46
 
47
- const longPromises = [queueTask(task.task_id), longPollTask(task)];
48
- const completedTask = await Promise.any(longPromises);
49
 
50
  generating = false;
51
  timerCleanup(completedTask);
 
1
+ import { generateDetails, createTask, longPollTask } from './network.js';
2
  import { cardHTML } from './card-html.js';
3
  import {
4
  durationTimer,
 
30
  scene.removeEventListener('mousemove', mousemoveHandlerForPreviousCard, true);
31
  cardSlot.innerHTML = '';
32
  generating = true;
 
33
  setOutput('booster', 'generating');
34
 
 
 
35
  const details = await generateDetails();
36
  pokeName = details.name;
37
  const task = await createTask(details.energy_type);
 
41
  const timer = durationTimer(durationDisplay);
42
  const timerCleanup = timer().cleanup;
43
 
44
+ const completedTask = await longPollTask(task);
 
45
 
46
  generating = false;
47
  timerCleanup(completedTask);
static/js/network.js CHANGED
@@ -19,11 +19,6 @@ const createTask = async (prompt) => {
19
  return task;
20
  };
21
 
22
- const queueTask = async (task_id) => {
23
- const queueResponse = await fetch(pathFor(`task/queue?task_id=${task_id}`));
24
- return queueResponse.json();
25
- };
26
-
27
  const pollTask = async (task) => {
28
  const taskResponse = await fetch(pathFor(`task/poll?task_id=${task.task_id}`));
29
 
@@ -50,4 +45,4 @@ const longPollTask = async (task, interval = 10_000, max) => {
50
  return await longPollTask(task, interval, max);
51
  };
52
 
53
- export { generateDetails, createTask, queueTask, longPollTask };
 
19
  return task;
20
  };
21
 
 
 
 
 
 
22
  const pollTask = async (task) => {
23
  const taskResponse = await fetch(pathFor(`task/poll?task_id=${task.task_id}`));
24
 
 
45
  return await longPollTask(task, interval, max);
46
  };
47
 
48
+ export { generateDetails, createTask, longPollTask };