# Q-SENN_Interface_heatmap / FeatureDiversityLoss.py
# Uploaded by Haaribo in commit 8d4ee22 ("Add application file").
import torch
from torch import nn
"""
Feature Diversity Loss

Usage to replicate the paper:
Call
    loss_function = FeatureDiversityLoss(0.196, linear)
to initialize the loss with the model's final linear layer.
At each mini-batch, pass the feature maps (output of the final convolutional
layer) together with the model outputs, and add the result to the loss:
    loss += loss_function(feature_maps, outputs)
"""
class FeatureDiversityLoss(nn.Module):
    """Feature diversity loss driven by the class weights of a linear layer.

    For every sample, the absolute weight row of the top predicted class is
    L2-normalised and used to weight the channels of the feature maps (which
    are first spatially softmaxed and rescaled by ``preserve_avg_func``). At
    every spatial location the maximum over channels is taken; the sum of
    those maxima, averaged over the batch, is negated and multiplied by
    ``scaling_factor``. Minimising the result therefore rewards weighted
    channels whose activations concentrate on different locations.
    """

    def __init__(self, scaling_factor, linear):
        """
        Args:
            scaling_factor: Multiplier applied to the loss (0.196 in the
                reference usage).
            linear: Layer whose ``weight`` rows are selected by predicted
                class; presumably an ``nn.Linear`` with weight shape
                (num_classes, num_features) -- confirm against the model.
        """
        super().__init__()
        self.scaling_factor = scaling_factor
        print("Scaling Factor: ", self.scaling_factor)
        self.linearLayer = linear

    def initialize(self, linearLayer):
        """Replace the layer whose weights are used by ``get_weights``."""
        self.linearLayer = linearLayer

    def get_weights(self, outputs):
        """Return, per sample, the absolute weight row of its argmax class.

        Args:
            outputs: Model outputs; the argmax over dim 1 selects the class.

        Returns:
            ``abs(weight)[argmax(outputs, dim=1)]``, one row per sample.
        """
        weight_matrix = torch.abs(self.linearLayer.weight)
        top_classes = torch.argmax(outputs, dim=1)
        return weight_matrix[top_classes]

    def forward(self, feature_maps, outputs):
        """Compute the scaled, negated diversity loss.

        Args:
            feature_maps: Final feature maps; ``preserve_avg_func`` averages
                over dims 2 and 3, so a 4-D tensor is expected, presumably
                (batch, channels, height, width). ``channels`` must match the
                width of the linear layer's weight rows.
            outputs: Model outputs used to pick each sample's top class.

        Returns:
            Scalar tensor ``-sum(max_over_channels) / batch * scaling_factor``.
        """
        relevant_weights = norm_vector(self.get_weights(outputs))
        scaled_maps = preserve_avg_func(feature_maps)
        # (batch, channels, H * W): one flat spatial axis per channel.
        flattened_feature_maps = scaled_maps.flatten(2)
        batch = flattened_feature_maps.size(0)
        # Broadcast the per-sample channel weights over the spatial axis.
        relevant_feature_maps = flattened_feature_maps * relevant_weights[..., None]
        # Max over channels at every location, summed over locations and batch.
        diversity_loss = torch.sum(torch.amax(relevant_feature_maps, dim=1))
        return -diversity_loss / batch * self.scaling_factor
def norm_vector(x):
    """Scale every slice of ``x`` along dim 1 to (approximately) unit L2 norm.

    A constant 1e-5 is added to each norm so all-zero slices do not cause a
    division by zero.
    """
    denominator = torch.norm(x, dim=1, keepdim=True) + 1e-5
    return x / denominator
def preserve_avg_func(x):
    """Spatially softmax each feature map and rescale it by its relative mean.

    Each channel's spatial mean is divided by the largest channel mean of the
    same sample (that denominator is clamped to at least 1e-6). The spatially
    softmaxed map of the channel is then multiplied by this ratio.

    Args:
        x: Feature maps; averaging over dims 2 and 3 implies a 4-D tensor,
            presumably (batch, channels, height, width).

    Returns:
        Tensor with the same shape as ``x``.
    """
    channel_means = x.mean(dim=(2, 3))
    largest_mean = torch.max(channel_means, dim=1).values
    ratio = channel_means / largest_mean.unsqueeze(-1).clamp(min=1e-6)
    probabilities = softmax_feature_maps(x)
    return probabilities * ratio[:, :, None, None]
def softmax_feature_maps(x):
    """Apply a softmax over all trailing (spatial) positions of each channel.

    Dims 0 and 1 are kept, everything after them is flattened for the
    softmax, and the result is viewed back to the shape of ``x``.
    """
    batch_size, channels = x.size(0), x.size(1)
    per_channel = x.reshape(batch_size, channels, -1)
    probabilities = per_channel.softmax(dim=2)
    return probabilities.view_as(x)