Spaces:
Runtime error
Runtime error
from typing import Literal, Optional | |
import requests | |
from pydantic import BaseModel | |
import docker | |
from inference.core.devices.utils import GLOBAL_DEVICE_ID | |
from inference.core.env import API_BASE_URL, API_KEY | |
from inference.core.logger import logger | |
from inference.core.utils.url_utils import wrap_url | |
from inference.enterprise.device_manager.container_service import get_container_by_id | |
class Command(BaseModel):
    """A device-manager command addressed to one container on this device.

    The accepted ``command`` values match the cases dispatched by
    ``handle_command``.
    """

    # Identifier of the command; handle_command passes it to ack_command.
    id: str
    # Id of the target container (looked up via get_container_by_id).
    containerId: str
    # Action to perform. "start" is listed because handle_command dispatches it.
    command: Literal["restart", "stop", "ping", "snapshot", "start", "update_version"]
    # Id of the device the command is addressed to.
    deviceId: str
    # When the command was requested; presumably a Unix timestamp — TODO confirm units.
    requested_on: Optional[int] = None
def fetch_commands(*, timeout: float = 30.0):
    """Fetch the pending commands for this device and handle each one in order.

    Args:
        timeout: Seconds to wait for the API to respond. Without a timeout,
            ``requests`` waits indefinitely, so a hung connection would stall
            the caller forever.

    Exceptions raised by ``requests`` (e.g. ``requests.Timeout``,
    ``requests.ConnectionError``) propagate to the caller. An HTTP error status
    is logged and no commands are handled.
    """
    url = wrap_url(
        f"{API_BASE_URL}/devices/{GLOBAL_DEVICE_ID}/commands?api_key={API_KEY}"
    )
    response = requests.get(url, timeout=timeout)
    if not response.ok:
        # Deliberately not logging the URL: it embeds the API key.
        logger.error(f"Failed to fetch commands: HTTP {response.status_code}")
        return
    for cmd in response.json().get("data", []):
        handle_command(cmd)
def handle_command(cmd_payload: dict): | |
was_processed = False | |
container_id = cmd_payload.get("containerId") | |
container = get_container_by_id(container_id) | |
if not container: | |
logger.warn(f"Container with id {container_id} not found") | |
ack_command(cmd_payload.get("id"), was_processed) | |
return | |
cmd = cmd_payload.get("command") | |
data = None | |
match cmd: | |
case "restart": | |
was_processed, data = container.restart() | |
case "stop": | |
was_processed, data = container.stop() | |
case "ping": | |
was_processed, data = container.ping() | |
case "snapshot": | |
was_processed, data = container.snapshot() | |
case "start": | |
was_processed, data = container.start() | |
case "update_version": | |
was_processed, data = handle_version_update(container) | |
case _: | |
logger.error("Unknown command: {}".format(cmd)) | |
return ack_command(cmd_payload.get("id"), was_processed, data=data) | |
def ack_command(command_id, was_processed, data=None, *, timeout: float = 30.0):
    """Report the outcome of a command back to the device-manager API.

    Args:
        command_id: Id of the command being acknowledged.
        was_processed: Whether the command was carried out.
        data: Optional result payload. It is left out of the request body when
            falsy (None, empty, 0).
        timeout: Seconds to wait for the API response. Without a timeout,
            ``requests`` can wait indefinitely.

    An HTTP error status (>= 400) is logged rather than raised. Network errors
    from ``requests`` propagate to the caller.
    """
    post_body = {
        "api_key": API_KEY,
        "commandId": command_id,
        "wasProcessed": was_processed,
    }
    if data:
        post_body["data"] = data
    url = wrap_url(f"{API_BASE_URL}/devices/{GLOBAL_DEVICE_ID}/commands/ack")
    response = requests.post(url, json=post_body, timeout=timeout)
    if not response.ok:
        logger.error(
            f"Failed to acknowledge command {command_id}: HTTP {response.status_code}"
        )
def _image_repository(image: str) -> str:
    """Return an image reference without its tag or digest.

    A ``:`` counts as a tag separator only when it appears after the last
    ``/``, so a registry port is kept: ``host:5000/org/app:1.2`` becomes
    ``host:5000/org/app``.
    """
    name = image.split("@", 1)[0]  # drop a "@sha256:..." digest, if present
    if ":" in name.rsplit("/", 1)[-1]:
        name = name.rsplit(":", 1)[0]
    return name


def handle_version_update(container):
    """Replace *container* with a new container running the ``latest`` tag of its image.

    The new container reuses the old one's startup config (detach, privileged,
    labels, port bindings, env) and runs with host networking.

    Returns:
        ``(True, None)`` once the new container has been started.
        ``(False, None)`` on any failure; the error is logged with its traceback.
    """
    try:
        config = container.get_startup_config()
        repository = _image_repository(config["image"])
        client = docker.from_env()
        try:
            # Pull explicitly: containers.run only pulls when the image is
            # missing locally, so a stale cached ":latest" would be reused.
            # Pulling before the kill also keeps the old container running if
            # the pull fails.
            client.images.pull(repository, tag="latest")
            # Stop the old container before starting the new one, presumably
            # so the new one can bind the same host ports (network="host").
            container.kill()
            new_container = client.containers.run(
                image=f"{repository}:latest",
                detach=config["detach"],
                privileged=config["privileged"],
                labels=config["labels"],
                ports=config["port_bindings"],
                environment=config["env"],
                network="host",
            )
            logger.info(f"New container started {new_container}")
        finally:
            client.close()
        return True, None
    except Exception as e:
        logger.exception(f"Version update failed: {e}")
        return False, None