Spaces:
Sleeping
Sleeping
# python image_gradio.py >> ./logs/image_gradio.log 2>&1
import os
import secrets
import time

import gradio as gr
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import torch

from pnpxai.core.detector.detector import extract_graph_data, symbolic_trace
from pnpxai.core.experiment.auto_explanation import AutoExplanationForImageClassification
# Number of plots / explainer accordions laid out per UI row.
PLOT_PER_LINE: int = 4
# NOTE(review): not referenced anywhere in this file; possibly used elsewhere.
N_FEATURES_TO_SHOW: int = 5

# Hyperparameter-optimization settings handed to experiment.optimize():
# trial budget, objective metric (looked up by class name, maximized) and sampler.
OPT_N_TRIALS: int = 10
OBJECTIVE_METRIC: str = "AbPC"
SAMPLE_METHOD: str = "tpe"

# Explainer class names that start checked and expanded in the UI.
DEFAULT_EXPLAINER: list = ["GradientXInput", "IntegratedGradients", "LRPEpsilonPlus"]
class App:
    """Root base class for the Gradio applications defined in this module."""

    def __init__(self):
        """Hold no state; concrete apps set up their own attributes."""
class Component:
    """Root base class for the UI building blocks defined in this module."""

    def __init__(self):
        """Hold no state; subclasses store whatever they need."""
class Tab(Component):
    """Base class for the app's top-level tabs."""

    def __init__(self):
        """Hold no state at this level; concrete tabs store their own inputs."""
class OverviewTab(Tab):
    """Landing tab that currently shows only a placeholder label."""

    def __init__(self):
        """The overview tab needs no configuration."""

    def show(self):
        """Render the tab inside the active ``gr.Blocks`` context."""
        with gr.Tab(label="Overview"):
            gr.Label("This is the overview tab.")
class DetectionTab(Tab):
    """Tab showing, for each experiment, its traced model architecture."""

    def __init__(self, experiments):
        # Mapping of experiment key -> info dict; only the 'experiment' entry is used here.
        self.experiments = experiments

    def show(self):
        """Render one DetectorRes panel per experiment inside a "Detection" tab."""
        with gr.Tab(label="Detection"):
            gr.Label("This is the detection tab.")
            # Keys are not needed, so iterate the info dicts directly.
            for exp_info in self.experiments.values():
                DetectorRes(exp_info['experiment']).show()
class LocalExpTab(Tab):
    """Tab holding one local-explanation panel (Experiment) per configured experiment."""

    def __init__(self, experiments):
        self.experiments = experiments
        # Experiment.__init__ only stores entries of the info dict; Gradio components
        # are created later, in show().
        self.experiment_components = [Experiment(exp_info) for exp_info in experiments.values()]

    def description(self):
        return "This tab shows the local explanation."

    def show(self):
        """Render every experiment panel inside a "Local Explanation" tab."""
        with gr.Tab(label="Local Explanation"):
            gr.Label("This is the local explanation tab.")
            # Iterate the panels themselves rather than indexing them by position
            # in a parallel loop over self.experiments.
            for component in self.experiment_components:
                component.show()
class DetectorRes(Component):
    """Detection panel for one experiment: task, model name and traced architecture graph."""

    def __init__(self, experiment):
        self.experiment = experiment
        # Trace once at construction; show() only lays out and draws the cached graph data.
        graph_module = symbolic_trace(experiment.model)
        self.graph_data = extract_graph_data(graph_module)

    def describe(self):
        return "This component shows the detection result."

    def _build_graph(self):
        """Return a DiGraph of the traced nodes, keeping only edges between known nodes."""
        graph = nx.DiGraph()
        for node in self.graph_data['nodes']:
            graph.add_node(node['name'])
        for edge in self.graph_data['edges']:
            if edge['source'] in graph.nodes and edge['target'] in graph.nodes:
                graph.add_edge(edge['source'], edge['target'])
        return graph

    @staticmethod
    def _layered_layout(graph):
        """Position nodes in layers given by their topological generation.

        Generations are numbered in reverse, so nodes with no incoming edges
        (the first generation) receive the highest layer index.
        """
        graph = graph.copy()  # annotate a copy so the caller's graph is left untouched
        generations = tuple(nx.topological_generations(graph))
        for layer, nodes in enumerate(reversed(generations)):
            for node in nodes:
                graph.nodes[node]["layer"] = layer
        return nx.multipartite_layout(graph, subset_key="layer", align='horizontal')

    @staticmethod
    def _plot_graph(graph, pos):
        """Draw ``graph`` at positions ``pos`` on a new 12x24-inch figure and return it."""
        fig = plt.figure(figsize=(12, 24))
        ax = fig.gca()
        nx.draw(graph, pos=pos, with_labels=True, node_size=60, font_size=8, ax=ax)
        fig.tight_layout()
        return fig

    def show(self):
        """Render the task/model summary row followed by the architecture plot."""
        graph = self._build_graph()
        fig = self._plot_graph(graph, self._layered_layout(graph))
        model_name = self.experiment.model.__class__.__name__
        with gr.Row():
            gr.Textbox(value="Image Classification", label="Task")
            gr.Textbox(value=model_name, label="Model")
        gr.Plot(value=fig, label=f"Model Architecture of {model_name}", visible=True)
class ImgGallery(Component):
    """Clickable gallery of input images that records the most recently clicked index."""

    def __init__(self, imgs):
        self.imgs = imgs
        # Hidden number component so the selection can be fed to other handlers as an input.
        self.selected_index = gr.Number(value=0, label="Selected Index", visible=False)

    def on_select(self, evt: gr.SelectData):
        """Return the index of the gallery item the user clicked."""
        clicked = evt.index
        return clicked

    def show(self):
        """Render the gallery and wire clicks into ``self.selected_index``."""
        gallery = gr.Gallery(value=self.imgs, label="Input Data Gallery", columns=6, height=200)
        self.gallery_obj = gallery
        gallery.select(self.on_select, outputs=self.selected_index)
class Experiment(Component):
    """Local-explanation panel for one experiment.

    The panel contains an input gallery, explainer and metric pickers, and a grid
    of explanation plots. ``exp_info`` must provide ``'experiment'``,
    ``'input_visualizer'`` and ``'target_visualizer'``.
    """

    def __init__(self, exp_info):
        self.exp_info = exp_info
        self.experiment = exp_info['experiment']
        self.input_visualizer = exp_info['input_visualizer']
        self.target_visualizer = exp_info['target_visualizer']

    def viz_input(self, input, data_id):
        """Return a Plotly figure of ``input`` titled with its data id, axes hidden."""
        orig_img_np = self.input_visualizer(input)
        orig_img = px.imshow(orig_img_np)
        orig_img.update_layout(
            title=f"Data ID: {data_id}",
            width=400,
            height=350,
            xaxis=dict(showticklabels=False, ticks='', showgrid=False),
            yaxis=dict(showticklabels=False, ticks='', showgrid=False),
        )
        return orig_img

    def get_prediction(self, record, topk=3):
        """Return a text summary of the ground-truth label and the ``topk`` most probable classes.

        ``record['output']`` is passed through softmax over the last dim, so it is
        presumably the model's logits.
        """
        # .cpu() is a no-op for CPU tensors and keeps .numpy() working if the model runs on a GPU.
        probs = record['output'].softmax(-1).squeeze().detach().cpu().numpy()
        text = f"Ground Truth Label: {self.target_visualizer(record['label'])}\n"
        # Descending order first, then take the head: unlike [-topk:], this is also correct for topk == 0.
        for rank, pred in enumerate(probs.argsort()[::-1][:topk], start=1):
            label = self.target_visualizer(torch.tensor(pred))
            text += f"Top {rank} Prediction: {label} ({probs[pred]:.2f})\n"
        return text

    def get_exp_plot(self, data_index, exp_res):
        """Render one explanation to a PNG and return its path."""
        return ExpRes(data_index, exp_res).show()

    def get_metric_id_by_name(self, metric_name):
        """Return the manager id of the first metric whose class is named ``metric_name``."""
        metric_info = self.experiment.manager.get_metrics()
        idx = [metric.__class__.__name__ for metric in metric_info[0]].index(metric_name)
        return metric_info[1][idx]

    def generate_record(self, data_id, metric_names):
        """Run the prediction and every checked explainer/metric for one sample.

        Returns a dict with the sample's inputs/labels/outputs/targets and an
        ``'explanations'`` list. Each explanation dict has ``explainer_nm``,
        ``value``, ``mode`` and ``evaluations``. The list is sorted by the first
        requested metric, best first.
        """
        base = self.experiment.run_batch([data_id], 0, 0, 0)
        record = {
            'data_id': data_id,
            'input': base['inputs'],
            'label': base['labels'],
            'output': base['outputs'],
            'target': base['targets'],
            'explanations': [],
        }
        metric_ids = [self.get_metric_id_by_name(metric_nm) for metric_nm in metric_names]
        for info in self.explainer_checkbox_group.info:
            if not info['checked']:
                continue
            explained = self.experiment.run_batch([data_id], info['id'], info['pp_id'], 0)
            explanation = {
                'explainer_nm': explained['explainer'].__class__.__name__,
                'value': explained['postprocessed'],
                'mode': info['mode'],
                'evaluations': [],
            }
            record['explanations'].append(explanation)
            for metric_id in metric_ids:
                res = self.experiment.run_batch([data_id], info['id'], info['pp_id'], metric_id)
                explanation['evaluations'].append({
                    'metric_nm': res['metric'].__class__.__name__,
                    'value': res['evaluation'],
                })

        def primary_score(explanation):
            # Missing scores sort last (reverse=True below puts -inf at the end).
            value = explanation['evaluations'][0]['value']
            return float('-inf') if value is None else float(value)

        # Guard the empty case: with no explainer checked, indexing [0] would raise IndexError.
        explanations = record['explanations']
        if explanations and explanations[0]['evaluations']:
            record['explanations'] = sorted(explanations, key=primary_score, reverse=True)
        return record

    @staticmethod
    def _rows_needed(n_items):
        """Number of PLOT_PER_LINE-wide rows needed to hold ``n_items`` plots (ceil division)."""
        return (n_items + PLOT_PER_LINE - 1) // PLOT_PER_LINE

    def _num_checked(self):
        """Count explainer variants currently checked in the explainer group."""
        return sum(1 for info in self.explainer_checkbox_group.info if info['checked'])

    @staticmethod
    def _clear_cached_plots(cache_dir):
        """Delete previously rendered plot files from ``cache_dir`` (created if missing).

        Only files whose stem is 16 characters long are removed, which matches
        the ``secrets.token_hex(8)`` names ExpRes writes.
        """
        os.makedirs(cache_dir, exist_ok=True)
        for fname in os.listdir(cache_dir):
            if len(fname.split(".")[0]) == 16:
                try:
                    os.remove(os.path.join(cache_dir, fname))
                except FileNotFoundError:
                    # Already gone (e.g. removed by another session); nothing to do.
                    pass

    def show(self):
        """Build the panel's components and wire the Explain button."""
        model_name = self.experiment.model.__class__.__name__
        with gr.Row():
            gr.Textbox(value="Image Classification", label="Task")
            gr.Textbox(value=model_name, label="Model")
            gr.Textbox(value="Heatmap", label="Explanation Type")

        # Gallery of every sample in the experiment's dataset.
        dset = self.experiment.manager._data.dataset
        imgs = [self.input_visualizer(dset[i][0]) for i in range(len(dset))]
        gallery = ImgGallery(imgs)
        gallery.show()

        explainers, _ = self.experiment.manager.get_explainers()
        explainer_names = [exp.__class__.__name__ for exp in explainers]
        self.explainer_checkbox_group = ExplainerCheckboxGroup(explainer_names, self.experiment, gallery)
        self.explainer_checkbox_group.show()

        cr_metrics_names = ["AbPC", "MoRF", "LeRF", "MuFidelity"]
        cn_metrics_names = ["Sensitivity"]
        cp_metrics_names = ["Complexity"]
        with gr.Accordion("Evaluators", open=True):
            with gr.Row():
                cr_metrics = gr.CheckboxGroup(choices=cr_metrics_names, value=[cr_metrics_names[0]], label="Correctness")

                def on_select(metrics):
                    # generate_record sorts by the first requested metric, so keep
                    # the first correctness metric selected.
                    if cr_metrics_names[0] not in metrics:
                        gr.Warning(f"{cr_metrics_names[0]} is required for sorting the explanations.")
                        return [cr_metrics_names[0]] + metrics
                    return metrics

                cr_metrics.select(on_select, inputs=cr_metrics, outputs=cr_metrics)
            with gr.Row():
                cn_metrics = gr.CheckboxGroup(choices=cn_metrics_names, label="Continuity")
            with gr.Row():
                cp_metrics = gr.CheckboxGroup(choices=cp_metrics_names, label="Compactness")
        metric_inputs = [cr_metrics, cn_metrics, cp_metrics]
        data_id = gallery.selected_index
        bttn = gr.Button("Explain", variant="primary")

        # Each explainer has a default and an optimized variant, hence two plot slots per explainer.
        buffer_size = 2 * len(explainer_names)
        buffer_n_rows = self._rows_needed(buffer_size)
        plots = [gr.Textbox(label="Prediction result", visible=False)]
        for _ in range(buffer_n_rows):
            with gr.Row():
                for _ in range(PLOT_PER_LINE):
                    plots.append(gr.Image(value=None, label="Blank", visible=False))

        def visible_rows():
            # Clamp to the pre-built buffer so a handler never returns more updates
            # than there are output components.
            return min(self._rows_needed(self._num_checked()), buffer_n_rows)

        def hidden_tail(n_rows):
            return [gr.Image(value=None, label="Blank", visible=False)] * ((buffer_n_rows - n_rows) * PLOT_PER_LINE)

        def show_plots():
            # Reveal blank placeholders immediately while the explanations are computed.
            n_rows = visible_rows()
            return (
                [gr.Textbox(label="Prediction result", visible=False)]
                + [gr.Image(value=None, label="Blank", visible=True)] * (n_rows * PLOT_PER_LINE)
                + hidden_tail(n_rows)
            )

        def render_plots(data_id, *metric_inputs):
            # Normalise the id coming from the hidden gr.Number.
            data_id = int(data_id)
            self._clear_cached_plots(f"{os.environ['GRADIO_TEMP_DIR']}/res")
            # Flatten the selected metric names across the three groups, skipping empty/None groups.
            metric_names = [name for group in metric_inputs if group for name in group]
            record = self.generate_record(data_id, metric_names)
            outputs = [gr.Textbox(label="Prediction result", value=self.get_prediction(record), visible=True)]
            n_rows = visible_rows()
            explanations = record['explanations']
            for slot in range(n_rows * PLOT_PER_LINE):
                if slot < len(explanations):
                    exp_res = explanations[slot]
                    path = self.get_exp_plot(data_id, exp_res)
                    outputs.append(gr.Image(value=path, label=f"{exp_res['explainer_nm']} ({exp_res['mode']})", visible=True))
                else:
                    outputs.append(gr.Image(value=None, label="Blank", visible=True))
            return outputs + hidden_tail(n_rows)

        # Chain the handlers so rendering always runs after the placeholders are shown.
        bttn.click(show_plots, outputs=plots).then(
            render_plots, inputs=[data_id] + metric_inputs, outputs=plots
        )
class ExplainerCheckboxGroup(Component):
    """Tracks which explainer variants are selected for rendering.

    ``self.info`` holds one dict per variant with keys ``nm`` (explainer class
    name), ``id`` (explainer id in the experiment manager), ``pp_id``
    (postprocessor id), ``mode`` (``'default'`` or ``'optimal'``) and ``checked``.
    """

    def __init__(self, explainer_names, experiment, gallery):
        super().__init__()
        self.explainer_names = explainer_names
        self.explainer_objs = []
        self.experiment = experiment
        self.gallery = gallery
        explainers, exp_ids = self.experiment.manager.get_explainers()
        # Default variants start checked only for the explainers in DEFAULT_EXPLAINER.
        self.info = [
            {
                'nm': explainer.__class__.__name__,
                'id': explainer_id,
                'pp_id': 0,
                'mode': 'default',
                'checked': explainer.__class__.__name__ in DEFAULT_EXPLAINER,
            }
            for explainer, explainer_id in zip(explainers, exp_ids)
        ]

    def update_check(self, exp_id, val=None):
        """Set ``checked`` to ``val`` (or toggle it when ``val`` is None) on entries with id ``exp_id``."""
        for entry in (e for e in self.info if e['id'] == exp_id):
            entry['checked'] = val if val is not None else not entry['checked']

    def insert_check(self, exp_nm, exp_id, pp_id):
        """Register an optimized variant, unchecked, unless ``exp_id`` is already tracked."""
        if not any(entry['id'] == exp_id for entry in self.info):
            self.info.append({'nm': exp_nm, 'id': exp_id, 'pp_id': pp_id, 'mode': 'optimal', 'checked': False})

    def update_gallery_change(self):
        """Reset every explainer's controls after a new gallery image is selected.

        Default checkboxes return to their initial state. Optimized checkboxes
        become unchecked and non-interactive, and buttons go back to "Optimize".
        The matching ``info`` entries are updated to agree. Returns the updates in
        the order: default checkboxes, optimized checkboxes, buttons.
        """
        count = len(self.explainer_objs)
        default_boxes = [
            gr.Checkbox(label="Default Parameter", value=obj.explainer_name in DEFAULT_EXPLAINER, interactive=True)
            for obj in self.explainer_objs
        ]
        optimal_boxes = [gr.Checkbox(label="Optimized Parameter (Not Optimal)", value=False, interactive=False)] * count
        buttons = [gr.Button(value="Optimize", size="sm", variant="primary")] * count
        for obj in self.explainer_objs:
            self.update_check(obj.default_exp_id, obj.explainer_name in DEFAULT_EXPLAINER)
            if hasattr(obj, "optimal_exp_id"):
                self.update_check(obj.optimal_exp_id, False)
        return default_boxes + optimal_boxes + buttons

    def get_checkboxes(self):
        """All default checkboxes followed by all optimized checkboxes."""
        defaults = [obj.default_check for obj in self.explainer_objs]
        optimals = [obj.opt_check for obj in self.explainer_objs]
        return defaults + optimals

    def get_bttns(self):
        """The Optimize button of every explainer, in display order."""
        return [obj.bttn for obj in self.explainer_objs]

    def show(self):
        """Render explainers PLOT_PER_LINE per row, defaults first, then alphabetically."""
        ordered = sorted(self.info, key=lambda entry: (entry['nm'] not in DEFAULT_EXPLAINER, entry['nm']))
        with gr.Accordion("Explainers", open=True):
            for start in range(0, len(self.explainer_names), PLOT_PER_LINE):
                with gr.Row():
                    for entry in ordered[start:start + PLOT_PER_LINE]:
                        widget = ExplainerCheckbox(entry['nm'], self, self.experiment, self.gallery)
                        self.explainer_objs.append(widget)
                        widget.show()
        self.gallery.gallery_obj.select(
            fn=self.update_gallery_change,
            outputs=self.get_checkboxes() + self.get_bttns(),
        )
class ExplainerCheckbox(Component):
    """Controls for one explainer: default/optimized checkboxes and an Optimize button.

    Check state is mirrored into ``groups.info`` (an ExplainerCheckboxGroup),
    which is what the Explain handler reads.
    """

    def __init__(self, explainer_name, groups, experiment, gallery):
        self.explainer_name = explainer_name
        self.groups = groups
        self.experiment = experiment
        self.gallery = gallery
        self.default_exp_id = self.get_explainer_id_by_name(explainer_name)
        self.obj_metric = self.get_metric_id_by_name(OBJECTIVE_METRIC)

    def get_explainer_id_by_name(self, explainer_name):
        """Return the manager id of the first explainer whose class is named ``explainer_name``."""
        explainer_info = self.experiment.manager.get_explainers()
        idx = [exp.__class__.__name__ for exp in explainer_info[0]].index(explainer_name)
        return explainer_info[1][idx]

    def get_metric_id_by_name(self, metric_name):
        """Return the manager id of the first metric whose class is named ``metric_name``."""
        metric_info = self.experiment.manager.get_metrics()
        idx = [metric.__class__.__name__ for metric in metric_info[0]].index(metric_name)
        return metric_info[1][idx]

    @staticmethod
    def _postprocessor_signature(pp_obj):
        """Identify a postprocessor by its pooling and normalization function class names."""
        return pp_obj.pooling_fn.__class__.__name__ + pp_obj.normalization_fn.__class__.__name__

    def optimize(self, data_id=None):
        """Optimize this explainer's hyperparameters on the selected sample.

        ``data_id`` is the gallery selection supplied as an event input. Reading
        ``self.gallery.selected_index.value`` inside the handler only gives the
        component's construction-time value, not the user's current selection.
        Registers the optimized explainer with the manager and the checkbox
        group, and returns updates for the optimized checkbox and the button.

        Raises ValueError if no registered postprocessor matches the optimized one.
        """
        if data_id is None:
            # Fallback when called without an input: the component's initial value.
            data_id = self.gallery.selected_index.value
        opt_output = self.experiment.optimize(
            data_ids=int(data_id),
            explainer_id=self.default_exp_id,
            metric_id=self.obj_metric,
            direction='maximize',
            sampler=SAMPLE_METHOD,
            n_trials=OPT_N_TRIALS,
        )
        wanted = self._postprocessor_signature(opt_output.postprocessor)
        for pp_obj, pp_id in zip(*self.experiment.manager.get_postprocessors()):
            if self._postprocessor_signature(pp_obj) == wanted:
                opt_postprocessor_id = pp_id
                break
        else:
            raise ValueError(f"No registered postprocessor matches the optimized postprocessor ({wanted}).")

        if hasattr(self, "optimal_exp_id"):
            # Re-optimizing: retire the previous optimized variant so a stale
            # checked entry is not rendered alongside the new one.
            self.groups.update_check(self.optimal_exp_id, False)
        opt_explainer_id = max(info['id'] for info in self.groups.info) + 1
        # Point the optimized explainer at the experiment's model instance.
        opt_output.explainer.model = self.experiment.model
        self.experiment.manager._explainers.append(opt_output.explainer)
        self.experiment.manager._explainer_ids.append(opt_explainer_id)
        self.groups.insert_check(self.explainer_name, opt_explainer_id, opt_postprocessor_id)
        self.optimal_exp_id = opt_explainer_id
        # The new variant is registered unchecked, so the checkbox shows unchecked too.
        checkbox = gr.update(label="Optimized Parameter (Optimal)", interactive=True, value=False)
        bttn = gr.update(value="Optimized", variant="secondary")
        return [checkbox, bttn]

    def default_on_select(self, evt: gr.SelectData):
        # SelectData.selected carries the checkbox's new checked state. The raw
        # payload's 'value' field is not that state, so storing it left entries
        # truthy. NOTE(review): verify against the installed Gradio version.
        self.groups.update_check(self.default_exp_id, evt.selected)

    def optimal_on_select(self, evt: gr.SelectData):
        if not hasattr(self, "optimal_exp_id"):
            raise ValueError("Optimal explainer id is not found.")
        self.groups.update_check(self.optimal_exp_id, evt.selected)

    def show(self):
        """Render the explainer's accordion; it is expanded for DEFAULT_EXPLAINER entries."""
        val = self.explainer_name in DEFAULT_EXPLAINER
        with gr.Accordion(self.explainer_name, open=val):
            checked = next(filter(lambda x: x['nm'] == self.explainer_name, self.groups.info))['checked']
            self.default_check = gr.Checkbox(label="Default Parameter", value=checked, interactive=True)
            self.opt_check = gr.Checkbox(label="Optimized Parameter (Not Optimal)", interactive=False)
            self.default_check.select(self.default_on_select)
            self.opt_check.select(self.optimal_on_select)
            self.bttn = gr.Button(value="Optimize", size="sm", variant="primary")
            # Feed the current gallery selection in as an input so optimize() sees the live value.
            self.bttn.click(
                self.optimize,
                inputs=[self.gallery.selected_index],
                outputs=[self.opt_check, self.bttn],
                queue=True,
                concurrency_limit=1,
            )
class ExpRes(Component):
    """Renders one explanation heatmap, with metric annotations, to a PNG file."""

    def __init__(self, data_index, exp_res):
        self.data_index = data_index
        # Explanation dict with 'value' (attribution) and 'evaluations' (list of {'metric_nm', 'value'}).
        self.exp_res = exp_res

    def show(self):
        """Write the heatmap to ``$GRADIO_TEMP_DIR/res/<16 hex chars>.png`` and return the path.

        Only the first element of ``exp_res['value']`` along dim 0 is drawn.
        Assumes a batched 2-D attribution map; TODO confirm the shape.
        """
        # .cpu() is a no-op on CPU tensors and keeps .numpy() working for GPU tensors.
        attribution = self.exp_res['value'][0].detach().cpu().numpy()
        fig = go.Figure(data=go.Heatmap(
            # Flip vertically: Plotly heatmaps draw row 0 at the bottom.
            z=np.flipud(attribution),
            colorscale='Reds',
            showscale=False,  # remove color bar
        ))

        # One annotation line per three metrics, e.g. "AbPC: 0.42, MoRF: 0.10, LeRF: 0.55".
        metric_values = [
            f"{evaluation['metric_nm'][:4]}: {evaluation['value'].item():.2f}"
            for evaluation in self.exp_res['evaluations']
            if evaluation['value'] is not None
        ]
        per_line = 3
        lines = [', '.join(metric_values[i:i + per_line]) for i in range(0, len(metric_values), per_line)]
        for row, metric_text in enumerate(lines, start=1):
            fig.add_annotation(
                x=0,
                y=-0.1 * row,
                xref='paper',
                yref='paper',
                text=metric_text,
                showarrow=False,
                font=dict(size=18),
            )

        fig = fig.update_layout(
            width=380,
            height=400,
            xaxis=dict(showticklabels=False, ticks='', showgrid=False),
            yaxis=dict(showticklabels=False, ticks='', showgrid=False),
            # Grow the bottom margin with the number of annotation lines.
            margin=dict(t=40, b=40 * len(lines), l=20, r=20),
        )

        # exist_ok avoids the check-then-create race between concurrent renders.
        root = f"{os.environ['GRADIO_TEMP_DIR']}/res"
        os.makedirs(root, exist_ok=True)
        path = f"{root}/{secrets.token_hex(8)}.png"
        fig.write_image(path)
        return path
class ImageClsApp(App):
    """Image-classification XAI app composed of Overview, Detection and Local Explanation tabs."""

    def __init__(self, experiments, **kwargs):
        self.name = "Image Classification App"
        super().__init__(**kwargs)
        self.experiments = experiments
        self.overview_tab = OverviewTab()
        self.detection_tab = DetectionTab(self.experiments)
        self.local_exp_tab = LocalExpTab(self.experiments)

    def title(self):
        """HTML header: project logo (served via /file=) and page title."""
        return """
<div style="text-align: center;">
<a href="https://openxaiproject.github.io/pnpxai/">
<img src="/file=data/static/XAI-Top-PnP.svg" width="100" height="100">
</a>
<h1> Plug and Play XAI Platform for Image Classification </h1>
</div>
"""

    def launch(self, **kwargs):
        """Build and return the ``gr.Blocks`` demo; it does not start a server."""
        with gr.Blocks(title=self.name) as demo:
            # NOTE(review): this marks the whole working directory as static, so any
            # file under it can be fetched by URL. The header only needs data/static;
            # consider narrowing (confirm against Gradio's set_static_paths docs).
            gr.set_static_paths(os.getcwd())
            gr.HTML(self.title())
            for tab in (self.overview_tab, self.detection_tab, self.local_exp_tab):
                tab.show()
        return demo
# if __name__ == '__main__':
# NOTE(review): the entry-point guard above is commented out, so everything below
# runs whenever this module is imported.
import os
import torch
import numpy as np
from torch.utils.data import DataLoader
from helpers import get_imagenet_dataset, get_torchvision_model, denormalize_image

# Rendered explanation PNGs are written and cleaned under $GRADIO_TEMP_DIR/res.
os.environ['GRADIO_TEMP_DIR'] = '.tmp'
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")


def target_visualizer(x):
    """Map a class-index tensor to its label name.

    Looks up the module-level ``dataset`` at call time, i.e. the last dataset
    assigned below. This assumes both ImageNet datasets share one label
    mapping (TODO confirm).
    """
    return dataset.dataset.idx_to_label(x.item())


def _make_input_visualizer(transform):
    """Return a visualizer that denormalizes with *this* transform's mean/std.

    A lambda closing over the module-level ``transform`` is late-bound. Once
    ``transform`` is reassigned for the second model, the first experiment would
    silently denormalize with the second model's statistics.
    """
    def input_visualizer(x):
        return denormalize_image(x, transform.mean, transform.std)
    return input_visualizer


def _build_experiment(model, loader):
    """Move ``model`` to ``device`` and wrap it with ``loader`` in an auto-explanation experiment."""
    return AutoExplanationForImageClassification(
        model=model.to(device),
        data=loader,
        input_extractor=lambda batch: batch[0],
        label_extractor=lambda batch: batch[-1],
        target_extractor=lambda outputs: outputs.argmax(-1),
        channel_dim=1,
    )


experiments = {}

model, transform = get_torchvision_model('resnet18')
dataset = get_imagenet_dataset(transform)
loader = DataLoader(dataset, batch_size=4, shuffle=False)
experiment1 = _build_experiment(model, loader)
experiments['experiment1'] = {
    'name': 'ResNet18',
    'experiment': experiment1,
    'input_visualizer': _make_input_visualizer(transform),
    'target_visualizer': target_visualizer,
}

model, transform = get_torchvision_model('vit_b_16')
dataset = get_imagenet_dataset(transform)
loader = DataLoader(dataset, batch_size=4, shuffle=False)
experiment2 = _build_experiment(model, loader)
experiments['experiment2'] = {
    'name': 'ViT-B_16',
    'experiment': experiment2,
    'input_visualizer': _make_input_visualizer(transform),
    'target_visualizer': target_visualizer,
}

app = ImageClsApp(experiments)
demo = app.launch()
demo.launch(favicon_path="data/static/XAI-Top-PnP.svg", share=True)
# demo.launch(favicon_path="data/static/XAI-Top-PnP.svg")