Yvonnefanf committed
Commit 7b5e67a
Parent(s): 404d7f6
first

This view is limited to 50 files because it contains too many changes; see the raw diff for the full change set.
- .gitattributes +1 -0
- README.md +29 -3
- active_learning.py +243 -0
- proxy.py +275 -0
- requirements.txt +33 -0
- singleVis/SingleVisualizationModel.py +95 -0
- singleVis/__init__.py +0 -0
- singleVis/__pycache__/SingleVisualizationModel.cpython-37.pyc +0 -0
- singleVis/__pycache__/SingleVisualizationModel.cpython-39.pyc +0 -0
- singleVis/__pycache__/__init__.cpython-37.pyc +0 -0
- singleVis/__pycache__/__init__.cpython-39.pyc +0 -0
- singleVis/__pycache__/backend.cpython-37.pyc +0 -0
- singleVis/__pycache__/backend.cpython-39.pyc +0 -0
- singleVis/__pycache__/custom_weighted_random_sampler.cpython-37.pyc +0 -0
- singleVis/__pycache__/custom_weighted_random_sampler.cpython-39.pyc +0 -0
- singleVis/__pycache__/data.cpython-37.pyc +0 -0
- singleVis/__pycache__/data.cpython-39.pyc +0 -0
- singleVis/__pycache__/edge_dataset.cpython-37.pyc +0 -0
- singleVis/__pycache__/edge_dataset.cpython-39.pyc +0 -0
- singleVis/__pycache__/intrinsic_dim.cpython-37.pyc +0 -0
- singleVis/__pycache__/intrinsic_dim.cpython-39.pyc +0 -0
- singleVis/__pycache__/jj1sk.cpython-37.pyc +0 -0
- singleVis/__pycache__/jj51sk.cpython-37.pyc +0 -0
- singleVis/__pycache__/jj551sk.cpython-37.pyc +0 -0
- singleVis/__pycache__/jjsk.cpython-37.pyc +0 -0
- singleVis/__pycache__/kcenter_greedy.cpython-37.pyc +0 -0
- singleVis/__pycache__/kcenter_greedy.cpython-39.pyc +0 -0
- singleVis/__pycache__/losses.cpython-37.pyc +0 -0
- singleVis/__pycache__/losses.cpython-39.pyc +0 -0
- singleVis/__pycache__/projector.cpython-37.pyc +0 -0
- singleVis/__pycache__/sVis.cpython-37.pyc +0 -0
- singleVis/__pycache__/s_Vis.cpython-37.pyc +0 -0
- singleVis/__pycache__/skeVis.cpython-37.pyc +0 -0
- singleVis/__pycache__/skeleVis.cpython-37.pyc +0 -0
- singleVis/__pycache__/skele_Vis.cpython-37.pyc +0 -0
- singleVis/__pycache__/skele_vis.cpython-37.pyc +0 -0
- singleVis/__pycache__/skele_viser.cpython-37.pyc +0 -0
- singleVis/__pycache__/skeletonVis.cpython-37.pyc +0 -0
- singleVis/__pycache__/skeletonViser.cpython-37.pyc +0 -0
- singleVis/__pycache__/skeletonVisualizer.cpython-37.pyc +0 -0
- singleVis/__pycache__/skeleton_generator.cpython-37.pyc +0 -0
- singleVis/__pycache__/skeleton_vis.cpython-37.pyc +0 -0
- singleVis/__pycache__/spatial_edge_constructor.cpython-37.pyc +0 -0
- singleVis/__pycache__/spatial_edge_constructor.cpython-39.pyc +0 -0
- singleVis/__pycache__/spatial_edge_constructor_.cpython-37.pyc +0 -0
- singleVis/__pycache__/spatial_skeleton_edge_constructor.cpython-37.pyc +0 -0
- singleVis/__pycache__/ss_Vis.cpython-37.pyc +0 -0
- singleVis/__pycache__/ssjj_Vis.cpython-37.pyc +0 -0
- singleVis/__pycache__/ssjjjjj_Vis.cpython-37.pyc +0 -0
- singleVis/__pycache__/sss_Vis.cpython-37.pyc +0 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.tar.gz filter=lfs diff=lfs merge=lfs -text
README.md
CHANGED
@@ -1,3 +1,29 @@
-
-
-
+# Training Dynamic
+Demo data is stored in /training_dynamic
+# Evaluate the subject model
+
+```
+conda activate myvenv
+python subject_model_eval.py
+```
+The training dynamic performance will be stored in /training_dynamic/Model/subject_model_eval.json
+
+
+# Run trustvis
+```
+
+conda activate deepdebugger
+# proxy only
+python proxy.py --epoch 1/2/3 (default 3)
+
+the vis results will be stored in /training_dynamic/Proxy/***.png
+the evaluation results will be stored in /training_dynamic/Model/proxy_eval.json
+
+# trustvis with AL
+python active_learning.py --epoch 1/2/3 (default 3)
+
+the vis results will be stored in /training_dynamic/Trust_al/***.png
+
+the evaluation results will be stored in /training_dynamic/Model/trustvis_al_eval.json
+
+```
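
Both proxy.py and active_learning.py (added below) read their settings from training_dynamic/config.json under the "DVI" key. The sketch below writes a minimal, hypothetical config of that shape: the key names mirror what the two scripts access, but every value here is illustrative and must be replaced with the settings of your own subject model and dataset.

```python
# Hypothetical config.json for /training_dynamic: keys mirror what
# proxy.py and active_learning.py read; all values are placeholders.
import json
import os

example_config = {
    "DVI": {
        "SETTING": "normal",                      # assumed value
        "CLASSES": [str(i) for i in range(10)],   # class names of the subject dataset
        "DATASET": "cifar10",                     # assumed value
        "GPU": 0,
        "EPOCH_START": 1,
        "EPOCH_END": 3,
        "EPOCH_PERIOD": 1,
        "TRAINING": {"NET": "resnet18", "train_num": 50000},  # NET must exist in Model/model.py
        "VISUALIZATION": {
            "PREPROCESS": 1,
            "LAMBDA1": 1.0,
            "LAMBDA2": 0.3,
            "BOUNDARY": {"B_N_EPOCHS": 0, "L_BOUND": 0.4},
            "ENCODER_DIMS": [512, 256, 256, 256, 2],
            "DECODER_DIMS": [2, 256, 256, 256, 512],
            "S_N_EPOCHS": 5,
            "N_NEIGHBORS": 15,
            "PATIENT": 3,
            "MAX_EPOCH": 20,
            "VIS_MODEL_NAME": "dvi",
            "EVALUATION_NAME": "evaluation",
        },
    }
}

with open(os.path.join("training_dynamic", "config.json"), "w") as f:
    json.dump(example_config, f, indent=2)
```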
active_learning.py
ADDED
@@ -0,0 +1,243 @@
########################################################################################################################
#                                                       IMPORT                                                         #
########################################################################################################################
import torch
import sys
import os
import json
import time
import numpy as np
import argparse

from torch.utils.data import DataLoader
from torch.utils.data import WeightedRandomSampler
from umap.umap_ import find_ab_params

from singleVis.custom_weighted_random_sampler import CustomWeightedRandomSampler
from singleVis.SingleVisualizationModel import VisModel
from singleVis.losses import UmapLoss, ReconstructionLoss, TemporalLoss, DVILoss, SingleVisLoss, DummyTemporalLoss
from singleVis.edge_dataset import DVIDataHandler
from singleVis.trainer import DVIALMODITrainer
from singleVis.data import NormalDataProvider
from singleVis.spatial_skeleton_edge_constructor import OriginSingleEpochSpatialEdgeConstructor, PredDistSingleEpochSpatialEdgeConstructor
from singleVis.projector import DVIProjector
from singleVis.eval.evaluator import Evaluator
from singleVis.utils import find_neighbor_preserving_rate
from singleVis.visualizer import visualizer
from trustVis.skeleton_generator import CenterSkeletonGenerator
########################################################################################################################
#                                                     PARAMETERS                                                       #
########################################################################################################################
"""This serves as an example of a DeepVisualInsight implementation in pytorch."""
VIS_METHOD = "DVI"  # DeepVisualInsight

########################################################################################################################
#                                                   LOAD PARAMETERS                                                    #
########################################################################################################################
parser = argparse.ArgumentParser(description='Process hyperparameters...')

# get workspace dir
current_path = os.getcwd()

new_path = os.path.join(current_path, 'training_dynamic')

parser.add_argument('--content_path', type=str, default=new_path)
parser.add_argument('--base', type=str, default='proxy')
parser.add_argument('--name', type=str, default='trustvis')
parser.add_argument('--start', type=int, default=1)
parser.add_argument('--end', type=int, default=3)
parser.add_argument('--epoch', type=int, default=3)
args = parser.parse_args()

SAVED_NAME = args.name

CONTENT_PATH = args.content_path
sys.path.append(CONTENT_PATH)
with open(os.path.join(CONTENT_PATH, "config.json"), "r") as f:
    config = json.load(f)
config = config[VIS_METHOD]

SETTING = config["SETTING"]
CLASSES = config["CLASSES"]
DATASET = config["DATASET"]
PREPROCESS = config["VISUALIZATION"]["PREPROCESS"]
GPU_ID = config["GPU"]
EPOCH_START = args.epoch
EPOCH_END = args.epoch
EPOCH_PERIOD = 1

# Training parameter (subject model)
TRAINING_PARAMETER = config["TRAINING"]
NET = TRAINING_PARAMETER["NET"]
LEN = TRAINING_PARAMETER["train_num"]

# Training parameter (visualization model)
VISUALIZATION_PARAMETER = config["VISUALIZATION"]
LAMBDA1 = VISUALIZATION_PARAMETER["LAMBDA1"]
LAMBDA2 = VISUALIZATION_PARAMETER["LAMBDA2"]
B_N_EPOCHS = VISUALIZATION_PARAMETER["BOUNDARY"]["B_N_EPOCHS"]
L_BOUND = VISUALIZATION_PARAMETER["BOUNDARY"]["L_BOUND"]
ENCODER_DIMS = VISUALIZATION_PARAMETER["ENCODER_DIMS"]
DECODER_DIMS = VISUALIZATION_PARAMETER["DECODER_DIMS"]
S_N_EPOCHS = VISUALIZATION_PARAMETER["S_N_EPOCHS"]
N_NEIGHBORS = VISUALIZATION_PARAMETER["N_NEIGHBORS"]
PATIENT = VISUALIZATION_PARAMETER["PATIENT"]
MAX_EPOCH = VISUALIZATION_PARAMETER["MAX_EPOCH"]

VIS_MODEL_NAME = VISUALIZATION_PARAMETER["VIS_MODEL_NAME"]
EVALUATION_NAME = VISUALIZATION_PARAMETER["EVALUATION_NAME"]

# Define hyperparameters
GPU_ID = 1
DEVICE = torch.device("cuda:{}".format(GPU_ID) if torch.cuda.is_available() else "cpu")

import Model.model as subject_model
net = eval("subject_model.{}()".format(NET))

########################################################################################################################
#                                                   TRAINING SETTING                                                   #
########################################################################################################################
BASE_MODEL_NAME = args.base
# PREPROCESS = 1
# Define data_provider
data_provider = NormalDataProvider(CONTENT_PATH, net, EPOCH_START, EPOCH_END, EPOCH_PERIOD, device=DEVICE, classes=CLASSES, epoch_name='Epoch', verbose=1)
# if PREPROCESS:
#     data_provider._meta_data()
#     if B_N_EPOCHS > 0:
#         data_provider._estimate_boundary(LEN//10, l_bound=L_BOUND)

# Define visualization models
model = VisModel(ENCODER_DIMS, DECODER_DIMS)

# Define Losses
negative_sample_rate = 5
min_dist = .1
_a, _b = find_ab_params(1.0, min_dist)
umap_loss_fn = UmapLoss(negative_sample_rate, DEVICE, _a, _b, repulsion_strength=1.0)
recon_loss_fn = ReconstructionLoss(beta=1.0)
single_loss_fn = SingleVisLoss(umap_loss_fn, recon_loss_fn, lambd=LAMBDA1)
# Define Projector
projector = DVIProjector(vis_model=model, content_path=CONTENT_PATH, vis_model_name=BASE_MODEL_NAME, device=DEVICE)  # vis_model_name: an initial DVI base model

start_flag = 1
prev_model = VisModel(ENCODER_DIMS, DECODER_DIMS)

for iteration in range(EPOCH_START, EPOCH_END+EPOCH_PERIOD, EPOCH_PERIOD):
    # Define DVI Loss
    if start_flag:
        temporal_loss_fn = DummyTemporalLoss(DEVICE)
        criterion = DVILoss(umap_loss_fn, recon_loss_fn, temporal_loss_fn, lambd1=LAMBDA1, lambd2=0.0, device=DEVICE)
        start_flag = 0
    else:
        # TODO AL mode, redefine train_representation
        prev_data = data_provider.train_representation(iteration-EPOCH_PERIOD)
        curr_data = data_provider.train_representation(iteration)
        npr = find_neighbor_preserving_rate(prev_data, curr_data, N_NEIGHBORS)
        temporal_loss_fn = TemporalLoss(w_prev, DEVICE)
        criterion = DVILoss(umap_loss_fn, recon_loss_fn, temporal_loss_fn, lambd1=LAMBDA1, lambd2=torch.from_numpy(LAMBDA2*npr), device=DEVICE)

    vis = visualizer(data_provider, projector, 200, "tab10")
    grid_high, grid_emd, border = vis.get_epoch_decision_view(iteration, 400, None, True)
    train_data_embedding = projector.batch_project(iteration, data_provider.train_representation(iteration))
    from sklearn.neighbors import NearestNeighbors
    import numpy as np

    # assume train_data_embedding and grid_emd are numpy arrays with one point per row
    threshold = 5  # hyper-parameter

    # use train_data_embedding to initialize NearestNeighbors
    nbrs = NearestNeighbors(n_neighbors=1, algorithm='ball_tree').fit(train_data_embedding)
    # for each point in grid_emd, find the nearest sample in train_data_embedding
    distances, indices = nbrs.kneighbors(grid_emd)
    # filter by distance
    mask = distances.ravel() < threshold
    selected_indices = np.arange(grid_emd.shape[0])[mask]

    grid_high_mask = grid_high[selected_indices]

    skeleton_generator = CenterSkeletonGenerator(data_provider, iteration, 0.5, 500)
    high_bom, high_rad = skeleton_generator.center_skeleton_genertaion()
    print("number", len(high_bom))

    # Define training parameters
    optimizer = torch.optim.Adam(model.parameters(), lr=.01, weight_decay=1e-5)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=.1)
    # Define Edge dataset
    t0 = time.time()
    spatial_cons = OriginSingleEpochSpatialEdgeConstructor(data_provider, iteration, S_N_EPOCHS, B_N_EPOCHS, N_NEIGHBORS)
    edge_to, edge_from, probs, feature_vectors, attention = spatial_cons.construct()
    t1 = time.time()

    probs = probs / (probs.max()+1e-3)
    eliminate_zeros = probs > 1e-3
    edge_to = edge_to[eliminate_zeros]
    edge_from = edge_from[eliminate_zeros]
    probs = probs[eliminate_zeros]

    dataset = DVIDataHandler(edge_to, edge_from, feature_vectors, attention)

    n_samples = int(np.sum(S_N_EPOCHS * probs) // 1)
    # choose sampler based on the size of the dataset
    if len(edge_to) > pow(2, 24):
        sampler = CustomWeightedRandomSampler(probs, n_samples, replacement=True)
    else:
        sampler = WeightedRandomSampler(probs, n_samples, replacement=True)
    edge_loader = DataLoader(dataset, batch_size=2000, sampler=sampler, num_workers=8, prefetch_factor=10)

    ####################################################################################################################
    #                                                     TRAIN                                                        #
    ####################################################################################################################
    file_path = os.path.join(data_provider.content_path, "Model", "Epoch_{}".format(iteration), "{}.pth".format(BASE_MODEL_NAME))
    save_model = torch.load(file_path, map_location="cpu")
    model.load_state_dict(save_model["state_dict"])

    trainer = DVIALMODITrainer(model, criterion, optimizer, lr_scheduler, edge_loader=edge_loader, DEVICE=DEVICE, grid_high_mask=grid_high_mask, high_bom=high_bom, high_rad=high_rad, iteration=iteration, data_provider=data_provider, prev_model=prev_model, S_N_EPOCHS=S_N_EPOCHS, B_N_EPOCHS=B_N_EPOCHS, N_NEIGHBORS=N_NEIGHBORS)

    t2 = time.time()
    trainer.train(PATIENT, MAX_EPOCH)
    t3 = time.time()

    # save result
    save_dir = data_provider.model_path
    trainer.record_time(save_dir, "time_{}".format(VIS_MODEL_NAME), "complex_construction", str(iteration), t1-t0)
    trainer.record_time(save_dir, "time_{}".format(VIS_MODEL_NAME), "training", str(iteration), t3-t2)
    save_dir = os.path.join(data_provider.model_path, "Epoch_{}".format(iteration))
    trainer.save(save_dir=save_dir, file_name="{}".format(SAVED_NAME))

    print("Finish epoch {}...".format(iteration))

    prev_model.load_state_dict(model.state_dict())
    for param in prev_model.parameters():
        param.requires_grad = False
    w_prev = dict(prev_model.named_parameters())

    print('AL runtime', t3-t0)
########################################################################################################################
#                                                    VISUALIZATION                                                     #
########################################################################################################################

from singleVis.visualizer import visualizer

vis = visualizer(data_provider, projector, 200, "tab10")
save_dir = os.path.join(data_provider.content_path, "Trust_al")
if not os.path.exists(save_dir):
    os.mkdir(save_dir)
for i in range(EPOCH_START, EPOCH_END+1, EPOCH_PERIOD):
    vis.savefig(i, path=os.path.join(save_dir, "{}_{}_{}.png".format(VIS_MODEL_NAME, i, VIS_METHOD)))


########################################################################################################################
#                                                      EVALUATION                                                      #
########################################################################################################################

evaluator = Evaluator(data_provider, projector)

Evaluation_NAME = 'trustvis_al_eval'
for i in range(EPOCH_START, EPOCH_END+1, EPOCH_PERIOD):
    evaluator.save_epoch_eval(i, 15, temporal_k=5, file_name="{}".format(Evaluation_NAME))
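
The grid-filtering step inside the loop above (keep only decision-view grid points whose nearest training embedding lies within `threshold`) can be exercised on its own. The sketch below uses synthetic 2-d data; the array shapes and the threshold value are assumptions, not the values produced by `vis.get_epoch_decision_view`.

```python
# Standalone sketch of the grid filtering used in active_learning.py:
# keep only grid points close to at least one training embedding.
import numpy as np
from sklearn.neighbors import NearestNeighbors

rng = np.random.default_rng(0)
train_data_embedding = rng.random((500, 2))   # hypothetical 2-d training embeddings
grid_emd = rng.random((400 * 400, 2))         # hypothetical decision-view grid points
threshold = 0.05                              # illustrative hyper-parameter

# nearest training embedding for every grid point
nbrs = NearestNeighbors(n_neighbors=1, algorithm="ball_tree").fit(train_data_embedding)
distances, _ = nbrs.kneighbors(grid_emd)

# mask out grid points that fall in empty regions of the embedding space
mask = distances.ravel() < threshold
selected_indices = np.arange(grid_emd.shape[0])[mask]
print(len(selected_indices), "of", grid_emd.shape[0], "grid points retained")
```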
proxy.py
ADDED
@@ -0,0 +1,275 @@
########################################################################################################################
#                                                       IMPORT                                                         #
########################################################################################################################
import torch
import sys
import os
import json
import time
import numpy as np
import argparse

from torch.utils.data import DataLoader
from torch.utils.data import WeightedRandomSampler
from umap.umap_ import find_ab_params

from singleVis.custom_weighted_random_sampler import CustomWeightedRandomSampler
from singleVis.SingleVisualizationModel import VisModel
from singleVis.losses import UmapLoss, ReconstructionLoss, TemporalLoss, DVILoss, SingleVisLoss, DummyTemporalLoss
from singleVis.edge_dataset import DVIDataHandler
from singleVis.trainer import DVITrainer
from singleVis.eval.evaluator import Evaluator
from singleVis.data import NormalDataProvider
# from singleVis.spatial_edge_constructor import SingleEpochSpatialEdgeConstructor
from singleVis.spatial_skeleton_edge_constructor import ProxyBasedSpatialEdgeConstructor

from singleVis.projector import DVIProjector
from singleVis.utils import find_neighbor_preserving_rate

from trustVis.skeleton_generator import CenterSkeletonGenerator
########################################################################################################################
#                                                    DVI PARAMETERS                                                    #
########################################################################################################################
"""This serves as an example of a DeepVisualInsight implementation in pytorch."""
VIS_METHOD = "DVI"  # DeepVisualInsight

########################################################################################################################
#                                                   LOAD PARAMETERS                                                    #
########################################################################################################################

parser = argparse.ArgumentParser(description='Process hyperparameters...')

# get workspace dir
current_path = os.getcwd()

new_path = os.path.join(current_path, 'training_dynamic')

parser.add_argument('--content_path', type=str, default=new_path)
# parser.add_argument('--start', type=int, default=1)
# parser.add_argument('--end', type=int, default=3)
parser.add_argument('--epoch', type=int, default=3)
# parser.add_argument('--epoch_end', type=int)
parser.add_argument('--epoch_period', type=int, default=1)
parser.add_argument('--preprocess', type=int, default=0)
parser.add_argument('--base', type=bool, default=False)
args = parser.parse_args()

CONTENT_PATH = args.content_path
sys.path.append(CONTENT_PATH)
with open(os.path.join(CONTENT_PATH, "config.json"), "r") as f:
    config = json.load(f)
config = config[VIS_METHOD]

# record output information
# now = time.strftime("%Y-%m-%d-%H_%M_%S", time.localtime(time.time()))
# sys.stdout = open(os.path.join(CONTENT_PATH, now+".txt"), "w")

SETTING = config["SETTING"]
CLASSES = config["CLASSES"]
DATASET = config["DATASET"]
PREPROCESS = config["VISUALIZATION"]["PREPROCESS"]
GPU_ID = config["GPU"]
GPU_ID = 0
EPOCH_START = config["EPOCH_START"]
EPOCH_END = config["EPOCH_END"]
EPOCH_PERIOD = config["EPOCH_PERIOD"]

EPOCH_START = args.epoch
EPOCH_END = args.epoch
EPOCH_PERIOD = args.epoch_period

# Training parameter (subject model)
TRAINING_PARAMETER = config["TRAINING"]
NET = TRAINING_PARAMETER["NET"]
LEN = TRAINING_PARAMETER["train_num"]

# Training parameter (visualization model)
VISUALIZATION_PARAMETER = config["VISUALIZATION"]
LAMBDA1 = VISUALIZATION_PARAMETER["LAMBDA1"]
LAMBDA2 = VISUALIZATION_PARAMETER["LAMBDA2"]
B_N_EPOCHS = VISUALIZATION_PARAMETER["BOUNDARY"]["B_N_EPOCHS"]
L_BOUND = VISUALIZATION_PARAMETER["BOUNDARY"]["L_BOUND"]
ENCODER_DIMS = VISUALIZATION_PARAMETER["ENCODER_DIMS"]
DECODER_DIMS = VISUALIZATION_PARAMETER["DECODER_DIMS"]

S_N_EPOCHS = VISUALIZATION_PARAMETER["S_N_EPOCHS"]
N_NEIGHBORS = VISUALIZATION_PARAMETER["N_NEIGHBORS"]
PATIENT = VISUALIZATION_PARAMETER["PATIENT"]
MAX_EPOCH = VISUALIZATION_PARAMETER["MAX_EPOCH"]

VIS_MODEL_NAME = 'proxy'  ### saved_as
EVALUATION_NAME = VISUALIZATION_PARAMETER["EVALUATION_NAME"]

# Define hyperparameters
DEVICE = torch.device("cuda:{}".format(GPU_ID) if torch.cuda.is_available() else "cpu")

import Model.model as subject_model
net = eval("subject_model.{}()".format(NET))

########################################################################################################################
#                                                   TRAINING SETTING                                                   #
########################################################################################################################
# Define data_provider
data_provider = NormalDataProvider(CONTENT_PATH, net, EPOCH_START, EPOCH_END, EPOCH_PERIOD, device=DEVICE, epoch_name='Epoch', classes=CLASSES, verbose=1)
PREPROCESS = args.preprocess
if PREPROCESS:
    data_provider._meta_data()
    if B_N_EPOCHS > 0:
        data_provider._estimate_boundary(LEN // 10, l_bound=L_BOUND)

# Define visualization models
model = VisModel(ENCODER_DIMS, DECODER_DIMS)


# Define Losses
negative_sample_rate = 5
min_dist = .1
_a, _b = find_ab_params(1.0, min_dist)
umap_loss_fn = UmapLoss(negative_sample_rate, DEVICE, _a, _b, repulsion_strength=1.0)
recon_loss_fn = ReconstructionLoss(beta=1.0)
single_loss_fn = SingleVisLoss(umap_loss_fn, recon_loss_fn, lambd=LAMBDA1)
# Define Projector
projector = DVIProjector(vis_model=model, content_path=CONTENT_PATH, vis_model_name=VIS_MODEL_NAME, device=DEVICE)


start_flag = 1

prev_model = VisModel(ENCODER_DIMS, DECODER_DIMS)

for iteration in range(EPOCH_START, EPOCH_END+EPOCH_PERIOD, EPOCH_PERIOD):
    # Define DVI Loss
    if start_flag:
        temporal_loss_fn = DummyTemporalLoss(DEVICE)
        criterion = DVILoss(umap_loss_fn, recon_loss_fn, temporal_loss_fn, lambd1=LAMBDA1, lambd2=0.0, device=DEVICE)
        start_flag = 0
    else:
        # TODO AL mode, redefine train_representation
        prev_data = data_provider.train_representation(iteration-EPOCH_PERIOD)
        prev_data = prev_data.reshape(prev_data.shape[0], prev_data.shape[1])
        curr_data = data_provider.train_representation(iteration)
        curr_data = curr_data.reshape(curr_data.shape[0], curr_data.shape[1])
        t_1 = time.time()
        npr = torch.tensor(find_neighbor_preserving_rate(prev_data, curr_data, N_NEIGHBORS)).to(DEVICE)
        t_2 = time.time()

        temporal_loss_fn = TemporalLoss(w_prev, DEVICE)
        criterion = DVILoss(umap_loss_fn, recon_loss_fn, temporal_loss_fn, lambd1=LAMBDA1, lambd2=LAMBDA2*npr, device=DEVICE)
    # Define training parameters
    optimizer = torch.optim.Adam(model.parameters(), lr=.01, weight_decay=1e-5)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=.1)
    # Define Edge dataset

    ###### generate the skeleton
    skeleton_generator = CenterSkeletonGenerator(data_provider, EPOCH_START, 1)
    # Start timing
    start_time = time.time()
    ## generate skeleton
    high_bom, _ = skeleton_generator.center_skeleton_genertaion()

    end_time = time.time()
    elapsed_time = end_time - start_time
    print("proxy generation finished")

    t0 = time.time()
    ##### construct the spatial complex
    spatial_cons = ProxyBasedSpatialEdgeConstructor(data_provider, iteration, S_N_EPOCHS, B_N_EPOCHS, N_NEIGHBORS, net, high_bom)
    edge_to, edge_from, probs, feature_vectors, attention = spatial_cons.construct()
    t1 = time.time()

    print('complex-construct:', t1-t0)

    probs = probs / (probs.max()+1e-3)
    eliminate_zeros = probs > 1e-3
    edge_to = edge_to[eliminate_zeros]
    edge_from = edge_from[eliminate_zeros]
    probs = probs[eliminate_zeros]

    dataset = DVIDataHandler(edge_to, edge_from, feature_vectors, attention)

    n_samples = int(np.sum(S_N_EPOCHS * probs) // 1)
    # choose sampler based on the size of the dataset
    if len(edge_to) > pow(2, 24):
        sampler = CustomWeightedRandomSampler(probs, n_samples, replacement=True)
    else:
        sampler = WeightedRandomSampler(probs, n_samples, replacement=True)
    edge_loader = DataLoader(dataset, batch_size=2000, sampler=sampler, num_workers=8, prefetch_factor=10)

    ####################################################################################################################
    #                                                     TRAIN                                                        #
    ####################################################################################################################

    trainer = DVITrainer(model, criterion, optimizer, lr_scheduler, edge_loader=edge_loader, DEVICE=DEVICE)

    t2 = time.time()
    trainer.train(PATIENT, MAX_EPOCH)
    t3 = time.time()
    print('training:', t3-t2)
    # save result
    save_dir = data_provider.model_path
    trainer.record_time(save_dir, "time_{}".format(VIS_MODEL_NAME), "complex_construction", str(iteration), t1-t0)
    trainer.record_time(save_dir, "time_{}".format(VIS_MODEL_NAME), "training", str(iteration), t3-t2)
    save_dir = os.path.join(data_provider.model_path, "Epoch_{}".format(iteration))
    trainer.save(save_dir=save_dir, file_name="{}".format(VIS_MODEL_NAME))

    print("Finish epoch {}...".format(iteration))

    prev_model.load_state_dict(model.state_dict())
    for param in prev_model.parameters():
        param.requires_grad = False
    w_prev = dict(prev_model.named_parameters())


########################################################################################################################
#                                                    VISUALIZATION                                                     #
########################################################################################################################

from singleVis.visualizer import visualizer
now = time.strftime("%Y-%m-%d-%H_%M_%S", time.localtime(time.time()))
vis = visualizer(data_provider, projector, 200, "tab10")
save_dir = os.path.join(data_provider.content_path, "Proxy")

if not os.path.exists(save_dir):
    os.mkdir(save_dir)
for i in range(EPOCH_START, EPOCH_END+1, EPOCH_PERIOD):
    vis.savefig(i, path=os.path.join(save_dir, "{}_{}_{}_{}.png".format(DATASET, i, VIS_METHOD, now)))
    data = data_provider.train_representation(i)
    data = data.reshape(data.shape[0], data.shape[1])

    ##### save embeddings and background for visualization
    emb = projector.batch_project(i, data)
    np.save(os.path.join(CONTENT_PATH, 'Model', 'Epoch_{}'.format(i), 'embedding.npy'), emb)
    vis.get_background(i, 200)

# emb = projector.batch_project(data_provider)


########################################################################################################################
#                                                      EVALUATION                                                      #
########################################################################################################################
# eval_epochs = range(EPOCH_START, EPOCH_END+1, EPOCH_PERIOD)
# EVAL_EPOCH_DICT = {
#     "mnist": [1, 10, 15],
#     "fmnist": [1, 25, 50],
#     "cifar10": [1, 100, 199]
# }
# eval_epochs = EVAL_EPOCH_DICT[DATASET]
evaluator = Evaluator(data_provider, projector)

Evaluation_NAME = 'proxy_eval'
for i in range(EPOCH_START, EPOCH_END+1, EPOCH_PERIOD):
    evaluator.save_epoch_eval(i, 15, temporal_k=5, file_name="{}".format(Evaluation_NAME))
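
proxy.py and active_learning.py build their edge DataLoader the same way: normalize the edge probabilities returned by the spatial complex constructor, drop near-zero edges, and draw roughly S_N_EPOCHS * sum(probs) edges with a weighted sampler. The sketch below isolates just that step with made-up edge data; only the constants match the scripts.

```python
# Minimal sketch of the probability-weighted edge sampling shared by
# proxy.py and active_learning.py (synthetic edge data, illustrative only).
import numpy as np
import torch
from torch.utils.data import WeightedRandomSampler

S_N_EPOCHS = 5
rng = np.random.default_rng(0)
probs = rng.random(1000).astype(np.float32)      # hypothetical edge probabilities
edge_to = np.arange(1000)
edge_from = rng.integers(0, 1000, size=1000)

# normalize and drop (near-)zero edges, as in the scripts
probs = probs / (probs.max() + 1e-3)
keep = probs > 1e-3
edge_to, edge_from, probs = edge_to[keep], edge_from[keep], probs[keep]

# number of weighted draws for one pass over the edge dataset
n_samples = int(np.sum(S_N_EPOCHS * probs) // 1)
sampler = WeightedRandomSampler(torch.from_numpy(probs).double(), n_samples, replacement=True)
print(n_samples, "edges drawn, e.g. first indices:", list(sampler)[:5])
```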
requirements.txt
ADDED
@@ -0,0 +1,33 @@
certifi @ file:///croot/certifi_1671487769961/work/certifi
cycler==0.11.0
fonttools==4.38.0
importlib-metadata==6.7.0
Jinja2==3.1.2
joblib==1.3.2
kiwisolver==1.4.5
kmapper==2.0.1
llvmlite==0.39.1
MarkupSafe==2.1.3
matplotlib==3.5.2
networkx==2.6.3
numba==0.56.4
numpy==1.21.6
nvidia-cublas-cu11==11.10.3.66
nvidia-cuda-nvrtc-cu11==11.7.99
nvidia-cuda-runtime-cu11==11.7.99
nvidia-cudnn-cu11==8.5.0.96
packaging==23.2
Pillow==9.2.0
pynndescent==0.5.11
pyparsing==3.1.1
python-dateutil==2.8.2
scikit-learn==1.0.2
scipy==1.7.3
six==1.16.0
threadpoolctl==3.1.0
torch==1.13.1
tqdm==4.66.1
typing_extensions==4.7.1
umap==0.1.1
umap-learn==0.5.3
zipp==3.15.0
singleVis/SingleVisualizationModel.py
ADDED
@@ -0,0 +1,95 @@
from torch import nn


class SingleVisualizationModel(nn.Module):
    def __init__(self, input_dims, output_dims, units, hidden_layer=3):
        super(SingleVisualizationModel, self).__init__()

        self.input_dims = input_dims
        self.output_dims = output_dims
        self.units = units
        self.hidden_layer = hidden_layer
        self._init_autoencoder()

    # TODO find the best model architecture
    def _init_autoencoder(self):
        self.encoder = nn.Sequential(
            nn.Linear(self.input_dims, self.units),
            nn.ReLU(True))
        for h in range(self.hidden_layer):
            self.encoder.add_module("{}".format(2*h+2), nn.Linear(self.units, self.units))
            self.encoder.add_module("{}".format(2*h+3), nn.ReLU(True))
        self.encoder.add_module("{}".format(2*(self.hidden_layer+1)), nn.Linear(self.units, self.output_dims))

        self.decoder = nn.Sequential(
            nn.Linear(self.output_dims, self.units),
            nn.ReLU(True))
        for h in range(self.hidden_layer):
            self.decoder.add_module("{}".format(2*h+2), nn.Linear(self.units, self.units))
            self.decoder.add_module("{}".format(2*h+3), nn.ReLU(True))
        self.decoder.add_module("{}".format(2*(self.hidden_layer+1)), nn.Linear(self.units, self.input_dims))

    def forward(self, edge_to, edge_from):
        outputs = dict()
        embedding_to = self.encoder(edge_to)
        embedding_from = self.encoder(edge_from)
        recon_to = self.decoder(embedding_to)
        recon_from = self.decoder(embedding_from)

        outputs["umap"] = (embedding_to, embedding_from)
        outputs["recon"] = (recon_to, recon_from)

        return outputs


class VisModel(nn.Module):
    """Define your own visualization model by specifying its structure."""

    def __init__(self, encoder_dims, decoder_dims):
        """Define your own visualization model by specifying its structure.

        Parameters
        ----------
        encoder_dims : list of int
            the layer widths of your encoder;
            for example, [100, 50, 2] denotes two fully connected layers with shapes (100, 50) and (50, 2)
        decoder_dims : list of int
            same as encoder_dims
        """
        super(VisModel, self).__init__()
        assert len(encoder_dims) > 1
        assert len(decoder_dims) > 1
        self.encoder_dims = encoder_dims
        self.decoder_dims = decoder_dims
        self._init_autoencoder()

    def _init_autoencoder(self):
        self.encoder = nn.Sequential()
        for i in range(0, len(self.encoder_dims)-2):
            self.encoder.add_module("{}".format(len(self.encoder)), nn.Linear(self.encoder_dims[i], self.encoder_dims[i+1]))
            self.encoder.add_module("{}".format(len(self.encoder)), nn.ReLU(True))
        self.encoder.add_module("{}".format(len(self.encoder)), nn.Linear(self.encoder_dims[-2], self.encoder_dims[-1]))

        self.decoder = nn.Sequential()
        for i in range(0, len(self.decoder_dims)-2):
            self.decoder.add_module("{}".format(len(self.decoder)), nn.Linear(self.decoder_dims[i], self.decoder_dims[i+1]))
            self.decoder.add_module("{}".format(len(self.decoder)), nn.ReLU(True))
        self.decoder.add_module("{}".format(len(self.decoder)), nn.Linear(self.decoder_dims[-2], self.decoder_dims[-1]))

    def forward(self, edge_to, edge_from):
        outputs = dict()
        embedding_to = self.encoder(edge_to)
        embedding_from = self.encoder(edge_from)
        recon_to = self.decoder(embedding_to)
        recon_from = self.decoder(embedding_from)

        outputs["umap"] = (embedding_to, embedding_from)
        outputs["recon"] = (recon_to, recon_from)

        return outputs


'''
The visualization model definition class
'''
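
A minimal usage sketch for VisModel, assuming a 512-dimensional feature space projected to 2-d; the real ENCODER_DIMS and DECODER_DIMS come from config.json.

```python
# Hypothetical dimensions; the actual values are read from config.json.
import torch
from singleVis.SingleVisualizationModel import VisModel

model = VisModel(encoder_dims=[512, 256, 2], decoder_dims=[2, 256, 512])
edge_to = torch.randn(8, 512)      # dummy high-dimensional samples
edge_from = torch.randn(8, 512)
out = model(edge_to, edge_from)
emb_to, emb_from = out["umap"]         # 2-d embeddings, shape (8, 2)
recon_to, recon_from = out["recon"]    # reconstructions, shape (8, 512)
print(emb_to.shape, recon_to.shape)
```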
singleVis/__init__.py
ADDED
File without changes

singleVis/__pycache__/SingleVisualizationModel.cpython-37.pyc
ADDED
Binary file (3.32 kB)

singleVis/__pycache__/SingleVisualizationModel.cpython-39.pyc
ADDED
Binary file (5.93 kB)

singleVis/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (154 Bytes)

singleVis/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (148 Bytes)

singleVis/__pycache__/backend.cpython-37.pyc
ADDED
Binary file (4.78 kB)

singleVis/__pycache__/backend.cpython-39.pyc
ADDED
Binary file (5.12 kB)

singleVis/__pycache__/custom_weighted_random_sampler.cpython-37.pyc
ADDED
Binary file (1.11 kB)

singleVis/__pycache__/custom_weighted_random_sampler.cpython-39.pyc
ADDED
Binary file (1.12 kB)

singleVis/__pycache__/data.cpython-37.pyc
ADDED
Binary file (35.8 kB)

singleVis/__pycache__/data.cpython-39.pyc
ADDED
Binary file (32.5 kB)

singleVis/__pycache__/edge_dataset.cpython-37.pyc
ADDED
Binary file (3.64 kB)

singleVis/__pycache__/edge_dataset.cpython-39.pyc
ADDED
Binary file (5.14 kB)

singleVis/__pycache__/intrinsic_dim.cpython-37.pyc
ADDED
Binary file (4.46 kB)

singleVis/__pycache__/intrinsic_dim.cpython-39.pyc
ADDED
Binary file (4.43 kB)

singleVis/__pycache__/jj1sk.cpython-37.pyc
ADDED
Binary file (16.5 kB)

singleVis/__pycache__/jj51sk.cpython-37.pyc
ADDED
Binary file (16.5 kB)

singleVis/__pycache__/jj551sk.cpython-37.pyc
ADDED
Binary file (16.5 kB)

singleVis/__pycache__/jjsk.cpython-37.pyc
ADDED
Binary file (16.6 kB)

singleVis/__pycache__/kcenter_greedy.cpython-37.pyc
ADDED
Binary file (5.33 kB)

singleVis/__pycache__/kcenter_greedy.cpython-39.pyc
ADDED
Binary file (4.9 kB)

singleVis/__pycache__/losses.cpython-37.pyc
ADDED
Binary file (9.27 kB)

singleVis/__pycache__/losses.cpython-39.pyc
ADDED
Binary file (11.9 kB)

singleVis/__pycache__/projector.cpython-37.pyc
ADDED
Binary file (12.3 kB)

singleVis/__pycache__/sVis.cpython-37.pyc
ADDED
Binary file (16.6 kB)

singleVis/__pycache__/s_Vis.cpython-37.pyc
ADDED
Binary file (16.6 kB)

singleVis/__pycache__/skeVis.cpython-37.pyc
ADDED
Binary file (16.4 kB)

singleVis/__pycache__/skeleVis.cpython-37.pyc
ADDED
Binary file (16.6 kB)

singleVis/__pycache__/skele_Vis.cpython-37.pyc
ADDED
Binary file (16.4 kB)

singleVis/__pycache__/skele_vis.cpython-37.pyc
ADDED
Binary file (16.4 kB)

singleVis/__pycache__/skele_viser.cpython-37.pyc
ADDED
Binary file (16.4 kB)

singleVis/__pycache__/skeletonVis.cpython-37.pyc
ADDED
Binary file (16.3 kB)

singleVis/__pycache__/skeletonViser.cpython-37.pyc
ADDED
Binary file (16.3 kB)

singleVis/__pycache__/skeletonVisualizer.cpython-37.pyc
ADDED
Binary file (16.3 kB)

singleVis/__pycache__/skeleton_generator.cpython-37.pyc
ADDED
Binary file (1.6 kB)

singleVis/__pycache__/skeleton_vis.cpython-37.pyc
ADDED
Binary file (16.5 kB)

singleVis/__pycache__/spatial_edge_constructor.cpython-37.pyc
ADDED
Binary file (49.7 kB)

singleVis/__pycache__/spatial_edge_constructor.cpython-39.pyc
ADDED
Binary file (44.8 kB)

singleVis/__pycache__/spatial_edge_constructor_.cpython-37.pyc
ADDED
Binary file (46.8 kB)

singleVis/__pycache__/spatial_skeleton_edge_constructor.cpython-37.pyc
ADDED
Binary file (16.6 kB)

singleVis/__pycache__/ss_Vis.cpython-37.pyc
ADDED
Binary file (16.4 kB)

singleVis/__pycache__/ssjj_Vis.cpython-37.pyc
ADDED
Binary file (16.6 kB)

singleVis/__pycache__/ssjjjjj_Vis.cpython-37.pyc
ADDED
Binary file (16.6 kB)

singleVis/__pycache__/sss_Vis.cpython-37.pyc
ADDED
Binary file (16.4 kB)