Spaces:
Sleeping
Sleeping
File size: 5,995 Bytes
5230215 dd737ab 5230215 83ffd1d 70dfc27 83ffd1d 7c68a99 83ffd1d 5230215 83ffd1d dd737ab 85e03c1 5230215 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 |
from transformers import AutoFeatureExtractor, AutoModel
import torch
from torchvision.transforms.functional import to_pil_image
from einops import rearrange, reduce
from skops import hub_utils
import matplotlib.pyplot as plt
import seaborn as sns
import gradio as gr
import os
import glob
import pickle
# One row per probed configuration:
# (display name, embedding-model Hub id, Emb-GAM repo name under Ramos-Ramos/).
_SETUP_TABLE = [
    ('ResNet-50', 'microsoft/resnet-50', 'emb-gam-resnet'),
    ('ViT', 'google/vit-base-patch16-224', 'emb-gam-vit'),
    ('DINO-ResNet-50', 'Ramos-Ramos/dino-resnet-50', 'emb-gam-dino-resnet'),
    ('DINO-ViT', 'facebook/dino-vitb16', 'emb-gam-dino'),
]
setups = [setup for setup, _, _ in _SETUP_TABLE]
embedder_names = [embedder for _, embedder, _ in _SETUP_TABLE]
gam_names = [gam for _, _, gam in _SETUP_TABLE]
# Reverse lookups: repo name -> display name used as the key everywhere else.
embedder_to_setup = {embedder: setup for setup, embedder, _ in _SETUP_TABLE}
gam_to_setup = {gam: setup for setup, _, gam in _SETUP_TABLE}
# Run the embedding models on the GPU when one is available.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
def _load_embedder(repo_id):
    '''Load one embedding model and describe how to turn its output into patches.

    Returns a dict with the keys:
        'feature_extractor': preprocessing object loaded for ``repo_id``
        'model': the model in eval mode, moved to ``device``
        'num_patches_side': side length of the square grid of patches
        'embedding_postprocess': callable mapping the model output to a
            (batch, num_patches, dim) tensor
    '''
    feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
    model = AutoModel.from_pretrained(repo_id).eval().to(device)
    if 'resnet-50' in repo_id:
        # Presumably the 7x7 final feature map of ResNet-50 on 224x224 inputs
        # — TODO confirm if the input size ever changes.
        num_patches_side = 7

        def embedding_postprocess(output):
            # Channels-first spatial map -> one embedding per grid cell.
            return rearrange(output.last_hidden_state, 'b d h w -> b (h w) d')
    else:
        num_patches_side = model.config.image_size // model.config.patch_size

        def embedding_postprocess(output):
            # Drop token 0 (presumably the [CLS] token) and keep patch tokens.
            return output.last_hidden_state[:, 1:]

    return {
        'feature_extractor': feature_extractor,
        'model': model,
        'num_patches_side': num_patches_side,
        'embedding_postprocess': embedding_postprocess,
    }


# Setup display name -> loaded embedder description.
embedders = {
    embedder_to_setup[repo_id]: _load_embedder(repo_id)
    for repo_id in embedder_names
}
def _load_gam(repo_name):
    '''Return the pickled Emb-GAM stored in the local directory ``repo_name``.

    The Hub repo ``Ramos-Ramos/<repo_name>`` is downloaded into that directory
    only when the directory does not exist yet; otherwise the local copy is
    reused.
    '''
    if not os.path.exists(repo_name):
        os.mkdir(repo_name)
        hub_utils.download(repo_id=f'Ramos-Ramos/{repo_name}', dst=repo_name)
    # NOTE(security): pickle.load can execute arbitrary code; this is only
    # acceptable as long as the Ramos-Ramos Hub repos are trusted.
    with open(f'{repo_name}/model.pkl', 'rb') as infile:
        return pickle.load(infile)


# Setup display name -> fitted GAM.
gams = {gam_to_setup[repo_name]: _load_gam(repo_name) for repo_name in gam_names}
# Class names used as heatmap column titles. labels[j] is assumed to match
# row j of each GAM's coef_ (TODO confirm against the training setup); they
# appear to be the ten Imagenette classes.
labels = [
    'tench', 'English springer', 'cassette player', 'chain saw', 'church',
    'French horn', 'garbage truck', 'gas pump', 'golf ball', 'parachute',
]
def _patch_contributions(input_img, setup):
    '''Compute the per-label, per-patch contribution maps of one visual Emb-GAM.

    The image is embedded with the embedding model of ``setup``. Every patch
    embedding is scored with that setup's GAM, and the GAM intercept is spread
    evenly over the patches. Summing map j over all patches therefore gives
    ``coef_[j] @ (sum of patch embeddings) + intercept_[j]``.

    Args:
        input_img: image accepted by the setup's feature extractor.
        setup: key into ``embedders`` and ``gams`` (one of ``setups``).

    Returns:
        numpy array of shape (rows of gam.coef_, num_patches_side,
        num_patches_side).
    '''
    embedder_setup = embedders[setup]
    feature_extractor = embedder_setup['feature_extractor']
    embedding_postprocess = embedder_setup['embedding_postprocess']
    num_patches_side = embedder_setup['num_patches_side']
    gam = gams[setup]

    inputs = {
        k: v.to(device)
        for k, v in feature_extractor(input_img, return_tensors='pt').items()
    }
    with torch.no_grad():
        # [0] drops the batch axis of the single-image batch -> (patches, dim)
        patch_embeddings = embedding_postprocess(
            embedder_setup['model'](**inputs)
        ).cpu()[0]

    contributions = (
        gam.coef_ @ patch_embeddings.T.numpy()
        + gam.intercept_.reshape(-1, 1) / (num_patches_side ** 2)
    )
    return contributions.reshape(-1, num_patches_side, num_patches_side)


def visualize(input_img, visual_emb_gam_setups, show_scores, show_cbars):
    '''Visualize the patch contributions to all labels of one or more visual Emb-GAMs.

    Layout: the input image occupies the first column, spanning all rows.
    Each row then holds one heatmap per label for one selected setup, and all
    heatmaps in a row share one colour scale.

    Args:
        input_img: PIL image from the Gradio image input, or ``None`` when no
            image has been provided.
        visual_emb_gam_setups: selected setup names (keys of ``embedders`` and
            ``gams``).
        show_scores: if true, label each heatmap with its summed contribution.
        show_cbars: if true, draw a colour bar next to each heatmap.

    Returns:
        A single matplotlib ``Figure``, because the interface has exactly one
        Plot output. The figure is empty when there is nothing to plot.
    '''
    if input_img is None or not visual_emb_gam_setups:
        # One Plot output -> return one figure, not a (fig, fig) tuple.
        return plt.Figure()

    patch_contributions = {
        setup: _patch_contributions(input_img, setup)
        for setup in visual_emb_gam_setups
    }

    num_rows = len(visual_emb_gam_setups)
    num_labels = len(labels)
    multiple_setups = num_rows > 1

    # Build the figure without pyplot. plt.subplots() registers every figure
    # in pyplot's global registry until explicitly closed (one leaked figure
    # per request), and plt.tight_layout() acts on pyplot's global "current"
    # figure rather than on this one.
    fig = plt.Figure(figsize=(20, round(10 / 4 * num_rows)))
    axs = fig.subplots(num_rows, num_labels + 1)

    # Replace the first column of axes with one axis spanning all rows that
    # shows the original image.
    first_col = axs[:, 0] if multiple_setups else [axs[0]]
    gs = first_col[0].get_gridspec()
    for ax in first_col:
        ax.remove()
    ax_orig_img = fig.add_subplot(gs[:, 0] if multiple_setups else gs[0])
    ax_orig_img.imshow(input_img)
    ax_orig_img.axis('off')

    # With a single row, subplots() returns a 1-D array. Wrap it so both cases
    # can be indexed as axs_maps[row][col].
    axs_maps = axs[:, 1:] if multiple_setups else [axs[1:]]
    for i, setup in enumerate(visual_emb_gam_setups):
        contributions = patch_contributions[setup]
        vmin = contributions.min()
        vmax = contributions.max()
        for j, label in enumerate(labels):
            ax = axs_maps[i][j]
            sns.heatmap(
                contributions[j],
                ax=ax,
                square=True,
                vmin=vmin,
                vmax=vmax,
                cbar=show_cbars,
            )
            if show_scores:
                # Sum over patches = this setup's total score for the label.
                ax.set_xlabel(f'{contributions[j].sum():.2f}')
            if j == 0:
                ax.set_ylabel(setup)
            if i == 0:
                ax.set_title(label)
            ax.set_xticks([])
            ax.set_yticks([])
    fig.tight_layout()
    return fig
description = 'Visualize the patch contributions of [visual Emb-GAMs](https://huggingface.co/models?other=visual%20emb-gam) to class labels.'
# Markdown shown below the demo. Paragraphs are separated by blank lines so
# they render as separate blocks; without them the first two paragraphs merge
# and "Also, check out..." is absorbed into the last list item.
article = '''An extension of [Emb-GAMs](https://arxiv.org/abs/2209.11799), visual Emb-GAMs classify images by embedding images, taking intermediate representations corresponding to different spatial regions, summing these up, and predicting a class label from the sum using a GAM.

The use of a sum of embeddings allows us to visualize which regions of an image contributed positively or negatively to each class score.

No paper yet, but you can refer to these tweets:

- [Tweet #1](https://twitter.com/patrick_j_ramos/status/1586992857969147904?s=20&t=5-j5gKK0FpZOgzR_9Wdm1g)
- [Tweet #2](https://twitter.com/patrick_j_ramos/status/1602187142062804992?s=20&t=roTFXfMkHHYVoCuNyN-AUA)

Also, check out the original [Emb-GAM paper](https://arxiv.org/abs/2209.11799).

```bibtex
@article{singh2022emb,
title={Emb-GAM: an Interpretable and Efficient Predictor using Pre-trained Language Models},
author={Singh, Chandan and Gao, Jianfeng},
journal={arXiv preprint arXiv:2209.11799},
year={2022}
}
```
'''
# Gradio UI: an image, a choice of Emb-GAM setups and two display toggles in;
# a single plot out (visualize returns one matplotlib Figure).
demo = gr.Interface(
    fn=visualize,
    inputs=[
        gr.Image(shape=(224, 224), type='pil', label='Input image'),
        gr.CheckboxGroup(setups, value=setups, label='Visual Emb-GAM'),
        gr.Checkbox(label='Show scores'),
        gr.Checkbox(label='Show color bars'),
    ],
    outputs=[
        gr.Plot(label='Patch contributions'),
    ],
    # glob() yields files in arbitrary, filesystem-dependent order; sort so the
    # example gallery is stable. Each example selects every setup and leaves
    # both toggles off.
    examples=[[path, setups, False, False] for path in sorted(glob.glob('examples/*'))],
    title='Visual Emb-GAM Probing',
    description=description,
    article=article,
    examples_per_page=20,
)
demo.launch(debug=True)