import pytorch_lightning as pl
import numpy as np
import torch
import PIL
import os
import random
from skimage.io import imread
import webdataset as wds
import PIL.Image as Image
from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler
from pathlib import Path
# from ldm.base_utils import read_pickle, pose_inverse
import torchvision.transforms as transforms
import torchvision
from einops import rearrange


def add_margin(pil_img, color=0, size=256):
    """Paste the image onto a square canvas of the given size, centered."""
    width, height = pil_img.size
    result = Image.new(pil_img.mode, (size, size), color)
    result.paste(pil_img, ((size - width) // 2, (size - height) // 2))
    return result


def prepare_inputs(image_path, elevation_input, crop_size=-1, image_size=256):
    image_input = Image.open(image_path)

    if crop_size != -1:
        # Crop the object using the alpha channel, rescale its longest side to
        # crop_size, then pad to a square canvas of image_size.
        alpha_np = np.asarray(image_input)[:, :, 3]
        coords = np.stack(np.nonzero(alpha_np), 1)[:, (1, 0)]
        min_x, min_y = np.min(coords, 0)
        max_x, max_y = np.max(coords, 0)
        ref_img_ = image_input.crop((min_x, min_y, max_x, max_y))
        h, w = ref_img_.height, ref_img_.width
        scale = crop_size / max(h, w)
        h_, w_ = int(scale * h), int(scale * w)
        ref_img_ = ref_img_.resize((w_, h_), resample=Image.BICUBIC)
        image_input = add_margin(ref_img_, size=image_size)
    else:
        image_input = add_margin(image_input, size=max(image_input.height, image_input.width))
        image_input = image_input.resize((image_size, image_size), resample=Image.BICUBIC)

    image_input = np.asarray(image_input)
    image_input = image_input.astype(np.float32) / 255.0
    ref_mask = image_input[:, :, 3:]
    image_input[:, :, :3] = image_input[:, :, :3] * ref_mask + 1 - ref_mask  # white background
    image_input = image_input[:, :, :3] * 2.0 - 1.0  # scale to [-1, 1]
    image_input = torch.from_numpy(image_input.astype(np.float32))
    elevation_input = torch.from_numpy(np.asarray([np.deg2rad(elevation_input)], np.float32))
    return {"input_image": image_input, "input_elevation": elevation_input}


class VideoTrainDataset(Dataset):
    def __init__(self, base_folder='/data/yanghaibo/datas/OBJAVERSE-LVIS/images', width=1024, height=576, sample_frames=25):
        """
        Args:
            base_folder (str): Root directory with one sub-folder of rendered frames per object.
            width (int): Target frame width in pixels.
            height (int): Target frame height in pixels.
            sample_frames (int): Number of frames returned per sample.
        """
        # Each sub-folder of base_folder holds the rendered frames of one object.
        self.base_folder = base_folder
        self.folders = os.listdir(self.base_folder)
        self.num_samples = len(self.folders)
        self.channels = 3
        self.width = width
        self.height = height
        self.sample_frames = sample_frames
        # One elevation per consecutive block of 16 frames (6 blocks per object).
        self.elevations = [-10, 0, 10, 20, 30, 40]

    def __len__(self):
        return self.num_samples

    def load_im(self, path):
        """Load an RGBA frame, composite it onto a white background, and return the RGB image and alpha mask."""
        img = imread(path)
        img = img.astype(np.float32) / 255.0
        mask = img[:, :, 3:]
        img[:, :, :3] = img[:, :, :3] * mask + 1 - mask  # white background
        img = Image.fromarray(np.uint8(img[:, :, :3] * 255.))
        return img, mask

    def __getitem__(self, idx):
        """
        Args:
            idx (int): Index of the sample (unused; a folder is chosen at random).

        Returns:
            dict: 'video' tensor of shape (channels, sample_frames, height, width),
                together with 'elevation', 'caption', 'fps_id' and 'motion_bucket_id'.
        """
        # Randomly select a folder (representing a video) from the base folder
        chosen_folder = random.choice(self.folders)
        folder_path = os.path.join(self.base_folder, chosen_folder)
        frames = os.listdir(folder_path)
        # Sort the frames by name
        frames.sort()

        # Ensure the selected folder has at least `sample_frames` frames
        if len(frames) < self.sample_frames:
            raise ValueError(
                f"The selected folder '{chosen_folder}' contains fewer than `{self.sample_frames}` frames.")

        # Randomly select a start index for the frame sequence. The elevation is
        # fixed by the 16-frame block the start index falls into; the block is
        # rotated so the sequence begins at start_idx and wraps around within
        # that block, always yielding 16 frames (so sample_frames should be 16).
        start_idx = random.randint(0, len(frames) - 1)
        range_id = start_idx // 16  # 0, 1, 2, 3, 4, 5
        elevation = self.elevations[range_id]
        selected_frames = []
        for frame_idx in range(start_idx, (range_id + 1) * 16):
            selected_frames.append(frames[frame_idx])
        for frame_idx in range(range_id * 16, start_idx):
            selected_frames.append(frames[frame_idx])

        # Initialize a tensor to store the pixel values
        pixel_values = torch.empty((self.sample_frames, self.channels, self.height, self.width))

        # Load and process each frame
        for i, frame_name in enumerate(selected_frames):
            frame_path = os.path.join(folder_path, frame_name)
            img, mask = self.load_im(frame_path)

            # Resize the image and convert it to a tensor
            img_resized = img.resize((self.width, self.height))
            img_tensor = torch.from_numpy(np.array(img_resized)).float()

            # Normalize the image by scaling pixel values to [-1, 1]
            img_normalized = img_tensor / 127.5 - 1

            # Rearrange channels if necessary
            if self.channels == 3:
                img_normalized = img_normalized.permute(2, 0, 1)  # For RGB images
            elif self.channels == 1:
                img_normalized = img_normalized.mean(dim=2, keepdim=True)  # For grayscale images

            pixel_values[i] = img_normalized

        pixel_values = rearrange(pixel_values, 't c h w -> c t h w')
        caption = chosen_folder + "_" + str(start_idx)
        return {'video': pixel_values, 'elevation': elevation, 'caption': caption,
                'fps_id': 7, 'motion_bucket_id': 127}


class SyncDreamerEvalData(Dataset):
    def __init__(self, image_dir):
        self.image_size = 512
        self.image_dir = Path(image_dir)
        self.crop_size = 20

        # Collect every .png in the evaluation directory.
        self.fns = []
        for fn in Path(image_dir).iterdir():
            if fn.suffix == '.png':
                self.fns.append(fn)
        print('============= length of dataset %d =============' % len(self.fns))

    def __len__(self):
        return len(self.fns)

    def get_data_for_index(self, index):
        input_img_fn = self.fns[index]
        elevation = 0
        # 512 is used as the crop size; the output keeps the default image_size of 256.
        return prepare_inputs(input_img_fn, elevation, crop_size=512)

    def __getitem__(self, index):
        return self.get_data_for_index(index)


class VideoDataset(pl.LightningDataModule):
    def __init__(self, base_folder, eval_folder, width, height, sample_frames, batch_size, num_workers=4, seed=0, **kwargs):
        super().__init__()
        self.base_folder = base_folder
        self.eval_folder = eval_folder
        self.width = width
        self.height = height
        self.sample_frames = sample_frames
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.seed = seed
        self.additional_args = kwargs

    def setup(self, stage=None):
        self.train_dataset = VideoTrainDataset(self.base_folder, self.width, self.height, self.sample_frames)
        self.val_dataset = SyncDreamerEvalData(image_dir=self.eval_folder)

    def train_dataloader(self):
        sampler = DistributedSampler(self.train_dataset, seed=self.seed)
        return wds.WebLoader(self.train_dataset, batch_size=self.batch_size,
                             num_workers=self.num_workers, shuffle=False, sampler=sampler)

    def val_dataloader(self):
        return wds.WebLoader(self.val_dataset, batch_size=self.batch_size,
                             num_workers=self.num_workers, shuffle=False)

    def test_dataloader(self):
        return wds.WebLoader(self.val_dataset, batch_size=self.batch_size,
                             num_workers=self.num_workers, shuffle=False)
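

# ---------------------------------------------------------------------------
# Minimal usage sketch (assumptions: the directory paths below are
# placeholders, and the training frames follow the 16-frames-per-elevation
# layout the dataset expects). It only wires the LightningDataModule together
# and pulls one validation batch; the train loader uses a DistributedSampler,
# so it needs torch.distributed to be initialized (e.g. by a Lightning DDP
# run) before it can be iterated.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    datamodule = VideoDataset(
        base_folder="/path/to/objaverse_renders",  # hypothetical training renders
        eval_folder="/path/to/eval_images",        # hypothetical folder of RGBA .png files
        width=512, height=512, sample_frames=16,
        batch_size=1, num_workers=0, seed=0,
    )
    datamodule.setup()
    val_batch = next(iter(datamodule.val_dataloader()))
    print(val_batch["input_image"].shape, val_batch["input_elevation"])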