# File size: 5,817 Bytes
# 07c6a04
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import argparse
import os

import imageio
import torch
import torchvision.transforms.functional as F
import tqdm
from calculate_lpips import calculate_lpips
from calculate_psnr import calculate_psnr
from calculate_ssim import calculate_ssim


def load_videos(directory, video_ids, file_extension):
    """
    Load every requested video from a directory, in the order given.

    Args:
        directory: Folder that holds the video files.
        video_ids: Iterable of file stems (file names without extension).
        file_extension: Extension without the leading dot, e.g. "mp4".

    Returns:
        A list with one entry per id, each being whatever ``load_video``
        returns for ``<directory>/<video_id>.<file_extension>``.

    Raises:
        ValueError: As soon as one of the expected files does not exist.
    """
    loaded = []
    for video_id in video_ids:
        candidate = os.path.join(directory, f"{video_id}.{file_extension}")
        # Guard clause: stop at the first missing file instead of skipping it.
        if not os.path.exists(candidate):
            raise ValueError(f"Video {video_id}.{file_extension} not found in {directory}")
        loaded.append(load_video(candidate))
    return loaded


def load_video(video_path):
    """
    Load a video from the given path and convert it to a PyTorch tensor.

    Args:
        video_path: Path of a video file readable by imageio's "ffmpeg" plugin.

    Returns:
        A tensor of shape (T, C, H, W) holding the frame values exactly as
        decoded by imageio (no rescaling is applied here). The tensor lives
        on the GPU when CUDA is available and on the CPU otherwise.

    Raises:
        RuntimeError: From ``torch.stack`` if the reader yields no frames.
    """
    # Fall back to the CPU so the loader also works on machines without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # The context manager closes the reader (and its ffmpeg subprocess) even
    # if decoding fails part-way through.
    with imageio.get_reader(video_path, "ffmpeg") as reader:
        # Each decoded frame is (H, W, C); permute to channel-first (C, H, W).
        frames = [torch.tensor(frame).permute(2, 0, 1) for frame in reader]

    # Stack on the host and transfer once rather than once per frame.
    return torch.stack(frames).to(device)


def resize_video(video, target_height, target_width):
    """
    Resize every frame of a video to ``target_height`` x ``target_width``.

    Args:
        video: Iterable of frames (e.g. a (T, C, H, W) tensor); each frame is
            passed on its own to ``torchvision.transforms.functional.resize``.
        target_height: Output frame height in pixels.
        target_width: Output frame width in pixels.

    Returns:
        The resized frames stacked along a new leading dimension.
    """
    output_size = [target_height, target_width]
    return torch.stack([F.resize(frame, output_size) for frame in video])


def _aspect_preserving_size(src_height, src_width, min_height, min_width):
    """Return a (height, width) with the source aspect ratio that covers the minimum size."""
    # Compare the scale factors min_height/src_height and min_width/src_width
    # by cross-multiplying, which keeps the decision in exact integer math.
    if min_height * src_width >= min_width * src_height:
        # Height needs the larger scale: pin the height, derive the width.
        return min_height, max(min_width, src_width * min_height // src_height)
    # Width needs the larger scale: pin the width, derive the height.
    return max(min_height, src_height * min_width // src_width), min_width


def preprocess_eval_video(eval_video, generated_video_shape):
    """
    Trim, upscale and center-crop a reference video to a generated video's shape.

    Args:
        eval_video: Reference video tensor of shape (T, C, H, W).
        generated_video_shape: ``(T_gen, C, H_gen, W_gen)`` of the generated
            video; the channel entry is ignored.

    Returns:
        The first ``T_gen`` frames of ``eval_video``, center-cropped to
        ``H_gen`` x ``W_gen``. If the reference is smaller than the target in
        either spatial dimension it is first upscaled, keeping its aspect
        ratio, until it covers the target. A reference that is already large
        enough is only cropped, never downscaled.

    Raises:
        ValueError: If the reference video has fewer frames than ``T_gen``.
    """
    T_gen, _, H_gen, W_gen = generated_video_shape
    T_eval, _, H_eval, W_eval = eval_video.shape

    if T_eval < T_gen:
        raise ValueError(f"Eval video time steps ({T_eval}) are less than generated video time steps ({T_gen}).")

    if H_eval < H_gen or W_eval < W_gen:
        # Upscale so both dimensions reach the target while keeping the
        # reference's own aspect ratio. Using the target's dimensions as the
        # base of both products distorts the frame for non-square targets.
        resize_height, resize_width = _aspect_preserving_size(H_eval, W_eval, H_gen, W_gen)
        eval_video = resize_video(eval_video, resize_height, resize_width)
        # Recalculate the dimensions
        _, _, H_eval, W_eval = eval_video.shape

    # Center crop
    start_h = (H_eval - H_gen) // 2
    start_w = (W_eval - W_gen) // 2
    cropped_video = eval_video[:T_gen, :, start_h : start_h + H_gen, start_w : start_w + W_gen]

    return cropped_video


def main(args):
    """
    Compute LPIPS, PSNR and SSIM between generated videos and their ground truth.

    Every ``*.mp4`` file in ``args.generated_video_dir`` is paired with the
    file of the same name in ``args.gt_video_dir``. Videos are evaluated in
    batches, a running average is printed after each batch, and the final
    summary line is both printed and written to
    ``./<name of generated_video_dir>.txt``.

    Args:
        args: Namespace with the string attributes ``gt_video_dir`` and
            ``generated_video_dir``.

    Raises:
        ValueError: If the generated directory contains no ``.mp4`` files.
    """
    device = "cuda"
    gt_video_dir = args.gt_video_dir
    generated_video_dir = args.generated_video_dir

    file_extension = "mp4"
    suffix = f".{file_extension}"
    # Strip only the trailing extension (str.replace would also remove a
    # ".mp4" in the middle of a name) and sort so batches are reproducible,
    # since os.listdir returns entries in arbitrary order.
    video_ids = sorted(
        name[: -len(suffix)] for name in os.listdir(generated_video_dir) if name.endswith(suffix)
    )
    if not video_ids:
        raise ValueError("No videos found in the generated video dataset. Exiting.")

    print(f"Find {len(video_ids)} videos")
    prompt_interval = 1
    batch_size = 16
    calculate_lpips_flag, calculate_psnr_flag, calculate_ssim_flag = True, True, True

    # (name, callable) pairs for the enabled metrics, in reporting order.
    metrics = []
    if calculate_lpips_flag:
        metrics.append(("lpips", lambda gt, gen: calculate_lpips(gt, gen, device=device)))
    if calculate_psnr_flag:
        metrics.append(("psnr", calculate_psnr))
    if calculate_ssim_flag:
        metrics.append(("ssim", calculate_ssim))

    # Per-metric sum of (batch mean * batch size). Dividing by the number of
    # videos seen gives a mean in which a short final batch is not
    # over-weighted.
    # NOTE(review): assumes each calculate_* "value" entry is already an
    # average over the videos of the batch - confirm against those helpers.
    weighted_sums = {name: 0.0 for name, _ in metrics}
    videos_done = 0

    def summarize():
        """Format the running per-metric averages as 'name: value, ...'."""
        return ", ".join(f"{name}: {weighted_sums[name] / videos_done:.4f}" for name, _ in metrics)

    batch_starts = range(0, len(video_ids), batch_size)
    for idx, start in enumerate(tqdm.tqdm(batch_starts)):
        batch_ids = video_ids[start : start + batch_size]
        gt_videos = []
        generated_videos = []
        for video_id in batch_ids:
            filename = f"{video_id}.{file_extension}"
            generated_videos.append(load_video(os.path.join(generated_video_dir, filename)))
            gt_videos.append(load_video(os.path.join(gt_video_dir, filename)))
        # Scale by 1/255 and hand CPU tensors to the metric helpers.
        # torch.stack requires all videos in the batch to share one shape.
        gt_videos_tensor = (torch.stack(gt_videos) / 255.0).cpu()
        generated_videos_tensor = (torch.stack(generated_videos) / 255.0).cpu()

        for name, metric_fn in metrics:
            values = metric_fn(gt_videos_tensor, generated_videos_tensor)["value"].values()
            batch_mean = sum(values) / len(values)
            weighted_sums[name] += batch_mean * len(batch_ids)
        videos_done += len(batch_ids)

        if (idx + 1) % prompt_interval == 0:
            # Report the number of videos, not the number of batches.
            print(f"Processed {videos_done} videos. {summarize()}")

    out_str = summarize()

    # save
    # normpath drops a trailing slash, which would otherwise make basename
    # return "" and the results land in "./.txt".
    output_name = os.path.basename(os.path.normpath(generated_video_dir))
    with open(f"./{output_name}.txt", "w+") as f:
        f.write(out_str)

    print(f"Processed all videos. {out_str}")


if __name__ == "__main__":
    # Command-line entry point: parse the two directories and run the evaluation.
    parser = argparse.ArgumentParser(
        description="Compute LPIPS, PSNR and SSIM between generated and ground-truth videos."
    )
    # Both directories are mandatory. Left optional they default to None,
    # which makes os.listdir scan the current directory and os.path.join fail
    # later with an unhelpful TypeError.
    parser.add_argument(
        "--gt_video_dir",
        type=str,
        required=True,
        help="Directory containing the ground-truth .mp4 files.",
    )
    parser.add_argument(
        "--generated_video_dir",
        type=str,
        required=True,
        help="Directory containing the generated .mp4 files (same file names as the ground truth).",
    )

    args = parser.parse_args()

    main(args)