File size: 5,423 Bytes
bd1896f
 
a235944
586d2b6
a235944
 
586d2b6
a235944
 
586d2b6
 
a235944
 
 
 
 
 
586d2b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a235944
 
 
 
 
 
 
 
 
 
 
 
 
 
 
586d2b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a235944
 
 
 
586d2b6
 
a235944
 
 
 
 
586d2b6
 
 
 
 
 
 
a235944
 
 
 
586d2b6
 
a235944
 
 
 
 
 
 
 
 
 
 
586d2b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a235944
586d2b6
 
 
 
a235944
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
from __future__ import annotations

import logging
from pathlib import PosixPath
from typing import Any

import cv2
import numpy as np
import rerun as rr
import torch
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from PIL import Image
from tqdm import tqdm

logger = logging.getLogger(__name__)


def get_frame(
    video_path: PosixPath,
    timestamp: float,
    video_cache: dict[PosixPath, tuple[list[np.ndarray], float]] | None = None,
) -> np.ndarray:
    """
    Extracts a specific frame from a video.

    `video_path`: path to the video.
    `timestamp`: timestamp of the wanted frame, in seconds.
    `video_cache`: cache to prevent reading the same video file twice. Maps a
        video path to `(frames, frame_rate)` where `frames` is the full list of
        decoded frames (the whole video is decoded on first access).

    Returns the decoded frame (BGR, as produced by OpenCV) at
    `frames[int(timestamp * frame_rate)]`.
    """

    if video_cache is None:
        video_cache = {}
    if video_path not in video_cache:
        cap = cv2.VideoCapture(str(video_path))
        try:
            frames = []
            while cap.isOpened():
                success, frame = cap.read()
                if not success:
                    break
                frames.append(frame)
            frame_rate = cap.get(cv2.CAP_PROP_FPS)
        finally:
            # Always release the capture handle; the original leaked it.
            cap.release()
        video_cache[video_path] = (frames, frame_rate)

    frames, frame_rate = video_cache[video_path]
    return frames[int(timestamp * frame_rate)]


def to_rerun(
    column_name: str,
    value: Any,
    video_cache: dict[PosixPath, tuple[list[np.ndarray], float]] | None = None,
    videos_dir: PosixPath | None = None,
) -> Any:
    """Do our best to interpret the value and convert it to a Rerun-compatible archetype.

    `column_name`: name of the dataset column; names containing "depth" are
        logged as `rr.DepthImage` instead of `rr.Image`.
    `value`: the cell value to convert (PIL image, ndarray, list, scalar,
        torch tensor, or a `{"path": ..., "timestamp": ...}` video reference).
    `video_cache`: optional cache forwarded to `get_frame`.
    `videos_dir`: base directory for relative video paths; defaults to "./".

    Anything unrecognized falls back to `rr.TextDocument(str(value))`.
    """
    if isinstance(value, Image.Image):
        if "depth" in column_name:
            return rr.DepthImage(value)
        else:
            return rr.Image(value)
    elif isinstance(value, np.ndarray):
        return rr.Tensor(value)
    elif isinstance(value, list):
        # Guard against empty lists: `value[0]` would raise IndexError.
        if value and isinstance(value[0], float):
            return rr.BarChart(value)
        else:
            return rr.TextDocument(str(value))  # Fallback to text
    elif isinstance(value, (float, int)):
        return rr.Scalar(value)
    elif isinstance(value, torch.Tensor):
        if value.dim() == 0:
            return rr.Scalar(value.item())
        elif value.dim() == 1:
            return rr.BarChart(value)
        elif value.dim() == 2 and "depth" in column_name:
            return rr.DepthImage(value)
        elif value.dim() == 2:
            return rr.Image(value)
        elif value.dim() == 3 and (value.shape[2] == 3 or value.shape[2] == 4):
            return rr.Image(value)  # Treat it as a RGB or RGBA image
        else:
            return rr.Tensor(value)
    elif isinstance(value, dict) and "path" in value and "timestamp" in value:
        # Video reference: resolve the frame lazily via the shared cache.
        path = (videos_dir or PosixPath("./")) / PosixPath(value["path"])
        timestamp = value["timestamp"]
        return rr.Image(get_frame(path, timestamp, video_cache=video_cache))
    else:
        return rr.TextDocument(str(value))  # Fallback to text


def log_lerobot_dataset_to_rerun(dataset: LeRobotDataset, episode_index: int) -> None:
    """Log one episode of a LeRobot dataset to Rerun.

    `dataset`: the LeRobot dataset to visualize.
    `episode_index`: which episode to log; rows whose `episode_index` column
        does not match are filtered out (rows without that column are kept).
    """
    # Special time-like columns for LeRobot datasets (https://huggingface.co/lerobot/):
    TIME_LIKE = {"index", "frame_id", "timestamp"}

    # Ignore these columns (again, LeRobot-specific):
    IGNORE = {"episode_data_index_from", "episode_data_index_to", "episode_id"}

    hf_ds_subset = dataset.hf_dataset.filter(
        lambda frame: "episode_index" not in frame or frame["episode_index"] == episode_index
    )

    # Shared across rows so each video file is decoded at most once.
    video_cache: dict[PosixPath, tuple[np.ndarray, float]] = {}

    for row in tqdm(hf_ds_subset):
        # Handle time-like columns first, since they set a state (time is an index in Rerun):
        for column_name in TIME_LIKE:
            if column_name in row:
                cell = row[column_name]
                # Unwrap 0-dim tensors to plain Python scalars before dispatch.
                if isinstance(cell, torch.Tensor) and cell.dim() == 0:
                    cell = cell.item()
                if isinstance(cell, int):
                    rr.set_time_sequence(column_name, cell)
                elif isinstance(cell, float):
                    rr.set_time_seconds(column_name, cell)  # assume seconds
                else:
                    # Use the module logger instead of print, consistent with the logging setup.
                    logger.warning(f"Unknown time-like column {column_name} with value {cell}")

        # Now log actual data columns:
        for column_name, cell in row.items():
            if column_name in TIME_LIKE or column_name in IGNORE:
                continue
            else:
                rr.log(
                    column_name,
                    to_rerun(column_name, cell, video_cache=video_cache, videos_dir=dataset.videos_dir.parent),
                )


def log_dataset_to_rerun(dataset: Any) -> None:
    """Log a generic iterable-of-dicts dataset to Rerun.

    `dataset`: any iterable yielding mapping-like rows (column name -> value).
    Time-like columns set the Rerun timeline; all other columns are converted
    via `to_rerun` and logged under their column name.
    """
    TIME_LIKE = {"index", "frame_id", "timestamp"}

    for row in tqdm(dataset):
        # Handle time-like columns first, since they set a state (time is an index in Rerun):
        for column_name in TIME_LIKE:
            if column_name in row:
                cell = row[column_name]
                if isinstance(cell, int):
                    rr.set_time_sequence(column_name, cell)
                elif isinstance(cell, float):
                    rr.set_time_seconds(column_name, cell)  # assume seconds
                else:
                    # Use the module logger instead of print, consistent with the logging setup.
                    logger.warning(f"Unknown time-like column {column_name} with value {cell}")

        # Now log actual data columns:
        for column_name, cell in row.items():
            if column_name in TIME_LIKE:
                continue
            rr.log(column_name, to_rerun(column_name, cell))