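# app.py for a Gradio demo (apparently a Hugging Face Space) that classifies
# the architectural style of a building photo and visualizes the model's
# evidence with Deep Feature Factorization (DFF) from pytorch_grad_cam.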
import cv2
import torch
import torchvision
import torch.nn as nn
import numpy as np
import gradio as gr
from typing import List
from torchvision import transforms
from matplotlib import pyplot as plt
from matplotlib.lines import Line2D
from pytorch_grad_cam import LayerCAM, DeepFeatureFactorization
from utils.save_load import load_model
labels = [
    "Achaemenid architecture",
    "American craftsman style",
    "American Foursquare architecture",
    "Ancient Egyptian architecture",
    "Art Deco architecture",
    "Art Nouveau architecture",
    "Baroque architecture",
    "Bauhaus architecture",
    "Beaux-Arts architecture",
    "Brutalism architecture",
    "Byzantine architecture",
    "Chicago school architecture",
    "Colonial architecture",
    "Deconstructivism",
    "Edwardian architecture",
    "Georgian architecture",
    "Gothic architecture",
    "Greek Revival architecture",
    "International style",
    "Islamic architecture",
    "Novelty architecture",
    "Palladian architecture",
    "Postmodern architecture",
    "Queen Anne architecture",
    "Romanesque architecture",
    "Russian Revival architecture",
    "Tudor Revival architecture"
]
print(len(labels))
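# Sanity check above: the classifier head must output one logit per label.
# EfficientNetV2-L backbone with a custom head; load_model (from the repo's
# utils.save_load) presumably restores the fine-tuned checkpoint in place.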
model = torchvision.models.efficientnet_v2_l()
model.classifier = nn.Sequential(
    nn.Dropout(p=0.4, inplace=True),
    nn.Linear(1280, len(labels), bias=True)
)
load_model(model)
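# Explainability setup: LayerCAM over the last feature block, and Deep
# Feature Factorization (NMF on the same activations) with the classifier
# head attached so each discovered concept can be scored against the labels.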
target_layer = model.features[-1]
classifier = model.classifier
# pytorch_grad_cam's CAM constructors expect a *list* of target layers
cam = LayerCAM(model=model, target_layers=[target_layer], use_cuda=False)
dff = DeepFeatureFactorization(
    model=model, target_layer=target_layer, computation_on_concepts=classifier)
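# Adapted, near-verbatim it seems, from pytorch_grad_cam's
# show_factorization_on_image helper: paints each DFF concept as a colored
# mask, blends the masks with the input image, and appends a legend strip.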
def show_factorization_on_image(img: np.ndarray,
                                explanations: np.ndarray,
                                colors: List[np.ndarray] = None,
                                image_weight: float = 0.5,
                                concept_labels: List = None) -> np.ndarray:
    n_components = explanations.shape[0]
    if colors is None:
        # taken from https://github.com/edocollins/DFF/blob/master/utils.py
        _cmap = plt.get_cmap('gist_rainbow')
        colors = [np.array(_cmap(i))
                  for i in np.arange(0, 1, 1.0 / n_components)]
    concept_per_pixel = explanations.argmax(axis=0)
    masks = []
    for i in range(n_components):
        mask = np.zeros(shape=(img.shape[0], img.shape[1], 3))
        mask[:, :, :] = colors[i][:3]
        explanation = explanations[i]
        explanation[concept_per_pixel != i] = 0
        # encode the concept's spatial weight as the brightness of its color
        mask = np.uint8(mask * 255)
        mask = cv2.cvtColor(mask, cv2.COLOR_RGB2HSV)
        mask[:, :, 2] = np.uint8(255 * explanation)
        mask = cv2.cvtColor(mask, cv2.COLOR_HSV2RGB)
        mask = np.float32(mask) / 255
        masks.append(mask)
    mask = np.sum(np.float32(masks), axis=0)
    result = img * image_weight + mask * (1 - image_weight)
    result = np.uint8(result * 255)
    if concept_labels is not None:
        px = 1 / plt.rcParams['figure.dpi']  # pixel in inches
        fig = plt.figure(figsize=(result.shape[1] * px, result.shape[0] * px))
        plt.rcParams['legend.fontsize'] = 6 * result.shape[0] / 256
        lw = 5 * result.shape[0] / 256
        lines = [Line2D([0], [0], color=colors[i], lw=lw)
                 for i in range(n_components)]
        plt.legend(lines,
                   concept_labels,
                   fancybox=False,
                   shadow=False,
                   frameon=False,
                   loc="center")
        plt.tight_layout(pad=0, w_pad=0, h_pad=0)
        plt.axis('off')
        fig.canvas.draw()
        data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
        plt.close(fig=fig)
        data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
        data = cv2.resize(data, (result.shape[1], result.shape[0]))
        result = np.vstack((result, data))
    return result
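# Build one legend string per DFF concept from the classifier scores:
# the top_k style names, each with its softmax percentage.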
def create_labels(concept_scores, top_k=2):
    """Create a list with the names of the top-scoring architecture styles."""
    concept_categories = np.argsort(concept_scores, axis=1)[:, ::-1][:, :top_k]
    concept_labels_topk = []
    for concept_index in range(concept_categories.shape[0]):
        categories = concept_categories[concept_index, :]
        concept_labels = []
        for category in categories:
            score = concept_scores[concept_index, category]
            label = f"{labels[category].split(',')[0]}:{score*100:.2f}%"
            concept_labels.append(label)
        concept_labels_topk.append("\n".join(concept_labels))
    return concept_labels_topk
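# Inference entry point for Gradio: returns (a) per-class confidences for the
# Label component and (b) the DFF visualization image. The normalization
# stats below look dataset-specific rather than the ImageNet defaults.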
def predict(rgb_img, top_k):
    top_k = int(top_k)  # the Gradio slider may deliver a float
    inp_01 = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize([0.4937, 0.5060, 0.5030],
                                 [0.2705, 0.2653, 0.2998]),
            transforms.Resize((224, 224)),
        ])(rgb_img)
    model.eval()
    with torch.no_grad():
        prediction = torch.nn.functional.softmax(
            model(inp_01.unsqueeze(0))[0], dim=0)
        confidences = {labels[i]: float(prediction[i])
                       for i in range(len(labels))}
    concepts, batch_explanations, concept_outputs = dff(
        inp_01.unsqueeze(0), 5)
    concept_outputs = torch.softmax(
        torch.from_numpy(concept_outputs), dim=-1).numpy()
    concept_label_strings = create_labels(concept_outputs, top_k=top_k)
    # upsample the (n_components, H', W') explanations to the input image size
    res = cv2.resize(np.transpose(
        batch_explanations[0], (1, 2, 0)), (rgb_img.size[0], rgb_img.size[1]))
    res = np.transpose(res, (2, 0, 1))
    visualization_01 = show_factorization_on_image(
        np.float32(rgb_img) / 255.0,
        res,
        image_weight=0.3,
        concept_labels=concept_label_strings)
    return confidences, visualization_01
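# Gradio UI: an image input plus a slider controlling how many top classes
# are named per concept in the legend; outputs the confidence label and the
# DFF visualization. The example images are assumed to exist under ./assets/.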
gr.Interface(fn=predict,
             inputs=[gr.Image(type="pil"),
                     gr.Slider(minimum=1, maximum=4,
                               label="Number of top results", step=1)],
             outputs=[gr.Label(num_top_classes=5), "image"],
             examples=[["./assets/bauhaus.jpg", 1],
                       ["./assets/frank_gehry.jpg", 2],
                       ["./assets/pyramid.jpg", 3]]
             ).launch()
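# On Spaces this file is launched automatically; to try it locally, run
# `python app.py` (requires torch, torchvision, grad-cam, opencv-python,
# matplotlib, and gradio, plus the repo's utils/ package and checkpoint).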