File size: 2,490 Bytes
df6c67d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import json
import socket
from typing import Optional

from inference.core import logger
from inference.enterprise.stream_management.manager.entities import ErrorType
from inference.enterprise.stream_management.manager.errors import (
    MalformedHeaderError,
    MalformedPayloadError,
    TransmissionChannelClosed,
)
from inference.enterprise.stream_management.manager.serialisation import (
    prepare_error_response,
)


def _receive_exactly(source: socket.socket, size: int, chunk_size: int) -> bytes:
    """Read from ``source`` until ``size`` bytes arrive or the peer closes.

    Never requests more than the number of bytes still missing, so data that
    belongs to a following message is left in the socket. The result is shorter
    than ``size`` only when ``recv`` returned ``b""`` (end of stream).
    """
    buffer = bytearray()
    while len(buffer) < size:
        chunk = source.recv(min(chunk_size, size - len(buffer)))
        if not chunk:
            break
        buffer += chunk
    return bytes(buffer)


def receive_socket_data(
    source: socket.socket, header_size: int, buffer_size: int
) -> dict:
    """Receive one length-prefixed JSON message from ``source``.

    Wire format: a ``header_size``-byte big-endian integer holding the payload
    length, followed by exactly that many bytes of JSON.

    :param source: connected socket to read from.
    :param header_size: number of bytes in the length prefix.
    :param buffer_size: maximum number of bytes requested per ``recv`` call
        while reading the payload.
    :return: the decoded JSON document.
    :raises MalformedHeaderError: if the stream ends before a full header
        arrives, or the header announces a non-positive payload size.
    :raises TransmissionChannelClosed: if the stream ends mid-payload.
    :raises MalformedPayloadError: if the payload is not valid JSON.
    """
    # The header may arrive split across several recv calls; only a short
    # result from _receive_exactly (peer closed) means it is truly truncated.
    header = _receive_exactly(source, header_size, header_size)
    if len(header) != header_size:
        raise MalformedHeaderError(
            f"Expected header size: {header_size}, received: {header}"
        )
    payload_size = int.from_bytes(bytes=header, byteorder="big")
    if payload_size <= 0:
        raise MalformedHeaderError(
            f"Header is indicating non positive payload size: {payload_size}"
        )
    # Read exactly payload_size bytes so the next message's bytes stay unread.
    received = _receive_exactly(source, payload_size, buffer_size)
    if len(received) < payload_size:
        raise TransmissionChannelClosed(
            "Socket was closed to read before payload was decoded."
        )
    try:
        return json.loads(received)
    except ValueError as error:
        raise MalformedPayloadError(
            "Received payload that is not in a JSON format"
        ) from error


def send_data_trough_socket(
    target: socket.socket,
    header_size: int,
    data: bytes,
    request_id: str,
    recover_from_overflow: bool = True,
    pipeline_id: Optional[str] = None,
) -> None:
    """Send ``data`` through ``target`` prefixed by its big-endian length.

    If ``len(data)`` does not fit in ``header_size`` bytes (``OverflowError``):
    when ``recover_from_overflow`` is true, an error response built by
    ``prepare_error_response`` for ``request_id``/``pipeline_id`` is sent
    instead; otherwise the overflow is logged and nothing is sent. Any other
    exception raised while building or sending the message is logged, not
    raised.
    """
    try:
        length_prefix = len(data).to_bytes(length=header_size, byteorder="big")
        target.sendall(length_prefix + data)
    except OverflowError as overflow:
        if recover_from_overflow:
            # NOTE(review): an exception raised by prepare_error_response is not
            # caught by the generic handler below and propagates to the caller.
            fallback_payload = prepare_error_response(
                request_id=request_id,
                error=overflow,
                error_type=ErrorType.INTERNAL_ERROR,
                pipeline_id=pipeline_id,
            )
            # Recovery is disabled for the retry, so an error response that is
            # itself too large is logged instead of recursing again.
            send_data_trough_socket(
                target=target,
                header_size=header_size,
                data=fallback_payload,
                request_id=request_id,
                recover_from_overflow=False,
                pipeline_id=pipeline_id,
            )
        else:
            logger.error(f"OverflowError was suppressed. {overflow}")
    except Exception as failure:
        logger.error(f"Could not send the response through socket. Error: {failure}")