Spaces:
Running
Running
from fastapi import FastAPI | |
from fastapi.staticfiles import StaticFiles | |
from fastapi.responses import FileResponse | |
from time import time | |
from statistics import mean | |
from modules.details import rand_details | |
from modules.inference import generate_image | |
# Application object; files from the local "static" directory are mounted
# under the "/static" URL prefix.
app = FastAPI()
_static_files = StaticFiles(directory="static")
app.mount("/static", _static_files, name="static")

# In-memory registry of task records (plain dicts), keyed by task id.
tasks = {}
def get_place_in_queue(task_id):
    """Return the 1-based position of ``task_id`` among pending tasks.

    Pending tasks are taken in the insertion order of the module-level
    ``tasks`` dict.

    Args:
        task_id: Key of the task in ``tasks``.

    Returns:
        The 1-based position, or 0 when the task is unknown or is not
        pending.
    """
    # Collect task *ids*, not task dicts: searching a list of dicts for a
    # string id can never match, which made the lookup always return 0.
    pending_ids = [
        tid for tid, task in tasks.items() if task["status"] == "pending"
    ]
    try:
        return pending_ids.index(task_id) + 1
    except ValueError:
        # list.index raises ValueError when the id is not in the list.
        return 0
def calculate_eta(task_id, default_duration=40):
    """Estimate how long ``task_id`` will take, based on its queue place.

    The estimate is the mean ``completed_at - created_at`` over every task
    that has a ``completed_at`` entry, multiplied by the task's
    ``initial_place_in_queue`` (a place of 0 is treated as 1).

    Args:
        task_id: Key of the task in the module-level ``tasks`` dict.
        default_duration: Per-task duration used while no task has a
            ``completed_at`` entry yet. Defaults to 40.

    Returns:
        The estimate, in the same unit as the stored timestamps.

    Raises:
        KeyError: If ``task_id`` is not present in ``tasks``.
    """
    durations = [
        task["completed_at"] - task["created_at"]
        for task in tasks.values()
        if "completed_at" in task
    ]
    # A stored place of 0 still means at least one task's worth of waiting.
    place = tasks[task_id]["initial_place_in_queue"] or 1
    per_task = mean(durations) if durations else default_duration
    return per_task * place
def index():
    """Return ``static/index.html`` wrapped in a ``text/html`` FileResponse."""
    page = FileResponse(path="static/index.html", media_type="text/html")
    return page
def create_task(prompt: str = "покемон"):
    """Register a new pending task for ``prompt`` and return its record.

    The task id is built from the creation timestamp and the prompt. The
    record is stored in the module-level ``tasks`` dict under that id.

    Args:
        prompt: Text prompt stored on the task record.

    Returns:
        The newly created task dict (the same object stored in ``tasks``).
    """
    created_at = time()
    task_id = f"{created_at}_{prompt}"
    task = {
        "task_id": task_id,
        "created_at": created_at,
        "prompt": prompt,
        # Placeholder keeps the key order; the real value is filled in below.
        "initial_place_in_queue": 0,
        "status": "pending",
        "poll_count": 0,
    }
    # Register the task BEFORE asking for its place in the queue: computing
    # the place inside the dict literal ran while the task was not yet in
    # `tasks`, so it could never find itself and was always 0.
    tasks[task_id] = task
    place = get_place_in_queue(task_id)
    task["initial_place_in_queue"] = place

    print("Place in queue: ", place)
    print("ETA: ", calculate_eta(task_id))
    return task
def queue_task(task_id: str):
    """Run image generation for a stored task and record the outcome.

    On success the result of ``generate_image`` is stored under ``"value"``
    and the status becomes ``"completed"``. If ``generate_image`` raises, the
    status becomes ``"failed"`` and ``repr`` of the exception is stored under
    ``"error"``. ``"completed_at"`` is set in both cases.

    Args:
        task_id: Key of the task in the module-level ``tasks`` dict.

    Returns:
        The updated task dict.

    Raises:
        KeyError: If ``task_id`` is not present in ``tasks``.
    """
    # Look the task up once, outside the try block, so an unknown id surfaces
    # as a single clean KeyError instead of a KeyError raised again from the
    # except and finally clauses.
    task = tasks[task_id]
    try:
        task["value"] = generate_image(task["prompt"])
    except Exception as ex:
        # Generation failures are recorded on the task rather than propagated.
        task["status"] = "failed"
        task["error"] = repr(ex)
    else:
        task["status"] = "completed"
    finally:
        # Stamp the finish time for failed tasks too.
        task["completed_at"] = time()
    return task
def poll_task(task_id: str):
    """Refresh and return the queue position and ETA of a task.

    Updates ``"place_in_queue"``, ``"eta"`` (rounded to one decimal) and
    increments ``"poll_count"`` on the stored task record.

    The ETA is the average ``completed_at - created_at`` of tasks whose
    status is ``"completed"`` (or 40 when there are none), multiplied by the
    task's place in the queue; a place of 0 is treated as 1.

    Args:
        task_id: Key of the task in the module-level ``tasks`` dict.

    Returns:
        The updated task dict.

    Raises:
        KeyError: If ``task_id`` is not present in ``tasks``.
    """
    task = tasks[task_id]

    pending_ids = [
        entry["task_id"]
        for entry in tasks.values()
        if entry["status"] == "pending"
    ]
    completed_durations = [
        entry["completed_at"] - entry["created_at"]
        for entry in tasks.values()
        if entry["status"] == "completed"
    ]

    try:
        place_in_queue = pending_ids.index(task_id) + 1
    except ValueError:
        # The task is no longer pending.
        place_in_queue = 0

    multiplier = place_in_queue or 1
    if completed_durations:
        eta = sum(completed_durations) / len(completed_durations) * multiplier
    else:
        # No completed task to learn from yet: fall back to a fixed guess.
        eta = 40 * multiplier

    task["place_in_queue"] = place_in_queue
    task["eta"] = round(eta, 1)
    task["poll_count"] += 1
    return task
# @app.route('/details') | |
async def generate_details():
    """Return the value produced by ``rand_details()``."""
    details = rand_details()
    return details
async def test(query: str = "quack"):
    """Print ``query`` and return it in a dict under the ``"duck"`` key."""
    print(query)
    payload = dict(duck=query)
    return payload
async def test(query: str = "test"):
    """Print ``query`` and return it in a dict under the ``"query"`` key.

    NOTE(review): an earlier function in this file is also named ``test``;
    this definition rebinds the module-level name. Confirm that is intended.
    """
    print(query)
    payload = dict(query=query)
    return payload