import numpy as np
import torch
from evaluation.helpers import softmax_feature_maps


class MultiKCrossChannelMaxPooledSum:
    """Accumulates cross-channel max-pooled locality statistics over a range of top-k values."""

    def __init__(self, top_k_range, weights, interactions, func="softmax"):
        self.top_k_range = top_k_range
        self.weights = weights
        self.failed = False
        self.max_ks = self.get_max_ks(weights)
        # Running sums of locality scores and sample counts, one slot per k.
        self.locality_of_used_features = torch.zeros(len(top_k_range), device=weights.device)
        self.locality_of_exclusely_used_features = torch.zeros(len(top_k_range), device=weights.device)
        self.ns_k = torch.zeros(len(top_k_range), device=weights.device)
        self.exclusive_ns = torch.zeros(len(top_k_range), device=weights.device)
        self.interactions = interactions
        self.func = func

    def get_max_ks(self, weights):
        # Number of non-zero weights per class row, i.e. the largest usable k for each class.
        # weights is expected to already be a torch tensor (its .device is used in __init__).
        return torch.count_nonzero(weights, dim=1)

    def get_top_n_locality(self, outputs, initial_feature_maps, k):
        feature_maps, relevant_weights, vector_size, top_classes = self.adapt_feature_maps(
            outputs, initial_feature_maps)
        max_ks = self.max_ks[top_classes]
        max_k_based_row_selection = max_ks >= k
        result = self.get_crosspooled(relevant_weights, max_k_based_row_selection, k, vector_size,
                                      feature_maps, separated=True)
        return result

    def get_locality(self, outputs, initial_feature_maps, n):
        return self.get_top_n_locality(outputs, initial_feature_maps, n)

    def get_result(self):
        exclusive_array = torch.zeros_like(self.locality_of_exclusely_used_features)
        local_array = torch.zeros_like(self.locality_of_used_features)
        # The cumulative sum of exclusive_ns is non-decreasing, so argmax returns the last
        # (0-based) slot that received any exclusively-used samples.
        cumulated = torch.cumsum(self.exclusive_ns, 0)
        end_idx = torch.argmax(cumulated)
        exclusivity_array = self.locality_of_exclusely_used_features[:end_idx + 1] / self.exclusive_ns[:end_idx + 1]
        # Slots with zero counts produce NaN; replace those entries with 0.
        exclusivity_array[exclusivity_array != exclusivity_array] = 0
        exclusive_array[:len(exclusivity_array)] = exclusivity_array
        # Average locality per k, skipping slots that never received any contribution.
        nonzero = self.locality_of_used_features != 0
        locality_array = self.locality_of_used_features[nonzero] / self.ns_k[nonzero]
        local_array[:len(locality_array)] = locality_array
        return local_array, exclusive_array

    def get_crosspooled(self, relevant_weights, mask, k, vector_size, feature_maps, separated=False):
        # Indices of the top-k weighted feature maps for each selected sample: batch x k.
        relevant_indices = get_relevant_indices(relevant_weights, k)[mask]
        # Expand to batch x k x vector_size so the matching flattened maps can be gathered.
        indices = relevant_indices.unsqueeze(2).repeat(1, 1, vector_size)
        sub_feature_maps = torch.gather(feature_maps[mask], 1, indices)
        # Cross-channel max pooling: for each spatial position, take the highest value
        # across the k selected maps. Shape: batch x vector_size.
        cross_pooled = torch.max(sub_feature_maps, 1)[0]
        if separated:
            # Per-sample score: sum over spatial positions, normalised by k.
            return torch.sum(cross_pooled, 1) / k
        else:
            # Aggregate score over the whole batch, plus the number of contributing samples.
            ns = len(cross_pooled)
            result = torch.sum(cross_pooled) / k
            return ns, result

    def adapt_feature_maps(self, outputs, initial_feature_maps):
        if self.func == "softmax":
            feature_maps = softmax_feature_maps(initial_feature_maps)
        else:
            # Fallback: use the raw feature maps when no normalisation function is configured.
            feature_maps = initial_feature_maps
        feature_maps = torch.flatten(feature_maps, 2)
        top_classes = torch.argmax(outputs, dim=1)
        relevant_weights = self.weights[top_classes]
        if relevant_weights.shape[1] != feature_maps.shape[1]:
            # Channel count does not match the classifier weights: fall back to the
            # localized features provided by the interactions helper.
            feature_maps = self.interactions.get_localized_features(initial_feature_maps)
            feature_maps = softmax_feature_maps(feature_maps)
            feature_maps = torch.flatten(feature_maps, 2)
        # Flattened spatial size of the maps that are actually used downstream.
        vector_size = feature_maps.shape[2]
        return feature_maps, relevant_weights, vector_size, top_classes

    def calculate_locality(self, outputs, initial_feature_maps):
        feature_maps, relevant_weights, vector_size, top_classes = self.adapt_feature_maps(
            outputs, initial_feature_maps)
        # Largest usable k for the predicted class of every sample in the batch.
        max_ks = self.max_ks[top_classes]
        for k in self.top_k_range:
            # Samples whose predicted class uses at least k features.
            max_k_based_row_selection = max_ks >= k
            if torch.sum(max_k_based_row_selection) == 0:
                # No class in this batch uses k or more features; larger k cannot match either.
                break
            # Samples whose predicted class uses exactly k features ("exclusive" contribution).
            exclusive_k = max_ks == k
            if torch.sum(exclusive_k) != 0:
                ns, result = self.get_crosspooled(relevant_weights, exclusive_k, k, vector_size, feature_maps)
                self.locality_of_exclusely_used_features[k - 1] += result
                self.exclusive_ns[k - 1] += ns
            ns, result = self.get_crosspooled(relevant_weights, max_k_based_row_selection, k, vector_size, feature_maps)
            self.ns_k[k - 1] += ns
            self.locality_of_used_features[k - 1] += result

    def __call__(self, outputs, initial_feature_maps):
        self.calculate_locality(outputs, initial_feature_maps)


def get_relevant_indices(weights, top_k):
    # Indices of the top_k largest weights per row.
    return weights.topk(top_k)[1]
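

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the evaluation pipeline): feeds random
    # logits and feature maps through the metric to show the expected call pattern. The shapes
    # below are arbitrary assumptions (10 classes, 16 channels, 7x7 maps), and the sketch assumes
    # softmax_feature_maps accepts a (batch, channels, H, W) tensor and returns the same shape.
    torch.manual_seed(0)
    num_classes, num_channels, spatial = 10, 16, 7
    class_weights = torch.rand(num_classes, num_channels)
    evaluator = MultiKCrossChannelMaxPooledSum(
        top_k_range=range(1, num_channels + 1),
        weights=class_weights,
        interactions=None,  # only needed when channel counts mismatch the weights
    )
    outputs = torch.rand(4, num_classes)                          # batch of logits
    feature_maps = torch.rand(4, num_channels, spatial, spatial)  # matching feature maps
    evaluator(outputs, feature_maps)
    locality, exclusive_locality = evaluator.get_result()
    print(locality.shape, exclusive_locality.shape)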