Spaces:
Runtime error
Runtime error
import asyncio | |
from asyncio import BoundedSemaphore | |
from time import perf_counter, time | |
from typing import Any, Dict, List, Optional | |
import orjson | |
from redis.asyncio import Redis | |
from inference.core.entities.requests.inference import ( | |
InferenceRequest, | |
request_from_type, | |
) | |
from inference.core.entities.responses.inference import response_from_type | |
from inference.core.env import NUM_PARALLEL_TASKS | |
from inference.core.managers.base import ModelManager | |
from inference.core.registries.base import ModelRegistry | |
from inference.core.registries.roboflow import get_model_type | |
from inference.enterprise.parallel.tasks import preprocess | |
from inference.enterprise.parallel.utils import FAILURE_STATE, SUCCESS_STATE | |
class ResultsChecker:
    """
    Queue asynchronous inference runs, keep track of the requests that are
    in flight, and hand their results back to the coroutines awaiting them.

    The number of simultaneously in-flight tasks is capped by a
    ``BoundedSemaphore`` sized with ``NUM_PARALLEL_TASKS``. Results are read
    from the ``"results"`` redis pub/sub channel by :meth:`loop`.
    """

    def __init__(self, redis: Redis):
        # task_id -> event that is set once a final result has been recorded.
        self.tasks: Dict[str, asyncio.Event] = {}
        # task_id -> payload of a task that finished successfully.
        self.dones = dict()
        # task_id -> payload (error message) of a task that failed.
        self.errors = dict()
        # Not read anywhere inside this class; kept for external users.
        self.running = True
        self.redis = redis
        self.semaphore: BoundedSemaphore = BoundedSemaphore(NUM_PARALLEL_TASKS)

    async def add_task(self, task_id: str, request: InferenceRequest):
        """
        Wait until there is an available cycle, then queue a task.

        Once a semaphore permit is obtained, an event is registered under
        ``task_id`` so the result can be tracked, and the preprocess Celery
        task is launched with the request serialised as a dict.

        Raises:
            Exception: whatever serialising or dispatching the request raises.
                In that case the permit and the tracking entry are rolled
                back before the exception propagates.
        """
        await self.semaphore.acquire()
        self.tasks[task_id] = asyncio.Event()
        try:
            preprocess.s(request.dict()).delay()
        except Exception:
            # No result message will ever arrive for a task that was never
            # dispatched, so `loop` would never release this permit. Undo
            # the bookkeeping here to avoid permanently losing capacity.
            self.tasks.pop(task_id, None)
            self.semaphore.release()
            raise

    def get_result(self, task_id: str) -> Any:
        """
        Check the done tasks and errored tasks for this task id.

        Returns:
            The stored payload of a successful task (the entry is removed).

        Raises:
            Exception: with the stored error payload if the task failed
                (the entry is removed).
            RuntimeError: if the task id is in neither dictionary.
        """
        if task_id in self.dones:
            return self.dones.pop(task_id)
        elif task_id in self.errors:
            message = self.errors.pop(task_id)
            raise Exception(message)
        else:
            raise RuntimeError(
                "Task result not found in either success or error dict. Unreachable"
            )

    async def loop(self):
        """
        Main loop. Listen on the "results" pub/sub channel and, for every
        tracked task whose final status arrives (failure or success), store
        its payload in the matching results dictionary and wake its waiter.
        """
        async with self.redis.pubsub() as pubsub:
            await pubsub.subscribe("results")
            async for message in pubsub.listen():
                # Skip anything that is not a published data message.
                if message["type"] != "message":
                    continue
                message = orjson.loads(message["data"])
                task_id = message.pop("task_id")
                # Ignore results for tasks this instance is not tracking.
                if task_id not in self.tasks:
                    continue
                # The task is finished: free its slot for the next add_task.
                self.semaphore.release()
                status = message.pop("status")
                if status == FAILURE_STATE:
                    self.errors[task_id] = message["payload"]
                elif status == SUCCESS_STATE:
                    self.dones[task_id] = message["payload"]
                else:
                    # NOTE(review): raising here ends the listener without
                    # setting the task's event, so its waiter never wakes.
                    # Presumably publishers only emit the two known states;
                    # confirm.
                    raise RuntimeError(
                        "Task result not found in possible states. Unreachable"
                    )
                self.tasks[task_id].set()
                # Yield to the event loop so woken waiters get a chance to run.
                await asyncio.sleep(0)

    async def wait_for_response(self, key: str):
        """
        Wait until the task registered under ``key`` has a final result,
        stop tracking it, and return (or raise) via :meth:`get_result`.

        Raises:
            KeyError: if ``key`` was never registered with :meth:`add_task`.
        """
        event = self.tasks[key]
        await event.wait()
        del self.tasks[key]
        return self.get_result(key)
class DispatchModelManager(ModelManager):
    """
    Model manager that does not run models itself: every inference request
    is handed to a ``ResultsChecker``, which queues it and returns the result.
    """

    def __init__(
        self,
        model_registry: ModelRegistry,
        checker: ResultsChecker,
        models: Optional[dict] = None,
    ):
        super().__init__(model_registry, models)
        self.checker = checker

    async def model_infer(self, model_id: str, request: InferenceRequest, **kwargs):
        """
        Dispatch ``request`` through the checker and await the response(s).

        When ``request.image`` is a list, one request per image is built and
        a list of responses is returned; otherwise a single response is
        returned. Each response's ``time`` is set to the elapsed time since
        this call started.

        Raises:
            NotImplementedError: if ``request.visualize_predictions`` is set.
        """
        if request.visualize_predictions:
            raise NotImplementedError("Visualisation of prediction is not supported")
        request.start = time()
        t = perf_counter()
        task_type = self.get_task_type(model_id, request.api_key)

        list_mode = isinstance(request.image, list)
        if list_mode:
            request_dict = request.dict()
            images = request_dict.pop("image")
            # Drop the id so it is not copied into every per-image request.
            del request_dict["id"]
            requests = [
                request_from_type(task_type, dict(**request_dict, image=image))
                for image in images
            ]
        else:
            requests = [request]

        await asyncio.gather(*(self.checker.add_task(r.id, r) for r in requests))
        # The result coroutines are created only after enqueuing succeeded;
        # creating them up front would leave un-awaited coroutine objects
        # behind (and a RuntimeWarning) whenever add_task raises.
        response_jsons = await asyncio.gather(
            *(self.checker.wait_for_response(r.id) for r in requests)
        )

        responses = []
        for response_json in response_jsons:
            response = response_from_type(task_type, response_json)
            response.time = perf_counter() - t
            responses.append(response)

        if list_mode:
            return responses
        return responses[0]

    def add_model(
        self, model_id: str, api_key: str, model_id_alias: Optional[str] = None
    ) -> None:
        """Do nothing; this manager does not load models."""
        pass

    def __contains__(self, model_id: str) -> bool:
        """Report every model id as present."""
        return True

    def get_task_type(self, model_id: str, api_key: Optional[str] = None) -> str:
        """Return the first element of ``get_model_type(model_id, api_key)``."""
        return get_model_type(model_id, api_key)[0]