chanycha commited on
Commit
054bff7
·
1 Parent(s): db8d869
Files changed (4) hide show
  1. _app.py +0 -663
  2. app.py +658 -8
  3. app_test.py +13 -0
  4. requirements.txt +1 -2
_app.py DELETED
@@ -1,663 +0,0 @@
1
- # python image_gradio.py >> ./logs/image_gradio.log 2>&1
2
- import time
3
- import os
4
- import gradio as gr
5
- from pnpxai.core.experiment import AutoExplanationForImageClassification
6
- from pnpxai.core.detector import extract_graph_data, symbolic_trace
7
- import matplotlib.pyplot as plt
8
- import plotly.graph_objects as go
9
- import plotly.express as px
10
- import networkx as nx
11
- import secrets
12
-
13
-
14
- PLOT_PER_LINE = 4
15
- N_FEATURES_TO_SHOW = 5
16
- OPT_N_TRIALS = 10
17
- OBJECTIVE_METRIC = "AbPC"
18
- SAMPLE_METHOD = "tpe"
19
-
20
- class App:
21
- def __init__(self):
22
- pass
23
-
24
- class Component:
25
- def __init__(self):
26
- pass
27
-
28
- class Tab(Component):
29
- def __init__(self):
30
- pass
31
-
32
- class OverviewTab(Tab):
33
- def __init__(self):
34
- pass
35
-
36
- def show(self):
37
- with gr.Tab(label="Overview") as tab:
38
- gr.Label("This is the overview tab.")
39
-
40
- class DetectionTab(Tab):
41
- def __init__(self, experiments):
42
- self.experiments = experiments
43
-
44
- def show(self):
45
- with gr.Tab(label="Detection") as tab:
46
- gr.Label("This is the detection tab.")
47
-
48
- for nm, exp_info in self.experiments.items():
49
- exp = exp_info['experiment']
50
- detector_res = DetectorRes(exp)
51
- detector_res.show()
52
-
53
- class LocalExpTab(Tab):
54
- def __init__(self, experiments):
55
- self.experiments = experiments
56
-
57
- self.experiment_components = []
58
- for nm, exp_info in self.experiments.items():
59
- self.experiment_components.append(Experiment(exp_info))
60
-
61
- def description(self):
62
- return "This tab shows the local explanation."
63
-
64
- def show(self):
65
- with gr.Tab(label="Local Explanation") as tab:
66
- gr.Label("This is the local explanation tab.")
67
-
68
- for i, exp in enumerate(self.experiments):
69
- self.experiment_components[i].show()
70
-
71
- class DetectorRes(Component):
72
- def __init__(self, experiment):
73
- self.experiment = experiment
74
- graph_module = symbolic_trace(experiment.model)
75
- self.graph_data = extract_graph_data(graph_module)
76
-
77
- def describe(self):
78
- return "This component shows the detection result."
79
-
80
- def show(self):
81
- G = nx.DiGraph()
82
- root = None
83
- for node in self.graph_data['nodes']:
84
- if node['op'] == 'placeholder':
85
- root = node['name']
86
-
87
- G.add_node(node['name'])
88
-
89
-
90
- for edge in self.graph_data['edges']:
91
- if edge['source'] in G.nodes and edge['target'] in G.nodes:
92
- G.add_edge(edge['source'], edge['target'])
93
-
94
-
95
- def get_pos1(graph):
96
- graph = graph.copy()
97
- for layer, nodes in enumerate(reversed(tuple(nx.topological_generations(graph)))):
98
- for node in nodes:
99
- graph.nodes[node]["layer"] = layer
100
-
101
- pos = nx.multipartite_layout(graph, subset_key="layer", align='horizontal')
102
- return pos
103
-
104
-
105
- def get_pos2(graph, root, levels=None, width=1., height=1.):
106
- '''
107
- G: the graph
108
- root: the root node
109
- levels: a dictionary
110
- key: level number (starting from 0)
111
- value: number of nodes in this level
112
- width: horizontal space allocated for drawing
113
- height: vertical space allocated for drawing
114
- '''
115
- TOTAL = "total"
116
- CURRENT = "current"
117
-
118
- def make_levels(levels, node=root, currentLevel=0, parent=None):
119
- # Compute the number of nodes for each level
120
- if not currentLevel in levels:
121
- levels[currentLevel] = {TOTAL: 0, CURRENT: 0}
122
- levels[currentLevel][TOTAL] += 1
123
- neighbors = graph.neighbors(node)
124
- for neighbor in neighbors:
125
- if not neighbor == parent:
126
- levels = make_levels(levels, neighbor, currentLevel + 1, node)
127
- return levels
128
-
129
- def make_pos(pos, node=root, currentLevel=0, parent=None, vert_loc=0):
130
- dx = 1/levels[currentLevel][TOTAL]
131
- left = dx/2
132
- pos[node] = ((left + dx*levels[currentLevel][CURRENT])*width, vert_loc)
133
- levels[currentLevel][CURRENT] += 1
134
- neighbors = graph.neighbors(node)
135
- for neighbor in neighbors:
136
- if not neighbor == parent:
137
- pos = make_pos(pos, neighbor, currentLevel +
138
- 1, node, vert_loc-vert_gap)
139
- return pos
140
-
141
- if levels is None:
142
- levels = make_levels({})
143
- else:
144
- levels = {l: {TOTAL: levels[l], CURRENT: 0} for l in levels}
145
- vert_gap = height / (max([l for l in levels])+1)
146
- return make_pos({})
147
-
148
-
149
- def plot_graph(graph, pos):
150
- fig = plt.figure(figsize=(12, 24))
151
- ax = fig.gca()
152
- nx.draw(graph, pos=pos, with_labels=True, node_size=60, font_size=8, ax=ax)
153
-
154
- fig.tight_layout()
155
- return fig
156
-
157
-
158
-
159
- pos = get_pos1(G)
160
- fig = plot_graph(G, pos)
161
- # pos = get_pos2(G, root)
162
- # fig = plot_graph(G, pos)
163
-
164
- with gr.Row():
165
- gr.Textbox(value="Image Classficiation", label="Task")
166
- gr.Textbox(value=f"{self.experiment.model.__class__.__name__}", label="Model")
167
- gr.Plot(value=fig, label=f"Model Architecture of {self.experiment.model.__class__.__name__}", visible=True)
168
-
169
-
170
-
171
- class ImgGallery(Component):
172
- def __init__(self, imgs):
173
- self.imgs = imgs
174
- self.selected_index = gr.Number(value=0, label="Selected Index", visible=False)
175
-
176
- def on_select(self, evt: gr.SelectData):
177
- return evt.index
178
-
179
- def show(self):
180
- self.gallery_obj = gr.Gallery(value=self.imgs, label="Input Data Gallery", columns=6, height=200)
181
- self.gallery_obj.select(self.on_select, outputs=self.selected_index)
182
-
183
-
184
- class Experiment(Component):
185
- def __init__(self, exp_info):
186
- self.exp_info = exp_info
187
- self.experiment = exp_info['experiment']
188
- self.input_visualizer = exp_info['input_visualizer']
189
- self.target_visualizer = exp_info['target_visualizer']
190
-
191
- def viz_input(self, input, data_id):
192
- orig_img_np = self.input_visualizer(input)
193
- orig_img = px.imshow(orig_img_np)
194
-
195
- orig_img.update_layout(
196
- title=f"Data ID: {data_id}",
197
- width=400,
198
- height=350,
199
- xaxis=dict(
200
- showticklabels=False,
201
- ticks='',
202
- showgrid=False
203
- ),
204
- yaxis=dict(
205
- showticklabels=False,
206
- ticks='',
207
- showgrid=False
208
- ),
209
- )
210
-
211
- return orig_img
212
-
213
-
214
- def get_prediction(self, record, topk=3):
215
- probs = record['output'].softmax(-1).squeeze().detach().numpy()
216
- text = f"Ground Truth Label: {self.target_visualizer(record['label'])}\n"
217
-
218
- for ind, pred in enumerate(probs.argsort()[-topk:][::-1]):
219
- label = self.target_visualizer(torch.tensor(pred))
220
- prob = probs[pred]
221
- text += f"Top {ind+1} Prediction: {label} ({prob:.2f})\n"
222
-
223
- return text
224
-
225
-
226
- def get_exp_plot(self, data_index, exp_res):
227
- return ExpRes(data_index, exp_res).show()
228
-
229
- def get_metric_id_by_name(self, metric_name):
230
- metric_info = self.experiment.manager.get_metrics()
231
- idx = [metric.__class__.__name__ for metric in metric_info[0]].index(metric_name)
232
- return metric_info[1][idx]
233
-
234
- def generate_record(self, data_id, metric_names):
235
- record = {}
236
- _base = self.experiment.run_batch([data_id], 0, 0, 0)
237
- record['data_id'] = data_id
238
- record['input'] = _base['inputs']
239
- record['label'] = _base['labels']
240
- record['output'] = _base['outputs']
241
- record['target'] = _base['targets']
242
- record['explanations'] = []
243
-
244
- metrics_ids = [self.get_metric_id_by_name(metric_nm) for metric_nm in metric_names]
245
-
246
- cnt = 0
247
- for info in self.explainer_checkbox_group.info:
248
- if info['checked']:
249
- base = self.experiment.run_batch([data_id], info['id'], info['pp_id'], 0)
250
- record['explanations'].append({
251
- 'explainer_nm': base['explainer'].__class__.__name__,
252
- 'value': base['postprocessed'],
253
- 'mode' : info['mode'],
254
- 'evaluations': []
255
- })
256
- for metric_id in metrics_ids:
257
- res = self.experiment.run_batch([data_id], info['id'], info['pp_id'], metric_id)
258
- record['explanations'][-1]['evaluations'].append({
259
- 'metric_nm': res['metric'].__class__.__name__,
260
- 'value' : res['evaluation']
261
- })
262
-
263
- cnt += 1
264
-
265
- # Sort record['explanations'] with respect to the metric values
266
- if len(record['explanations'][0]['evaluations']) > 0:
267
- record['explanations'] = sorted(record['explanations'], key=lambda x: x['evaluations'][0]['value'], reverse=True)
268
-
269
- return record
270
-
271
-
272
- def show(self):
273
- with gr.Row():
274
- gr.Textbox(value="Image Classficiation", label="Task")
275
- gr.Textbox(value=f"{self.experiment.model.__class__.__name__}", label="Model")
276
- gr.Textbox(value="Heatmap", label="Explanation Type")
277
-
278
- dset = self.experiment.manager._data.dataset
279
- imgs = []
280
- for i in range(len(dset)):
281
- img = self.input_visualizer(dset[i][0])
282
- imgs.append(img)
283
- gallery = ImgGallery(imgs)
284
- gallery.show()
285
-
286
- explainers, _ = self.experiment.manager.get_explainers()
287
- explainer_names = [exp.__class__.__name__ for exp in explainers]
288
-
289
- self.explainer_checkbox_group = ExplainerCheckboxGroup(explainer_names, self.experiment, gallery)
290
- self.explainer_checkbox_group.show()
291
-
292
- cr_metrics_names = ["AbPC", "MoRF", "LeRF", "MuFidelity"]
293
- cn_metrics_names = ["Sensitivity"]
294
- cp_metrics_names = ["Complexity"]
295
- with gr.Accordion("Evaluators", open=True):
296
- with gr.Row():
297
- cr_metrics = gr.CheckboxGroup(choices=cr_metrics_names, value=[cr_metrics_names[0]], label="Correctness")
298
- def on_select(metrics):
299
- if cr_metrics_names[0] not in metrics:
300
- gr.Warning(f"{cr_metrics_names[0]} is required for the sorting the explanations.")
301
- return [cr_metrics_names[0]] + metrics
302
- else:
303
- return metrics
304
-
305
- cr_metrics.select(on_select, inputs=cr_metrics, outputs=cr_metrics)
306
- with gr.Row():
307
- # cn_metrics = gr.CheckboxGroup(choices=cn_metrics_names, value=cn_metrics_names, label="Continuity")
308
- cn_metrics = gr.CheckboxGroup(choices=cn_metrics_names, label="Continuity")
309
- with gr.Row():
310
- # cp_metrics = gr.CheckboxGroup(choices=cp_metrics_names, value=cp_metrics_names[0], label="Compactness")
311
- cp_metrics = gr.CheckboxGroup(choices=cp_metrics_names, label="Compactness")
312
-
313
- metric_inputs = [cr_metrics, cn_metrics, cp_metrics]
314
-
315
- data_id = gallery.selected_index
316
- bttn = gr.Button("Explain", variant="primary")
317
-
318
- buffer_size = 2 * len(explainer_names)
319
- buffer_n_rows = buffer_size // PLOT_PER_LINE
320
- buffer_n_rows = buffer_n_rows + 1 if buffer_size % PLOT_PER_LINE != 0 else buffer_n_rows
321
-
322
- plots = [gr.Textbox(label="Prediction result", visible=False)]
323
- for i in range(buffer_n_rows):
324
- with gr.Row():
325
- for j in range(PLOT_PER_LINE):
326
- plot = gr.Image(value=None, label="Blank", visible=False)
327
- plots.append(plot)
328
-
329
- def show_plots():
330
- _plots = [gr.Textbox(label="Prediction result", visible=False)]
331
- num_plots = sum([1 for info in self.explainer_checkbox_group.info if info['checked']])
332
- n_rows = num_plots // PLOT_PER_LINE
333
- n_rows = n_rows + 1 if num_plots % PLOT_PER_LINE != 0 else n_rows
334
- _plots += [gr.Image(value=None, label="Blank", visible=True)] * (n_rows * PLOT_PER_LINE)
335
- _plots += [gr.Image(value=None, label="Blank", visible=False)] * ((buffer_n_rows - n_rows) * PLOT_PER_LINE)
336
- return _plots
337
-
338
- def render_plots(data_id, *metric_inputs):
339
- # Clear Cache Files
340
- cache_dir = f"{os.environ['GRADIO_TEMP_DIR']}/res"
341
- if not os.path.exists(cache_dir): os.makedirs(cache_dir)
342
- for f in os.listdir(cache_dir):
343
- if len(f.split(".")[0]) == 16:
344
- os.remove(os.path.join(cache_dir, f))
345
-
346
- # Render Plots
347
- metric_input = []
348
- for metric in metric_inputs:
349
- if metric:
350
- metric_input += metric
351
-
352
- record = self.generate_record(data_id, metric_input)
353
-
354
- pred = self.get_prediction(record)
355
- plots = [gr.Textbox(label="Prediction result", value=pred, visible=True)]
356
-
357
- num_plots = sum([1 for info in self.explainer_checkbox_group.info if info['checked']])
358
- n_rows = num_plots // PLOT_PER_LINE
359
- n_rows = n_rows + 1 if num_plots % PLOT_PER_LINE != 0 else n_rows
360
-
361
- for i in range(n_rows):
362
- for j in range(PLOT_PER_LINE):
363
- if i*PLOT_PER_LINE+j < len(record['explanations']):
364
- exp_res = record['explanations'][i*PLOT_PER_LINE+j]
365
- path = self.get_exp_plot(data_id, exp_res)
366
- plot_obj = gr.Image(value=path, label=f"{exp_res['explainer_nm']} ({exp_res['mode']})", visible=True)
367
- plots.append(plot_obj)
368
- else:
369
- plots.append(gr.Image(value=None, label="Blank", visible=True))
370
-
371
- plots += [gr.Image(value=None, label="Blank", visible=False)] * ((buffer_n_rows - n_rows) * PLOT_PER_LINE)
372
-
373
- return plots
374
-
375
- bttn.click(show_plots, outputs=plots)
376
- bttn.click(render_plots, inputs=[data_id] + metric_inputs, outputs=plots)
377
-
378
-
379
-
380
- class ExplainerCheckboxGroup(Component):
381
- def __init__(self, explainer_names, experiment, gallery):
382
- super().__init__()
383
- self.explainer_names = explainer_names
384
- self.explainer_objs = []
385
- self.experiment = experiment
386
- self.gallery = gallery
387
- explainers, exp_ids = self.experiment.manager.get_explainers()
388
-
389
- self.info = []
390
- for exp, exp_id in zip(explainers, exp_ids):
391
- self.info.append({'nm': exp.__class__.__name__, 'id': exp_id, 'pp_id' : 0, 'mode': 'default', 'checked': True})
392
-
393
- def update_check(self, exp_id, val=None):
394
- for info in self.info:
395
- if info['id'] == exp_id:
396
- if val is not None:
397
- info['checked'] = val
398
- else:
399
- info['checked'] = not info['checked']
400
-
401
- def insert_check(self, exp_nm, exp_id, pp_id):
402
- if exp_id in [info['id'] for info in self.info]:
403
- return
404
-
405
- self.info.append({'nm': exp_nm, 'id': exp_id, 'pp_id' : pp_id, 'mode': 'optimal', 'checked': False})
406
-
407
- def update_gallery_change(self):
408
- checkboxes = []
409
- bttns = []
410
- checkboxes += [gr.Checkbox(label="Default Parameter", value=True, interactive=True)] * len(self.explainer_objs)
411
- checkboxes += [gr.Checkbox(label="Optimized Parameter (Not Optimal)", value=False, interactive=False)] * len(self.explainer_objs)
412
- bttns += [gr.Button(value="Optimize", size="sm", variant="primary")] * len(self.explainer_objs)
413
-
414
- for exp in self.explainer_objs:
415
- self.update_check(exp.default_exp_id, True)
416
- if hasattr(exp, "optimal_exp_id"):
417
- self.update_check(exp.optimal_exp_id, False)
418
- return checkboxes + bttns
419
-
420
- def get_checkboxes(self):
421
- checkboxes = []
422
- checkboxes += [exp.default_check for exp in self.explainer_objs]
423
- checkboxes += [exp.opt_check for exp in self.explainer_objs]
424
- return checkboxes
425
-
426
- def get_bttns(self):
427
- return [exp.bttn for exp in self.explainer_objs]
428
-
429
- def show(self):
430
- cnt = 0
431
- with gr.Accordion("Explainers", open=True):
432
- while cnt * PLOT_PER_LINE < len(self.explainer_names):
433
- with gr.Row():
434
- for info in self.info[cnt*PLOT_PER_LINE:(cnt+1)*PLOT_PER_LINE]:
435
- explainer_obj = ExplainerCheckbox(info['nm'], self, self.experiment, self.gallery)
436
- self.explainer_objs.append(explainer_obj)
437
- explainer_obj.show()
438
- cnt += 1
439
-
440
- checkboxes = self.get_checkboxes()
441
- bttns = self.get_bttns()
442
- self.gallery.gallery_obj.select(
443
- fn=self.update_gallery_change,
444
- outputs=checkboxes + bttns
445
- )
446
-
447
-
448
- class ExplainerCheckbox(Component):
449
- def __init__(self, explainer_name, groups, experiment, gallery):
450
- self.explainer_name = explainer_name
451
- self.groups = groups
452
- self.experiment = experiment
453
- self.gallery = gallery
454
-
455
- self.default_exp_id = self.get_explainer_id_by_name(explainer_name)
456
- self.obj_metric = self.get_metric_id_by_name(OBJECTIVE_METRIC)
457
-
458
- def get_explainer_id_by_name(self, explainer_name):
459
- explainer_info = self.experiment.manager.get_explainers()
460
- idx = [exp.__class__.__name__ for exp in explainer_info[0]].index(explainer_name)
461
- return explainer_info[1][idx]
462
-
463
- def get_metric_id_by_name(self, metric_name):
464
- metric_info = self.experiment.manager.get_metrics()
465
- idx = [metric.__class__.__name__ for metric in metric_info[0]].index(metric_name)
466
- return metric_info[1][idx]
467
-
468
-
469
- def optimize(self):
470
- # if self.explainer_name in ["Lime", "KernelShap", "IntegratedGradients"]:
471
- # gr.Info("Lime, KernelShap and IntegratedGradients currently do not support hyperparameter optimization.")
472
- # return [gr.update()] * 2
473
-
474
- data_id = self.gallery.selected_index
475
-
476
- optimized, _, _ = self.experiment.optimize(
477
- data_id=data_id.value,
478
- explainer_id=self.default_exp_id,
479
- metric_id=self.obj_metric,
480
- direction='maximize',
481
- sampler=SAMPLE_METHOD,
482
- n_trials=OPT_N_TRIALS,
483
- )
484
-
485
- opt_explainer_id = optimized['explainer_id']
486
- opt_postprocessor_id = optimized['postprocessor_id']
487
-
488
- self.groups.insert_check(self.explainer_name, opt_explainer_id, opt_postprocessor_id)
489
- self.optimal_exp_id = opt_explainer_id
490
- checkbox = gr.update(label="Optimized Parameter (Optimal)", interactive=True)
491
- bttn = gr.update(value="Optimized", variant="secondary")
492
-
493
- return [checkbox, bttn]
494
-
495
-
496
- def default_on_select(self, evt: gr.EventData):
497
- self.groups.update_check(self.default_exp_id, evt._data['value'])
498
-
499
- def optimal_on_select(self, evt: gr.EventData):
500
- if hasattr(self, "optimal_exp_id"):
501
- self.groups.update_check(self.optimal_exp_id, evt._data['value'])
502
- else:
503
- raise ValueError("Optimal explainer id is not found.")
504
-
505
- def show(self):
506
- with gr.Accordion(self.explainer_name, open=False):
507
- self.default_check = gr.Checkbox(label="Default Parameter", value=True, interactive=True)
508
- self.opt_check = gr.Checkbox(label="Optimized Parameter (Not Optimal)", interactive=False)
509
-
510
- self.default_check.select(self.default_on_select)
511
- self.opt_check.select(self.optimal_on_select)
512
-
513
- self.bttn = gr.Button(value="Optimize", size="sm", variant="primary")
514
- self.bttn.click(self.optimize, outputs=[self.opt_check, self.bttn], queue=True, concurrency_limit=1)
515
-
516
-
517
- class ExpRes(Component):
518
- def __init__(self, data_index, exp_res):
519
- self.data_index = data_index
520
- self.exp_res = exp_res
521
-
522
- def show(self):
523
- value = self.exp_res['value']
524
-
525
- fig = go.Figure(data=go.Heatmap(
526
- z=np.flipud(value[0].detach().numpy()),
527
- colorscale='Reds',
528
- showscale=False # remove color bar
529
- ))
530
-
531
- evaluations = self.exp_res['evaluations']
532
- metric_values = [f"{eval['metric_nm'][:4]}: {eval['value'].item():.2f}" for eval in evaluations if eval['value'] is not None]
533
- n = 3
534
- cnt = 0
535
- while cnt * n < len(metric_values):
536
- metric_text = ', '.join(metric_values[cnt*n:cnt*n+n])
537
- fig.add_annotation(
538
- x=0,
539
- y=-0.1 * (cnt+1),
540
- xref='paper',
541
- yref='paper',
542
- text=metric_text,
543
- showarrow=False,
544
- font=dict(
545
- size=18,
546
- ),
547
- )
548
- cnt += 1
549
-
550
-
551
- fig = fig.update_layout(
552
- width=380,
553
- height=400,
554
- xaxis=dict(
555
- showticklabels=False,
556
- ticks='',
557
- showgrid=False
558
- ),
559
- yaxis=dict(
560
- showticklabels=False,
561
- ticks='',
562
- showgrid=False
563
- ),
564
- margin=dict(t=40, b=40*cnt, l=20, r=20),
565
- )
566
-
567
- # Generate Random Unique ID
568
- root = f"{os.environ['GRADIO_TEMP_DIR']}/res"
569
- if not os.path.exists(root): os.makedirs(root)
570
- key = secrets.token_hex(8)
571
- path = f"{root}/{key}.png"
572
- fig.write_image(path)
573
- return path
574
-
575
-
576
- class ImageClsApp(App):
577
- def __init__(self, experiments, **kwargs):
578
- self.name = "Image Classification App"
579
- super().__init__(**kwargs)
580
-
581
- self.experiments = experiments
582
-
583
- self.overview_tab = OverviewTab()
584
- self.detection_tab = DetectionTab(self.experiments)
585
- self.local_exp_tab = LocalExpTab(self.experiments)
586
-
587
- def title(self):
588
- return """
589
- <div style="text-align: center;">
590
- <img src="/file=data/static/XAI-Top-PnP.svg" width="100" height="100">
591
- <h1> Plug and Play XAI Platform for Image Classification </h1>
592
- </div>
593
- """
594
-
595
- def launch(self, **kwargs):
596
- with gr.Blocks(
597
- title=self.name,
598
- ) as demo:
599
- cwd = os.getcwd()
600
- gr.set_static_paths(cwd)
601
- gr.HTML(self.title())
602
-
603
- self.overview_tab.show()
604
- self.detection_tab.show()
605
- self.local_exp_tab.show()
606
-
607
- return demo
608
-
609
- # if __name__ == '__main__':
610
- import os
611
- import torch
612
- import numpy as np
613
- from torch.utils.data import DataLoader
614
- from helpers import get_imagenet_dataset, get_torchvision_model, denormalize_image
615
-
616
- os.environ['GRADIO_TEMP_DIR'] = '.tmp'
617
-
618
- def target_visualizer(x): return dataset.dataset.idx_to_label(x.item())
619
-
620
- experiments = {}
621
-
622
- model, transform = get_torchvision_model('resnet18')
623
- dataset = get_imagenet_dataset(transform)
624
- loader = DataLoader(dataset, batch_size=4, shuffle=False)
625
- experiment1 = AutoExplanationForImageClassification(
626
- model=model,
627
- data=loader,
628
- input_extractor=lambda batch: batch[0],
629
- label_extractor=lambda batch: batch[-1],
630
- target_extractor=lambda outputs: outputs.argmax(-1),
631
- channel_dim=1
632
- )
633
-
634
- experiments['experiment1'] = {
635
- 'name': 'ResNet18',
636
- 'experiment': experiment1,
637
- 'input_visualizer': lambda x: denormalize_image(x, transform.mean, transform.std),
638
- 'target_visualizer': target_visualizer,
639
- }
640
-
641
-
642
- model, transform = get_torchvision_model('vit_b_16')
643
- dataset = get_imagenet_dataset(transform)
644
- loader = DataLoader(dataset, batch_size=4, shuffle=False)
645
- experiment2 = AutoExplanationForImageClassification(
646
- model=model,
647
- data=loader,
648
- input_extractor=lambda batch: batch[0],
649
- label_extractor=lambda batch: batch[-1],
650
- target_extractor=lambda outputs: outputs.argmax(-1),
651
- channel_dim=1
652
- )
653
-
654
- experiments['experiment2'] = {
655
- 'name': 'ViT-B_16',
656
- 'experiment': experiment2,
657
- 'input_visualizer': lambda x: denormalize_image(x, transform.mean, transform.std),
658
- 'target_visualizer': target_visualizer,
659
- }
660
-
661
- app = ImageClsApp(experiments)
662
- demo = app.launch()
663
- demo.launch(favicon_path="data/static/XAI-Top-PnP.svg", share=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app.py CHANGED
@@ -1,13 +1,663 @@
 
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
2
 
3
- with gr.Blocks() as block:
4
- textbox = gr.Textbox(label="Enter your text here")
5
- bttn = gr.Button()
6
- output = gr.Textbox(label="Output")
7
 
8
- def submit(input_data):
9
- return input_data + "Submitted"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
- bttn.click(submit, inputs=[textbox], outputs=[output])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
- block.launch()
 
 
 
1
+ # python image_gradio.py >> ./logs/image_gradio.log 2>&1
2
+ import time
3
+ import os
4
  import gradio as gr
5
+ from pnpxai.core.experiment import AutoExplanationForImageClassification
6
+ from pnpxai.core.detector import extract_graph_data, symbolic_trace
7
+ import matplotlib.pyplot as plt
8
+ import plotly.graph_objects as go
9
+ import plotly.express as px
10
+ import networkx as nx
11
+ import secrets
12
 
 
 
 
 
13
 
14
+ PLOT_PER_LINE = 4
15
+ N_FEATURES_TO_SHOW = 5
16
+ OPT_N_TRIALS = 10
17
+ OBJECTIVE_METRIC = "AbPC"
18
+ SAMPLE_METHOD = "tpe"
19
+
20
+ class App:
21
+ def __init__(self):
22
+ pass
23
+
24
+ class Component:
25
+ def __init__(self):
26
+ pass
27
+
28
+ class Tab(Component):
29
+ def __init__(self):
30
+ pass
31
+
32
+ class OverviewTab(Tab):
33
+ def __init__(self):
34
+ pass
35
+
36
+ def show(self):
37
+ with gr.Tab(label="Overview") as tab:
38
+ gr.Label("This is the overview tab.")
39
+
40
+ class DetectionTab(Tab):
41
+ def __init__(self, experiments):
42
+ self.experiments = experiments
43
+
44
+ def show(self):
45
+ with gr.Tab(label="Detection") as tab:
46
+ gr.Label("This is the detection tab.")
47
+
48
+ for nm, exp_info in self.experiments.items():
49
+ exp = exp_info['experiment']
50
+ detector_res = DetectorRes(exp)
51
+ detector_res.show()
52
+
53
+ class LocalExpTab(Tab):
54
+ def __init__(self, experiments):
55
+ self.experiments = experiments
56
+
57
+ self.experiment_components = []
58
+ for nm, exp_info in self.experiments.items():
59
+ self.experiment_components.append(Experiment(exp_info))
60
+
61
+ def description(self):
62
+ return "This tab shows the local explanation."
63
+
64
+ def show(self):
65
+ with gr.Tab(label="Local Explanation") as tab:
66
+ gr.Label("This is the local explanation tab.")
67
+
68
+ for i, exp in enumerate(self.experiments):
69
+ self.experiment_components[i].show()
70
+
71
+ class DetectorRes(Component):
72
+ def __init__(self, experiment):
73
+ self.experiment = experiment
74
+ graph_module = symbolic_trace(experiment.model)
75
+ self.graph_data = extract_graph_data(graph_module)
76
+
77
+ def describe(self):
78
+ return "This component shows the detection result."
79
 
80
+ def show(self):
81
+ G = nx.DiGraph()
82
+ root = None
83
+ for node in self.graph_data['nodes']:
84
+ if node['op'] == 'placeholder':
85
+ root = node['name']
86
+
87
+ G.add_node(node['name'])
88
+
89
+
90
+ for edge in self.graph_data['edges']:
91
+ if edge['source'] in G.nodes and edge['target'] in G.nodes:
92
+ G.add_edge(edge['source'], edge['target'])
93
+
94
+
95
+ def get_pos1(graph):
96
+ graph = graph.copy()
97
+ for layer, nodes in enumerate(reversed(tuple(nx.topological_generations(graph)))):
98
+ for node in nodes:
99
+ graph.nodes[node]["layer"] = layer
100
+
101
+ pos = nx.multipartite_layout(graph, subset_key="layer", align='horizontal')
102
+ return pos
103
+
104
+
105
+ def get_pos2(graph, root, levels=None, width=1., height=1.):
106
+ '''
107
+ G: the graph
108
+ root: the root node
109
+ levels: a dictionary
110
+ key: level number (starting from 0)
111
+ value: number of nodes in this level
112
+ width: horizontal space allocated for drawing
113
+ height: vertical space allocated for drawing
114
+ '''
115
+ TOTAL = "total"
116
+ CURRENT = "current"
117
+
118
+ def make_levels(levels, node=root, currentLevel=0, parent=None):
119
+ # Compute the number of nodes for each level
120
+ if not currentLevel in levels:
121
+ levels[currentLevel] = {TOTAL: 0, CURRENT: 0}
122
+ levels[currentLevel][TOTAL] += 1
123
+ neighbors = graph.neighbors(node)
124
+ for neighbor in neighbors:
125
+ if not neighbor == parent:
126
+ levels = make_levels(levels, neighbor, currentLevel + 1, node)
127
+ return levels
128
+
129
+ def make_pos(pos, node=root, currentLevel=0, parent=None, vert_loc=0):
130
+ dx = 1/levels[currentLevel][TOTAL]
131
+ left = dx/2
132
+ pos[node] = ((left + dx*levels[currentLevel][CURRENT])*width, vert_loc)
133
+ levels[currentLevel][CURRENT] += 1
134
+ neighbors = graph.neighbors(node)
135
+ for neighbor in neighbors:
136
+ if not neighbor == parent:
137
+ pos = make_pos(pos, neighbor, currentLevel +
138
+ 1, node, vert_loc-vert_gap)
139
+ return pos
140
+
141
+ if levels is None:
142
+ levels = make_levels({})
143
+ else:
144
+ levels = {l: {TOTAL: levels[l], CURRENT: 0} for l in levels}
145
+ vert_gap = height / (max([l for l in levels])+1)
146
+ return make_pos({})
147
+
148
+
149
+ def plot_graph(graph, pos):
150
+ fig = plt.figure(figsize=(12, 24))
151
+ ax = fig.gca()
152
+ nx.draw(graph, pos=pos, with_labels=True, node_size=60, font_size=8, ax=ax)
153
+
154
+ fig.tight_layout()
155
+ return fig
156
+
157
+
158
+
159
+ pos = get_pos1(G)
160
+ fig = plot_graph(G, pos)
161
+ # pos = get_pos2(G, root)
162
+ # fig = plot_graph(G, pos)
163
+
164
+ with gr.Row():
165
+ gr.Textbox(value="Image Classficiation", label="Task")
166
+ gr.Textbox(value=f"{self.experiment.model.__class__.__name__}", label="Model")
167
+ gr.Plot(value=fig, label=f"Model Architecture of {self.experiment.model.__class__.__name__}", visible=True)
168
+
169
+
170
+
171
+ class ImgGallery(Component):
172
+ def __init__(self, imgs):
173
+ self.imgs = imgs
174
+ self.selected_index = gr.Number(value=0, label="Selected Index", visible=False)
175
+
176
+ def on_select(self, evt: gr.SelectData):
177
+ return evt.index
178
+
179
+ def show(self):
180
+ self.gallery_obj = gr.Gallery(value=self.imgs, label="Input Data Gallery", columns=6, height=200)
181
+ self.gallery_obj.select(self.on_select, outputs=self.selected_index)
182
+
183
+
184
+ class Experiment(Component):
185
+ def __init__(self, exp_info):
186
+ self.exp_info = exp_info
187
+ self.experiment = exp_info['experiment']
188
+ self.input_visualizer = exp_info['input_visualizer']
189
+ self.target_visualizer = exp_info['target_visualizer']
190
+
191
+ def viz_input(self, input, data_id):
192
+ orig_img_np = self.input_visualizer(input)
193
+ orig_img = px.imshow(orig_img_np)
194
+
195
+ orig_img.update_layout(
196
+ title=f"Data ID: {data_id}",
197
+ width=400,
198
+ height=350,
199
+ xaxis=dict(
200
+ showticklabels=False,
201
+ ticks='',
202
+ showgrid=False
203
+ ),
204
+ yaxis=dict(
205
+ showticklabels=False,
206
+ ticks='',
207
+ showgrid=False
208
+ ),
209
+ )
210
+
211
+ return orig_img
212
+
213
+
214
+ def get_prediction(self, record, topk=3):
215
+ probs = record['output'].softmax(-1).squeeze().detach().numpy()
216
+ text = f"Ground Truth Label: {self.target_visualizer(record['label'])}\n"
217
+
218
+ for ind, pred in enumerate(probs.argsort()[-topk:][::-1]):
219
+ label = self.target_visualizer(torch.tensor(pred))
220
+ prob = probs[pred]
221
+ text += f"Top {ind+1} Prediction: {label} ({prob:.2f})\n"
222
+
223
+ return text
224
+
225
+
226
+ def get_exp_plot(self, data_index, exp_res):
227
+ return ExpRes(data_index, exp_res).show()
228
+
229
+ def get_metric_id_by_name(self, metric_name):
230
+ metric_info = self.experiment.manager.get_metrics()
231
+ idx = [metric.__class__.__name__ for metric in metric_info[0]].index(metric_name)
232
+ return metric_info[1][idx]
233
+
234
+ def generate_record(self, data_id, metric_names):
235
+ record = {}
236
+ _base = self.experiment.run_batch([data_id], 0, 0, 0)
237
+ record['data_id'] = data_id
238
+ record['input'] = _base['inputs']
239
+ record['label'] = _base['labels']
240
+ record['output'] = _base['outputs']
241
+ record['target'] = _base['targets']
242
+ record['explanations'] = []
243
+
244
+ metrics_ids = [self.get_metric_id_by_name(metric_nm) for metric_nm in metric_names]
245
+
246
+ cnt = 0
247
+ for info in self.explainer_checkbox_group.info:
248
+ if info['checked']:
249
+ base = self.experiment.run_batch([data_id], info['id'], info['pp_id'], 0)
250
+ record['explanations'].append({
251
+ 'explainer_nm': base['explainer'].__class__.__name__,
252
+ 'value': base['postprocessed'],
253
+ 'mode' : info['mode'],
254
+ 'evaluations': []
255
+ })
256
+ for metric_id in metrics_ids:
257
+ res = self.experiment.run_batch([data_id], info['id'], info['pp_id'], metric_id)
258
+ record['explanations'][-1]['evaluations'].append({
259
+ 'metric_nm': res['metric'].__class__.__name__,
260
+ 'value' : res['evaluation']
261
+ })
262
+
263
+ cnt += 1
264
+
265
+ # Sort record['explanations'] with respect to the metric values
266
+ if len(record['explanations'][0]['evaluations']) > 0:
267
+ record['explanations'] = sorted(record['explanations'], key=lambda x: x['evaluations'][0]['value'], reverse=True)
268
+
269
+ return record
270
+
271
+
272
+ def show(self):
273
+ with gr.Row():
274
+ gr.Textbox(value="Image Classficiation", label="Task")
275
+ gr.Textbox(value=f"{self.experiment.model.__class__.__name__}", label="Model")
276
+ gr.Textbox(value="Heatmap", label="Explanation Type")
277
+
278
+ dset = self.experiment.manager._data.dataset
279
+ imgs = []
280
+ for i in range(len(dset)):
281
+ img = self.input_visualizer(dset[i][0])
282
+ imgs.append(img)
283
+ gallery = ImgGallery(imgs)
284
+ gallery.show()
285
+
286
+ explainers, _ = self.experiment.manager.get_explainers()
287
+ explainer_names = [exp.__class__.__name__ for exp in explainers]
288
+
289
+ self.explainer_checkbox_group = ExplainerCheckboxGroup(explainer_names, self.experiment, gallery)
290
+ self.explainer_checkbox_group.show()
291
+
292
+ cr_metrics_names = ["AbPC", "MoRF", "LeRF", "MuFidelity"]
293
+ cn_metrics_names = ["Sensitivity"]
294
+ cp_metrics_names = ["Complexity"]
295
+ with gr.Accordion("Evaluators", open=True):
296
+ with gr.Row():
297
+ cr_metrics = gr.CheckboxGroup(choices=cr_metrics_names, value=[cr_metrics_names[0]], label="Correctness")
298
+ def on_select(metrics):
299
+ if cr_metrics_names[0] not in metrics:
300
+ gr.Warning(f"{cr_metrics_names[0]} is required for the sorting the explanations.")
301
+ return [cr_metrics_names[0]] + metrics
302
+ else:
303
+ return metrics
304
+
305
+ cr_metrics.select(on_select, inputs=cr_metrics, outputs=cr_metrics)
306
+ with gr.Row():
307
+ # cn_metrics = gr.CheckboxGroup(choices=cn_metrics_names, value=cn_metrics_names, label="Continuity")
308
+ cn_metrics = gr.CheckboxGroup(choices=cn_metrics_names, label="Continuity")
309
+ with gr.Row():
310
+ # cp_metrics = gr.CheckboxGroup(choices=cp_metrics_names, value=cp_metrics_names[0], label="Compactness")
311
+ cp_metrics = gr.CheckboxGroup(choices=cp_metrics_names, label="Compactness")
312
+
313
+ metric_inputs = [cr_metrics, cn_metrics, cp_metrics]
314
+
315
+ data_id = gallery.selected_index
316
+ bttn = gr.Button("Explain", variant="primary")
317
+
318
+ buffer_size = 2 * len(explainer_names)
319
+ buffer_n_rows = buffer_size // PLOT_PER_LINE
320
+ buffer_n_rows = buffer_n_rows + 1 if buffer_size % PLOT_PER_LINE != 0 else buffer_n_rows
321
+
322
+ plots = [gr.Textbox(label="Prediction result", visible=False)]
323
+ for i in range(buffer_n_rows):
324
+ with gr.Row():
325
+ for j in range(PLOT_PER_LINE):
326
+ plot = gr.Image(value=None, label="Blank", visible=False)
327
+ plots.append(plot)
328
+
329
+ def show_plots():
330
+ _plots = [gr.Textbox(label="Prediction result", visible=False)]
331
+ num_plots = sum([1 for info in self.explainer_checkbox_group.info if info['checked']])
332
+ n_rows = num_plots // PLOT_PER_LINE
333
+ n_rows = n_rows + 1 if num_plots % PLOT_PER_LINE != 0 else n_rows
334
+ _plots += [gr.Image(value=None, label="Blank", visible=True)] * (n_rows * PLOT_PER_LINE)
335
+ _plots += [gr.Image(value=None, label="Blank", visible=False)] * ((buffer_n_rows - n_rows) * PLOT_PER_LINE)
336
+ return _plots
337
+
338
+ def render_plots(data_id, *metric_inputs):
339
+ # Clear Cache Files
340
+ cache_dir = f"{os.environ['GRADIO_TEMP_DIR']}/res"
341
+ if not os.path.exists(cache_dir): os.makedirs(cache_dir)
342
+ for f in os.listdir(cache_dir):
343
+ if len(f.split(".")[0]) == 16:
344
+ os.remove(os.path.join(cache_dir, f))
345
+
346
+ # Render Plots
347
+ metric_input = []
348
+ for metric in metric_inputs:
349
+ if metric:
350
+ metric_input += metric
351
+
352
+ record = self.generate_record(data_id, metric_input)
353
+
354
+ pred = self.get_prediction(record)
355
+ plots = [gr.Textbox(label="Prediction result", value=pred, visible=True)]
356
+
357
+ num_plots = sum([1 for info in self.explainer_checkbox_group.info if info['checked']])
358
+ n_rows = num_plots // PLOT_PER_LINE
359
+ n_rows = n_rows + 1 if num_plots % PLOT_PER_LINE != 0 else n_rows
360
+
361
+ for i in range(n_rows):
362
+ for j in range(PLOT_PER_LINE):
363
+ if i*PLOT_PER_LINE+j < len(record['explanations']):
364
+ exp_res = record['explanations'][i*PLOT_PER_LINE+j]
365
+ path = self.get_exp_plot(data_id, exp_res)
366
+ plot_obj = gr.Image(value=path, label=f"{exp_res['explainer_nm']} ({exp_res['mode']})", visible=True)
367
+ plots.append(plot_obj)
368
+ else:
369
+ plots.append(gr.Image(value=None, label="Blank", visible=True))
370
+
371
+ plots += [gr.Image(value=None, label="Blank", visible=False)] * ((buffer_n_rows - n_rows) * PLOT_PER_LINE)
372
+
373
+ return plots
374
+
375
+ bttn.click(show_plots, outputs=plots)
376
+ bttn.click(render_plots, inputs=[data_id] + metric_inputs, outputs=plots)
377
+
378
+
379
+
380
+ class ExplainerCheckboxGroup(Component):
381
+ def __init__(self, explainer_names, experiment, gallery):
382
+ super().__init__()
383
+ self.explainer_names = explainer_names
384
+ self.explainer_objs = []
385
+ self.experiment = experiment
386
+ self.gallery = gallery
387
+ explainers, exp_ids = self.experiment.manager.get_explainers()
388
+
389
+ self.info = []
390
+ for exp, exp_id in zip(explainers, exp_ids):
391
+ self.info.append({'nm': exp.__class__.__name__, 'id': exp_id, 'pp_id' : 0, 'mode': 'default', 'checked': True})
392
+
393
+ def update_check(self, exp_id, val=None):
394
+ for info in self.info:
395
+ if info['id'] == exp_id:
396
+ if val is not None:
397
+ info['checked'] = val
398
+ else:
399
+ info['checked'] = not info['checked']
400
+
401
+ def insert_check(self, exp_nm, exp_id, pp_id):
402
+ if exp_id in [info['id'] for info in self.info]:
403
+ return
404
+
405
+ self.info.append({'nm': exp_nm, 'id': exp_id, 'pp_id' : pp_id, 'mode': 'optimal', 'checked': False})
406
+
407
+ def update_gallery_change(self):
408
+ checkboxes = []
409
+ bttns = []
410
+ checkboxes += [gr.Checkbox(label="Default Parameter", value=True, interactive=True)] * len(self.explainer_objs)
411
+ checkboxes += [gr.Checkbox(label="Optimized Parameter (Not Optimal)", value=False, interactive=False)] * len(self.explainer_objs)
412
+ bttns += [gr.Button(value="Optimize", size="sm", variant="primary")] * len(self.explainer_objs)
413
+
414
+ for exp in self.explainer_objs:
415
+ self.update_check(exp.default_exp_id, True)
416
+ if hasattr(exp, "optimal_exp_id"):
417
+ self.update_check(exp.optimal_exp_id, False)
418
+ return checkboxes + bttns
419
+
420
+ def get_checkboxes(self):
421
+ checkboxes = []
422
+ checkboxes += [exp.default_check for exp in self.explainer_objs]
423
+ checkboxes += [exp.opt_check for exp in self.explainer_objs]
424
+ return checkboxes
425
+
426
+ def get_bttns(self):
427
+ return [exp.bttn for exp in self.explainer_objs]
428
+
429
+ def show(self):
430
+ cnt = 0
431
+ with gr.Accordion("Explainers", open=True):
432
+ while cnt * PLOT_PER_LINE < len(self.explainer_names):
433
+ with gr.Row():
434
+ for info in self.info[cnt*PLOT_PER_LINE:(cnt+1)*PLOT_PER_LINE]:
435
+ explainer_obj = ExplainerCheckbox(info['nm'], self, self.experiment, self.gallery)
436
+ self.explainer_objs.append(explainer_obj)
437
+ explainer_obj.show()
438
+ cnt += 1
439
+
440
+ checkboxes = self.get_checkboxes()
441
+ bttns = self.get_bttns()
442
+ self.gallery.gallery_obj.select(
443
+ fn=self.update_gallery_change,
444
+ outputs=checkboxes + bttns
445
+ )
446
+
447
+
448
+ class ExplainerCheckbox(Component):
449
+ def __init__(self, explainer_name, groups, experiment, gallery):
450
+ self.explainer_name = explainer_name
451
+ self.groups = groups
452
+ self.experiment = experiment
453
+ self.gallery = gallery
454
+
455
+ self.default_exp_id = self.get_explainer_id_by_name(explainer_name)
456
+ self.obj_metric = self.get_metric_id_by_name(OBJECTIVE_METRIC)
457
+
458
+ def get_explainer_id_by_name(self, explainer_name):
459
+ explainer_info = self.experiment.manager.get_explainers()
460
+ idx = [exp.__class__.__name__ for exp in explainer_info[0]].index(explainer_name)
461
+ return explainer_info[1][idx]
462
+
463
+ def get_metric_id_by_name(self, metric_name):
464
+ metric_info = self.experiment.manager.get_metrics()
465
+ idx = [metric.__class__.__name__ for metric in metric_info[0]].index(metric_name)
466
+ return metric_info[1][idx]
467
+
468
+
469
+ def optimize(self):
470
+ # if self.explainer_name in ["Lime", "KernelShap", "IntegratedGradients"]:
471
+ # gr.Info("Lime, KernelShap and IntegratedGradients currently do not support hyperparameter optimization.")
472
+ # return [gr.update()] * 2
473
+
474
+ data_id = self.gallery.selected_index
475
+
476
+ optimized, _, _ = self.experiment.optimize(
477
+ data_id=data_id.value,
478
+ explainer_id=self.default_exp_id,
479
+ metric_id=self.obj_metric,
480
+ direction='maximize',
481
+ sampler=SAMPLE_METHOD,
482
+ n_trials=OPT_N_TRIALS,
483
+ )
484
+
485
+ opt_explainer_id = optimized['explainer_id']
486
+ opt_postprocessor_id = optimized['postprocessor_id']
487
+
488
+ self.groups.insert_check(self.explainer_name, opt_explainer_id, opt_postprocessor_id)
489
+ self.optimal_exp_id = opt_explainer_id
490
+ checkbox = gr.update(label="Optimized Parameter (Optimal)", interactive=True)
491
+ bttn = gr.update(value="Optimized", variant="secondary")
492
+
493
+ return [checkbox, bttn]
494
+
495
+
496
+ def default_on_select(self, evt: gr.EventData):
497
+ self.groups.update_check(self.default_exp_id, evt._data['value'])
498
+
499
+ def optimal_on_select(self, evt: gr.EventData):
500
+ if hasattr(self, "optimal_exp_id"):
501
+ self.groups.update_check(self.optimal_exp_id, evt._data['value'])
502
+ else:
503
+ raise ValueError("Optimal explainer id is not found.")
504
+
505
+ def show(self):
506
+ with gr.Accordion(self.explainer_name, open=False):
507
+ self.default_check = gr.Checkbox(label="Default Parameter", value=True, interactive=True)
508
+ self.opt_check = gr.Checkbox(label="Optimized Parameter (Not Optimal)", interactive=False)
509
+
510
+ self.default_check.select(self.default_on_select)
511
+ self.opt_check.select(self.optimal_on_select)
512
+
513
+ self.bttn = gr.Button(value="Optimize", size="sm", variant="primary")
514
+ self.bttn.click(self.optimize, outputs=[self.opt_check, self.bttn], queue=True, concurrency_limit=1)
515
+
516
+
517
+ class ExpRes(Component):
518
+ def __init__(self, data_index, exp_res):
519
+ self.data_index = data_index
520
+ self.exp_res = exp_res
521
+
522
+ def show(self):
523
+ value = self.exp_res['value']
524
+
525
+ fig = go.Figure(data=go.Heatmap(
526
+ z=np.flipud(value[0].detach().numpy()),
527
+ colorscale='Reds',
528
+ showscale=False # remove color bar
529
+ ))
530
+
531
+ evaluations = self.exp_res['evaluations']
532
+ metric_values = [f"{eval['metric_nm'][:4]}: {eval['value'].item():.2f}" for eval in evaluations if eval['value'] is not None]
533
+ n = 3
534
+ cnt = 0
535
+ while cnt * n < len(metric_values):
536
+ metric_text = ', '.join(metric_values[cnt*n:cnt*n+n])
537
+ fig.add_annotation(
538
+ x=0,
539
+ y=-0.1 * (cnt+1),
540
+ xref='paper',
541
+ yref='paper',
542
+ text=metric_text,
543
+ showarrow=False,
544
+ font=dict(
545
+ size=18,
546
+ ),
547
+ )
548
+ cnt += 1
549
+
550
+
551
+ fig = fig.update_layout(
552
+ width=380,
553
+ height=400,
554
+ xaxis=dict(
555
+ showticklabels=False,
556
+ ticks='',
557
+ showgrid=False
558
+ ),
559
+ yaxis=dict(
560
+ showticklabels=False,
561
+ ticks='',
562
+ showgrid=False
563
+ ),
564
+ margin=dict(t=40, b=40*cnt, l=20, r=20),
565
+ )
566
+
567
+ # Generate Random Unique ID
568
+ root = f"{os.environ['GRADIO_TEMP_DIR']}/res"
569
+ if not os.path.exists(root): os.makedirs(root)
570
+ key = secrets.token_hex(8)
571
+ path = f"{root}/{key}.png"
572
+ fig.write_image(path)
573
+ return path
574
+
575
+
576
+ class ImageClsApp(App):
577
+ def __init__(self, experiments, **kwargs):
578
+ self.name = "Image Classification App"
579
+ super().__init__(**kwargs)
580
+
581
+ self.experiments = experiments
582
+
583
+ self.overview_tab = OverviewTab()
584
+ self.detection_tab = DetectionTab(self.experiments)
585
+ self.local_exp_tab = LocalExpTab(self.experiments)
586
+
587
+ def title(self):
588
+ return """
589
+ <div style="text-align: center;">
590
+ <img src="/file=data/static/XAI-Top-PnP.svg" width="100" height="100">
591
+ <h1> Plug and Play XAI Platform for Image Classification </h1>
592
+ </div>
593
+ """
594
+
595
+ def launch(self, **kwargs):
596
+ with gr.Blocks(
597
+ title=self.name,
598
+ ) as demo:
599
+ cwd = os.getcwd()
600
+ gr.set_static_paths(cwd)
601
+ gr.HTML(self.title())
602
+
603
+ self.overview_tab.show()
604
+ self.detection_tab.show()
605
+ self.local_exp_tab.show()
606
+
607
+ return demo
608
+
609
+ # if __name__ == '__main__':
610
+ import os
611
+ import torch
612
+ import numpy as np
613
+ from torch.utils.data import DataLoader
614
+ from helpers import get_imagenet_dataset, get_torchvision_model, denormalize_image
615
+
616
+ os.environ['GRADIO_TEMP_DIR'] = '.tmp'
617
+
618
+ def target_visualizer(x): return dataset.dataset.idx_to_label(x.item())
619
+
620
+ experiments = {}
621
+
622
+ model, transform = get_torchvision_model('resnet18')
623
+ dataset = get_imagenet_dataset(transform)
624
+ loader = DataLoader(dataset, batch_size=4, shuffle=False)
625
+ experiment1 = AutoExplanationForImageClassification(
626
+ model=model,
627
+ data=loader,
628
+ input_extractor=lambda batch: batch[0],
629
+ label_extractor=lambda batch: batch[-1],
630
+ target_extractor=lambda outputs: outputs.argmax(-1),
631
+ channel_dim=1
632
+ )
633
+
634
+ experiments['experiment1'] = {
635
+ 'name': 'ResNet18',
636
+ 'experiment': experiment1,
637
+ 'input_visualizer': lambda x: denormalize_image(x, transform.mean, transform.std),
638
+ 'target_visualizer': target_visualizer,
639
+ }
640
+
641
+
642
+ model, transform = get_torchvision_model('vit_b_16')
643
+ dataset = get_imagenet_dataset(transform)
644
+ loader = DataLoader(dataset, batch_size=4, shuffle=False)
645
+ experiment2 = AutoExplanationForImageClassification(
646
+ model=model,
647
+ data=loader,
648
+ input_extractor=lambda batch: batch[0],
649
+ label_extractor=lambda batch: batch[-1],
650
+ target_extractor=lambda outputs: outputs.argmax(-1),
651
+ channel_dim=1
652
+ )
653
+
654
+ experiments['experiment2'] = {
655
+ 'name': 'ViT-B_16',
656
+ 'experiment': experiment2,
657
+ 'input_visualizer': lambda x: denormalize_image(x, transform.mean, transform.std),
658
+ 'target_visualizer': target_visualizer,
659
+ }
660
 
661
+ app = ImageClsApp(experiments)
662
+ demo = app.launch()
663
+ demo.launch(favicon_path="data/static/XAI-Top-PnP.svg", share=True)
app_test.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ with gr.Blocks() as block:
4
+ textbox = gr.Textbox(label="Enter your text here")
5
+ bttn = gr.Button()
6
+ output = gr.Textbox(label="Output")
7
+
8
+ def submit(input_data):
9
+ return input_data + "Submitted"
10
+
11
+ bttn.click(submit, inputs=[textbox], outputs=[output])
12
+
13
+ block.launch()
requirements.txt CHANGED
@@ -20,9 +20,8 @@ flask
20
  flask-cors
21
  flask-restx
22
  optuna
 
23
 
24
  # for text explainers
25
  transformers>=4.0.0
26
  gensim>=4.0.0
27
-
28
- git+https://github.com/OpenXAIProject/pnpxai.git@dev#egg=pnpxai
 
20
  flask-cors
21
  flask-restx
22
  optuna
23
+ git+https://github.com/OpenXAIProject/pnpxai.git@dev#egg=pnpxai
24
 
25
  # for text explainers
26
  transformers>=4.0.0
27
  gensim>=4.0.0