account18hackathon committed on
Commit
4e40454
·
1 Parent(s): 43ca29c

Upload 3 files

Files changed (3)
  1. pretrain.py +270 -0
  2. sophia.py +202 -0
  3. utils.py +376 -0
pretrain.py ADDED
@@ -0,0 +1,270 @@
+ from performer_pytorch import PerformerLM
+ from performer_pytorch.autoregressive_wrapper import AutoregressiveWrapper
+
+ import argparse
+ import random
+ import os
+ from tqdm import tqdm
+ import gzip
+ import numpy as np
+ import torch
+ import torch.optim as optim
+ from torch.nn import functional as F
+ from torch.cuda.amp import autocast, GradScaler
+
+ from functools import reduce
+ import pandas as pd
+ from scipy import sparse
+ from sklearn.model_selection import train_test_split, ShuffleSplit, StratifiedShuffleSplit, StratifiedKFold
+ from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, precision_recall_fscore_support, classification_report
+ from torch import nn
+ from torch.optim import Adam, SGD, AdamW
+ from torch.optim.lr_scheduler import StepLR, CosineAnnealingWarmRestarts, CyclicLR
+ from torch.utils.data import DataLoader, Dataset
+ from torch.utils.data.distributed import DistributedSampler
+ from torch.nn.parallel import DistributedDataParallel as DDP
+ import torch.distributed as dist
+
+ import scanpy as sc
+ import anndata as ad
+ from utils import *
+ import pickle as pkl
+
+ from sophia import SophiaG
+
+
+ os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
+
+ # constants
+
+ # NUM_BATCHES = int(1e5)
+ # BATCH_SIZE = 4
+ GRADIENT_ACCUMULATE_EVERY = 4
+ LEARNING_RATE = 1e-4
+ VALIDATE_EVERY = 100
+ GENERATE_EVERY = 500
+ GENERATE_LENGTH = 2048  # referenced by the generation branch below
+ # SEQ_LEN = 4096
+
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--local_rank", type=int, default=-1, help='Local process rank.')
+ parser.add_argument("--bin_num", type=int, default=5, help='Number of expression bins.')
+ parser.add_argument("--gene_num", type=int, default=16906, help='Number of genes.')
+ parser.add_argument("--epoch", type=int, default=1, help='Number of epochs.')
+ parser.add_argument("--seed", type=int, default=2021, help='Random seed.')
+ parser.add_argument("--batch_size", type=int, default=8, help='Batch size.')
+ parser.add_argument("--learning_rate", type=float, default=1e-4, help='Learning rate.')
+ parser.add_argument("--grad_acc", type=int, default=60, help='Number of gradient accumulation steps.')
+ parser.add_argument("--valid_every", type=int, default=1, help='Number of training epochs between two validations.')
+ parser.add_argument("--pos_embed", type=bool, default=True, help='Whether to use Gene2vec position encoding.')
+ parser.add_argument("--data_path", type=str, default='./data/panglao_human.h5ad', help='Path of the data for pretraining.')
+ parser.add_argument("--model_path", type=str, default='./panglao_pretrained.pth', help='Path of the pretrained model.')
+ parser.add_argument("--ckpt_dir", type=str, default='./ckpts/', help='Directory in which to save checkpoints.')
+ parser.add_argument("--model_name", type=str, default='finetune', help='Saved model name.')
+
+ args = parser.parse_args()
+ # rank = int(os.environ["RANK"])
+ # local_rank = args.local_rank
+ # is_master = local_rank == 0
+
+ SEED = args.seed
+ EPOCHS = args.epoch
+ BATCH_SIZE = args.batch_size
+ GRADIENT_ACCUMULATION = args.grad_acc
+ LEARNING_RATE = args.learning_rate
+ SEQ_LEN = args.gene_num + 1
+ VALIDATE_EVERY = args.valid_every
+
+ PATIENCE = 10
+ UNASSIGN_THRES = 0.0
+
+ CLASS = args.bin_num + 2
+ POS_EMBED_USING = args.pos_embed
+
+ model_name = args.model_name
+ ckpt_dir = args.ckpt_dir
+
+ # single-process setup; the distributed initialisation is kept below for reference
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ seed_all(SEED)
+
+ # dist.init_process_group(backend='nccl')
+ # torch.cuda.set_device(local_rank)
+ # device = torch.device("cuda", local_rank)
+ # world_size = torch.distributed.get_world_size()
+ # seed_all(SEED + torch.distributed.get_rank())
+
+
+ # helpers
+
+ def cycle(loader):
+     while True:
+         for data in loader:
+             yield data
+
+ def decode_token(token):
+     return str(chr(max(32, token)))
+
+ def decode_tokens(tokens):
+     return ''.join(list(map(decode_token, tokens)))
+
+ # instantiate model
+
+ model = PerformerLM(
+     num_tokens = args.bin_num + 2,
+     dim = 200,
+     depth = 3,
+     max_seq_len = SEQ_LEN,
+     heads = 5,
+     causal = False,
+     reversible = False,
+     use_scalenorm = True,
+     local_attn_heads = 0,
+     g2v_position_emb = POS_EMBED_USING,
+     generalized_attention = True
+ )
+
+ model = AutoregressiveWrapper(model)
+ model.cuda()
+
+
+ # prepare single-cell data
+
+ class SCDataset(Dataset):
+     def __init__(self, data, label):
+         super().__init__()
+         self.data = data
+         self.label = label
+
+     def __getitem__(self, index):
+         # draw a random cell (the index argument is intentionally ignored)
+         rand_start = random.randint(0, self.data.shape[0]-1)
+         full_seq = self.data[rand_start].toarray()[0]
+         full_seq[full_seq > (CLASS - 2)] = CLASS - 2
+         full_seq = torch.from_numpy(full_seq).long()
+         full_seq = torch.cat((full_seq, torch.tensor([0]))).to(device)
+         seq_label = self.label[rand_start]
+         return full_seq, seq_label
+
+     def __len__(self):
+         return self.data.shape[0]
+
+ class SCDatasetPretrain(Dataset):
+     def __init__(self, data, seq_len):
+         super().__init__()
+         self.data = data
+         self.seq_len = seq_len
+
+     def __getitem__(self, index):
+         # rand_start = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,))
+         # full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
+
+         # draw a random cell, clip the binned expression values to the top bin,
+         # and append a padding token so the sequence length is gene_num + 1
+         rand_start = random.randint(0, self.data.shape[0]-1)
+         full_seq = self.data[rand_start].toarray()[0]
+         full_seq[full_seq > (CLASS - 2)] = CLASS - 2
+         full_seq = torch.from_numpy(full_seq).long()
+         full_seq = torch.cat((full_seq, torch.tensor([0])))
+
+         return full_seq.cuda()
+
+     def __len__(self):
+         return self.data.shape[0]
+
+
+ data = sc.read_h5ad(args.data_path)
+ # data = data[:1000, :]
+ # label_dict, label = np.unique(np.array(data.obs['cell_type']), return_inverse=True)  # convert string categories to integer codes; label_dict[label] restores the strings
+ # # store the label dict and label for prediction
+ # with open('label_dict', 'wb') as fp:
+ #     pkl.dump(label_dict, fp)
+ # with open('label', 'wb') as fp:
+ #     pkl.dump(label, fp)
+ # class_num = np.unique(label, return_counts=True)[1].tolist()
+ # class_weight = torch.tensor([(1 - (x / sum(class_num))) ** 2 for x in class_num])
+ # label = torch.from_numpy(label)
+ data = data.X
+
+ acc = []
+ f1 = []
+ f1w = []
+ skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=SEED)
+ pred_list = pd.Series(['un'] * data.shape[0])
+
+ # sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=SEED)
+ # for index_train in sss.split(data):
+ #     data_train = data[index_train]
+ #     data_val = data[index_val]
+ #     train_dataset = SCDatasetPretrain(data_train, SEQ_LEN)
+ #     val_dataset = SCDatasetPretrain(data_val, SEQ_LEN)
+
+ #     train_sampler = DistributedSampler(train_dataset)
+ #     val_sampler = DistributedSampler(val_dataset)
+ #     train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=train_sampler)
+ #     val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, sampler=val_sampler)
+
+ # simple 80/20 split for pretraining
+ index_train = int(data.shape[0] * 0.8)
+ data_train = data[:index_train]
+ data_val = data[index_train:]
+ train_dataset = SCDatasetPretrain(data_train, SEQ_LEN)
+ val_dataset = SCDatasetPretrain(data_val, SEQ_LEN)
+
+ train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE)
+ val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
+ # train_loader = cycle(DataLoader(train_dataset, batch_size=BATCH_SIZE))
+ # val_loader = cycle(DataLoader(val_dataset, batch_size=BATCH_SIZE))
+
+ # optimizer
+
+ optim = SophiaG(model.parameters(), lr=2e-4,
+                 betas=(0.965, 0.99), rho=0.01, weight_decay=1e-1)
+ # optim = torch.optim.SGD(model.parameters(), lr=1e-8, momentum=0.9)
+ # optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
+ scaler = GradScaler()
+
+ # training
+
+ for i in tqdm(range(EPOCHS), mininterval=10., desc='training'):
+     model.train()
+
+     # gradients are accumulated over the whole epoch and applied in a single optimizer step
+     # for __ in range(GRADIENT_ACCUMULATE_EVERY):
+     with autocast():
+         # loss = model(next(train_loader), return_loss=True)
+         for index, data_batch in enumerate(tqdm(train_loader)):
+             loss = model(data_batch, return_loss=True)
+             scaler.scale(loss).backward()
+
+     print(f'training loss: {loss.item()}')
+
+     scaler.unscale_(optim)
+     torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
+     scaler.step(optim)
+     scaler.update()
+     optim.zero_grad()
+
+     # if i % VALIDATE_EVERY == 0:
+     #     model.eval()
+     #     with torch.no_grad():
+     #         # loss = model(next(val_loader), return_loss=True)
+     #         for index, data_batch in enumerate(tqdm(val_loader)):
+     #             loss = model(data_batch, return_loss=True)
+     #             print(f'validation loss: {loss.item()}')
+
+     if i % GENERATE_EVERY == 0 and i != 0:
+         model.eval()
+         inp = random.choice(val_dataset)[:-1]
+         prime = decode_tokens(inp)
+         print(prime, '\n\n', '*' * 100)
+
+         sample = model.generate(inp, GENERATE_LENGTH)
+         output_str = decode_tokens(sample)
+         print(output_str)
+
+ # save model
+ print('save model')
+ checkpoint = {'state_dict': model.state_dict(), 'optimizer': optim.state_dict()}
+ torch.save(checkpoint, os.path.join(ckpt_dir, 'model_gene_attn.pth'))
+
+ a = 1
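
For reference, a minimal sketch of how the checkpoint written at the end of pretrain.py could be reloaded. The model configuration and the 'state_dict' key mirror the script above; the checkpoint path matches the default ckpt_dir, while the CPU map_location and eval-only usage are assumptions. Note that g2v_position_emb is an argument of the PerformerLM variant configured above, not of stock performer_pytorch.

# Reload sketch (assumes the same PerformerLM/AutoregressiveWrapper variant used above).
import torch
from performer_pytorch import PerformerLM
from performer_pytorch.autoregressive_wrapper import AutoregressiveWrapper

SEQ_LEN = 16906 + 1                      # gene_num + 1, as in pretrain.py
model = AutoregressiveWrapper(PerformerLM(
    num_tokens = 5 + 2,                  # bin_num + 2
    dim = 200, depth = 3, max_seq_len = SEQ_LEN, heads = 5,
    causal = False, reversible = False, use_scalenorm = True,
    local_attn_heads = 0, g2v_position_emb = True, generalized_attention = True,
))
checkpoint = torch.load('./ckpts/model_gene_attn.pth', map_location='cpu')
model.load_state_dict(checkpoint['state_dict'])   # key saved by pretrain.py above
model.eval()
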
sophia.py ADDED
@@ -0,0 +1,202 @@
+ import math
+ import torch
+ from torch import Tensor
+ from torch.optim.optimizer import Optimizer
+ from typing import List, Optional
+
+
+ class SophiaG(Optimizer):
+     def __init__(self, params, lr=1e-4, betas=(0.965, 0.99), rho=0.04,
+                  weight_decay=1e-1, *, maximize: bool = False,
+                  capturable: bool = False):
+         if not 0.0 <= lr:
+             raise ValueError("Invalid learning rate: {}".format(lr))
+         if not 0.0 <= betas[0] < 1.0:
+             raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+         if not 0.0 <= betas[1] < 1.0:
+             raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+         if not 0.0 <= rho:
+             raise ValueError("Invalid rho parameter: {}".format(rho))
+         if not 0.0 <= weight_decay:
+             raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+         defaults = dict(lr=lr, betas=betas, rho=rho,
+                         weight_decay=weight_decay,
+                         maximize=maximize, capturable=capturable)
+         super(SophiaG, self).__init__(params, defaults)
+
+     def __setstate__(self, state):
+         super().__setstate__(state)
+         for group in self.param_groups:
+             group.setdefault('maximize', False)
+             group.setdefault('capturable', False)
+         state_values = list(self.state.values())
+         step_is_tensor = (len(state_values) != 0) and torch.is_tensor(state_values[0]['step'])
+         if not step_is_tensor:
+             for s in state_values:
+                 s['step'] = torch.tensor(float(s['step']))
+
+     @torch.no_grad()
+     def update_hessian(self):
+         for group in self.param_groups:
+             beta1, beta2 = group['betas']
+             for p in group['params']:
+                 if p.grad is None:
+                     continue
+                 state = self.state[p]
+
+                 if len(state) == 0:
+                     state['step'] = torch.zeros((1,), dtype=torch.float, device=p.device) \
+                         if self.defaults['capturable'] else torch.tensor(0.)
+                     state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                     state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+
+                 if 'hessian' not in state.keys():
+                     state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+
+                 state['hessian'].mul_(beta2).addcmul_(p.grad, p.grad, value=1 - beta2)
+
+     @torch.no_grad()
+     def step(self, closure=None, bs=5120):
+         loss = None
+         if closure is not None:
+             with torch.enable_grad():
+                 loss = closure()
+
+         for group in self.param_groups:
+             params_with_grad = []
+             grads = []
+             exp_avgs = []
+             state_steps = []
+             hessian = []
+             beta1, beta2 = group['betas']
+
+             for p in group['params']:
+                 if p.grad is None:
+                     continue
+                 params_with_grad.append(p)
+
+                 if p.grad.is_sparse:
+                     raise RuntimeError('SophiaG does not support sparse gradients')
+                 grads.append(p.grad)
+                 state = self.state[p]
+                 # State initialization
+                 if len(state) == 0:
+                     state['step'] = torch.zeros((1,), dtype=torch.float, device=p.device) \
+                         if self.defaults['capturable'] else torch.tensor(0.)
+                     state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                     state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+
+                 if 'hessian' not in state.keys():
+                     state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+
+                 exp_avgs.append(state['exp_avg'])
+                 state_steps.append(state['step'])
+                 hessian.append(state['hessian'])
+
+                 if self.defaults['capturable']:
+                     bs = torch.ones((1,), dtype=torch.float, device=p.device) * bs
+
+             sophiag(params_with_grad,
+                     grads,
+                     exp_avgs,
+                     hessian,
+                     state_steps,
+                     bs=bs,
+                     beta1=beta1,
+                     beta2=beta2,
+                     rho=group['rho'],
+                     lr=group['lr'],
+                     weight_decay=group['weight_decay'],
+                     maximize=group['maximize'],
+                     capturable=group['capturable'])
+
+         return loss
+
+
+ def sophiag(params: List[Tensor],
+             grads: List[Tensor],
+             exp_avgs: List[Tensor],
+             hessian: List[Tensor],
+             state_steps: List[Tensor],
+             capturable: bool = False,
+             *,
+             bs: int,
+             beta1: float,
+             beta2: float,
+             rho: float,
+             lr: float,
+             weight_decay: float,
+             maximize: bool):
+
+     if not all(isinstance(t, torch.Tensor) for t in state_steps):
+         raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors")
+
+     func = _single_tensor_sophiag
+
+     func(params,
+          grads,
+          exp_avgs,
+          hessian,
+          state_steps,
+          bs=bs,
+          beta1=beta1,
+          beta2=beta2,
+          rho=rho,
+          lr=lr,
+          weight_decay=weight_decay,
+          maximize=maximize,
+          capturable=capturable)
+
+
+ def _single_tensor_sophiag(params: List[Tensor],
+                            grads: List[Tensor],
+                            exp_avgs: List[Tensor],
+                            hessian: List[Tensor],
+                            state_steps: List[Tensor],
+                            *,
+                            bs: int,
+                            beta1: float,
+                            beta2: float,
+                            rho: float,
+                            lr: float,
+                            weight_decay: float,
+                            maximize: bool,
+                            capturable: bool):
+
+     for i, param in enumerate(params):
+         grad = grads[i] if not maximize else -grads[i]
+         exp_avg = exp_avgs[i]
+         hess = hessian[i]
+         step_t = state_steps[i]
+
+         if capturable:
+             assert param.is_cuda and step_t.is_cuda and bs.is_cuda
+
+         if torch.is_complex(param):
+             grad = torch.view_as_real(grad)
+             exp_avg = torch.view_as_real(exp_avg)
+             hess = torch.view_as_real(hess)
+             param = torch.view_as_real(param)
+
+         # update step
+         step_t += 1
+
+         # Perform stepweight decay
+         param.mul_(1 - lr * weight_decay)
+
+         # Decay the first and second moment running average coefficient
+         exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+
+         if capturable:
+             step = step_t
+             step_size = lr
+             step_size_neg = step_size.neg()
+
+             ratio = (exp_avg.abs() / (rho * bs * hess + 1e-15)).clamp(None, 1)
+             param.addcmul_(exp_avg.sign(), ratio, value=step_size_neg)
+         else:
+             step = step_t.item()
+             step_size_neg = -lr
+
+             ratio = (exp_avg.abs() / (rho * bs * hess + 1e-15)).clamp(None, 1)
+             param.addcmul_(exp_avg.sign(), ratio, value=step_size_neg)
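
For context, a minimal sketch of how SophiaG is intended to be used with a periodic Hessian refresh: step() consumes the gradient, while update_hessian() maintains the diagonal Hessian EMA that rho scales against. The toy data, the refresh interval k, and the sampled-label (Gauss-Newton-Bartlett style) estimator follow the Sophia paper's recommended recipe and are assumptions, not part of this commit; pretrain.py above only calls step().

# Illustrative SophiaG loop with periodic Hessian updates (toy problem, assumed hyperparameters).
import torch
import torch.nn.functional as F
from sophia import SophiaG

x = torch.randn(512, 10)                         # toy features
y = torch.randint(0, 2, (512,))                  # toy labels
model = torch.nn.Linear(10, 2)
opt = SophiaG(model.parameters(), lr=2e-4, betas=(0.965, 0.99), rho=0.01, weight_decay=1e-1)

k = 10                                           # refresh the Hessian EMA every k iterations (assumption)
for it in range(100):
    idx = torch.randint(0, 512, (32,))
    loss = F.cross_entropy(model(x[idx]), y[idx])
    loss.backward()
    opt.step(bs=32)                              # bs scales the denominator rho * bs * hessian
    opt.zero_grad(set_to_none=True)
    if it % k == k - 1:
        # backprop a loss on labels sampled from the model, then update the Hessian EMA
        logits = model(x[idx])
        y_samp = torch.distributions.Categorical(logits=logits).sample()
        F.cross_entropy(logits, y_samp).backward()
        opt.update_hessian()
        opt.zero_grad(set_to_none=True)
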
utils.py ADDED
@@ -0,0 +1,376 @@
+ # -*- coding: utf-8 -*-
+
+ from __future__ import print_function
+ import json
+ import os
+ import struct
+ import sys
+ import platform
+ import re
+ import time
+ import traceback
+ import requests
+ import socket
+ import random
+ import math
+ import numpy as np
+ import torch
+ import logging
+ import datetime
+ from torch.optim.lr_scheduler import _LRScheduler
+ from torch import nn
+ import torch.nn.functional as F
+ from torch.nn.modules.loss import _WeightedLoss
+
+
+ def seed_all(seed_value, cuda_deterministic=False):
+     """
+     Set all random seeds.
+     """
+     random.seed(seed_value)
+     os.environ['PYTHONHASHSEED'] = str(seed_value)
+     np.random.seed(seed_value)
+     torch.manual_seed(seed_value)
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed(seed_value)
+         torch.cuda.manual_seed_all(seed_value)
+         # Speed-reproducibility tradeoff: https://pytorch.org/docs/stable/notes/randomness.html
+         if cuda_deterministic:  # slower, more reproducible
+             torch.backends.cudnn.deterministic = True
+             torch.backends.cudnn.benchmark = False
+         else:  # faster, less reproducible
+             torch.backends.cudnn.deterministic = False
+             torch.backends.cudnn.benchmark = True
+
+
+ def set_log(logfileName, rank=-1):
+     """
+     The master node logs everything; other nodes log only warnings and errors.
+     """
+     log_file_folder = os.path.dirname(logfileName)
+     time_now = datetime.datetime.now()
+     logfileName = f'{logfileName}_{time_now.year}_{time_now.month}_{time_now.day}_{time_now.hour}_{time_now.minute}.log'
+     if not os.path.exists(log_file_folder):
+         os.makedirs(log_file_folder)
+
+     logging.basicConfig(level=logging.INFO if rank in [-1, 0] else logging.WARN,
+                         format='[%(asctime)s %(levelname)s %(filename)s line %(lineno)d %(process)d] %(message)s',
+                         datefmt='[%X]',
+                         handlers=[logging.FileHandler(logfileName), logging.StreamHandler()]
+                         )
+     logger = logging.getLogger()
+     return logger
+
+
+ def save_ckpt(epoch, model, optimizer, scheduler, losses, model_name, ckpt_folder):
+     """
+     Save a model checkpoint.
+     """
+     if not os.path.exists(ckpt_folder):
+         os.makedirs(ckpt_folder)
+     torch.save(
+         {
+             'epoch': epoch,
+             'model_state_dict': model.module.state_dict(),
+             'optimizer_state_dict': optimizer.state_dict(),
+             'scheduler_state_dict': scheduler.state_dict(),
+             'losses': losses,
+         },
+         f'{ckpt_folder}{model_name}_{epoch}.pth'
+     )
+
+ def save_simple_ckpt(model, model_name, ckpt_folder):
+     """
+     Save a model checkpoint (weights only).
+     """
+     if not os.path.exists(ckpt_folder):
+         os.makedirs(ckpt_folder)
+     torch.save(
+         {
+             'model_state_dict': model.module.state_dict()
+         },
+         f'{ckpt_folder}{model_name}.pth'
+     )
+
+ def save_best_ckpt(epoch, model, optimizer, scheduler, losses, model_name, ckpt_folder):
+     """
+     Save the best model checkpoint.
+     """
+     if not os.path.exists(ckpt_folder):
+         os.makedirs(ckpt_folder)
+     torch.save(
+         {
+             'epoch': epoch,
+             'model_state_dict': model.module.state_dict(),
+             'optimizer_state_dict': optimizer.state_dict(),
+             'scheduler_state_dict': scheduler.state_dict(),
+             'losses': losses,
+         },
+         f'{ckpt_folder}{model_name}_best.pth'
+     )
+
+ def get_reduced(tensor, current_device, dest_device, world_size):
+     """
+     Reduce a scalar variable or tensor from all GPUs onto the destination GPU and return its mean.
+     """
+     tensor = tensor.clone().detach() if torch.is_tensor(tensor) else torch.tensor(tensor)
+     tensor = tensor.to(current_device)
+     torch.distributed.reduce(tensor, dst=dest_device)
+     tensor_mean = tensor.item() / world_size
+     return tensor_mean
+
+ def get_ndtensor_reduced(tensor, current_device, dest_device, world_size):
+     """
+     Reduce a 1-D or 2-D tensor from all GPUs onto the destination GPU and return its element-wise mean.
+     """
+     tensor = tensor.clone().detach() if torch.is_tensor(tensor) else torch.tensor(tensor)
+     tensor = tensor.to(current_device)
+     torch.distributed.reduce(tensor, dst=dest_device)
+     tensor_mean = torch.zeros(tensor.shape)
+     if len(tensor.shape) == 2:
+         for i in range(tensor.shape[0]):
+             for j in range(tensor.shape[1]):
+                 tensor_mean[i, j] = tensor[i, j].item() / world_size
+     elif len(tensor.shape) == 1:
+         for i in range(tensor.shape[0]):
+             tensor_mean[i] = tensor[i].item() / world_size
+     return tensor_mean
+
+ def numel(m: torch.nn.Module, only_trainable: bool = False):
+     """
+     Returns the total number of parameters used by `m` (only counting
+     shared parameters once); if `only_trainable` is True, then only
+     includes parameters with `requires_grad = True`.
+     """
+     parameters = m.parameters()
+     if only_trainable:
+         parameters = list(p for p in parameters if p.requires_grad)
+     unique = dict((p.data_ptr(), p) for p in parameters).values()
+     return sum(p.numel() for p in unique)
+
+
+ def label_smooth(y, K, epsilon=0.1):
+     """
+     Label smoothing for multiclass labels.
+     One-hot encodes labels `y` over `K` classes. `y` should be of the form [1, 6, 3, etc.].
+     """
+     m = len(y)
+     out = np.ones((m, K)) * epsilon / K
+     for index in range(m):
+         out[index][y[index] - 1] += 1 - epsilon
+     return torch.tensor(out)
+
+
+ class SequentialDistributedSampler(torch.utils.data.sampler.Sampler):
+     """
+     Distributed sampler that subsamples indices sequentially,
+     making it easier to collate all results at the end.
+     Even though we only use this sampler for eval and predict (no training),
+     which means that the model params won't have to be synced (i.e. will not hang
+     for synchronization even if varied number of forward passes), we still add extra
+     samples to the sampler to make it evenly divisible (like in `DistributedSampler`)
+     to make it easy to `gather` or `reduce` resulting tensors at the end of the loop.
+     """
+
+     def __init__(self, dataset, batch_size, world_size, rank=None, num_replicas=None):
+         if num_replicas is None:
+             if not torch.distributed.is_available():
+                 raise RuntimeError("Requires distributed package to be available")
+             num_replicas = world_size
+         if rank is None:
+             if not torch.distributed.is_available():
+                 raise RuntimeError("Requires distributed package to be available")
+             rank = torch.distributed.get_rank()
+         self.dataset = dataset
+         self.num_replicas = num_replicas
+         self.rank = rank
+         self.batch_size = batch_size
+         self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.batch_size / self.num_replicas)) * self.batch_size
+         self.total_size = self.num_samples * self.num_replicas
+
+     def __iter__(self):
+         indices = list(range(len(self.dataset)))
+         # add extra samples to make it evenly divisible
+         indices += [indices[-1]] * (self.total_size - len(indices))
+         # subsample
+         indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples]
+         return iter(indices)
+
+     def __len__(self):
+         return self.num_samples
+
+
+ def distributed_concat(tensor, num_total_examples, world_size):
+     """
+     Concatenate the inference results from all processes.
+     """
+     output_tensors = [tensor.clone() for _ in range(world_size)]
+     torch.distributed.all_gather(output_tensors, tensor)
+     concat = torch.cat(output_tensors, dim=0)
+     # truncate the dummy elements added by SequentialDistributedSampler
+     return concat[:num_total_examples]
+
+
+ class CosineAnnealingWarmupRestarts(_LRScheduler):
+     """
+     optimizer (Optimizer): Wrapped optimizer.
+     first_cycle_steps (int): First cycle step size.
+     cycle_mult (float): Cycle steps magnification. Default: 1.
+     max_lr (float): First cycle's max learning rate. Default: 0.1.
+     min_lr (float): Min learning rate. Default: 0.001.
+     warmup_steps (int): Linear warmup step size. Default: 0.
+     gamma (float): Decrease rate of max learning rate per cycle. Default: 1.
+     last_epoch (int): The index of the last epoch. Default: -1.
+     """
+
+     def __init__(self,
+                  optimizer: torch.optim.Optimizer,
+                  first_cycle_steps: int,
+                  cycle_mult: float = 1.,
+                  max_lr: float = 0.1,
+                  min_lr: float = 0.001,
+                  warmup_steps: int = 0,
+                  gamma: float = 1.,
+                  last_epoch: int = -1
+                  ):
+         assert warmup_steps < first_cycle_steps
+
+         self.first_cycle_steps = first_cycle_steps  # first cycle step size
+         self.cycle_mult = cycle_mult                # cycle steps magnification
+         self.base_max_lr = max_lr                   # first max learning rate
+         self.max_lr = max_lr                        # max learning rate in the current cycle
+         self.min_lr = min_lr                        # min learning rate
+         self.warmup_steps = warmup_steps            # warmup step size
+         self.gamma = gamma                          # decrease rate of max learning rate per cycle
+
+         self.cur_cycle_steps = first_cycle_steps    # current cycle step size
+         self.cycle = 0                              # cycle count
+         self.step_in_cycle = last_epoch             # step within the current cycle
+
+         super(CosineAnnealingWarmupRestarts, self).__init__(optimizer, last_epoch)
+
+         # set learning rate to min_lr
+         self.init_lr()
+
+     def init_lr(self):
+         self.base_lrs = []
+         for param_group in self.optimizer.param_groups:
+             param_group['lr'] = self.min_lr
+             self.base_lrs.append(self.min_lr)
+
+     def get_lr(self):
+         if self.step_in_cycle == -1:
+             return self.base_lrs
+         elif self.step_in_cycle < self.warmup_steps:
+             return [(self.max_lr - base_lr) * self.step_in_cycle / self.warmup_steps + base_lr for base_lr in self.base_lrs]
+         else:
+             return [base_lr + (self.max_lr - base_lr) \
+                     * (1 + math.cos(math.pi * (self.step_in_cycle - self.warmup_steps) \
+                                     / (self.cur_cycle_steps - self.warmup_steps))) / 2
+                     for base_lr in self.base_lrs]
+
+     def step(self, epoch=None):
+         if epoch is None:
+             epoch = self.last_epoch + 1
+             self.step_in_cycle = self.step_in_cycle + 1
+             if self.step_in_cycle >= self.cur_cycle_steps:
+                 self.cycle += 1
+                 self.step_in_cycle = self.step_in_cycle - self.cur_cycle_steps
+                 self.cur_cycle_steps = int((self.cur_cycle_steps - self.warmup_steps) * self.cycle_mult) + self.warmup_steps
+         else:
+             if epoch >= self.first_cycle_steps:
+                 if self.cycle_mult == 1.:
+                     self.step_in_cycle = epoch % self.first_cycle_steps
+                     self.cycle = epoch // self.first_cycle_steps
+                 else:
+                     n = int(math.log((epoch / self.first_cycle_steps * (self.cycle_mult - 1) + 1), self.cycle_mult))
+                     self.cycle = n
+                     self.step_in_cycle = epoch - int(self.first_cycle_steps * (self.cycle_mult ** n - 1) / (self.cycle_mult - 1))
+                     self.cur_cycle_steps = self.first_cycle_steps * self.cycle_mult ** (n)
+             else:
+                 self.cur_cycle_steps = self.first_cycle_steps
+                 self.step_in_cycle = epoch
+
+         self.max_lr = self.base_max_lr * (self.gamma ** self.cycle)
+         self.last_epoch = math.floor(epoch)
+         for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
+             param_group['lr'] = lr
+
+
+ class DistanceLoss(_WeightedLoss):
+     """
+     CrossEntropyLoss weighted by the distance between the predicted and true bins.
+     """
+     def __init__(self, weight=None, reduction='mean', ignore_index=None):
+         super().__init__(weight=weight, reduction=reduction)
+         self.weight = weight
+         self.reduction = reduction
+         self.ignore_index = ignore_index
+
+     def forward(self, inputs, targets):
+         if len(inputs.shape) > 2:
+             inputs = inputs.reshape(-1, inputs.size(-1))
+         if len(targets.shape) > 1:
+             targets = targets.reshape(-1)
+         if self.ignore_index is not None:
+             keep_index = (targets != self.ignore_index).nonzero(as_tuple=True)[0]
+             targets = torch.index_select(targets, 0, keep_index)  # targets[targets != self.ignore_index]
+             inputs = torch.index_select(inputs, 0, keep_index)
+         lsm = F.log_softmax(inputs, -1)
+         targets = torch.empty(size=(targets.size(0), inputs.size(-1)), device=targets.device).fill_(0).scatter_(1, targets.data.unsqueeze(1), 1)
+         if self.weight is not None:
+             lsm = lsm * self.weight.unsqueeze(0)
+         loss = -(targets * lsm).sum(-1)
+         inputs = nn.Softmax(dim=-1)(inputs)[..., 1:-1].argmax(dim=-1) + 1
+         # print('inputs', inputs.device, inputs.shape)
+         targets = nn.Softmax(dim=-1)(targets)[..., 1:-1].argmax(dim=-1) + 1
+         # print('targets', targets.device, targets.shape)
+         distance = abs(inputs - targets) + 1e-2
+         # print('loss.shape', loss.shape)
+         # print('distance.shape', distance.shape)
+         loss = loss * distance
+         if self.reduction == 'sum':
+             loss = loss.sum()
+         elif self.reduction == 'mean':
+             loss = loss.mean()
+         return loss
+
+
+ class LabelSmoothCrossEntropyLoss(_WeightedLoss):
+     """
+     CrossEntropyLoss with label smoothing.
+     """
+     def __init__(self, weight=None, reduction='mean', smoothing=0.0):
+         super().__init__(weight=weight, reduction=reduction)
+         self.smoothing = smoothing
+         self.weight = weight
+         self.reduction = reduction
+
+     @staticmethod
+     def _smooth_one_hot(targets: torch.Tensor, n_classes: int, smoothing=0.0):
+         assert 0 <= smoothing < 1
+         with torch.no_grad():
+             targets = torch.empty(size=(targets.size(0), n_classes),
+                                   device=targets.device) \
+                 .fill_(smoothing / (n_classes - 1)) \
+                 .scatter_(1, targets.data.unsqueeze(1), 1. - smoothing)
+         return targets
+
+     def forward(self, inputs, targets):
+         targets = LabelSmoothCrossEntropyLoss._smooth_one_hot(targets, inputs.size(-1),
+                                                               self.smoothing)
+         lsm = F.log_softmax(inputs, -1)
+
+         if self.weight is not None:
+             lsm = lsm * self.weight.unsqueeze(0)
+
+         loss = -(targets * lsm).sum(-1)
+
+         if self.reduction == 'sum':
+             loss = loss.sum()
+         elif self.reduction == 'mean':
+             loss = loss.mean()
+
+         return loss
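
For context, a minimal sketch wiring together seed_all, CosineAnnealingWarmupRestarts and LabelSmoothCrossEntropyLoss from utils.py. The toy model, step counts and learning rates below are placeholders, not values from this commit; only the seed (2021) mirrors the pretrain.py default.

# Illustrative use of the utilities above on a toy classifier (assumed hyperparameters).
import torch
from utils import seed_all, CosineAnnealingWarmupRestarts, LabelSmoothCrossEntropyLoss

seed_all(2021)                                    # same default seed as pretrain.py
model = torch.nn.Linear(10, 7)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = CosineAnnealingWarmupRestarts(
    optimizer,
    first_cycle_steps=15,      # length of the first cosine cycle (placeholder)
    cycle_mult=2,              # each restart doubles the cycle length
    max_lr=1e-4,
    min_lr=1e-6,
    warmup_steps=5,            # linear warmup from min_lr to max_lr
    gamma=0.9,                 # shrink max_lr after every cycle
)
criterion = LabelSmoothCrossEntropyLoss(smoothing=0.1)

for epoch in range(30):
    logits = model(torch.randn(8, 10))            # toy batch
    loss = criterion(logits, torch.randint(0, 7, (8,)))
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()           # once per epoch (or per step, matching the unit of first_cycle_steps)
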