account18hackathon committed · Commit e19e1b1 · 1 Parent(s): 2cf93bf
Upload 4 files

- pretrain.py +0 -75
- requirements.txt +8 -0
pretrain.py
CHANGED
@@ -34,18 +34,11 @@ import pickle as pkl
 from sophia import SophiaG
 
 
-os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
 
-# # constants
-
-# NUM_BATCHES = int(1e5)
-# BATCH_SIZE = 4
 GRADIENT_ACCUMULATE_EVERY = 4
 LEARNING_RATE = 1e-4
 VALIDATE_EVERY = 100
 GENERATE_EVERY = 500
-# GENERATE_LENGTH = 2048
-# SEQ_LEN = 4096
 
 
 parser = argparse.ArgumentParser()
|
@@ -65,9 +58,6 @@ parser.add_argument("--ckpt_dir", type=str, default='./ckpts/', help='Directory
 parser.add_argument("--model_name", type=str, default='finetune', help='Finetuned model name.')
 
 args = parser.parse_args()
-# rank = int(os.environ["RANK"])
-# local_rank = args.local_rank
-# is_master = local_rank == 0
 
 SEED = args.seed
 EPOCHS = args.epoch
|
@@ -86,14 +76,6 @@ POS_EMBED_USING = args.pos_embed
 model_name = args.model_name
 ckpt_dir = args.ckpt_dir
 
-# dist.init_process_group(backend='nccl')
-# torch.cuda.set_device(local_rank)
-# device = torch.device("cuda", local_rank)
-# world_size = torch.distributed.get_world_size()
-
-# seed_all(SEED + torch.distributed.get_rank())
-
-
 
 # helpers
 
@@ -127,27 +109,7 @@ model = PerformerLM(
 model = AutoregressiveWrapper(model)
 model.cuda()
 
-
-
 # prepare sc data
-
-class SCDataset(Dataset):
-    def __init__(self, data, label):
-        super().__init__()
-        self.data = data
-        self.label = label
-
-    def __getitem__(self, index):
-        rand_start = random.randint(0, self.data.shape[0]-1)
-        full_seq = self.data[rand_start].toarray()[0]
-        full_seq[full_seq > (CLASS - 2)] = CLASS - 2
-        full_seq = torch.from_numpy(full_seq).long()
-        full_seq = torch.cat((full_seq, torch.tensor([0]))).to(device)
-        seq_label = self.label[rand_start]
-        return full_seq, seq_label
-
-    def __len__(self):
-        return self.data.shape[0]
 
 class SCDatasetPretrain(Dataset):
     def __init__(self, data, seq_len):
@@ -169,19 +131,8 @@ class SCDatasetPretrain(Dataset):
 
     def __len__(self):
         return self.data.shape[0]
-
 
 data = sc.read_h5ad(args.data_path)
-# data = data[:1000, :]
-# label_dict, label = np.unique(np.array(data.obs['cell_type']), return_inverse=True)  # Convert string categories to integer categories; label_dict[label] restores the strings
-# # store the label dict and label for prediction
-# with open('label_dict', 'wb') as fp:
-#     pkl.dump(label_dict, fp)
-# with open('label', 'wb') as fp:
-#     pkl.dump(label, fp)
-# class_num = np.unique(label, return_counts=True)[1].tolist()
-# class_weight = torch.tensor([(1 - (x / sum(class_num))) ** 2 for x in class_num])
-# label = torch.from_numpy(label)
 data = data.X
 
 acc = []
@@ -190,18 +141,6 @@ f1w = []
 skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=SEED)
 pred_list = pd.Series(['un'] * data.shape[0])
 
-# sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=SEED)
-# for index_train in sss.split(data):
-#     data_train = data[index_train]
-#     data_val = data[index_val]
-#     train_dataset = SCDatasetPretrain(data_train, SEQ_LEN)
-#     val_dataset = SCDatasetPretrain(data_val, SEQ_LEN)
-
-# train_sampler = DistributedSampler(train_dataset)
-# val_sampler = DistributedSampler(val_dataset)
-# train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=train_sampler)
-# val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, sampler=val_sampler)
-
 index_train = int(data.shape[0]*0.8)
 data_train = data[:index_train]
 data_val = data[index_train:]
@@ -210,15 +149,11 @@ val_dataset = SCDatasetPretrain(data_val, SEQ_LEN)
 
 train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE)
 val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
-# train_loader = cycle(DataLoader(train_dataset, batch_size=BATCH_SIZE))
-# val_loader = cycle(DataLoader(val_dataset, batch_size=BATCH_SIZE))
 
 # optimizer
 
 optim = SophiaG(model.parameters(), lr=2e-4,
                 betas=(0.965, 0.99), rho=0.01, weight_decay=1e-1)
-# optim = torch.optim.SGD(model.parameters(), lr=1e-8, momentum=0.9)
-# optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
 scaler = GradScaler()
 
 # training
@@ -244,14 +179,6 @@ for i in tqdm(range(EPOCHS), mininterval=10., desc='training'):
         scaler.update()
         optim.zero_grad()
 
-    # if i % VALIDATE_EVERY == 0:
-    #     model.eval()
-    #     with torch.no_grad():
-    #         # loss = model(next(val_loader), return_loss=True)
-    #         for index, data_batch in enumerate(tqdm(val_loader)):
-    #             loss = model(data_batch, return_loss=True)
-    #             print(f'validation loss: {loss.item()}')
-
     if i % GENERATE_EVERY == 0 and i != 0:
         model.eval()
         inp = random.choice(val_dataset)[:-1]
@@ -266,5 +193,3 @@ for i in tqdm(range(EPOCHS), mininterval=10., desc='training'):
     print('save model')
     checkpoint = {'state_dict': model.state_dict(), 'optimizer': optim.state_dict()}
     torch.save(checkpoint, os.path.join(ckpt_dir, 'model_gene_attn.pth'))
-
-a=1
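The lines this diff keeps outline the single-GPU training recipe: SophiaG with mixed precision and gradient accumulation. Because the loop body between hunks is elided, the following is only a minimal sketch of the pattern those lines imply, not the repository's exact loop. It assumes model, train_loader, and EPOCHS are defined as above, and that model(batch, return_loss=True) follows the AutoregressiveWrapper call visible in the deleted validation block.

import torch
from torch.cuda.amp import autocast, GradScaler
from sophia import SophiaG  # same import as pretrain.py

GRADIENT_ACCUMULATE_EVERY = 4

optim = SophiaG(model.parameters(), lr=2e-4,
                betas=(0.965, 0.99), rho=0.01, weight_decay=1e-1)
scaler = GradScaler()

for i in range(EPOCHS):
    model.train()
    for index, data_batch in enumerate(train_loader):
        with autocast():  # fp16 forward pass
            loss = model(data_batch, return_loss=True)
        # Average over the accumulation window so GRADIENT_ACCUMULATE_EVERY
        # micro-batches behave like one larger batch, then scale for fp16.
        scaler.scale(loss / GRADIENT_ACCUMULATE_EVERY).backward()
        if (index + 1) % GRADIENT_ACCUMULATE_EVERY == 0:
            scaler.step(optim)  # unscales gradients, then optim.step()
            scaler.update()
            optim.zero_grad()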
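Note the split this commit settles on is a sequential 80/20 cut (data[:index_train] / data[index_train:]), whereas the deleted comments sketched a shuffled StratifiedShuffleSplit plus DistributedSampler pipeline. For comparison, a minimal shuffled 80/20 split without DDP, using the scikit-learn version pinned below, could look like this (variable names as in pretrain.py):

import numpy as np
from sklearn.model_selection import train_test_split

# Shuffled 80/20 split; no stratification, since pretraining uses no labels.
indices = np.arange(data.shape[0])
train_idx, val_idx = train_test_split(indices, test_size=0.2, random_state=SEED)
data_train, data_val = data[train_idx], data[val_idx]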
requirements.txt
ADDED
@@ -0,0 +1,8 @@
+torch==1.8.1
+torchvision==0.9.1
+transformers==4.6.1
+scanpy==1.7.2
+scikit-learn==0.24.2
+scipy==1.5.4
+numpy==1.19.2
+pandas==1.1.5
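These are exact pins, installable with pip install -r requirements.txt. Note that the sophia package imported by pretrain.py is not among them; presumably it is provided by another of the four uploaded files or must be obtained separately.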