# Source: OMG / inference / enterprise / stream_management / api / stream_manager_client.py
# Uploaded by Fucius ("Upload 422 files"), commit df6c67d (verified).
import asyncio
import json
from asyncio import StreamReader, StreamWriter
from json import JSONDecodeError
from typing import Optional, Tuple
from inference.core import logger
from inference.enterprise.stream_management.api.entities import (
CommandContext,
CommandResponse,
InferencePipelineStatusResponse,
ListPipelinesResponse,
PipelineInitialisationRequest,
)
from inference.enterprise.stream_management.api.errors import (
ConnectivityError,
ProcessesManagerAuthorisationError,
ProcessesManagerClientError,
ProcessesManagerInternalError,
ProcessesManagerInvalidPayload,
ProcessesManagerNotFoundError,
ProcessesManagerOperationError,
)
from inference.enterprise.stream_management.manager.entities import (
ERROR_TYPE_KEY,
PIPELINE_ID_KEY,
REQUEST_ID_KEY,
RESPONSE_KEY,
STATUS_KEY,
TYPE_KEY,
CommandType,
ErrorType,
OperationStatus,
)
from inference.enterprise.stream_management.manager.errors import (
CommunicationProtocolError,
MalformedHeaderError,
MalformedPayloadError,
MessageToBigError,
TransmissionChannelClosed,
)
# Default framing parameters for the length-prefixed JSON protocol used with
# the processes manager: a HEADER_SIZE-byte big-endian length header followed
# by the body, which is read in chunks of at most BUFFER_SIZE bytes.
BUFFER_SIZE = 16_384
HEADER_SIZE = 4

# Maps the error type reported by the manager (under ERROR_TYPE_KEY) to the
# exception raised on the client side. Types not listed here fall back to
# ProcessesManagerClientError in ``dispatch_error``.
ERRORS_MAPPING = {
    reported_type.value: exception_class
    for reported_type, exception_class in (
        (ErrorType.INTERNAL_ERROR, ProcessesManagerInternalError),
        (ErrorType.INVALID_PAYLOAD, ProcessesManagerInvalidPayload),
        (ErrorType.NOT_FOUND, ProcessesManagerNotFoundError),
        (ErrorType.OPERATION_ERROR, ProcessesManagerOperationError),
        (ErrorType.AUTHORISATION_ERROR, ProcessesManagerAuthorisationError),
    )
}
class StreamManagerClient:
    """Asynchronous client for the stream-management processes manager.

    Each public coroutine builds a command dict and hands it to
    ``send_command`` (one request/reply exchange per call). It then maps the
    decoded reply onto a response entity. Replies flagged by
    ``is_request_unsuccessful`` are turned into exceptions by
    ``dispatch_error`` before any entity is built.
    """

    @classmethod
    def init(
        cls,
        host: str,
        port: int,
        operations_timeout: Optional[float] = None,
        header_size: int = HEADER_SIZE,
        buffer_size: int = BUFFER_SIZE,
    ) -> "StreamManagerClient":
        """Alternate constructor that supplies the module-level framing defaults."""
        return cls(
            host=host,
            port=port,
            operations_timeout=operations_timeout,
            header_size=header_size,
            buffer_size=buffer_size,
        )

    def __init__(
        self,
        host: str,
        port: int,
        operations_timeout: Optional[float],
        header_size: int,
        buffer_size: int,
    ):
        # Address of the processes manager.
        self._host, self._port = host, port
        # Forwarded to ``send_command`` as its ``timeout`` argument.
        self._operations_timeout = operations_timeout
        # Framing parameters forwarded to ``send_command``.
        self._header_size, self._buffer_size = header_size, buffer_size

    async def list_pipelines(self) -> ListPipelinesResponse:
        """Return the pipelines reported by the manager."""
        reply = await self._handle_command(
            command={TYPE_KEY: CommandType.LIST_PIPELINES}
        )
        body = reply[RESPONSE_KEY]
        return ListPipelinesResponse(
            status=body[STATUS_KEY],
            context=CommandContext(
                request_id=reply.get(REQUEST_ID_KEY),
                pipeline_id=reply.get(PIPELINE_ID_KEY),
            ),
            pipelines=body["pipelines"],
        )

    async def initialise_pipeline(
        self, initialisation_request: PipelineInitialisationRequest
    ) -> CommandResponse:
        """Send an INIT command built from ``initialisation_request``.

        Fields whose value is ``None`` are left out of the command.
        """
        command = {
            **initialisation_request.dict(exclude_none=True),
            TYPE_KEY: CommandType.INIT,
        }
        return build_response(response=await self._handle_command(command=command))

    async def terminate_pipeline(self, pipeline_id: str) -> CommandResponse:
        """Send a TERMINATE command for ``pipeline_id``."""
        return await self._send_pipeline_command(CommandType.TERMINATE, pipeline_id)

    async def pause_pipeline(self, pipeline_id: str) -> CommandResponse:
        """Pause ``pipeline_id``; on the wire this is the ``MUTE`` command."""
        return await self._send_pipeline_command(CommandType.MUTE, pipeline_id)

    async def resume_pipeline(self, pipeline_id: str) -> CommandResponse:
        """Send a RESUME command for ``pipeline_id``."""
        return await self._send_pipeline_command(CommandType.RESUME, pipeline_id)

    async def get_status(self, pipeline_id: str) -> InferencePipelineStatusResponse:
        """Fetch the status report of ``pipeline_id``."""
        reply = await self._handle_command(
            command={TYPE_KEY: CommandType.STATUS, PIPELINE_ID_KEY: pipeline_id}
        )
        body = reply[RESPONSE_KEY]
        return InferencePipelineStatusResponse(
            status=body[STATUS_KEY],
            context=CommandContext(
                request_id=reply.get(REQUEST_ID_KEY),
                pipeline_id=reply.get(PIPELINE_ID_KEY),
            ),
            report=body["report"],
        )

    async def _send_pipeline_command(
        self, command_type: CommandType, pipeline_id: str
    ) -> CommandResponse:
        """Send a command that targets a single pipeline and wrap the reply."""
        reply = await self._handle_command(
            command={TYPE_KEY: command_type, PIPELINE_ID_KEY: pipeline_id}
        )
        return build_response(response=reply)

    async def _handle_command(self, command: dict) -> dict:
        """Send ``command`` and return the decoded reply, raising on failure."""
        response = await send_command(
            host=self._host,
            port=self._port,
            command=command,
            header_size=self._header_size,
            buffer_size=self._buffer_size,
            timeout=self._operations_timeout,
        )
        if not is_request_unsuccessful(response=response):
            return response
        dispatch_error(error_response=response)
        return response
async def send_command(
    host: str,
    port: int,
    command: dict,
    header_size: int,
    buffer_size: int,
    timeout: Optional[float] = None,
) -> dict:
    """Send one command to the processes manager and return its decoded reply.

    A new connection is opened for every call. It is closed after the reply
    has been read, and also when the exchange fails part-way.

    Args:
        host: Manager host name or address.
        port: Manager TCP port.
        command: JSON-serialisable command payload.
        header_size: Width of the length header in bytes.
        buffer_size: Maximum number of bytes requested per read.
        timeout: Seconds allowed for connecting and for each I/O step;
            ``None`` waits indefinitely.

    Returns:
        The JSON-decoded reply.

    Raises:
        MalformedPayloadError: the reply is not valid JSON / not decodable text.
        ConnectivityError: an OSError or timeout occurred while connecting,
            reading or closing.
        Other errors raised by ``send_message`` / ``receive_message`` propagate
        unchanged (unless they are OSError subclasses).
    """
    try:
        reader, writer = await establish_socket_connection(
            host=host, port=port, timeout=timeout
        )
        try:
            await send_message(
                writer=writer, message=command, header_size=header_size, timeout=timeout
            )
            data = await receive_message(
                reader, header_size=header_size, buffer_size=buffer_size, timeout=timeout
            )
        finally:
            # Always release the socket, even when the exchange failed. We only
            # wait for the close to complete on the success path (below), so a
            # close-time error cannot replace an exception already in flight.
            writer.close()
        await writer.wait_closed()
        # json.loads on bytes decodes the text first; invalid byte sequences
        # surface as UnicodeDecodeError rather than JSONDecodeError.
        return json.loads(data)
    except (JSONDecodeError, UnicodeDecodeError) as error:
        raise MalformedPayloadError(
            f"Could not decode response. Cause: {error}"
        ) from error
    except (OSError, asyncio.TimeoutError) as error:
        raise ConnectivityError(
            "Could not communicate with Process Manager"
        ) from error
async def establish_socket_connection(
    host: str, port: int, timeout: Optional[float] = None
) -> Tuple[StreamReader, StreamWriter]:
    """Open a TCP connection to ``host:port`` and return its stream pair.

    The connection attempt is bounded by ``timeout`` seconds via
    ``asyncio.wait_for``; ``None`` waits indefinitely.
    """
    pending_connection = asyncio.open_connection(host, port)
    streams = await asyncio.wait_for(pending_connection, timeout=timeout)
    return streams
async def send_message(
    writer: StreamWriter,
    message: dict,
    header_size: int,
    timeout: Optional[float] = None,
) -> None:
    """Serialise ``message`` as JSON and write it as one length-prefixed frame.

    The frame is the UTF-8 body's length encoded big-endian in
    ``header_size`` bytes, followed by the body. The write is then flushed
    with ``drain()``.

    Args:
        writer: Stream the frame is written to.
        message: JSON-serialisable payload.
        header_size: Width of the length header in bytes.
        timeout: Seconds to wait for ``drain()``; ``None`` waits indefinitely.

    Raises:
        MalformedPayloadError: a TypeError occurred, typically because
            ``message`` is not JSON-serialisable.
        MessageToBigError: the body length does not fit in ``header_size`` bytes.
        ConnectivityError: flushing the write timed out.
        CommunicationProtocolError: any other failure while sending.
    """
    try:
        body = json.dumps(message).encode("utf-8")
        # int.to_bytes raises OverflowError when the length needs more bytes
        # than the header provides.
        header = len(body).to_bytes(length=header_size, byteorder="big")
        writer.write(header + body)
        await asyncio.wait_for(writer.drain(), timeout=timeout)
    except TypeError as error:
        raise MalformedPayloadError(
            f"Could not serialise message. Details: {error}"
        ) from error
    except OverflowError as error:
        raise MessageToBigError(
            f"Could not send message due to size overflow. Details: {error}"
        ) from error
    except asyncio.TimeoutError as error:
        raise ConnectivityError(
            "Could not communicate with Process Manager"
        ) from error
    except Exception as error:
        raise CommunicationProtocolError(
            f"Could not send message. Cause: {error}"
        ) from error
async def receive_message(
    reader: StreamReader,
    header_size: int,
    buffer_size: int,
    timeout: Optional[float] = None,
) -> bytes:
    """Read one length-prefixed frame from ``reader`` and return its body.

    The frame starts with a ``header_size``-byte big-endian length, followed
    by exactly that many body bytes. Bytes after the declared length are left
    unread on ``reader``.

    Args:
        reader: Stream to read from.
        header_size: Width of the length header in bytes.
        buffer_size: Maximum number of bytes requested per body read.
        timeout: Seconds allowed for each individual read; ``None`` waits
            indefinitely.

    Returns:
        The raw body bytes (not decoded).

    Raises:
        MalformedHeaderError: the stream ended before a full header arrived.
        TransmissionChannelClosed: the stream ended before the full body arrived.
        asyncio.TimeoutError: a single read exceeded ``timeout``.
    """
    try:
        # readexactly waits for the whole header; a plain read() may return a
        # short chunk on a slow connection even though more bytes are coming.
        header = await asyncio.wait_for(
            reader.readexactly(header_size), timeout=timeout
        )
    except asyncio.IncompleteReadError as error:
        raise MalformedHeaderError("Header size missmatch") from error
    payload_size = int.from_bytes(header, byteorder="big")
    received = bytearray()
    while len(received) < payload_size:
        # Cap each read at the bytes still missing (for a positive buffer_size)
        # so data following this frame is not swallowed into it.
        remaining = payload_size - len(received)
        chunk = await asyncio.wait_for(
            reader.read(min(buffer_size, remaining)), timeout=timeout
        )
        if not chunk:
            raise TransmissionChannelClosed(
                "Socket was closed to read before payload was decoded."
            )
        # bytearray extends in place, avoiding quadratic bytes concatenation.
        received += chunk
    return bytes(received)
def is_request_unsuccessful(response: dict) -> bool:
    """Return True unless the reply's status equals ``OperationStatus.SUCCESS``.

    A reply with no ``RESPONSE_KEY`` section, or with no status inside it, is
    treated as a failure.
    """
    reply_body = response.get(RESPONSE_KEY, {})
    reported_status = reply_body.get(STATUS_KEY, OperationStatus.FAILURE.value)
    return reported_status != OperationStatus.SUCCESS.value
def dispatch_error(error_response: dict) -> None:
    """Log a failed manager reply and raise the matching client-side exception.

    The error type found under ``ERROR_TYPE_KEY`` selects the exception class
    from ``ERRORS_MAPPING``. Unknown or missing types raise
    ``ProcessesManagerClientError``. This function never returns normally.
    """
    payload = error_response.get(RESPONSE_KEY, {})
    error_type = payload.get(ERROR_TYPE_KEY)
    error_class = payload.get("error_class", "N/A")
    error_message = payload.get("error_message", "N/A")
    logger.error(
        f"Error in ProcessesManagerClient. error_type={error_type} error_class={error_class} "
        f"error_message={error_message}"
    )
    exception_type = ERRORS_MAPPING.get(error_type, ProcessesManagerClientError)
    raise exception_type(
        f"Error in ProcessesManagerClient. Error type: {error_type}. Details: {error_message}"
    )
def build_response(response: dict) -> CommandResponse:
    """Build a ``CommandResponse`` from a decoded manager reply.

    The status is read from the reply's ``RESPONSE_KEY`` section (KeyError if
    it is absent). The request and pipeline ids are optional top-level fields
    and default to ``None``.
    """
    status = response[RESPONSE_KEY][STATUS_KEY]
    request_id = response.get(REQUEST_ID_KEY)
    pipeline_id = response.get(PIPELINE_ID_KEY)
    return CommandResponse(
        status=status,
        context=CommandContext(request_id=request_id, pipeline_id=pipeline_id),
    )