import os
import signal
from dataclasses import asdict
from multiprocessing import Process, Queue
from types import FrameType
from typing import Callable, Optional, Tuple

from inference.core import logger
from inference.core.exceptions import (
    MissingApiKeyError,
    RoboflowAPINotAuthorizedError,
    RoboflowAPINotNotFoundError,
)
from inference.core.interfaces.camera.entities import VideoFrame
from inference.core.interfaces.camera.exceptions import StreamOperationNotAllowedError
from inference.core.interfaces.camera.video_source import (
    BufferConsumptionStrategy,
    BufferFillingStrategy,
)
from inference.core.interfaces.stream.entities import ObjectDetectionPrediction
from inference.core.interfaces.stream.inference_pipeline import InferencePipeline
from inference.core.interfaces.stream.sinks import UDPSink
from inference.core.interfaces.stream.watchdog import (
    BasePipelineWatchDog,
    PipelineWatchDog,
)
from inference.enterprise.stream_management.manager.entities import (
    STATUS_KEY,
    TYPE_KEY,
    CommandType,
    ErrorType,
    OperationStatus,
)
from inference.enterprise.stream_management.manager.serialisation import describe_error


def ignore_signal(signal_number: int, frame: FrameType) -> None:
    pid = os.getpid()
    logger.info(
        f"Ignoring signal {signal_number} in InferencePipelineManager in process:{pid}"
    )
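
# Note (added comment): InferencePipelineManager runs as a separate OS process and
# talks to its parent exclusively through the two multiprocessing queues it is
# constructed with. Commands arrive as `(request_id, payload)` tuples on
# `command_queue`; results are pushed back on `responses_queue`.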


class InferencePipelineManager(Process):
    @classmethod
    def init(
        cls, command_queue: Queue, responses_queue: Queue
    ) -> "InferencePipelineManager":
        return cls(command_queue=command_queue, responses_queue=responses_queue)

    def __init__(self, command_queue: Queue, responses_queue: Queue):
        super().__init__()
        self._command_queue = command_queue
        self._responses_queue = responses_queue
        self._inference_pipeline: Optional[InferencePipeline] = None
        self._watchdog: Optional[PipelineWatchDog] = None
        self._stop = False

    def run(self) -> None:
        signal.signal(signal.SIGINT, ignore_signal)
        signal.signal(signal.SIGTERM, self._handle_termination_signal)
        while not self._stop:
            command: Optional[Tuple[str, dict]] = self._command_queue.get()
            if command is None:
                break
            request_id, payload = command
            self._handle_command(request_id=request_id, payload=payload)

    def _handle_command(self, request_id: str, payload: dict) -> None:
        try:
            logger.info(f"Processing request={request_id}...")
            command_type = payload[TYPE_KEY]
            if command_type is CommandType.INIT:
                return self._initialise_pipeline(request_id=request_id, payload=payload)
            if command_type is CommandType.TERMINATE:
                return self._terminate_pipeline(request_id=request_id)
            if command_type is CommandType.MUTE:
                return self._mute_pipeline(request_id=request_id)
            if command_type is CommandType.RESUME:
                return self._resume_pipeline(request_id=request_id)
            if command_type is CommandType.STATUS:
                return self._get_pipeline_status(request_id=request_id)
            raise NotImplementedError(
                f"Command type `{command_type}` cannot be handled"
            )
        except (KeyError, NotImplementedError) as error:
            self._handle_error(
                request_id=request_id, error=error, error_type=ErrorType.INVALID_PAYLOAD
            )
        except Exception as error:
            self._handle_error(
                request_id=request_id, error=error, error_type=ErrorType.INTERNAL_ERROR
            )
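
    # A sketch of the INIT payload `_initialise_pipeline` expects, inferred from the
    # key accesses below. The values are placeholders, and optional keys may be
    # omitted; enum spellings for the buffer strategies are assumptions (the code
    # upper-cases whatever string is supplied):
    #
    #     {
    #         "model_id": "my-project/1",          # hypothetical model id
    #         "video_reference": "rtsp://...",      # any reference InferencePipeline accepts
    #         "sink_configuration": {"type": "udp_sink", "host": "127.0.0.1", "port": 9090},
    #         "api_key": "<ROBOFLOW_API_KEY>",
    #         "max_fps": 30,
    #         "model_configuration": {"type": "object-detection", "confidence": 0.5},
    #         "source_buffer_filling_strategy": "drop_oldest",  # assumed enum value
    #     }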

    def _initialise_pipeline(self, request_id: str, payload: dict) -> None:
        try:
            watchdog = BasePipelineWatchDog()
            sink = assembly_pipeline_sink(sink_config=payload["sink_configuration"])
            source_buffer_filling_strategy, source_buffer_consumption_strategy = (
                None,
                None,
            )
            if "source_buffer_filling_strategy" in payload:
                source_buffer_filling_strategy = BufferFillingStrategy(
                    payload["source_buffer_filling_strategy"].upper()
                )
            if "source_buffer_consumption_strategy" in payload:
                source_buffer_consumption_strategy = BufferConsumptionStrategy(
                    payload["source_buffer_consumption_strategy"].upper()
                )
            model_configuration = payload["model_configuration"]
            if model_configuration["type"] != "object-detection":
                raise NotImplementedError("Only object-detection models are supported")
            self._inference_pipeline = InferencePipeline.init(
                model_id=payload["model_id"],
                video_reference=payload["video_reference"],
                on_prediction=sink,
                api_key=payload.get("api_key"),
                max_fps=payload.get("max_fps"),
                watchdog=watchdog,
                source_buffer_filling_strategy=source_buffer_filling_strategy,
                source_buffer_consumption_strategy=source_buffer_consumption_strategy,
                class_agnostic_nms=model_configuration.get("class_agnostic_nms"),
                confidence=model_configuration.get("confidence"),
                iou_threshold=model_configuration.get("iou_threshold"),
                max_candidates=model_configuration.get("max_candidates"),
                max_detections=model_configuration.get("max_detections"),
                active_learning_enabled=payload.get("active_learning_enabled"),
            )
            self._watchdog = watchdog
            self._inference_pipeline.start(use_main_thread=False)
            self._responses_queue.put(
                (request_id, {STATUS_KEY: OperationStatus.SUCCESS})
            )
            logger.info(f"Pipeline initialised. request_id={request_id}...")
        except (MissingApiKeyError, KeyError, NotImplementedError) as error:
            self._handle_error(
                request_id=request_id, error=error, error_type=ErrorType.INVALID_PAYLOAD
            )
        except RoboflowAPINotAuthorizedError as error:
            self._handle_error(
                request_id=request_id,
                error=error,
                error_type=ErrorType.AUTHORISATION_ERROR,
            )
        except RoboflowAPINotNotFoundError as error:
            self._handle_error(
                request_id=request_id, error=error, error_type=ErrorType.NOT_FOUND
            )

    def _terminate_pipeline(self, request_id: str) -> None:
        if self._inference_pipeline is None:
            self._responses_queue.put(
                (request_id, {STATUS_KEY: OperationStatus.SUCCESS})
            )
            self._stop = True
            return None
        try:
            self._execute_termination()
            logger.info(f"Pipeline terminated. request_id={request_id}...")
            self._responses_queue.put(
                (request_id, {STATUS_KEY: OperationStatus.SUCCESS})
            )
        except StreamOperationNotAllowedError as error:
            self._handle_error(
                request_id=request_id, error=error, error_type=ErrorType.OPERATION_ERROR
            )
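
    # Added note: on SIGTERM the handler below shuts the pipeline down and then puts
    # the `None` sentinel on the command queue, which `run` interprets as "exit the
    # command loop", so the process ends cleanly either way.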

    def _handle_termination_signal(self, signal_number: int, frame: FrameType) -> None:
        try:
            pid = os.getpid()
            logger.info(f"Terminating pipeline in process:{pid}...")
            if self._inference_pipeline is not None:
                self._execute_termination()
            self._command_queue.put(None)
            logger.info(f"Termination successful in process:{pid}...")
        except Exception as error:
            logger.warning(f"Could not terminate pipeline gracefully. Error: {error}")

    def _execute_termination(self) -> None:
        self._inference_pipeline.terminate()
        self._inference_pipeline.join()
        self._stop = True

    def _mute_pipeline(self, request_id: str) -> None:
        if self._inference_pipeline is None:
            return self._handle_error(
                request_id=request_id, error_type=ErrorType.OPERATION_ERROR
            )
        try:
            self._inference_pipeline.mute_stream()
            logger.info(f"Pipeline muted. request_id={request_id}...")
            self._responses_queue.put(
                (request_id, {STATUS_KEY: OperationStatus.SUCCESS})
            )
        except StreamOperationNotAllowedError as error:
            self._handle_error(
                request_id=request_id, error=error, error_type=ErrorType.OPERATION_ERROR
            )

    def _resume_pipeline(self, request_id: str) -> None:
        if self._inference_pipeline is None:
            return self._handle_error(
                request_id=request_id, error_type=ErrorType.OPERATION_ERROR
            )
        try:
            self._inference_pipeline.resume_stream()
            logger.info(f"Pipeline resumed. request_id={request_id}...")
            self._responses_queue.put(
                (request_id, {STATUS_KEY: OperationStatus.SUCCESS})
            )
        except StreamOperationNotAllowedError as error:
            self._handle_error(
                request_id=request_id, error=error, error_type=ErrorType.OPERATION_ERROR
            )

    def _get_pipeline_status(self, request_id: str) -> None:
        if self._watchdog is None:
            return self._handle_error(
                request_id=request_id, error_type=ErrorType.OPERATION_ERROR
            )
        try:
            report = self._watchdog.get_report()
            if report is None:
                return self._handle_error(
                    request_id=request_id, error_type=ErrorType.OPERATION_ERROR
                )
            response_payload = {
                STATUS_KEY: OperationStatus.SUCCESS,
                "report": asdict(report),
            }
            self._responses_queue.put((request_id, response_payload))
            logger.info(f"Pipeline status returned. request_id={request_id}...")
        except StreamOperationNotAllowedError as error:
            self._handle_error(
                request_id=request_id, error=error, error_type=ErrorType.OPERATION_ERROR
            )

    def _handle_error(
        self,
        request_id: str,
        error: Optional[Exception] = None,
        error_type: ErrorType = ErrorType.INTERNAL_ERROR,
    ):
        logger.error(
            f"Could not handle Command. request_id={request_id}, error={error}, error_type={error_type}"
        )
        response_payload = describe_error(error, error_type=error_type)
        self._responses_queue.put((request_id, response_payload))
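
# Added note: `assembly_pipeline_sink` currently understands a single sink type.
# A minimal configuration it accepts, based on the keys read below, would be:
#     {"type": "udp_sink", "host": "127.0.0.1", "port": 9090}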


def assembly_pipeline_sink(
    sink_config: dict,
) -> Callable[[ObjectDetectionPrediction, VideoFrame], None]:
    if sink_config["type"] != "udp_sink":
        raise NotImplementedError("Only `udp_sink` sink type is supported")
    sink = UDPSink.init(ip_address=sink_config["host"], port=sink_config["port"])
    return sink.send_predictions
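

# A minimal usage sketch, not part of the original module: spawn the manager process,
# drive it through the (request_id, payload) command protocol implemented in `run`,
# and shut it down with a TERMINATE command. All payload values are placeholders.
if __name__ == "__main__":
    command_queue, responses_queue = Queue(), Queue()
    manager = InferencePipelineManager.init(
        command_queue=command_queue, responses_queue=responses_queue
    )
    manager.start()
    # Ask the manager to build and start an InferencePipeline (hypothetical values).
    command_queue.put(
        (
            "request-1",
            {
                TYPE_KEY: CommandType.INIT,
                "model_id": "my-project/1",
                "video_reference": "rtsp://camera.local/stream",
                "api_key": os.environ.get("ROBOFLOW_API_KEY"),
                "sink_configuration": {
                    "type": "udp_sink",
                    "host": "127.0.0.1",
                    "port": 9090,
                },
                "model_configuration": {"type": "object-detection"},
            },
        )
    )
    request_id, response = responses_queue.get()
    print(f"{request_id}: {response.get(STATUS_KEY)}")
    # TERMINATE stops the pipeline (if one was started) and ends the command loop.
    command_queue.put(("request-2", {TYPE_KEY: CommandType.TERMINATE}))
    responses_queue.get()
    manager.join()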