# Adapted from https://github.com/universome/stylegan-v/blob/master/src/metrics/metric_utils.py
import os
import pickle
import random
from typing import Optional, Tuple

import numpy as np
import torch

def seed_everything(seed: int):
    '''
    Seed every source of randomness (Python, NumPy, and PyTorch) for reproducibility.

    Args:
        seed: The seed value to use.
    '''
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
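    # Note (illustrative, not in the original): multi-GPU runs may also want
    # torch.cuda.manual_seed_all(seed); fully deterministic behavior further
    # requires torch.backends.cudnn.deterministic = True and
    # torch.backends.cudnn.benchmark = False.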


class FeatureStats:
    '''
    Accumulates statistics over batches of features: optionally the raw
    features themselves and/or their running mean and covariance.

    Args:
        capture_all: Whether to store all the features.
        capture_mean_cov: Whether to accumulate the mean and covariance.
        max_items: Maximum number of items to store (None means unlimited).
    '''
    def __init__(self, capture_all: bool = False, capture_mean_cov: bool = False, max_items: Optional[int] = None):
        self.capture_all = capture_all
        self.capture_mean_cov = capture_mean_cov
        self.max_items = max_items
        self.num_items = 0
        self.num_features = None
        self.all_features = None
        self.raw_mean = None
        self.raw_cov = None

    def set_num_features(self, num_features: int):
        '''
        Set the number of feature dimensions.

        Args:
            num_features: Number of feature dimensions.
        '''
        if self.num_features is not None:
            assert num_features == self.num_features
        else:
            self.num_features = num_features
            self.all_features = []
            self.raw_mean = np.zeros([num_features], dtype=np.float64)
            self.raw_cov = np.zeros([num_features, num_features], dtype=np.float64)

    def is_full(self) -> bool:
        '''
        Check if the maximum number of samples is reached.

        Returns:
            True if the storage is full, False otherwise.
        '''
        return (self.max_items is not None) and (self.num_items >= self.max_items)

    def append(self, x: np.ndarray):
        '''
        Add the newly computed features to the list. Update the mean and covariance.

        Args:
            x: New features to record.
        '''
        x = np.asarray(x, dtype=np.float32)
        assert x.ndim == 2
        if (self.max_items is not None) and (self.num_items + x.shape[0] > self.max_items):
            if self.num_items >= self.max_items:
                return
            x = x[:self.max_items - self.num_items]

        self.set_num_features(x.shape[1])
        self.num_items += x.shape[0]
        if self.capture_all:
            self.all_features.append(x)
        if self.capture_mean_cov:
            x64 = x.astype(np.float64)
            self.raw_mean += x64.sum(axis=0)
            self.raw_cov += x64.T @ x64

    def append_torch(self, x: torch.Tensor, rank: int, num_gpus: int):
        '''
        Add the newly computed PyTorch features to the list. Update the mean and covariance.

        Args:
            x: New features to record.
            rank: Rank of the current GPU.
            num_gpus: Total number of GPUs.
        '''
        assert isinstance(x, torch.Tensor) and x.ndim == 2
        assert 0 <= rank < num_gpus
        if num_gpus > 1:
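            # Each rank broadcasts its own batch in turn, so every process
            # sees all samples and accumulates identical statistics.
            # Assumes torch.distributed has been initialized.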
            ys = []
            for src in range(num_gpus):
                y = x.clone()
                torch.distributed.broadcast(y, src=src)
                ys.append(y)
            x = torch.stack(ys, dim=1).flatten(0, 1) # interleave samples
        self.append(x.cpu().numpy())

    def get_all(self) -> np.ndarray:
        '''
        Get all the stored features as a NumPy array.

        Returns:
            Concatenation of the stored features.
        '''
        assert self.capture_all
        return np.concatenate(self.all_features, axis=0)

    def get_all_torch(self) -> torch.Tensor:
        '''
        Get all the stored features as a PyTorch tensor.

        Returns:
            Concatenation of the stored features.
        '''
        return torch.from_numpy(self.get_all())

    def get_mean_cov(self) -> Tuple[np.ndarray, np.ndarray]:
        '''
        Get the mean and covariance of the stored features.

        Returns:
            Mean and covariance of the stored features.
        '''
        assert self.capture_mean_cov
        mean = self.raw_mean / self.num_items
        cov = self.raw_cov / self.num_items
        cov = cov - np.outer(mean, mean)
        return mean, cov
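
    # Illustrative note (not in the original module): downstream metrics such
    # as FID typically consume two (mean, cov) pairs like so, assuming scipy
    # is available:
    #     m = np.square(mu_real - mu_gen).sum()
    #     s, _ = scipy.linalg.sqrtm(sigma_real @ sigma_gen, disp=False)
    #     fid = float(np.real(m + np.trace(sigma_real + sigma_gen - 2 * s)))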

    def save(self, pkl_file: str):
        '''
        Save the features and statistics to a pickle file.

        Args:
            pkl_file: Path to the pickle file.
        '''
        with open(pkl_file, 'wb') as f:
            pickle.dump(self.__dict__, f)

    @staticmethod
    def load(pkl_file: str) -> 'FeatureStats':
        '''
        Load the features and statistics from a pickle file.

        Args:
            pkl_file: Path to the pickle file.

        Returns:
            A FeatureStats object restored from the file.
        '''
        with open(pkl_file, 'rb') as f:
            s = pickle.load(f)
        obj = FeatureStats(capture_all=s['capture_all'], max_items=s['max_items'])
        obj.__dict__.update(s)
        print(f'Loaded {obj.num_items} features from {pkl_file}')
        return obj
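

# Minimal usage sketch (illustrative, not part of the original module):
# accumulate random features, then read back the mean and covariance.
if __name__ == '__main__':
    seed_everything(0)
    stats = FeatureStats(capture_all=True, capture_mean_cov=True, max_items=1000)
    while not stats.is_full():
        # Hypothetical feature batches; real code would use e.g. Inception-v3 outputs.
        stats.append(np.random.randn(128, 2048))
    mean, cov = stats.get_mean_cov()
    print(stats.get_all().shape, mean.shape, cov.shape)  # (1000, 2048) (2048,) (2048, 2048)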