Spaces:
Runtime error
Runtime error
import logging | |
import time | |
from asyncio import Queue as AioQueue | |
from dataclasses import asdict | |
from multiprocessing import shared_memory | |
from queue import Queue | |
from threading import Thread | |
from typing import Dict, List, Tuple | |
import numpy as np | |
import orjson | |
from redis import ConnectionPool, Redis | |
from inference.core.entities.requests.inference import ( | |
InferenceRequest, | |
request_from_type, | |
) | |
from inference.core.env import MAX_ACTIVE_MODELS, MAX_BATCH_SIZE, REDIS_HOST, REDIS_PORT | |
from inference.core.managers.base import ModelManager | |
from inference.core.managers.decorators.fixed_size_cache import WithFixedSizeCache | |
from inference.core.models.roboflow import RoboflowInferenceModel | |
from inference.core.registries.roboflow import RoboflowModelRegistry | |
from inference.enterprise.parallel.tasks import postprocess | |
from inference.enterprise.parallel.utils import ( | |
SharedMemoryMetadata, | |
failure_handler, | |
shm_manager, | |
) | |
# Root logger at WARNING: the per-batch ``logger.info`` timing messages in this
# module stay silent unless the level is lowered.
logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger()

from inference.models.utils import ROBOFLOW_MODEL_TYPES

# Fallback batch size used when a model's metadata reports a non-numeric
# batch size. An unbounded MAX_BATCH_SIZE (infinity) is capped at 32.
BATCH_SIZE = 32 if MAX_BATCH_SIZE == float("inf") else MAX_BATCH_SIZE

# Weight of request age in batch selection: this many seconds of average age
# count as much as one completely full batch.
AGE_TRADEOFF_SECONDS_FACTOR = 30
class InferServer:
    """Pull batched inference requests from Redis and run them on local models.

    Three loops cooperate:

    * ``infer_loop`` (caller's thread) picks a batch from Redis, loads its
      images from shared memory and hands it to ``batch_queue``.
    * ``infer`` (background thread) runs the model on each queued batch and
      pushes one response per request onto ``response_queue``.
    * ``write_responses`` (background thread) writes each response to shared
      memory and launches postprocessing.
    """

    def __init__(self, redis: Redis) -> None:
        """Build the model manager and start the two worker threads.

        Args:
            redis: Redis client used to read queued requests and, through
                ``failure_handler``, to report failed requests.
        """
        self.redis = redis
        model_registry = RoboflowModelRegistry(ROBOFLOW_MODEL_TYPES)
        model_manager = ModelManager(model_registry)
        self.model_manager = WithFixedSizeCache(
            model_manager, max_size=MAX_ACTIVE_MODELS
        )
        self.running = True
        self.response_queue = Queue()
        self.write_thread = Thread(target=self.write_responses)
        self.write_thread.start()
        # maxsize=1: at most one loaded batch waits while another is being
        # inferred; ``get_batch`` blocks on ``put`` until the slot is free.
        self.batch_queue = Queue(maxsize=1)
        self.infer_thread = Thread(target=self.infer)
        self.infer_thread.start()

    def write_responses(self):
        """Forever: take a response off the queue and hand it to postprocessing.

        Errors are logged and swallowed so one bad response cannot stop the
        writer thread.
        """
        while True:
            try:
                response = self.response_queue.get()
                write_infer_arrays_and_launch_postprocess(*response)
            except Exception as error:
                logger.warning(
                    "Encountered error while writing response:\n" + str(error)
                )

    def infer_loop(self):
        """Poll Redis for models with pending requests and queue batches.

        Runs until ``self.running`` is set to False. Sleeps 1 ms when nothing
        is pending; any error while building a batch is logged and the loop
        carries on.
        """
        while self.running:
            try:
                model_names = get_requested_model_names(self.redis)
                if not model_names:
                    time.sleep(0.001)
                    continue
                self.get_batch(model_names)
            except Exception as error:
                logger.warning("Encountered error in infer loop:\n" + str(error))

    def infer(self):
        """Forever: run the model on each queued batch and enqueue the responses.

        A failing batch must not kill this thread: with the consumer gone,
        ``batch_queue`` (maxsize=1) would fill up and ``infer_loop`` would block
        forever on ``put``. Failures are therefore routed through
        ``failure_handler`` (as ``get_batch`` does), logged, and skipped.
        """
        while True:
            model_id, images, batch, preproc_return_metadatas = self.batch_queue.get()
            responses = []
            try:
                # Raw ids were stashed by ``get_batch`` before ``b["request"]``
                # was replaced by a typed request object.
                request_ids = [b["request_id"] for b in batch]
                with failure_handler(self.redis, *request_ids):
                    responses = self._predict_batch(
                        model_id, images, batch, preproc_return_metadatas
                    )
            except Exception as error:
                logger.warning(
                    "Encountered error while running inference:\n" + str(error)
                )
                continue
            # Enqueue only after the whole batch succeeded, so a request never
            # receives both a response and a failure report.
            for response in responses:
                self.response_queue.put_nowait(response)

    def _predict_batch(self, model_id, images, batch, preproc_return_metadatas):
        """Run ``model_id`` on ``images`` and pair each output with its request.

        Args:
            model_id: Identifier passed to ``model_manager.predict``.
            images: Batch input passed to ``model_manager.predict``.
            batch: Request entries; each must provide ``b["request"]``.
            preproc_return_metadatas: One preprocessing metadata item per entry.

        Returns:
            A list of ``(output, request, metadata)`` tuples, one per request.
            ``predict`` returns one sequence per output head; ``zip(*outputs)``
            regroups them into one tuple of outputs per request.
        """
        outputs = self.model_manager.predict(model_id, images)
        return [
            (output, b["request"], metadata)
            for output, b, metadata in zip(
                zip(*outputs), batch, preproc_return_metadatas
            )
        ]

    def get_batch(self, model_names):
        """Select a batch, load its images and queue it for inference.

        Args:
            model_names: Models that currently have pending requests.

        Everything after batch selection runs inside ``failure_handler`` so the
        selected requests are covered if loading fails. Blocks on
        ``batch_queue.put`` while a previously loaded batch is still waiting.
        """
        start = time.perf_counter()
        batch, model_id = get_batch(self.redis, model_names)
        logger.info(f"Inferring: model<{model_id}> batch_size<{len(batch)}>")
        request_ids = [b["request"]["id"] for b in batch]
        with failure_handler(self.redis, *request_ids):
            self.model_manager.add_model(model_id, batch[0]["request"]["api_key"])
            model_type = self.model_manager.get_task_type(model_id)
            for b, request_id in zip(batch, request_ids):
                # Keep the raw id: ``infer`` needs it for failure reporting
                # after the request dict has been converted below.
                b["request_id"] = request_id
                b["request"] = request_from_type(model_type, b["request"])
                b["shm_metadata"] = SharedMemoryMetadata(**b["shm_metadata"])
            metadata_processed = time.perf_counter()
            logger.info(
                f"Took {(metadata_processed - start):.3f} seconds to process metadata"
            )
            with shm_manager(
                *[b["shm_metadata"].shm_name for b in batch], unlink_on_success=True
            ) as shms:
                images, preproc_return_metadatas = load_batch(batch, shms)
                loaded = time.perf_counter()
                logger.info(
                    f"Took {(loaded - metadata_processed):.3f} seconds to load batch"
                )
                self.batch_queue.put(
                    (model_id, images, batch, preproc_return_metadatas)
                )
def get_requested_model_names(redis: Redis) -> List[str]:
    """Return the models whose pending-request counter is positive.

    Reads the ``requests`` hash (model name -> count) and keeps every name
    whose count, converted with ``int``, is greater than zero.
    """
    pending = []
    for name, raw_count in redis.hgetall("requests").items():
        if int(raw_count) > 0:
            pending.append(name)
    return pending
def get_batch(redis: Redis, model_names: List[str]) -> Tuple[List[Dict], str]:
    """
    Pick the most worthwhile batch among the given models and claim it.

    redis[Redis]: redis client
    model_names[List[str]]: models with a nonzero number of pending requests
    returns:
        Tuple[List[Dict], str]
        List[Dict] is the batch of decoded request dicts
        str is the id of the model the batch belongs to

    The chosen entries are removed from the model's ``infer:<model>`` sorted
    set and the model's counter in the ``requests`` hash is decremented by the
    batch length.
    """
    batch_sizes = []
    for name in model_names:
        configured = RoboflowInferenceModel.model_metadata_from_memcache_endpoint(
            name
        )["batch_size"]
        # A string batch size is replaced by the module-wide default.
        batch_sizes.append(BATCH_SIZE if isinstance(configured, str) else configured)

    # Peek at the lowest-scored ``size`` entries of every model's queue.
    candidates = [
        redis.zrange(f"infer:{name}", 0, size - 1, withscores=True)
        for name, size in zip(model_names, batch_sizes)
    ]

    winner = select_best_inference_batch(candidates, batch_sizes)
    entries = candidates[winner]
    selected_model = model_names[winner]

    # Each entry is (member, score); the member is the serialized request.
    members = [entry[0] for entry in entries]
    redis.zrem(f"infer:{selected_model}", *members)
    redis.hincrby("requests", selected_model, -len(entries))

    return [orjson.loads(member) for member in members], selected_model
def select_best_inference_batch(batches, batch_sizes):
    """Return the index of the batch most worth running next.

    Args:
        batches: One candidate batch per model; each batch is a sequence of
            ``(member, score)`` pairs. The score is assumed to be the time the
            request was enqueued, in ``time.time()`` seconds — confirm against
            the producer.
        batch_sizes: Maximum batch size of each model, aligned with ``batches``.

    Returns:
        Index into ``batches`` of the candidate with the highest fitness.

    Fitness is ``average_age / AGE_TRADEOFF_SECONDS_FACTOR + fill_ratio``, so
    both older requests and fuller batches raise priority. Age is measured as
    ``now - score``; the opposite sign would rank the oldest batch lowest and
    could starve it. An empty candidate scores ``-inf`` so it is never
    preferred over a non-empty one (and no mean of an empty list is taken).
    """
    now = time.time()
    fitnesses = []
    for batch, batch_size in zip(batches, batch_sizes):
        if not batch:
            fitnesses.append(float("-inf"))
            continue
        average_age = float(np.mean([now - float(entry[1]) for entry in batch]))
        fill_ratio = len(batch) / batch_size
        fitnesses.append(average_age / AGE_TRADEOFF_SECONDS_FACTOR + fill_ratio)
    return fitnesses.index(max(fitnesses))
def load_batch(
    batch: List[Dict[str, str]], shms: List[shared_memory.SharedMemory]
) -> Tuple[List[np.ndarray], List[Dict]]:
    """Copy each request's image out of shared memory.

    ``batch`` and ``shms`` are walked in lockstep. For every pair, an array of
    the shape and dtype given by the entry's ``shm_metadata`` is viewed over
    the segment's buffer and then copied, so the result does not reference the
    shared segment.

    Returns:
        ``(images, preproc_return_metadatas)``: the copied arrays and each
        entry's ``preprocess_metadata``, in batch order.
    """
    pairs = list(zip(batch, shms))
    images = [
        np.ndarray(
            entry["shm_metadata"].array_shape,
            dtype=entry["shm_metadata"].array_dtype,
            buffer=segment.buf,
        ).copy()
        for entry, segment in pairs
    ]
    preproc_return_metadatas = [entry["preprocess_metadata"] for entry, _ in pairs]
    return images, preproc_return_metadatas
def write_infer_arrays_and_launch_postprocess(
    arrs: Tuple[np.ndarray, ...],
    request: InferenceRequest,
    preproc_return_metadata: Dict,
):
    """Write inference results to shared memory and launch the postprocessing task.

    Args:
        arrs: Output arrays of one request; each gets its own shared-memory
            segment sized to ``arr.nbytes``.
        request: The request the outputs belong to; forwarded as ``request.dict()``.
        preproc_return_metadata: Preprocessing metadata forwarded unchanged.

    Raises:
        Whatever ``SharedMemory`` creation raises (for example ``ValueError``
        for a zero-byte array, ``OSError`` when shared memory is exhausted);
        segments already created for this call are released first.
    """
    # Create segments one at a time: until ``shm_manager`` takes over nobody
    # else owns them, so a failure part-way through must release the earlier
    # ones here or they would stay allocated.
    shms = []
    try:
        for arr in arrs:
            shms.append(shared_memory.SharedMemory(create=True, size=arr.nbytes))
    except Exception:
        for shm in shms:
            shm.close()
            shm.unlink()
        raise

    with shm_manager(*shms):
        shm_metadatas = []
        for arr, shm in zip(arrs, shms):
            # View the segment as an array of the same shape/dtype and copy in.
            shared = np.ndarray(arr.shape, dtype=arr.dtype, buffer=shm.buf)
            shared[:] = arr[:]
            shm_metadata = SharedMemoryMetadata(
                shm_name=shm.name, array_shape=arr.shape, array_dtype=arr.dtype.name
            )
            shm_metadatas.append(asdict(shm_metadata))

        # Only plain metadata is handed to the task, not the array data.
        postprocess.s(
            tuple(shm_metadatas), request.dict(), preproc_return_metadata
        ).delay()
if __name__ == "__main__":
    # decode_responses=True: Redis replies arrive as str rather than bytes.
    redis = Redis(
        connection_pool=ConnectionPool(
            host=REDIS_HOST, port=REDIS_PORT, decode_responses=True
        )
    )
    server = InferServer(redis)
    server.infer_loop()