patrickramos committed on
Commit
5230215
·
1 Parent(s): a9b3a0e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +159 -0
app.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoFeatureExtractor, AutoModel
2
+ import torch
3
+ from torchvision.transforms.functional import to_pil_image
4
+ from einops import rearrange, reduce
5
+ from skops import hub_utils
6
+ import matplotlib.pyplot as plt
7
+ import seaborn as sns
8
+ import gradio as gr
9
+
10
+ import os
11
+ import pickle
12
+
13
+
14
# One row per supported visual Emb-GAM setup:
# (display name, embedding checkpoint id, GAM repo name).
_SETUP_TABLE = (
    ('ResNet-50', 'microsoft/resnet-50', 'emb-gam-resnet'),
    ('ViT', 'google/vit-base-patch16-224', 'emb-gam-vit'),
    ('DINO-ResNet-50', 'Ramos-Ramos/dino-resnet-50', 'emb-gam-dino-resnet'),
    ('DINO-ViT', 'facebook/dino-vitb16', 'emb-gam-dino'),
)

# Parallel lists derived from the table; position i of each list refers to
# the same setup.
setups = [row[0] for row in _SETUP_TABLE]
embedder_names = [row[1] for row in _SETUP_TABLE]
gam_names = [row[2] for row in _SETUP_TABLE]

# Lookups from checkpoint id / GAM repo name back to the display name.
embedder_to_setup = {checkpoint: setup for setup, checkpoint, _ in _SETUP_TABLE}
gam_to_setup = {gam: setup for setup, _, gam in _SETUP_TABLE}

# Run on the GPU when one is available, otherwise on the CPU.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)
22
+
23
def _load_embedder(checkpoint):
    '''Loads the feature extractor and model for `checkpoint` and bundles them
    with the patch-grid size and the function that turns the model output into
    per-patch embeddings.'''
    extractor = AutoFeatureExtractor.from_pretrained(checkpoint)
    backbone = AutoModel.from_pretrained(checkpoint).eval().to(device)

    if 'resnet-50' in checkpoint:
        # Hard-coded 7x7 grid; the (b, d, h, w) feature map is flattened to
        # (b, h*w, d) so every spatial location becomes one "patch".
        side = 7
        postprocess = lambda x: rearrange(x.last_hidden_state, 'b d h w -> b (h w) d')
    else:
        # Grid size follows from the model config; the first token is dropped
        # (presumably the [CLS] token -- confirm against the checkpoints).
        side = backbone.config.image_size // backbone.config.patch_size
        postprocess = lambda x: x.last_hidden_state[:, 1:]

    return {
        'feature_extractor': extractor,
        'model': backbone,
        'num_patches_side': side,
        'embedding_postprocess': postprocess,
    }


# Embedding models keyed by setup display name.
embedders = {}
for name in embedder_names:
    embedders[embedder_to_setup[name]] = _load_embedder(name)
36
+
37
# Fitted GAMs keyed by setup display name.
gams = {}
for name in gam_names:
    model_path = os.path.join(name, 'model.pkl')

    # Key the download on the pickle itself rather than on the directory: a
    # directory left behind by an interrupted download would otherwise be
    # taken as "already downloaded" and open() below would fail on every
    # subsequent start.
    if not os.path.exists(model_path):
        os.makedirs(name, exist_ok=True)
        hub_utils.download(repo_id=f'Ramos-Ramos/{name}', dst=name)

    # NOTE(security): pickle.load executes arbitrary code contained in the
    # file. This is only acceptable because the repos under 'Ramos-Ramos/'
    # are trusted; do not point this at repos you do not control.
    with open(model_path, 'rb') as infile:
        gams[gam_to_setup[name]] = pickle.load(infile)
45
+
46
# Class names, in order: labels[j] is used as the title of the j-th heatmap
# column, i.e. it names row j of each GAM's contribution maps.
labels = [
    'tench', 'English springer', 'cassette player', 'chain saw', 'church',
    'French horn', 'garbage truck', 'gas pump', 'golf ball', 'parachute',
]
58
+
59
def visualize(input_img, visual_emb_gam_setups, show_scores, show_cbars):
    '''Visualizes the patch contributions to all labels of one or more visual
    Emb-GAMs

    Args:
        input_img: image to explain; it is passed to each setup's feature
            extractor and drawn with imshow (the interface supplies it as a
            PIL image).
        visual_emb_gam_setups: names of the setups to show; each must be a key
            of both `embedders` and `gams`.
        show_scores: if truthy, the sum of each heatmap is written beneath it.
        show_cbars: if truthy, a color bar is drawn next to each heatmap.

    Returns:
        A single matplotlib figure: the input image in the first column and
        one heatmap per (setup, label) in the remaining columns. An empty
        figure is returned when there is no image or no setup selected.
    '''

    # Nothing to draw. Return exactly one figure so this path has the same
    # arity as the normal path and the single Plot output.
    if input_img is None or not visual_emb_gam_setups:
        return plt.Figure()

    num_setups = len(visual_emb_gam_setups)
    num_labels = len(labels)
    patch_contributions = {}

    # get patch contributions per Emb-GAM
    for setup in visual_emb_gam_setups:
        # prepare embedding model and GAM
        embedder_setup = embedders[setup]
        feature_extractor = embedder_setup['feature_extractor']
        embedding_postprocess = embedder_setup['embedding_postprocess']
        num_patches_side = embedder_setup['num_patches_side']
        gam = gams[setup]

        # get patch embeddings
        inputs = {
            k: v.to(device)
            for k, v
            in feature_extractor(input_img, return_tensors='pt').items()
        }
        with torch.no_grad():
            patch_embeddings = embedding_postprocess(
                embedder_setup['model'](**inputs)
            ).cpu()[0]

        # get patch contributions: project every patch embedding onto the GAM
        # coefficients and give each patch an equal share of the intercept,
        # then fold the flat patch axis back into a square grid per label
        patch_contributions[setup] = (
            gam.coef_
            @ patch_embeddings.T.numpy()
            + gam.intercept_.reshape(-1, 1) / (num_patches_side ** 2)
        ).reshape(-1, num_patches_side, num_patches_side)

    # plot heatmaps

    # set up figure; squeeze=False keeps `axs` two-dimensional even for a
    # single setup, so one code path serves any number of rows
    fig, axs = plt.subplots(
        num_setups,
        num_labels + 1,
        figsize=(20, round(10/4 * num_setups)),
        squeeze=False
    )

    # replace the first column with one axes spanning all rows
    gs = axs[0, 0].get_gridspec()
    for ax in axs[:, 0]:
        ax.remove()
    ax_orig_img = fig.add_subplot(gs[:, 0])

    # plot original image
    ax_orig_img.imshow(input_img)
    ax_orig_img.axis('off')

    # plot patch contributions
    for i, setup in enumerate(visual_emb_gam_setups):
        contributions = patch_contributions[setup]
        # one color scale per setup so its label maps are comparable
        vmin = contributions.min()
        vmax = contributions.max()
        for j in range(num_labels):
            ax = axs[i, j + 1]
            sns.heatmap(
                contributions[j],
                ax=ax,
                square=True,
                vmin=vmin,
                vmax=vmax,
                cbar=show_cbars
            )
            if show_scores:
                ax.set_xlabel(f'{contributions[j].sum():.2f}')
            if j == 0:
                ax.set_ylabel(setup)
            if i == 0:
                ax.set_title(labels[j])
            ax.set_xticks([])
            ax.set_yticks([])

    # lay out this figure explicitly instead of pyplot's "current" figure
    fig.tight_layout()

    # Deregister the figure from pyplot so figures do not pile up across
    # requests; the Figure object itself stays usable by the caller.
    plt.close(fig)

    return fig
146
+
147
# Input widgets, in the order of visualize()'s parameters:
# image, selected setups, show-scores flag, show-color-bars flag.
_demo_inputs = [
    gr.Image(shape=(224, 224), type='pil', label='Input image'),
    gr.CheckboxGroup(setups, value=setups, label='Visual Emb-GAM'),
    gr.Checkbox(label='Show scores'),
    gr.Checkbox(label='Show color bars'),
]

# Single plot output holding the figure produced by visualize().
_demo_outputs = [gr.Plot(label='Patch contributions')]

demo = gr.Interface(fn=visualize, inputs=_demo_inputs, outputs=_demo_outputs)
demo.launch(debug=True)