import logging
import time
from dataclasses import asdict
from multiprocessing import shared_memory
from queue import Queue
from threading import Thread
from typing import Dict, List, Tuple

import numpy as np
import orjson
from redis import ConnectionPool, Redis

from inference.core.entities.requests.inference import (
    InferenceRequest,
    request_from_type,
)
from inference.core.env import MAX_ACTIVE_MODELS, MAX_BATCH_SIZE, REDIS_HOST, REDIS_PORT
from inference.core.managers.base import ModelManager
from inference.core.managers.decorators.fixed_size_cache import WithFixedSizeCache
from inference.core.models.roboflow import RoboflowInferenceModel
from inference.core.registries.roboflow import RoboflowModelRegistry
from inference.enterprise.parallel.tasks import postprocess
from inference.enterprise.parallel.utils import (
    SharedMemoryMetadata,
    failure_handler,
    shm_manager,
)

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger()

from inference.models.utils import ROBOFLOW_MODEL_TYPES

# MAX_BATCH_SIZE may be unbounded; fall back to a finite size for batching
BATCH_SIZE = MAX_BATCH_SIZE
if BATCH_SIZE == float("inf"):
    BATCH_SIZE = 32

AGE_TRADEOFF_SECONDS_FACTOR = 30


class InferServer:
    """Pull batches of preprocessed requests from Redis, run inference, and
    hand results off for postprocessing.

    Work is pipelined across three threads: the main thread (`infer_loop`)
    assembles batches, `infer` runs the model, and `write_responses` copies
    outputs into shared memory and launches the postprocess task.
    """

    def __init__(self, redis: Redis) -> None:
        self.redis = redis
        model_registry = RoboflowModelRegistry(ROBOFLOW_MODEL_TYPES)
        model_manager = ModelManager(model_registry)
        self.model_manager = WithFixedSizeCache(
            model_manager, max_size=MAX_ACTIVE_MODELS
        )
        self.running = True
        self.response_queue = Queue()
        self.write_thread = Thread(target=self.write_responses)
        self.write_thread.start()
        # maxsize=1 keeps batch assembly at most one step ahead of inference
        self.batch_queue = Queue(maxsize=1)
        self.infer_thread = Thread(target=self.infer)
        self.infer_thread.start()

    def write_responses(self):
        while True:
            try:
                response = self.response_queue.get()
                write_infer_arrays_and_launch_postprocess(*response)
            except Exception as error:
                logger.warning(
                    "Encountered error while writing response:\n" + str(error)
                )

    def infer_loop(self):
        while self.running:
            try:
                model_names = get_requested_model_names(self.redis)
                if not model_names:
                    time.sleep(0.001)
                    continue
                self.get_batch(model_names)
            except Exception as error:
                logger.warning("Encountered error in infer loop:\n" + str(error))
                continue

    def infer(self):
        while True:
            model_id, images, batch, preproc_return_metadatas = self.batch_queue.get()
            outputs = self.model_manager.predict(model_id, images)
            # outputs is a tuple of batched arrays; zip(*outputs) yields the
            # per-request slices in request order
            for output, b, metadata in zip(
                zip(*outputs), batch, preproc_return_metadatas
            ):
                self.response_queue.put_nowait((output, b["request"], metadata))

    def get_batch(self, model_names):
        start = time.perf_counter()
        batch, model_id = get_batch(self.redis, model_names)
        logger.info(f"Inferring: model<{model_id}> batch_size<{len(batch)}>")
        with failure_handler(self.redis, *[b["request"]["id"] for b in batch]):
            self.model_manager.add_model(model_id, batch[0]["request"]["api_key"])
            model_type = self.model_manager.get_task_type(model_id)
            for b in batch:
                request = request_from_type(model_type, b["request"])
                b["request"] = request
                b["shm_metadata"] = SharedMemoryMetadata(**b["shm_metadata"])
            metadata_processed = time.perf_counter()
            logger.info(
                f"Took {(metadata_processed - start):.3f} seconds to process metadata"
            )
            with shm_manager(
                *[b["shm_metadata"].shm_name for b in batch], unlink_on_success=True
            ) as shms:
                images, preproc_return_metadatas = load_batch(batch, shms)
                loaded = time.perf_counter()
                logger.info(
                    f"Took {(loaded - metadata_processed):.3f} seconds to load batch"
                )
                self.batch_queue.put(
                    (model_id, images, batch, preproc_return_metadatas)
                )
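# ---------------------------------------------------------------------------
# Illustrative sketch (hypothetical, not part of the server): in the full
# pipeline the preprocess worker is what enqueues work, but the contract this
# server consumes can be read off get_batch()/get_requested_model_names():
# each entry is a JSON object with "request", "shm_metadata", and
# "preprocess_metadata" keys stored in the sorted set "infer:{model_id}",
# while the hash "requests" tracks per-model pending counts. The score is
# assumed here to be the enqueue timestamp, since select_best_inference_batch()
# compares scores against time.time(). The helper name below is made up.
def _enqueue_request_sketch(redis: Redis, model_id: str, request_dict: Dict) -> None:
    """Hypothetical minimal producer, shown for documentation only."""
    redis.zadd(f"infer:{model_id}", {orjson.dumps(request_dict): time.time()})
    redis.hincrby("requests", model_id, 1)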
def get_requested_model_names(redis: Redis) -> List[str]:
    request_counts = redis.hgetall("requests")
    model_names = [
        model_name for model_name, count in request_counts.items() if int(count) > 0
    ]
    return model_names


def get_batch(redis: Redis, model_names: List[str]) -> Tuple[List[Dict], str]:
    """Run a heuristic to select the best batch to infer on.

    Args:
        redis (Redis): redis client
        model_names (List[str]): models with a nonzero number of pending requests

    Returns:
        Tuple[List[Dict], str]: the batch of request dicts and the selected model id
    """
    batch_sizes = [
        RoboflowInferenceModel.model_metadata_from_memcache_endpoint(m)["batch_size"]
        for m in model_names
    ]
    # models that don't report a numeric batch size fall back to the default
    batch_sizes = [b if not isinstance(b, str) else BATCH_SIZE for b in batch_sizes]
    batches = [
        redis.zrange(f"infer:{m}", 0, b - 1, withscores=True)
        for m, b in zip(model_names, batch_sizes)
    ]
    model_index = select_best_inference_batch(batches, batch_sizes)
    batch = batches[model_index]
    selected_model = model_names[model_index]
    redis.zrem(f"infer:{selected_model}", *[b[0] for b in batch])
    redis.hincrby("requests", selected_model, -len(batch))
    batch = [orjson.loads(b[0]) for b in batch]
    return batch, selected_model


def select_best_inference_batch(batches, batch_sizes):
    """Score each candidate batch by queue fullness plus a time-based term and
    return the index of the highest-scoring model."""
    now = time.time()
    average_ages = [np.mean([float(b[1]) - now for b in batch]) for batch in batches]
    lengths = [
        len(batch) / batch_size for batch, batch_size in zip(batches, batch_sizes)
    ]
    fitnesses = [
        age / AGE_TRADEOFF_SECONDS_FACTOR + length
        for age, length in zip(average_ages, lengths)
    ]
    model_index = fitnesses.index(max(fitnesses))
    return model_index


def load_batch(
    batch: List[Dict[str, str]], shms: List[shared_memory.SharedMemory]
) -> Tuple[List[np.ndarray], List[Dict]]:
    images = []
    preproc_return_metadatas = []
    for b, shm in zip(batch, shms):
        shm_metadata: SharedMemoryMetadata = b["shm_metadata"]
        # copy out of shared memory so the segment can be unlinked safely
        image = np.ndarray(
            shm_metadata.array_shape, dtype=shm_metadata.array_dtype, buffer=shm.buf
        ).copy()
        images.append(image)
        preproc_return_metadatas.append(b["preprocess_metadata"])
    return images, preproc_return_metadatas


def write_infer_arrays_and_launch_postprocess(
    arrs: Tuple[np.ndarray, ...],
    request: InferenceRequest,
    preproc_return_metadata: Dict,
):
    """Write inference results to shared memory and launch the postprocessing task."""
    shms = [shared_memory.SharedMemory(create=True, size=arr.nbytes) for arr in arrs]
    with shm_manager(*shms):
        shm_metadatas = []
        for arr, shm in zip(arrs, shms):
            shared = np.ndarray(arr.shape, dtype=arr.dtype, buffer=shm.buf)
            shared[:] = arr[:]
            shm_metadata = SharedMemoryMetadata(
                shm_name=shm.name, array_shape=arr.shape, array_dtype=arr.dtype.name
            )
            shm_metadatas.append(asdict(shm_metadata))
        postprocess.s(
            tuple(shm_metadatas), request.dict(), preproc_return_metadata
        ).delay()


if __name__ == "__main__":
    pool = ConnectionPool(host=REDIS_HOST, port=REDIS_PORT, decode_responses=True)
    redis = Redis(connection_pool=pool)
    InferServer(redis).infer_loop()
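# ---------------------------------------------------------------------------
# Worked example (toy numbers, hypothetical) of the selection heuristic in
# select_best_inference_batch(), which computes per model:
#     fitness = mean(score - now) / AGE_TRADEOFF_SECONDS_FACTOR
#               + len(batch) / batch_size
# Model A: 8 of 8 slots filled, scores averaging now - 1:
#     fitness = -1 / 30 + 8 / 8 ≈ 0.967
# Model B: 2 of 4 slots filled, scores averaging now - 15:
#     fitness = -15 / 30 + 2 / 4 = 0.0
# Model A is selected: queue fullness dominates unless the time term, damped
# by the 30-second AGE_TRADEOFF_SECONDS_FACTOR, pulls a batch far enough down.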