import os
import signal
import socket
import sys
from functools import partial
from multiprocessing import Process, Queue
from socketserver import BaseRequestHandler, BaseServer
from types import FrameType
from typing import Any, Dict, Optional, Tuple
from uuid import uuid4

from inference.core import logger
from inference.enterprise.stream_management.manager.communication import (
    receive_socket_data,
    send_data_trough_socket,
)
from inference.enterprise.stream_management.manager.entities import (
    PIPELINE_ID_KEY,
    STATUS_KEY,
    TYPE_KEY,
    CommandType,
    ErrorType,
    OperationStatus,
)
from inference.enterprise.stream_management.manager.errors import MalformedPayloadError
from inference.enterprise.stream_management.manager.inference_pipeline_manager import (
    InferencePipelineManager,
)
from inference.enterprise.stream_management.manager.serialisation import (
    describe_error,
    prepare_error_response,
    prepare_response,
)
from inference.enterprise.stream_management.manager.tcp_server import RoboflowTCPServer

PROCESSES_TABLE: Dict[str, Tuple[Process, Queue, Queue]] = {}
HEADER_SIZE = 4
SOCKET_BUFFER_SIZE = 16384
HOST = os.getenv("STREAM_MANAGER_HOST", "127.0.0.1")
PORT = int(os.getenv("STREAM_MANAGER_PORT", "7070"))
SOCKET_TIMEOUT = float(os.getenv("STREAM_MANAGER_SOCKET_TIMEOUT", "5.0"))
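
# Illustrative only: the actual wire framing is implemented in `communication.py`
# (`send_data_trough_socket` / `receive_socket_data`). The sketch below ASSUMES a
# length-prefixed protocol: a 4-byte big-endian header (HEADER_SIZE) carrying the
# byte length of a UTF-8 JSON body. If the real framing differs, these helpers do
# not apply; they exist purely to document the assumed message shape.
def _example_send_message(sock: socket.socket, payload: dict) -> None:
    import json
    import struct

    body = json.dumps(payload).encode("utf-8")
    # Prefix the body with its length so the receiver knows how much to read.
    sock.sendall(struct.pack(">I", len(body)) + body)


def _example_receive_message(sock: socket.socket) -> dict:
    import json
    import struct

    # Read the fixed-size length header first (recv may return partial data).
    header = _example_read_exactly(sock, HEADER_SIZE)
    (length,) = struct.unpack(">I", header)
    body = _example_read_exactly(sock, length)
    return json.loads(body.decode("utf-8"))


def _example_read_exactly(sock: socket.socket, size: int) -> bytes:
    chunks = b""
    while len(chunks) < size:
        chunk = sock.recv(min(SOCKET_BUFFER_SIZE, size - len(chunks)))
        if not chunk:
            raise ConnectionError("Socket closed before the full message arrived")
        chunks += chunk
    return chunks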


class InferencePipelinesManagerHandler(BaseRequestHandler):
    def __init__(
        self,
        request: socket.socket,
        client_address: Any,
        server: BaseServer,
        processes_table: Dict[str, Tuple[Process, Queue, Queue]],
    ):
        # State must be set before the superclass init, as
        # BaseRequestHandler.__init__() invokes handle()
        self._processes_table = processes_table
        super().__init__(request, client_address, server)

    def handle(self) -> None:
        pipeline_id: Optional[str] = None
        request_id = str(uuid4())
        try:
            data = receive_socket_data(
                source=self.request,
                header_size=HEADER_SIZE,
                buffer_size=SOCKET_BUFFER_SIZE,
            )
            data[TYPE_KEY] = CommandType(data[TYPE_KEY])
            if data[TYPE_KEY] is CommandType.LIST_PIPELINES:
                return self._list_pipelines(request_id=request_id)
            if data[TYPE_KEY] is CommandType.INIT:
                return self._initialise_pipeline(request_id=request_id, command=data)
            pipeline_id = data[PIPELINE_ID_KEY]
            if data[TYPE_KEY] is CommandType.TERMINATE:
                self._terminate_pipeline(
                    request_id=request_id, pipeline_id=pipeline_id, command=data
                )
            else:
                response = handle_command(
                    processes_table=self._processes_table,
                    request_id=request_id,
                    pipeline_id=pipeline_id,
                    command=data,
                )
                serialised_response = prepare_response(
                    request_id=request_id, response=response, pipeline_id=pipeline_id
                )
                send_data_trough_socket(
                    target=self.request,
                    header_size=HEADER_SIZE,
                    data=serialised_response,
                    request_id=request_id,
                    pipeline_id=pipeline_id,
                )
        except (KeyError, ValueError, MalformedPayloadError) as error:
            logger.error(
                f"Invalid payload in processes manager. error={error} request_id={request_id}..."
            )
            payload = prepare_error_response(
                request_id=request_id,
                error=error,
                error_type=ErrorType.INVALID_PAYLOAD,
                pipeline_id=pipeline_id,
            )
            send_data_trough_socket(
                target=self.request,
                header_size=HEADER_SIZE,
                data=payload,
                request_id=request_id,
                pipeline_id=pipeline_id,
            )
        except Exception as error:
            logger.error(
                f"Internal error in processes manager. error={error} request_id={request_id}..."
            )
            payload = prepare_error_response(
                request_id=request_id,
                error=error,
                error_type=ErrorType.INTERNAL_ERROR,
                pipeline_id=pipeline_id,
            )
            send_data_trough_socket(
                target=self.request,
                header_size=HEADER_SIZE,
                data=payload,
                request_id=request_id,
                pipeline_id=pipeline_id,
            )

    def _list_pipelines(self, request_id: str) -> None:
        serialised_response = prepare_response(
            request_id=request_id,
            response={
                "pipelines": list(self._processes_table.keys()),
                STATUS_KEY: OperationStatus.SUCCESS,
            },
            pipeline_id=None,
        )
        send_data_trough_socket(
            target=self.request,
            header_size=HEADER_SIZE,
            data=serialised_response,
            request_id=request_id,
        )

    def _initialise_pipeline(self, request_id: str, command: dict) -> None:
        pipeline_id = str(uuid4())
        command_queue = Queue()
        responses_queue = Queue()
        inference_pipeline_manager = InferencePipelineManager.init(
            command_queue=command_queue,
            responses_queue=responses_queue,
        )
        inference_pipeline_manager.start()
        self._processes_table[pipeline_id] = (
            inference_pipeline_manager,
            command_queue,
            responses_queue,
        )
        command_queue.put((request_id, command))
        response = get_response_ignoring_thrash(
            responses_queue=responses_queue, matching_request_id=request_id
        )
        serialised_response = prepare_response(
            request_id=request_id, response=response, pipeline_id=pipeline_id
        )
        send_data_trough_socket(
            target=self.request,
            header_size=HEADER_SIZE,
            data=serialised_response,
            request_id=request_id,
            pipeline_id=pipeline_id,
        )

    def _terminate_pipeline(
        self, request_id: str, pipeline_id: str, command: dict
    ) -> None:
        response = handle_command(
            processes_table=self._processes_table,
            request_id=request_id,
            pipeline_id=pipeline_id,
            command=command,
        )
        if response[STATUS_KEY] is OperationStatus.SUCCESS:
            logger.info(
                f"Joining inference pipeline. pipeline_id={pipeline_id} request_id={request_id}"
            )
            join_inference_pipeline(
                processes_table=self._processes_table, pipeline_id=pipeline_id
            )
            logger.info(
                f"Joined inference pipeline. pipeline_id={pipeline_id} request_id={request_id}"
            )
        serialised_response = prepare_response(
            request_id=request_id, response=response, pipeline_id=pipeline_id
        )
        send_data_trough_socket(
            target=self.request,
            header_size=HEADER_SIZE,
            data=serialised_response,
            request_id=request_id,
            pipeline_id=pipeline_id,
        )
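
# Illustrative only: a command, as consumed by handle(), is a dict whose TYPE_KEY
# entry holds a CommandType value and whose PIPELINE_ID_KEY entry (for pipeline-
# scoped commands such as TERMINATE) selects the target process. The remaining
# fields of an INIT command are interpreted downstream by InferencePipelineManager
# and are not reproduced here.
def _example_commands(pipeline_id: str) -> Tuple[dict, dict, dict]:
    list_command = {TYPE_KEY: CommandType.LIST_PIPELINES.value}
    # A real INIT command also carries pipeline configuration fields.
    init_command = {TYPE_KEY: CommandType.INIT.value}
    terminate_command = {
        TYPE_KEY: CommandType.TERMINATE.value,
        PIPELINE_ID_KEY: pipeline_id,
    }
    return list_command, init_command, terminate_command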


def handle_command(
    processes_table: Dict[str, Tuple[Process, Queue, Queue]],
    request_id: str,
    pipeline_id: str,
    command: dict,
) -> dict:
    if pipeline_id not in processes_table:
        return describe_error(exception=None, error_type=ErrorType.NOT_FOUND)
    _, command_queue, responses_queue = processes_table[pipeline_id]
    command_queue.put((request_id, command))
    return get_response_ignoring_thrash(
        responses_queue=responses_queue, matching_request_id=request_id
    )


def get_response_ignoring_thrash(
    responses_queue: Queue, matching_request_id: str
) -> dict:
    # Responses to stale requests may still sit in the queue; drop them until the
    # response matching this request arrives.
    while True:
        response = responses_queue.get()
        if response[0] == matching_request_id:
            return response[1]
        logger.warning(
            f"Dropping response for request_id={response[0]} with payload={response[1]}"
        )


def execute_termination(
    signal_number: int,
    frame: FrameType,
    processes_table: Dict[str, Tuple[Process, Queue, Queue]],
) -> None:
    pipeline_ids = list(processes_table.keys())
    for pipeline_id in pipeline_ids:
        logger.info(f"Terminating pipeline: {pipeline_id}")
        processes_table[pipeline_id][0].terminate()
        logger.info(f"Pipeline: {pipeline_id} terminated.")
        logger.info(f"Joining pipeline: {pipeline_id}")
        processes_table[pipeline_id][0].join()
        logger.info(f"Pipeline: {pipeline_id} joined.")
    logger.info("Termination handler completed.")
    sys.exit(0)


def join_inference_pipeline(
    processes_table: Dict[str, Tuple[Process, Queue, Queue]], pipeline_id: str
) -> None:
    # The queues are not needed once the process exits, so they are not unpacked.
    inference_pipeline_manager, _, _ = processes_table[pipeline_id]
    inference_pipeline_manager.join()
    del processes_table[pipeline_id]


if __name__ == "__main__":
    signal.signal(
        signal.SIGINT, partial(execute_termination, processes_table=PROCESSES_TABLE)
    )
    signal.signal(
        signal.SIGTERM, partial(execute_termination, processes_table=PROCESSES_TABLE)
    )
    with RoboflowTCPServer(
        server_address=(HOST, PORT),
        handler_class=partial(
            InferencePipelinesManagerHandler, processes_table=PROCESSES_TABLE
        ),
        socket_operations_timeout=SOCKET_TIMEOUT,
    ) as tcp_server:
        logger.info(
            f"Inference Pipeline Processes Manager is ready to accept connections at {(HOST, PORT)}"
        )
        tcp_server.serve_forever()
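

# Illustrative only: a minimal client session, assuming the length-prefixed JSON
# framing sketched near the top of this file. `_example_send_message` and
# `_example_receive_message` are the hypothetical helpers defined above; a real
# client should reuse the helpers from `communication.py` instead.
def _example_client_session() -> dict:
    with socket.create_connection((HOST, PORT), timeout=SOCKET_TIMEOUT) as sock:
        # Ask the manager for the identifiers of all running pipelines.
        _example_send_message(sock, {TYPE_KEY: CommandType.LIST_PIPELINES.value})
        return _example_receive_message(sock)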