# Q-SENN_Interface_singlecolor / visualization.py
import gradio as gr
from load_model import extract_sel_mean_std_bias_assignemnt
from pathlib import Path
from architectures.model_mapping import get_model
from configs.dataset_params import dataset_constants
import torch
import torchvision.transforms as transforms
import pandas as pd
import cv2
import numpy as np
from PIL import Image
from get_data import get_augmentation
from configs.dataset_params import normalize_params
import random
from evaluation.diversity import MultiKCrossChannelMaxPooledSum


def overlapping_features_on_input(model, output, feature_maps, input, target):
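    """Overlay each selected feature's activation map on the input image.

    Returns [input_image, heatmap_1, ..., heatmap_k], one JET heatmap per
    feature with nonzero weight for the target (or predicted) class.
    """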
    W = model.linear.layer.weight  # sparse classification-layer weight matrix
    feature_maps = feature_maps.detach().cpu().numpy().squeeze()
    print("feature_maps", feature_maps.shape)
    if target is not None:
        label = target - 1  # targets are 1-indexed, weight rows are 0-indexed
    else:
        output = output.detach().cpu().numpy()
        label = np.argmax(output)
    Interpretable_Selection = W[label, :]
    print("W", Interpretable_Selection)
    input_np = np.array(input)
    h, w = input_np.shape[:2]
    print("h,w:", h, w)
    Interpretable_Features = []
    input_np = cv2.resize(input_np, (448, 448))
    Feature_image_list = [input_np]
    # color_id = 0  # optional: tint each feature with its own single color (see below)
    # COLOR = ['R', 'G', 'B', 'Y', 'P', 'C']
    for S in range(len(Interpretable_Selection)):
        if Interpretable_Selection[S] != 0:
            Interpretable_Features.append(feature_maps[S])
            Feature_image = cv2.resize(feature_maps[S], (448, 448))
            # min-max normalize to [0, 255]; the max() guards against a constant map
            Feature_image = Feature_image - np.min(Feature_image)
            Feature_image = Feature_image / max(np.max(Feature_image), 1e-8)
            Feature_image = (Feature_image * 255).astype(np.uint8)
            # Optional: tint each feature with its own single color instead of
            # the JET colormap, cycling through R, G, B, Y(ellow), P(urple) and
            # C(yan); channel indices per color:
            # CHANNELS = {'R': [0], 'G': [1], 'B': [2], 'Y': [0, 1], 'P': [0, 2], 'C': [1, 2]}
            # color = COLOR[color_id % len(COLOR)]
            # Feature_image_color = np.zeros_like(input_np)
            # for c in CHANNELS[color]:
            #     Feature_image_color[:, :, c] = Feature_image
            # Feature_image = Feature_image_color
            # color_id += 1
            # Optional gamma correction:
            # Feature_image = np.power(Feature_image, 1.5)
            # applyColorMap returns BGR; convert to RGB before blending so the
            # overlay matches the RGB input handed to Gradio
            Feature_image = cv2.applyColorMap(Feature_image, cv2.COLORMAP_JET)
            Feature_image = cv2.cvtColor(Feature_image, cv2.COLOR_BGR2RGB)
            Feature_image = 0.3 * Feature_image + 0.7 * input_np
            Feature_image = Feature_image - np.min(Feature_image)
            Feature_image = Feature_image / max(np.max(Feature_image), 1e-8)
            Feature_image = (Feature_image * 255).astype(np.uint8)
            Feature_image_list.append(Feature_image)
print("len of Features:",len(Interpretable_Features))
return Feature_image_list
def genreate_intepriable_output(input, dataset="CUB2011", arch="resnet50", seed=123456, model_type="qsenn", n_features=50, n_per_class=5, img_size=448, reduced_strides=False, folder=None, with_featuremaps=True):
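    """Load the Q-SENN checkpoint for `dataset`/`arch`/`seed`, run it on a
    single image, and return the 1-indexed predicted class, the model and
    (optionally) the spatial feature maps."""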
n_classes = dataset_constants[dataset]["num_classes"]
    input = Image.fromarray(input)
    print("input shape", input.size)
    model = get_model(arch, n_classes, reduced_strides)
    tr = transform_input_img(input, img_size)
device = torch.device("cpu")
if folder is None:
folder = Path(f"tmp/{arch}/{dataset}/{seed}/")
model.load_state_dict(torch.load(folder / "Trained_DenseModel.pth",map_location=torch.device('cpu')))
state_dict = torch.load(folder / f"{model_type}_{n_features}_{n_per_class}_FinetunedModel.pth",map_location=torch.device('cpu'))
selection= torch.load(folder / f"SlDD_Selection_50.pt",map_location=torch.device('cpu'))
state_dict['linear.selection']=selection
feature_sel, sparse_layer, current_mean, current_std, bias_sparse = extract_sel_mean_std_bias_assignemnt(state_dict)
model.set_model_sldd(feature_sel, sparse_layer, current_mean, current_std, bias_sparse)
model.load_state_dict(state_dict)
input = tr(input)
    input = input.unsqueeze(0)
    input = input.to(device)
    model = model.to(device)
    model.eval()
    with torch.no_grad():
        output, feature_maps, final_features = model(input, with_feature_maps=True, with_final_features=True)
    print("feature map size:", feature_maps.size())
    output_np = output.detach().cpu().numpy()
    output_np = np.argmax(output_np) + 1  # CUB class ids are 1-indexed
    if with_featuremaps:
        return output_np, model, feature_maps
    else:
        return output_np, model

def get_options_from_trainingset(output, model, TR, device, with_other_class):
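    """Assemble the answer options: for the predicted class (plus three random
    distractor classes in quiz mode) load the pre-rendered explanation image
    and record which of the class's selected features carry negative weight."""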
print("outputclass:",output)
data_dir=Path("tmp/Datasets/CUB200/CUB_200_2011/")
labels = pd.read_csv("image_class_labels.txt", sep=' ', names=['img_id', 'target'])
namelist=pd.read_csv(data_dir/"images.txt",sep=' ',names=['img_id','file_name'])
classlist=pd.read_csv(data_dir/"classes.txt",sep=' ',names=['cl_id','class_name'])
options_output=labels[labels['target']==output]
print(options_output)
print(labels)
options=options_output.sample(4)
    # mode 2: add three random distractor classes for the quiz
    if with_other_class:
        other_targets = random.sample([i for i in range(1, 201) if i != output], 3)  # CUB class ids run 1..200
        all_targets = [output] + other_targets
        for tg in other_targets:
            others = labels[labels['target'] == tg]
            options_others = others.sample(4)
            options = pd.concat([options, options_others], ignore_index=True)
    else:
        all_targets = [output]
    # shuffled_options = options.sample(frac=1).reset_index(drop=True)
    print("options:", options)
    op = []
    # resample_img_id_list = []  # resample filter
    W = model.linear.layer.weight  # used to flag each class's negative-weight features
    model.eval()
    with torch.no_grad():
        for t in all_targets:
            # 1-based positions of the negative weights among this class's
            # nonzero (selected) features
            W_class = W[t - 1, :]
            nonzero_weights = [w for w in W_class if w != 0]
            features_id_neg = [i + 1 for i, w in enumerate(nonzero_weights) if w < 0]
            # load the pre-rendered class-explanation image; cv2.imread returns
            # BGR, so swap channels for display
            image = cv2.imread(f"options_heatmap/{t}.jpg")
            concatenate_class = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            op.append((concatenate_class, features_id_neg))
    return op

def transform_input_img(input, img_size):
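    """Build a transform that resizes the shorter image side to img_size and
    center-crops to img_size x img_size."""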
    # PIL's .size is (width, height); scale the shorter side to img_size
    w, h = input.size
    if w >= h:
        new_h = img_size
        new_w = int(img_size * w / h)
    else:
        new_w = img_size
        new_h = int(img_size * h / w)
    return transforms.Compose([
        transforms.Resize((new_h, new_w)),  # torchvision Resize expects (h, w)
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
    ])

def post_next_image(OPT: str, key: str):
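    """Grade the user's quiz answer and disable the four option buttons."""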
    if OPT == key:
        return ("Congratulations! You matched the model's prediction this time.",
                gr.update(interactive=False), gr.update(interactive=False),
                gr.update(interactive=False), gr.update(interactive=False))
    else:
        return (f"Sorry, the model's prediction was {key}.",
                gr.update(interactive=False), gr.update(interactive=False),
                gr.update(interactive=False), gr.update(interactive=False))

def get_features_on_interface(input):
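    """Quiz mode: return four shuffled class-explanation options, the letter of
    the correct one, the prompt text, and an update disabling the input."""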
    img_size = 448
    output, model = genreate_intepriable_output(input, dataset="CUB2011",
                                                arch="resnet50", seed=123456,
                                                model_type="qsenn", n_features=50, n_per_class=5,
                                                img_size=img_size, reduced_strides=False, folder=None,
                                                with_featuremaps=False)
    TR = get_augmentation(0.1, img_size, False, False, True, True, normalize_params["CUB2011"])
    device = torch.device("cpu")
    op = get_options_from_trainingset(output, model, TR, device, with_other_class=True)
    key = op[0][0]  # the true class's explanation image, before shuffling
    random.shuffle(op)
option=[(op[0][0],"A"),
(op[1][0],"B"),
(op[2][0],"C"),
(op[3][0],"D")]
for value,char in option:
if np.array_equal(value,key):
key_op=char
print("key",key_op)
    # Optionally append each option's negative-feature ids to its label, e.g.:
    # option[i] = (op[i][0], f"{char}, features {', '.join(map(str, op[i][1]))} are negative.")
return option, key_op," These are some class explanations from our model for different classes,which of these classes has our model predicted?",gr.update(interactive=False)
def direct_inference(input):
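    """Inference mode: return the input's per-feature heatmaps stacked next to
    the predicted class's pre-rendered explanation, plus a textual label."""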
    img_size = 448
    output, model, feature_maps = genreate_intepriable_output(input, dataset="CUB2011",
                                                              arch="resnet50", seed=123456,
                                                              model_type="qsenn", n_features=50, n_per_class=5,
                                                              img_size=img_size, reduced_strides=False, folder=None,
                                                              with_featuremaps=True)
    TR = get_augmentation(0.1, img_size, False, False, True, True, normalize_params["CUB2011"])
    device = torch.device("cpu")
    concatenated_image = get_options_from_trainingset(output, model, TR, device, with_other_class=False)
    # overlay the selected features on the cropped input image
    Input = Image.fromarray(input)
    tr = transform_input_img(Input, img_size)
    Input = tr(Input)
    image_np = (Input * 255).clamp(0, 255).byte()
    image_np = image_np.permute(1, 2, 0).numpy()
    ORI = overlapping_features_on_input(model, output, feature_maps, image_np, output)
    ORI_arrays = [np.array(img) for img in ORI]
    concatenated_ORI = np.concatenate(ORI_arrays, axis=0)
    print(concatenated_ORI.shape, concatenated_image[0][0].shape)
    # place the input's heatmap column next to the class-explanation image
    concatenated_image_final_array = np.concatenate((concatenated_ORI, concatenated_image[0][0]), axis=1)
    print(concatenated_image_final_array.shape)
    # map the 1-indexed class id to its CUB class name
    data_dir = Path("tmp/Datasets/CUB200/CUB_200_2011/")
    classlist = pd.read_csv(data_dir / "classes.txt", sep=' ', names=['cl_id', 'class_name'])
    output_name = classlist.loc[classlist['cl_id'] == output, 'class_name'].values[0]
    if concatenated_image[0][1]:
        output_name_and_features = f"{output_name}, features {', '.join(map(str, concatenated_image[0][1]))} are negative."
    else:
        output_name_and_features = f"{output_name}, all features are positive."
    return concatenated_image_final_array, output_name_and_features

def filter_with_diversity(featuremaps, output, weight):
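    """Score the diversity of the feature maps with
    MultiKCrossChannelMaxPooledSum over K in range(1, 6)."""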
    localizer = MultiKCrossChannelMaxPooledSum(range(1, 6), weight, None)
    localizer(output.to("cpu"), featuremaps.to("cpu"))
    locality, exclusive_locality = localizer.get_result()
    diversity = locality[4].item()  # entry for the largest K (K=5) under range(1, 6)
    return diversity
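

# Minimal usage sketch, not part of the original interface: "example.jpg" is a
# hypothetical local image, and the checkpoints, CUB files, and pre-rendered
# options_heatmap/<class>.jpg images must exist under the paths hard-coded above.
if __name__ == "__main__":
    img = np.array(Image.open("example.jpg").convert("RGB"))
    explanation_grid, label_text = direct_inference(img)
    print(label_text)
    Image.fromarray(explanation_grid).save("explanation.jpg")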