Spaces:
Runtime error
Runtime error
import json | |
import socket | |
from typing import Optional | |
from inference.core import logger | |
from inference.enterprise.stream_management.manager.entities import ErrorType | |
from inference.enterprise.stream_management.manager.errors import ( | |
MalformedHeaderError, | |
MalformedPayloadError, | |
TransmissionChannelClosed, | |
) | |
from inference.enterprise.stream_management.manager.serialisation import ( | |
prepare_error_response, | |
) | |
def receive_socket_data( | |
source: socket.socket, header_size: int, buffer_size: int | |
) -> dict: | |
header = source.recv(header_size) | |
if len(header) != header_size: | |
raise MalformedHeaderError( | |
f"Expected header size: {header_size}, received: {header}" | |
) | |
payload_size = int.from_bytes(bytes=header, byteorder="big") | |
if payload_size <= 0: | |
raise MalformedHeaderError( | |
f"Header is indicating non positive payload size: {payload_size}" | |
) | |
received = b"" | |
while len(received) < payload_size: | |
chunk = source.recv(buffer_size) | |
if len(chunk) == 0: | |
raise TransmissionChannelClosed( | |
"Socket was closed to read before payload was decoded." | |
) | |
received += chunk | |
try: | |
return json.loads(received) | |
except ValueError: | |
raise MalformedPayloadError("Received payload that is not in a JSON format") | |
def send_data_trough_socket( | |
target: socket.socket, | |
header_size: int, | |
data: bytes, | |
request_id: str, | |
recover_from_overflow: bool = True, | |
pipeline_id: Optional[str] = None, | |
) -> None: | |
try: | |
data_size = len(data) | |
header = data_size.to_bytes(length=header_size, byteorder="big") | |
payload = header + data | |
target.sendall(payload) | |
except OverflowError as error: | |
if not recover_from_overflow: | |
logger.error(f"OverflowError was suppressed. {error}") | |
return None | |
error_response = prepare_error_response( | |
request_id=request_id, | |
error=error, | |
error_type=ErrorType.INTERNAL_ERROR, | |
pipeline_id=pipeline_id, | |
) | |
send_data_trough_socket( | |
target=target, | |
header_size=header_size, | |
data=error_response, | |
request_id=request_id, | |
recover_from_overflow=False, | |
pipeline_id=pipeline_id, | |
) | |
except Exception as error: | |
logger.error(f"Could not send the response through socket. Error: {error}") | |