ClothingGAN / netdissect /aceplotablate.py
mfrashad's picture
Init code
97069e1
raw
history blame
2.12 kB
import os, sys, argparse, json, shutil
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure
from matplotlib.ticker import MaxNLocator
import matplotlib
def main():
parser = argparse.ArgumentParser(description='ACE optimization utility',
prog='python -m netdissect.aceoptimize')
parser.add_argument('--classname', type=str, default=None,
help='intervention classname')
parser.add_argument('--layer', type=str, default='layer4',
help='layer name')
parser.add_argument('--outdir', type=str, default=None,
help='dissection directory')
parser.add_argument('--metric', type=str, default=None,
help='experiment variant')
args = parser.parse_args()
if args.metric is None:
args.metric = 'ace'
run_command(args)
def run_command(args):
fig = Figure(figsize=(4.5,3.5))
FigureCanvas(fig)
ax = fig.add_subplot(111)
for metric in [args.metric, 'iou']:
jsonname = os.path.join(args.outdir, args.layer, 'fullablation',
'%s-%s.json' % (args.classname, metric))
with open(jsonname) as f:
summary = json.load(f)
baseline = summary['baseline']
effects = summary['ablation_effects'][:26]
norm_effects = [0] + [1.0 - e / baseline for e in effects]
ax.plot(norm_effects, label=
'Units by ACE' if 'ace' in metric else 'Top units by IoU')
ax.set_title('Effect of ablating units for %s' % (args.classname))
ax.grid(True)
ax.legend()
ax.set_ylabel('Portion of %s pixels removed' % args.classname)
ax.set_xlabel('Number of units ablated')
ax.set_ylim(0, 1.0)
ax.set_xlim(0, 25)
fig.tight_layout()
dirname = os.path.join(args.outdir, args.layer, 'fullablation')
fig.savefig(os.path.join(dirname, 'effect-%s-%s.png' %
(args.classname, args.metric)))
fig.savefig(os.path.join(dirname, 'effect-%s-%s.pdf' %
(args.classname, args.metric)))
if __name__ == "__main__":
    # Executed as a script (not imported): parse CLI flags and plot.
    main()