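"""Visualize LeRobot datasets (https://huggingface.co/lerobot) in the Rerun viewer.

Each column of a dataset row is converted to a best-matching Rerun archetype and
logged against the dataset's time-like columns, so an episode can be scrubbed
frame by frame in the viewer.
"""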
from __future__ import annotations
import logging
from pathlib import PosixPath
from typing import Any
import cv2
import numpy as np
import rerun as rr
import torch
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from PIL import Image
from tqdm import tqdm
logger = logging.getLogger(__name__)
def get_frame(
    video_path: PosixPath,
    timestamp: float,
    video_cache: dict[PosixPath, tuple[list[np.ndarray], float]] | None = None,
) -> np.ndarray:
"""
Extracts a specific frame from a video.
`video_path`: path to the video.
`timestamp`: timestamp of the wanted frame.
`video_cache`: cache to prevent reading the same video file twice.
"""
if video_cache is None:
video_cache = {}
if video_path not in video_cache:
cap = cv2.VideoCapture(str(video_path))
frames = []
        while cap.isOpened():
            success, frame = cap.read()
            if success:
                # OpenCV decodes frames as BGR; convert to RGB, which Rerun expects.
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            else:
                break
        frame_rate = cap.get(cv2.CAP_PROP_FPS)
        cap.release()  # free the capture handle once all frames are cached
        video_cache[video_path] = (frames, frame_rate)

    frames, frame_rate = video_cache[video_path]
    return frames[int(timestamp * frame_rate)]
def to_rerun(
    column_name: str,
    value: Any,
    video_cache: dict[PosixPath, tuple[list[np.ndarray], float]] | None = None,
    videos_dir: PosixPath | None = None,
) -> Any:
"""Do our best to interpret the value and convert it to a Rerun-compatible archetype."""
if isinstance(value, Image.Image):
if "depth" in column_name:
return rr.DepthImage(value)
else:
return rr.Image(value)
elif isinstance(value, np.ndarray):
return rr.Tensor(value)
    elif isinstance(value, list):
        if value and isinstance(value[0], float):  # guard against empty lists
            return rr.BarChart(value)
        else:
            return rr.TextDocument(str(value))  # Fallback to text
    elif isinstance(value, (int, float)):
        return rr.Scalar(value)
elif isinstance(value, torch.Tensor):
if value.dim() == 0:
return rr.Scalar(value.item())
elif value.dim() == 1:
return rr.BarChart(value)
elif value.dim() == 2 and "depth" in column_name:
return rr.DepthImage(value)
elif value.dim() == 2:
return rr.Image(value)
        elif value.dim() == 3 and value.shape[2] in (3, 4):
            return rr.Image(value)  # Treat it as an RGB or RGBA image
else:
return rr.Tensor(value)
elif isinstance(value, dict) and "path" in value and "timestamp" in value:
path = (videos_dir or PosixPath("./")) / PosixPath(value["path"])
timestamp = value["timestamp"]
return rr.Image(get_frame(path, timestamp, video_cache=video_cache))
else:
return rr.TextDocument(str(value)) # Fallback to text
def log_lerobot_dataset_to_rerun(dataset: LeRobotDataset, episode_index: int) -> None:
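    """Log a single episode of a LeRobot dataset, one Rerun entity per column."""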
# Special time-like columns for LeRobot datasets (https://huggingface.co/lerobot/):
TIME_LIKE = {"index", "frame_id", "timestamp"}
# Ignore these columns (again, LeRobot-specific):
IGNORE = {"episode_data_index_from", "episode_data_index_to", "episode_id"}
hf_ds_subset = dataset.hf_dataset.filter(
lambda frame: "episode_index" not in frame or frame["episode_index"] == episode_index
)
    video_cache: dict[PosixPath, tuple[list[np.ndarray], float]] = {}
for row in tqdm(hf_ds_subset):
# Handle time-like columns first, since they set a state (time is an index in Rerun):
for column_name in TIME_LIKE:
if column_name in row:
cell = row[column_name]
if isinstance(cell, torch.Tensor) and cell.dim() == 0:
cell = cell.item()
if isinstance(cell, int):
rr.set_time_sequence(column_name, cell)
elif isinstance(cell, float):
rr.set_time_seconds(column_name, cell) # assume seconds
else:
                    logger.warning(f"Unknown time-like column {column_name} with value {cell}")
# Now log actual data columns:
for column_name, cell in row.items():
            if column_name in TIME_LIKE or column_name in IGNORE:
                continue
            rr.log(
                column_name,
                to_rerun(column_name, cell, video_cache=video_cache, videos_dir=dataset.videos_dir.parent),
            )
def log_dataset_to_rerun(dataset: Any) -> None:
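    """Log a generic iterable of dict-like rows; a simpler variant of `log_lerobot_dataset_to_rerun`."""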
TIME_LIKE = {"index", "frame_id", "timestamp"}
for row in tqdm(dataset):
# Handle time-like columns first, since they set a state (time is an index in Rerun):
for column_name in TIME_LIKE:
if column_name in row:
cell = row[column_name]
if isinstance(cell, int):
rr.set_time_sequence(column_name, cell)
elif isinstance(cell, float):
rr.set_time_seconds(column_name, cell) # assume seconds
else:
                    logger.warning(f"Unknown time-like column {column_name} with value {cell}")
# Now log actual data columns:
for column_name, cell in row.items():
if column_name in TIME_LIKE:
continue
rr.log(column_name, to_rerun(column_name, cell))
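

if __name__ == "__main__":
    # Minimal entry-point sketch: visualize the first episode of a dataset.
    # Assumes a local Rerun viewer and Hugging Face Hub access; "lerobot/pusht"
    # is only an illustrative repo id, as is the application id passed to rr.init.
    rr.init("lerobot_dataset_viewer", spawn=True)
    dataset = LeRobotDataset("lerobot/pusht")
    log_lerobot_dataset_to_rerun(dataset, episode_index=0)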