File size: 2,782 Bytes
4c519fd
 
5c239ba
 
8f246ac
5c239ba
 
 
 
 
 
 
 
 
 
 
 
 
3750ff9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5c239ba
 
 
 
4c519fd
 
 
5c239ba
 
 
4c519fd
5c239ba
3750ff9
5c239ba
 
 
 
3750ff9
 
 
5c239ba
 
 
 
 
 
 
 
 
4c519fd
 
 
5c239ba
 
 
 
 
 
 
 
4c519fd
 
 
 
 
 
 
8f246ac
 
4c519fd
 
 
 
 
 
 
8f246ac
 
4c519fd
 
 
 
 
5c239ba
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
from time import time
from statistics import mean
from flask import Flask, jsonify, render_template, request

from modules.details import rand_details
from modules.inference import generate_image

# The Flask application object; every route below registers against it.
app = Flask(import_name=__name__)


@app.route('/')
def index():
    """Render ``index.html`` using the mapping returned by ``rand_details()`` as template context."""
    context = rand_details()
    return render_template('index.html', **context)


# In-memory task registry: task_id -> task record dict (see create_task for
# the keys). Insertion order doubles as queue order for pending tasks.
tasks = dict()


def place_in_queue(task_id):
    """Return the 1-based position of *task_id* among pending tasks.

    Pending tasks are ordered by their insertion order in ``tasks`` (dicts
    preserve insertion order), so earlier-created tasks come first.

    Returns 0 when the task is unknown or no longer pending.
    """
    # Compare ids, not whole task dicts: searching a list of dicts for a
    # string id can never succeed.
    pending_ids = [task["task_id"] for task in tasks.values()
                   if task["status"] == "pending"]

    try:
        return pending_ids.index(task_id) + 1
    except ValueError:
        return 0


def calculate_eta(task_id, default_duration=40):
    """Estimate how many seconds remain until *task_id* is done.

    The estimate is the mean ``completed_at - created_at`` duration over all
    tasks that have a ``completed_at`` timestamp. It is multiplied by the
    task's ``initial_place_in_queue``, with 0 treated as 1.

    Args:
        task_id: key of the task in ``tasks``.
        default_duration: seconds per queue slot to assume while no task has
            completed yet (defaults to 40).

    Raises:
        KeyError: if *task_id* is not in ``tasks``.
    """
    durations = [task["completed_at"] - task["created_at"]
                 for task in tasks.values() if "completed_at" in task]

    place = tasks[task_id]["initial_place_in_queue"] or 1

    per_task = sum(durations) / len(durations) if durations else default_duration
    return per_task * place


@app.route('/task/create')
def create_task():
    """Register a new pending task and return its record as JSON.

    Query params:
        prompt: text prompt; a built-in default is used when it is missing
            or empty.

    The record is stored in ``tasks`` under an id built from the creation
    timestamp and the prompt. Image generation itself happens in
    ``/task/queue``, not here.
    """
    prompt = request.args.get('prompt') or "покемон"

    created_at = time()

    task_id = f"{str(created_at)}_{prompt}"

    tasks[task_id] = {
        "task_id": task_id,
        "created_at": created_at,
        "prompt": prompt,
        # Placeholder: the real position can only be computed once this task
        # is present in ``tasks`` as pending (set right below).
        "initial_place_in_queue": 0,
        "status": "pending",
        "poll_count": 0,
    }
    place = place_in_queue(task_id)
    tasks[task_id]["initial_place_in_queue"] = place

    print("Place in queue: ", place)
    print("ETA: ", calculate_eta(task_id))

    return jsonify(tasks[task_id])


@app.route('/task/queue')
def queue_task():
    """Run image generation for an existing task and return the updated task.

    Query params:
        task_id: id of a task previously created via ``/task/create``.

    Responds with a 404 JSON error when *task_id* is missing or unknown.
    If ``generate_image`` raises, the task is marked ``"failed"`` and the
    exception propagates. Only ``"pending"`` tasks count toward queue
    positions, so this keeps a crashed job from stalling everyone behind it.
    """
    task_id = request.args.get('task_id')
    task = tasks.get(task_id)
    if task is None:
        return jsonify({"error": "unknown task_id", "task_id": task_id}), 404

    try:
        task["value"] = generate_image(task["prompt"])
    except Exception:
        # Without this the task would stay "pending" forever and push every
        # later task one slot further back in the queue.
        task["status"] = "failed"
        raise

    task["status"] = "completed"

    task["completed_at"] = time()

    return jsonify(task)


@app.route('/task/poll')
def poll_task():
    """Refresh a task's queue position and ETA, bump its poll count, return it.

    Query params:
        task_id: id of a task previously created via ``/task/create``.

    Behaviour:
        - ``place_in_queue`` is the 1-based position among pending tasks, in
          ``tasks`` insertion order; it is 0 once the task is no longer
          pending.
        - ``eta`` is the mean created->completed duration of completed tasks,
          or 40 seconds if none have completed. It is multiplied by the
          position (0 treated as 1) and rounded to one decimal.
        - Responds with a 404 JSON error when *task_id* is missing or unknown.
    """
    task_id = request.args.get('task_id')
    task = tasks.get(task_id)
    if task is None:
        return jsonify({"error": "unknown task_id", "task_id": task_id}), 404

    pending_ids = []
    completed_durations = []

    for other in tasks.values():
        if other["status"] == "pending":
            pending_ids.append(other["task_id"])
        elif other["status"] == "completed":
            completed_durations.append(
                other["completed_at"] - other["created_at"])

    # Named ``place`` so it does not shadow the module-level place_in_queue().
    place = pending_ids.index(task_id) + 1 if task_id in pending_ids else 0

    if completed_durations:
        per_task = sum(completed_durations) / len(completed_durations)
    else:
        per_task = 40
    eta = per_task * (place or 1)

    task["place_in_queue"] = place
    task["eta"] = round(eta, 1)
    task["poll_count"] += 1

    return jsonify(task)


@app.route('/details')
def generate_details():
    """Return the mapping produced by ``rand_details()`` as a JSON response."""
    payload = rand_details()
    return jsonify(payload)


if __name__ == '__main__':
    # Bind to all network interfaces (not only localhost) on port 7860.
    app.run(port=7860, host='0.0.0.0')