File size: 7,085 Bytes
29cdbe6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
'''
A simple tool to generate a sample of outputs from a GAN,
subject to filtering, sorting, or intervention.
'''

import torch, numpy, os, argparse, numbers, sys, shutil
from PIL import Image
from torch.utils.data import TensorDataset
from netdissect.zdataset import standard_z_sample
from netdissect.progress import default_progress, verbose_progress
from netdissect.autoeval import autoimport_eval
from netdissect.workerpool import WorkerBase, WorkerPool
from netdissect.nethook import edit_layers, retain_layers

def main():
    '''
    Command-line entry point.

    Loads a generator, chooses a set of z vectors (optionally the ones
    that most strongly activate --maximize_units at --layer), optionally
    ablates --ablate_units at --layer, and writes one image per z vector
    into --outdir together with a copy of the lightbox viewer page.
    '''
    parser = argparse.ArgumentParser(description='GAN sample making utility')
    parser.add_argument('--model', type=str, default=None,
            help='constructor for the model to test')
    parser.add_argument('--pthfile', type=str, default=None,
            help='filename of .pth file for the model')
    parser.add_argument('--outdir', type=str, default='images',
            help='directory for image output')
    parser.add_argument('--size', type=int, default=100,
            help='number of images to output')
    parser.add_argument('--test_size', type=int, default=None,
            help='number of images to test')
    parser.add_argument('--layer', type=str, default=None,
            help='layer to inspect')
    parser.add_argument('--seed', type=int, default=1,
            help='seed')
    parser.add_argument('--maximize_units', type=int, nargs='+', default=None,
            help='units to maximize')
    parser.add_argument('--ablate_units', type=int, nargs='+', default=None,
            help='units to ablate')
    parser.add_argument('--quiet', action='store_true', default=False,
            help='silences console output')
    if len(sys.argv) == 1:
        parser.print_usage(sys.stderr)
        sys.exit(1)
    args = parser.parse_args()
    # Fail early with a usage message instead of crashing later inside
    # autoimport_eval(None) or the layer-hooking helpers with layer=None.
    if args.model is None:
        parser.error('--model is required')
    if args.layer is None and (args.maximize_units is not None
            or args.ablate_units is not None):
        parser.error('--layer is required with --maximize_units '
                'or --ablate_units')
    verbose_progress(not args.quiet)

    # Instantiate the model
    model = autoimport_eval(args.model)
    if args.pthfile is not None:
        data = torch.load(args.pthfile)
        # Some checkpoints nest the weights under 'state_dict' alongside
        # other entries; only the weights are needed here.
        if 'state_dict' in data:
            data = data['state_dict']
        model.load_state_dict(data)
    # Unwrap any DataParallel-wrapped model
    if isinstance(model, torch.nn.DataParallel):
        model = next(model.children())
    # Sampling is pure inference: switch modules such as batch norm and
    # dropout to their evaluation behavior.
    model.eval()
    # The first Conv2d/ConvTranspose2d/Linear module (in model.modules()
    # order) determines the z size: z gets trailing (1, 1) dims if that
    # layer is convolutional, and no extra dims if it is linear.
    first_layer = next((c for c in model.modules()
            if isinstance(c, (torch.nn.Conv2d, torch.nn.ConvTranspose2d,
                torch.nn.Linear))), None)
    if first_layer is None:
        raise ValueError('cannot infer z size: model has no Conv2d, '
                'ConvTranspose2d or Linear layer')
    if isinstance(first_layer, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
        z_channels = first_layer.in_channels
        spatialdims = (1, 1)
    else:
        z_channels = first_layer.in_features
        spatialdims = ()
    # Instrument the model if needed
    if args.maximize_units is not None:
        retain_layers(model, [args.layer])
    model.cuda()

    # Get the sample of z vectors
    if args.maximize_units is None:
        indexes = torch.arange(args.size)
        z_sample = standard_z_sample(args.size, z_channels, seed=args.seed)
        z_sample = z_sample.view(tuple(z_sample.shape) + spatialdims)
    else:
        # By default, if maximizing units, get a 'top 5%' sample.
        if args.test_size is None:
            args.test_size = args.size * 20
        z_universe = standard_z_sample(args.test_size, z_channels,
                seed=args.seed)
        z_universe = z_universe.view(tuple(z_universe.shape) + spatialdims)
        indexes = get_highest_znums(model, z_universe, args.maximize_units,
                args.size, seed=args.seed)
        z_sample = z_universe[indexes]

    if args.ablate_units:
        # Install the editing hooks only; save_znum_images() sets the
        # actual ablation mask for args.layer before generating images.
        edit_layers(model, [args.layer])

    save_znum_images(args.outdir, model, z_sample, indexes,
            args.layer, args.ablate_units)
    copy_lightbox_to(args.outdir)


def get_highest_znums(model, z_universe, max_units, size,
        batch_size=100, seed=1):
    '''
    Return the indexes of the `size` z vectors in `z_universe` with the
    highest combined activation of the units in `max_units`.

    The model must already be instrumented with retain_layers() on exactly
    one layer; that layer's retained output is read after each forward
    pass. For each z, every unit in `max_units` (indexes along axis 1 of
    the retained feature) contributes its maximum activation over all
    positions after the unit axis, and those maxima are summed into one
    score. The chosen indexes are returned sorted in ascending index order,
    not in score order.

    `seed` is accepted for interface compatibility but is not used.

    Raises ValueError if the model does not retain exactly one layer.
    '''
    retained_items = list(model.retained.items())
    if len(retained_items) != 1:
        raise ValueError('expected exactly one retained layer, found %d'
                % len(retained_items))
    layer = retained_items[0][0]
    progress = default_progress()
    with torch.no_grad():
        # Pass 1: collect max activation stats
        z_loader = torch.utils.data.DataLoader(TensorDataset(z_universe),
                    batch_size=batch_size, num_workers=2,
                    pin_memory=True)
        scores = []
        for [z] in progress(z_loader, desc='Finding max activations'):
            z = z.cuda()
            model(z)
            feature = model.retained[layer]
            # Keep only the requested units, flatten everything after the
            # unit axis, and take each unit's peak activation.
            max_feature = feature[:, max_units, ...].view(
                    feature.shape[0], len(max_units), -1).max(2)[0]
            scores.append(max_feature.sum(1).cpu())
        scores = torch.cat(scores, 0)
        # Positions of the `size` largest scores, re-sorted by index.
        highest = (-scores).sort(0)[1][:size].sort(0)[0]
    return highest


def save_znum_images(dirname, model, z_sample, indexes, layer, ablated_units,
        name_template="image_{}.png", lightbox=False, batch_size=100, seed=1):
    '''
    Generate one image per z vector in `z_sample` and write each to
    `dirname`, handing the actual file writes to a WorkerPool of
    SaveImageWorker.

    dirname: output directory; created if it does not exist.
    model: generator. Its output is rescaled with (x + 1) / 2 * 255 and
        clamped to [0, 255], so values in [-1, 1] span the byte range. The
        output is then permuted from (N, C, H, W) to (N, H, W, C).
    z_sample: z vectors, batched along dim 0.
    indexes: optional indexable whose entries support .item(); when given,
        the i-th image is named by indexes[i].item() instead of by i.
    layer, ablated_units: when ablated_units is non-empty, a 0/1 mask over
        channel indexes 0..max(ablated_units) is stored in
        model.ablation[layer]. model.ablation must already exist; main()
        calls edit_layers() on the layer before this.
    name_template: str.format template for file names, filled with the
        image index.
    lightbox, seed: accepted for interface compatibility; unused.
    batch_size: number of z vectors per forward pass.
    '''
    progress = default_progress()
    os.makedirs(dirname, exist_ok=True)
    with torch.no_grad():
        # Pass 2: now generate images
        z_loader = torch.utils.data.DataLoader(TensorDataset(z_sample),
                    batch_size=batch_size, num_workers=2,
                    pin_memory=True)
        saver = WorkerPool(SaveImageWorker)
        # Truthiness (not `is not None`): an empty list would make max()
        # raise, and it means "ablate nothing" anyway.
        if ablated_units:
            dims = max(2, max(ablated_units) + 1) # >=2 to avoid broadcast
            mask = torch.zeros(dims)
            mask[ablated_units] = 1
            # Shape (1, dims, 1, 1): the mask runs along axis 1 of a 4-D
            # feature map.
            model.ablation[layer] = mask[None,:,None,None].cuda()
        for batch_num, [z] in enumerate(progress(z_loader,
                desc='Saving images')):
            z = z.cuda()
            # The loader does not shuffle, so batches map to contiguous
            # positions in z_sample.
            start_index = batch_num * batch_size
            im = ((model(z) + 1) / 2 * 255).clamp(0, 255).byte().permute(
                    0, 2, 3, 1).cpu()
            for i, image in enumerate(im):
                index = i + start_index
                if indexes is not None:
                    index = indexes[index].item()
                filename = os.path.join(dirname, name_template.format(index))
                saver.add(image.numpy(), filename)
    saver.join()

def copy_lightbox_to(dirname):
    '''
    Copy the lightbox.html page that sits next to this source file into
    `dirname`, saved under the name '+lightbox.html'.
    '''
    module_dir = os.path.dirname(__file__)
    srcdir = os.path.realpath(os.path.join(os.getcwd(), module_dir))
    source_page = os.path.join(srcdir, 'lightbox.html')
    target_page = os.path.join(dirname, '+lightbox.html')
    shutil.copy(source_page, target_page)

class SaveImageWorker(WorkerBase):
    '''
    WorkerPool worker used by save_znum_images: each job writes one image
    array to one file.
    '''
    def work(self, data, filename):
        '''Convert the array `data` to a PIL image and save it to `filename`.'''
        picture = Image.fromarray(data)
        picture.save(filename, optimize=True, quality=100)

# Run the command-line tool only when this file is executed as a script.
if __name__ == '__main__':
    main()