Spaces:
Sleeping
Sleeping
import numpy as np | |
import torch | |
from evaluation.helpers import softmax_feature_maps | |
class MultiKCrossChannelMaxPooledSum: | |
def __init__(self, top_k_range, weights, interactions, func="softmax"): | |
self.top_k_range = top_k_range | |
self.weights = weights | |
self.failed = False | |
self.max_ks = self.get_max_ks(weights) | |
self.locality_of_used_features = torch.zeros(len(top_k_range), device=weights.device) | |
self.locality_of_exclusely_used_features = torch.zeros(len(top_k_range), device=weights.device) | |
self.ns_k = torch.zeros(len(top_k_range), device=weights.device) | |
self.exclusive_ns = torch.zeros(len(top_k_range), device=weights.device) | |
self.interactions = interactions | |
self.func = func | |
def get_max_ks(self, weights): | |
nonzeros = torch.count_nonzero(torch.tensor(weights), 1) | |
return nonzeros | |
def get_top_n_locality(self, outputs, initial_feature_maps, k): | |
feature_maps, relevant_weights, vector_size, top_classes = self.adapt_feature_maps(outputs, | |
initial_feature_maps) | |
max_ks = self.max_ks[top_classes] | |
max_k_based_row_selection = max_ks >= k | |
result = self.get_crosspooled(relevant_weights, max_k_based_row_selection, k, vector_size, feature_maps, | |
separated=True) | |
return result | |
def get_locality(self, outputs, initial_feature_maps, n): | |
answer = self.get_top_n_locality(outputs, initial_feature_maps, n) | |
return answer | |
def get_result(self): | |
# if torch.sum(self.exclusive_ns) ==0: | |
# end_idx = len(self.exclusive_ns) - 1 | |
# else: | |
exclusive_array = torch.zeros_like(self.locality_of_exclusely_used_features) | |
local_array = torch.zeros_like(self.locality_of_used_features) | |
# if self.failed: | |
# return local_array, exclusive_array | |
cumulated = torch.cumsum(self.exclusive_ns, 0) | |
end_idx = torch.argmax(cumulated) | |
exclusivity_array = self.locality_of_exclusely_used_features[:end_idx + 1] / self.exclusive_ns[:end_idx + 1] | |
exclusivity_array[exclusivity_array != exclusivity_array] = 0 | |
exclusive_array[:len(exclusivity_array)] = exclusivity_array | |
locality_array = self.locality_of_used_features[self.locality_of_used_features != 0] / self.ns_k[ | |
self.locality_of_used_features != 0] | |
local_array[:len(locality_array)] = locality_array | |
return local_array, exclusive_array | |
def get_crosspooled(self, relevant_weights, mask, k, vector_size, feature_maps, separated=False): | |
relevant_indices = get_relevant_indices(relevant_weights, k)[mask] | |
# this should have size batch x k x featuremapsize squared] | |
indices = relevant_indices.unsqueeze(2).repeat(1, 1, vector_size) | |
sub_feature_maps = torch.gather(feature_maps[mask], 1, indices) | |
# shape batch x featuremapsquared: For each "pixel" the highest value | |
cross_pooled = torch.max(sub_feature_maps, 1)[0] | |
if separated: | |
return torch.sum(cross_pooled, 1) / k | |
else: | |
ns = len(cross_pooled) | |
result = torch.sum(cross_pooled) / (k) | |
# should be batch x map size | |
return ns, result | |
def adapt_feature_maps(self, outputs, initial_feature_maps): | |
if self.func == "softmax": | |
feature_maps = softmax_feature_maps(initial_feature_maps) | |
feature_maps = torch.flatten(feature_maps, 2) | |
vector_size = feature_maps.shape[2] | |
top_classes = torch.argmax(outputs, dim=1) | |
relevant_weights = self.weights[top_classes] | |
if relevant_weights.shape[1] != feature_maps.shape[1]: | |
feature_maps = self.interactions.get_localized_features(initial_feature_maps) | |
feature_maps = softmax_feature_maps(feature_maps) | |
feature_maps = torch.flatten(feature_maps, 2) | |
return feature_maps, relevant_weights, vector_size, top_classes | |
def calculate_locality(self, outputs, initial_feature_maps): | |
feature_maps, relevant_weights, vector_size, top_classes = self.adapt_feature_maps(outputs, | |
initial_feature_maps) | |
max_ks = self.max_ks[top_classes] | |
for k in self.top_k_range: | |
# relevant_k_s = max_ks[] | |
max_k_based_row_selection = max_ks >= k | |
if torch.sum(max_k_based_row_selection) == 0: | |
break | |
exclusive_k = max_ks == k | |
if torch.sum(exclusive_k) != 0: | |
ns, result = self.get_crosspooled(relevant_weights, exclusive_k, k, vector_size, feature_maps) | |
self.locality_of_exclusely_used_features[k - 1] += result | |
self.exclusive_ns[k - 1] += ns | |
ns, result = self.get_crosspooled(relevant_weights, max_k_based_row_selection, k, vector_size, feature_maps) | |
self.ns_k[k - 1] += ns | |
self.locality_of_used_features[k - 1] += result | |
def __call__(self, outputs, initial_feature_maps): | |
self.calculate_locality(outputs, initial_feature_maps) | |
def get_relevant_indices(weights, top_k): | |
top_k = weights.topk(top_k)[1] | |
return top_k |