File size: 4,367 Bytes
d60982d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
import sys
import os
import torch
import random
import numpy as np
from torch.autograd import Variable
from torch.nn.parameter import Parameter
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from .ViT import *
from .gcn import GCNBlock
from torch_geometric.nn import GCNConv, DenseGraphConv, dense_mincut_pool
from torch.nn import Linear
class Classifier(nn.Module):
    """Graph classifier: GCN embedding -> mincut pooling -> transformer.

    Node features are embedded by one ``GCNBlock``, pooled into
    ``node_cluster_num`` clusters with ``dense_mincut_pool``, prefixed with a
    learnable class token and classified by ``VisionTransformer``. When
    requested, ``forward`` also produces GraphCAM artefacts (cluster
    assignments, class probabilities and one relevance map per class) via
    ``self.transformer.relprop``.
    """

    def __init__(self, n_class):
        super(Classifier, self).__init__()
        self.n_class = n_class
        self.embed_dim = 64
        # Not read by forward(); kept as a public attribute for compatibility.
        self.num_layers = 3
        self.node_cluster_num = 100
        self.transformer = VisionTransformer(num_classes=n_class, embed_dim=self.embed_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        self.criterion = nn.CrossEntropyLoss()
        self.bn = 1
        self.add_self = 1
        self.normalize_embedding = 1
        # GCN embedding, presumably 512-d input features -> embed_dim (64);
        # confirm the argument order against gcn.GCNBlock.
        self.conv1 = GCNBlock(512, self.embed_dim, self.bn, self.add_self, self.normalize_embedding, 0., 0)
        # Cluster-assignment logits: embed_dim (64) -> node_cluster_num (100).
        self.pool1 = Linear(self.embed_dim, self.node_cluster_num)

    def forward(self, node_feat, labels, adj, mask, is_print=False, graphcam_flag=False, to_file=True):
        """Classify a batch of graphs.

        Args:
            node_feat: node features, multiplied by ``mask.unsqueeze(2)``;
                presumably (batch, nodes, 512) -- TODO confirm.
            labels: class indices for the cross-entropy loss, or None to skip
                the loss. Returned unchanged.
            adj: dense adjacency passed to the GCN and to mincut pooling.
            mask: node validity mask, presumably (batch, nodes).
            is_print: unused; kept for interface compatibility.
            graphcam_flag: if True, also compute GraphCAM artefacts for the
                first graph in the batch.
            to_file: if True, GraphCAM artefacts are written under
                ``graphcam/``; otherwise they are collected in a dict.

        Returns:
            ``(pred, labels, loss)`` when ``to_file`` is True, otherwise
            ``(pred, labels, loss, graphcam_tensors)``. ``pred`` is the
            argmax over dim 1 of the logits. ``loss`` is cross-entropy plus
            the two auxiliary losses returned by ``dense_mincut_pool``, or
            None when ``labels`` is None.
        """
        X = mask.unsqueeze(2) * node_feat
        X = self.conv1(X, adj, mask)
        s = self.pool1(X)

        graphcam_tensors = {}
        if graphcam_flag:
            self._record_assignments(s, graphcam_tensors, to_file)

        X, adj, mc1, o1 = dense_mincut_pool(X, adj, s, mask)
        batch_size = X.size(0)
        cls_token = self.cls_token.repeat(batch_size, 1, 1)
        X = torch.cat([cls_token, X], dim=1)
        out = self.transformer(X)

        loss = None
        if labels is not None:
            loss = self.criterion(out, labels) + mc1 + o1

        pred = out.detach().max(dim=1)[1]

        if graphcam_flag:
            self._record_graphcam(out, graphcam_tensors, to_file)

        if not to_file:
            return pred, labels, loss, graphcam_tensors
        return pred, labels, loss

    def _record_assignments(self, s, graphcam_tensors, to_file):
        """Save or collect the first graph's cluster-assignment logits ``s[0]``."""
        s_matrix = torch.argmax(s[0], dim=1)
        if to_file:
            os.makedirs('graphcam', exist_ok=True)
            torch.save(s_matrix, 'graphcam/s_matrix.pt')
            torch.save(s[0], 'graphcam/s_matrix_ori.pt')
            # Remove stale attention dumps from a previous run. Each file is
            # checked on its own so a partially missing set no longer raises
            # FileNotFoundError.
            for name in ('att_1.pt', 'att_2.pt', 'att_3.pt'):
                stale = os.path.join('graphcam', name)
                if os.path.exists(stale):
                    os.remove(stale)
        else:
            graphcam_tensors['s_matrix'] = s_matrix
            graphcam_tensors['s_matrix_ori'] = s[0]

    def _record_graphcam(self, out, graphcam_tensors, to_file):
        """Compute class probabilities and one relevance map per class.

        Relevance is computed for the first graph in the batch only
        (``out[0]``). Results go to ``graphcam/prob.pt`` and
        ``graphcam/cam_<class>.pt`` when ``to_file`` is True, otherwise into
        ``graphcam_tensors`` under ``'prob'`` and ``'cam_<class>'``.
        """
        # Explicit dim=1 equals PyTorch's implicit choice for 2-D logits and
        # avoids the deprecation warning for an omitted ``dim``.
        p = F.softmax(out, dim=1)
        if to_file:
            torch.save(p, 'graphcam/prob.pt')
        else:
            graphcam_tensors['prob'] = p
        for index_ in range(self.n_class):
            # Selector keeping only class ``index_``, weighted by its logit.
            one_hot = np.zeros((1, out.size(-1)), dtype=np.float32)
            one_hot[0, index_] = out[0, index_].item()
            # Place the selector on ``out``'s own device; a hard-coded
            # 'cuda'-if-available broke CPU-resident models on GPU hosts.
            score = torch.sum(torch.from_numpy(one_hot).to(out.device) * out)
            self.transformer.zero_grad()
            # retain_graph: the same graph is back-propagated once per class.
            score.backward(retain_graph=True)
            cam = self.transformer.relprop(
                torch.tensor(one_hot).to(out.device),
                method="transformer_attribution",
                is_ablation=False,
                start_layer=0,
                alpha=1,
            )
            if to_file:
                torch.save(cam, 'graphcam/cam_{}.pt'.format(index_))
            else:
                graphcam_tensors[f'cam_{index_}'] = cam
|