NiniCat commited on
Commit
e6bbcf7
·
1 Parent(s): e93c659

Add tiger file

Browse files
Files changed (1) hide show
  1. tiger.py +417 -0
tiger.py ADDED
@@ -0,0 +1,417 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import gzip
4
+ import pickle
5
+ import numpy as np
6
+ import pandas as pd
7
+ import tensorflow as tf
8
+ from Bio import SeqIO
9
+
10
+ # column names
11
+ ID_COL = 'Transcript ID'
12
+ SEQ_COL = 'Transcript Sequence'
13
+ TARGET_COL = 'Target Sequence'
14
+ GUIDE_COL = 'Guide Sequence'
15
+ MM_COL = 'Number of Mismatches'
16
+ SCORE_COL = 'Guide Score'
17
+
18
+ # nucleotide tokens
19
+ NUCLEOTIDE_TOKENS = dict(zip(['A', 'C', 'G', 'T', 'N'], [0, 1, 2, 3, 255]))
20
+ NUCLEOTIDE_COMPLEMENT = dict(zip(['A', 'C', 'G', 'T'], ['T', 'G', 'C', 'A']))
21
+
22
+ # model hyper-parameters
23
+ GUIDE_LEN = 23
24
+ CONTEXT_5P = 3
25
+ CONTEXT_3P = 0
26
+ TARGET_LEN = CONTEXT_5P + GUIDE_LEN + CONTEXT_3P
27
+ UNIT_INTERVAL_MAP = 'sigmoid'
28
+
29
+ # reference transcript files
30
+ REFERENCE_TRANSCRIPTS = ('gencode.v19.pc_transcripts.fa.gz', 'gencode.v19.lncRNA_transcripts.fa.gz')
31
+
32
+ # application configuration
33
+ BATCH_SIZE_COMPUTE = 500
34
+ BATCH_SIZE_SCAN = 20
35
+ BATCH_SIZE_TRANSCRIPTS = 50
36
+ NUM_TOP_GUIDES = 10
37
+ NUM_MISMATCHES = 3
38
+ RUN_MODES = dict(
39
+ all='All on-target guides per transcript',
40
+ top_guides='Top {:d} guides per transcript'.format(NUM_TOP_GUIDES),
41
+ titration='Top {:d} guides per transcript & their titration candidates'.format(NUM_TOP_GUIDES)
42
+ )
43
+
44
+
45
+ # configure GPUs
46
+ for gpu in tf.config.list_physical_devices('GPU'):
47
+ tf.config.experimental.set_memory_growth(gpu, enable=True)
48
+ if len(tf.config.list_physical_devices('GPU')) > 0:
49
+ tf.config.experimental.set_visible_devices(tf.config.list_physical_devices('GPU')[0], 'GPU')
50
+
51
+
52
+ def load_transcripts(fasta_files: list, enforce_unique_ids: bool = True):
53
+
54
+ # load all transcripts from fasta files into a DataFrame
55
+ transcripts = pd.DataFrame()
56
+ for file in fasta_files:
57
+ try:
58
+ if os.path.splitext(file)[1] == '.gz':
59
+ with gzip.open(file, 'rt') as f:
60
+ df = pd.DataFrame([(t.id, str(t.seq)) for t in SeqIO.parse(f, 'fasta')], columns=[ID_COL, SEQ_COL])
61
+ else:
62
+ df = pd.DataFrame([(t.id, str(t.seq)) for t in SeqIO.parse(file, 'fasta')], columns=[ID_COL, SEQ_COL])
63
+ except Exception as e:
64
+ print(e, 'while loading', file)
65
+ continue
66
+ transcripts = pd.concat([transcripts, df])
67
+
68
+ # set index
69
+ transcripts[ID_COL] = transcripts[ID_COL].apply(lambda s: s.split('|')[0])
70
+ transcripts.set_index(ID_COL, inplace=True)
71
+ if enforce_unique_ids:
72
+ assert not transcripts.index.has_duplicates, "duplicate transcript ID's detected in fasta file"
73
+
74
+ return transcripts
75
+
76
+
77
+ def sequence_complement(sequence: list):
78
+ return [''.join([NUCLEOTIDE_COMPLEMENT[nt] for nt in list(seq)]) for seq in sequence]
79
+
80
+
81
+ def one_hot_encode_sequence(sequence: list, add_context_padding: bool = False):
82
+
83
+ # stack list of sequences into a tensor
84
+ sequence = tf.ragged.stack([tf.constant(list(seq)) for seq in sequence], axis=0)
85
+
86
+ # tokenize sequence
87
+ nucleotide_table = tf.lookup.StaticVocabularyTable(
88
+ initializer=tf.lookup.KeyValueTensorInitializer(
89
+ keys=tf.constant(list(NUCLEOTIDE_TOKENS.keys()), dtype=tf.string),
90
+ values=tf.constant(list(NUCLEOTIDE_TOKENS.values()), dtype=tf.int64)),
91
+ num_oov_buckets=1)
92
+ sequence = tf.RaggedTensor.from_row_splits(values=nucleotide_table.lookup(sequence.values),
93
+ row_splits=sequence.row_splits).to_tensor(255)
94
+
95
+ # add context padding if requested
96
+ if add_context_padding:
97
+ pad_5p = 255 * tf.ones([sequence.shape[0], CONTEXT_5P], dtype=sequence.dtype)
98
+ pad_3p = 255 * tf.ones([sequence.shape[0], CONTEXT_3P], dtype=sequence.dtype)
99
+ sequence = tf.concat([pad_5p, sequence, pad_3p], axis=1)
100
+
101
+ # one-hot encode
102
+ sequence = tf.one_hot(sequence, depth=4, dtype=tf.float16)
103
+
104
+ return sequence
105
+
106
+
107
+ def process_data(transcript_seq: str):
108
+
109
+ # convert to upper case
110
+ transcript_seq = transcript_seq.upper()
111
+
112
+ # get all target sites
113
+ target_seq = [transcript_seq[i: i + TARGET_LEN] for i in range(len(transcript_seq) - TARGET_LEN + 1)]
114
+
115
+ # prepare guide sequences
116
+ guide_seq = sequence_complement([seq[CONTEXT_5P:len(seq) - CONTEXT_3P] for seq in target_seq])
117
+
118
+ # model inputs
119
+ model_inputs = tf.concat([
120
+ tf.reshape(one_hot_encode_sequence(target_seq, add_context_padding=False), [len(target_seq), -1]),
121
+ tf.reshape(one_hot_encode_sequence(guide_seq, add_context_padding=True), [len(guide_seq), -1]),
122
+ ], axis=-1)
123
+ return target_seq, guide_seq, model_inputs
124
+
125
+
126
+ def calibrate_predictions(predictions: np.array, num_mismatches: np.array, params: pd.DataFrame = None):
127
+ if params is None:
128
+ params = pd.read_pickle('calibration_params.pkl')
129
+ correction = np.squeeze(params.set_index('num_mismatches').loc[num_mismatches, 'slope'].to_numpy())
130
+ return correction * predictions
131
+
132
+
133
+ def score_predictions(predictions: np.array, params: pd.DataFrame = None):
134
+ if params is None:
135
+ params = pd.read_pickle('scoring_params.pkl')
136
+
137
+ if UNIT_INTERVAL_MAP == 'sigmoid':
138
+ params = params.iloc[0]
139
+ return 1 - 1 / (1 + np.exp(params['a'] * predictions + params['b']))
140
+
141
+ elif UNIT_INTERVAL_MAP == 'min-max':
142
+ return 1 - (predictions - params['a']) / (params['b'] - params['a'])
143
+
144
+ elif UNIT_INTERVAL_MAP == 'exp-lin-exp':
145
+ # regime indices
146
+ active_saturation = predictions < params['a']
147
+ linear_regime = (params['a'] <= predictions) & (predictions <= params['c'])
148
+ inactive_saturation = params['c'] < predictions
149
+
150
+ # linear regime
151
+ slope = (params['d'] - params['b']) / (params['c'] - params['a'])
152
+ intercept = -params['a'] * slope + params['b']
153
+ predictions[linear_regime] = slope * predictions[linear_regime] + intercept
154
+
155
+ # active saturation regime
156
+ alpha = slope / params['b']
157
+ beta = alpha * params['a'] - np.log(params['b'])
158
+ predictions[active_saturation] = np.exp(alpha * predictions[active_saturation] - beta)
159
+
160
+ # inactive saturation regime
161
+ alpha = slope / (1 - params['d'])
162
+ beta = -alpha * params['c'] - np.log(1 - params['d'])
163
+ predictions[inactive_saturation] = 1 - np.exp(-alpha * predictions[inactive_saturation] - beta)
164
+
165
+ return 1 - predictions
166
+
167
+ else:
168
+ raise NotImplementedError
169
+
170
+
171
+ def get_on_target_predictions(transcripts: pd.DataFrame, model: tf.keras.Model, status_update_fn=None):
172
+
173
+ # loop over transcripts
174
+ predictions = pd.DataFrame()
175
+ for i, (index, row) in enumerate(transcripts.iterrows()):
176
+
177
+ # parse transcript sequence
178
+ target_seq, guide_seq, model_inputs = process_data(row[SEQ_COL])
179
+
180
+ # get predictions
181
+ lfc_estimate = model.predict(model_inputs, batch_size=BATCH_SIZE_COMPUTE, verbose=False)[:, 0]
182
+ lfc_estimate = calibrate_predictions(lfc_estimate, num_mismatches=np.zeros_like(lfc_estimate))
183
+ scores = score_predictions(lfc_estimate)
184
+ predictions = pd.concat([predictions, pd.DataFrame({
185
+ ID_COL: [index] * len(scores),
186
+ TARGET_COL: target_seq,
187
+ GUIDE_COL: guide_seq,
188
+ SCORE_COL: scores})])
189
+
190
+ # progress update
191
+ percent_complete = 100 * min((i + 1) / len(transcripts), 1)
192
+ update_text = 'Evaluating on-target guides for each transcript: {:.2f}%'.format(percent_complete)
193
+ print('\r' + update_text, end='')
194
+ if status_update_fn is not None:
195
+ status_update_fn(update_text, percent_complete)
196
+ print('')
197
+
198
+ return predictions
199
+
200
+
201
+ def top_guides_per_transcript(predictions: pd.DataFrame):
202
+
203
+ # select and sort top guides for each transcript
204
+ top_guides = pd.DataFrame()
205
+ for transcript in predictions[ID_COL].unique():
206
+ df = predictions.loc[predictions[ID_COL] == transcript]
207
+ df = df.sort_values(SCORE_COL, ascending=False).reset_index(drop=True).iloc[:NUM_TOP_GUIDES]
208
+ top_guides = pd.concat([top_guides, df])
209
+
210
+ return top_guides.reset_index(drop=True)
211
+
212
+
213
+ def get_titration_candidates(top_guide_predictions: pd.DataFrame):
214
+
215
+ # generate a table of all titration candidates
216
+ titration_candidates = pd.DataFrame()
217
+ for _, row in top_guide_predictions.iterrows():
218
+ for i in range(len(row[GUIDE_COL])):
219
+ nt = row[GUIDE_COL][i]
220
+ for mutation in set(NUCLEOTIDE_TOKENS.keys()) - {nt, 'N'}:
221
+ sm_guide = list(row[GUIDE_COL])
222
+ sm_guide[i] = mutation
223
+ sm_guide = ''.join(sm_guide)
224
+ assert row[GUIDE_COL] != sm_guide
225
+ titration_candidates = pd.concat([titration_candidates, pd.DataFrame({
226
+ ID_COL: [row[ID_COL]],
227
+ TARGET_COL: [row[TARGET_COL]],
228
+ GUIDE_COL: [sm_guide],
229
+ MM_COL: [1]
230
+ })])
231
+
232
+ return titration_candidates
233
+
234
+
235
+ def find_off_targets(top_guides: pd.DataFrame, status_update_fn=None):
236
+
237
+ # load reference transcripts
238
+ reference_transcripts = load_transcripts([os.path.join('transcripts', f) for f in REFERENCE_TRANSCRIPTS])
239
+
240
+ # one-hot encode guides to form a filter
241
+ guide_filter = one_hot_encode_sequence(sequence_complement(top_guides[GUIDE_COL]), add_context_padding=False)
242
+ guide_filter = tf.transpose(guide_filter, [1, 2, 0])
243
+
244
+ # loop over transcripts in batches
245
+ i = 0
246
+ off_targets = pd.DataFrame()
247
+ while i < len(reference_transcripts):
248
+ # select batch
249
+ df_batch = reference_transcripts.iloc[i:min(i + BATCH_SIZE_SCAN, len(reference_transcripts))]
250
+ i += BATCH_SIZE_SCAN
251
+
252
+ # find locations of off-targets
253
+ transcripts = one_hot_encode_sequence(df_batch[SEQ_COL].values.tolist(), add_context_padding=False)
254
+ num_mismatches = GUIDE_LEN - tf.nn.conv1d(transcripts, guide_filter, stride=1, padding='SAME')
255
+ loc_off_targets = tf.where(tf.round(num_mismatches) <= NUM_MISMATCHES).numpy()
256
+
257
+ # off-targets discovered
258
+ if len(loc_off_targets) > 0:
259
+
260
+ # log off-targets
261
+ dict_off_targets = pd.DataFrame({
262
+ 'On-target ' + ID_COL: top_guides.iloc[loc_off_targets[:, 2]][ID_COL],
263
+ GUIDE_COL: top_guides.iloc[loc_off_targets[:, 2]][GUIDE_COL],
264
+ 'Off-target ' + ID_COL: df_batch.index.values[loc_off_targets[:, 0]],
265
+ 'Guide Midpoint': loc_off_targets[:, 1],
266
+ SEQ_COL: df_batch[SEQ_COL].values[loc_off_targets[:, 0]],
267
+ MM_COL: tf.gather_nd(num_mismatches, loc_off_targets).numpy().astype(int),
268
+ }).to_dict('records')
269
+
270
+ # trim transcripts to targets
271
+ for row in dict_off_targets:
272
+ start_location = row['Guide Midpoint'] - (GUIDE_LEN // 2)
273
+ del row['Guide Midpoint']
274
+ target = row[SEQ_COL]
275
+ del row[SEQ_COL]
276
+ if start_location < CONTEXT_5P:
277
+ target = target[0:GUIDE_LEN + CONTEXT_3P]
278
+ target = 'N' * (TARGET_LEN - len(target)) + target
279
+ elif start_location + GUIDE_LEN + CONTEXT_3P > len(target):
280
+ target = target[start_location - CONTEXT_5P:]
281
+ target = target + 'N' * (TARGET_LEN - len(target))
282
+ else:
283
+ target = target[start_location - CONTEXT_5P:start_location + GUIDE_LEN + CONTEXT_3P]
284
+ if row[MM_COL] == 0 and 'N' not in target:
285
+ assert row[GUIDE_COL] == sequence_complement([target[CONTEXT_5P:TARGET_LEN - CONTEXT_3P]])[0]
286
+ row[TARGET_COL] = target
287
+
288
+ # append new off-targets
289
+ off_targets = pd.concat([off_targets, pd.DataFrame(dict_off_targets)])
290
+
291
+ # progress update
292
+ percent_complete = 100 * min((i + 1) / len(reference_transcripts), 1)
293
+ update_text = 'Scanning for off-targets: {:.2f}%'.format(percent_complete)
294
+ print('\r' + update_text, end='')
295
+ if status_update_fn is not None:
296
+ status_update_fn(update_text, percent_complete)
297
+ print('')
298
+
299
+ return off_targets
300
+
301
+
302
+ def predict_off_target(off_targets: pd.DataFrame, model: tf.keras.Model):
303
+ if len(off_targets) == 0:
304
+ return pd.DataFrame()
305
+
306
+ # compute off-target predictions
307
+ model_inputs = tf.concat([
308
+ tf.reshape(one_hot_encode_sequence(off_targets[TARGET_COL], add_context_padding=False), [len(off_targets), -1]),
309
+ tf.reshape(one_hot_encode_sequence(off_targets[GUIDE_COL], add_context_padding=True), [len(off_targets), -1]),
310
+ ], axis=-1)
311
+ lfc_estimate = model.predict(model_inputs, batch_size=BATCH_SIZE_COMPUTE, verbose=False)[:, 0]
312
+ lfc_estimate = calibrate_predictions(lfc_estimate, off_targets['Number of Mismatches'].to_numpy())
313
+ off_targets[SCORE_COL] = score_predictions(lfc_estimate)
314
+
315
+ return off_targets.reset_index(drop=True)
316
+
317
+
318
+ def tiger_exhibit(transcripts: pd.DataFrame, mode: str, check_off_targets: bool, status_update_fn=None):
319
+
320
+ # load model
321
+ if os.path.exists('model'):
322
+ tiger = tf.keras.models.load_model('model')
323
+ else:
324
+ print('no saved model!')
325
+ exit()
326
+
327
+ # evaluate all on-target guides per transcript
328
+ on_target_predictions = get_on_target_predictions(transcripts, tiger, status_update_fn)
329
+
330
+ # initialize other outputs
331
+ titration_predictions = off_target_predictions = None
332
+
333
+ if mode == 'all' and not check_off_targets:
334
+ off_target_candidates = None
335
+
336
+ elif mode == 'top_guides':
337
+ on_target_predictions = top_guides_per_transcript(on_target_predictions)
338
+ off_target_candidates = on_target_predictions
339
+
340
+ elif mode == 'titration':
341
+ on_target_predictions = top_guides_per_transcript(on_target_predictions)
342
+ titration_candidates = get_titration_candidates(on_target_predictions)
343
+ titration_predictions = predict_off_target(titration_candidates, model=tiger)
344
+ off_target_candidates = pd.concat([on_target_predictions, titration_predictions])
345
+
346
+ else:
347
+ raise NotImplementedError
348
+
349
+ # check off-target effects for top guides
350
+ if check_off_targets and off_target_candidates is not None:
351
+ off_target_candidates = find_off_targets(off_target_candidates, status_update_fn)
352
+ off_target_predictions = predict_off_target(off_target_candidates, model=tiger)
353
+ if len(off_target_predictions) > 0:
354
+ off_target_predictions = off_target_predictions.sort_values(SCORE_COL, ascending=False)
355
+ off_target_predictions = off_target_predictions.reset_index(drop=True)
356
+
357
+ # finalize tables
358
+ for df in [on_target_predictions, titration_predictions, off_target_predictions]:
359
+ if df is not None and len(df) > 0:
360
+ for col in df.columns:
361
+ if ID_COL in col and set(df[col].unique()) == {'ManualEntry'}:
362
+ del df[col]
363
+ df[GUIDE_COL] = df[GUIDE_COL].apply(lambda s: s[::-1]) # reverse guide sequences
364
+ df[TARGET_COL] = df[TARGET_COL].apply(lambda seq: seq[CONTEXT_5P:len(seq) - CONTEXT_3P]) # remove context
365
+
366
+ return on_target_predictions, titration_predictions, off_target_predictions
367
+
368
+
369
+ if __name__ == '__main__':
370
+
371
+ # common arguments
372
+ parser = argparse.ArgumentParser()
373
+ parser.add_argument('--mode', type=str, default='titration')
374
+ parser.add_argument('--check_off_targets', action='store_true', default=False)
375
+ parser.add_argument('--fasta_path', type=str, default=None)
376
+ args = parser.parse_args()
377
+
378
+ # check for any existing results
379
+ if os.path.exists('on_target.csv') or os.path.exists('titration.csv') or os.path.exists('off_target.csv'):
380
+ raise FileExistsError('please rename or delete existing results')
381
+
382
+ # load transcripts from a directory of fasta files
383
+ if args.fasta_path is not None and os.path.exists(args.fasta_path):
384
+ df_transcripts = load_transcripts([os.path.join(args.fasta_path, f) for f in os.listdir(args.fasta_path)])
385
+
386
+ # otherwise consider simple test case with first 50 nucleotides from EIF3B-003's CDS
387
+ else:
388
+ df_transcripts = pd.DataFrame({
389
+ ID_COL: ['ManualEntry'],
390
+ SEQ_COL: ['ATGCAGGACGCGGAGAACGTGGCGGTGCCCGAGGCGGCCGAGGAGCGCGC']})
391
+ df_transcripts.set_index(ID_COL, inplace=True)
392
+
393
+ # process in batches
394
+ batch = 0
395
+ num_batches = len(df_transcripts) // BATCH_SIZE_TRANSCRIPTS
396
+ num_batches += (len(df_transcripts) % BATCH_SIZE_TRANSCRIPTS > 0)
397
+ for idx in range(0, len(df_transcripts), BATCH_SIZE_TRANSCRIPTS):
398
+ batch += 1
399
+ print('Batch {:d} of {:d}'.format(batch, num_batches))
400
+
401
+ # run batch
402
+ idx_stop = min(idx + BATCH_SIZE_TRANSCRIPTS, len(df_transcripts))
403
+ df_on_target, df_titration, df_off_target = tiger_exhibit(
404
+ transcripts=df_transcripts[idx:idx_stop],
405
+ mode=args.mode,
406
+ check_off_targets=args.check_off_targets
407
+ )
408
+
409
+ # save batch results
410
+ df_on_target.to_csv('on_target.csv', header=batch == 1, index=False, mode='a')
411
+ if df_titration is not None:
412
+ df_titration.to_csv('titration.csv', header=batch == 1, index=False, mode='a')
413
+ if df_off_target is not None:
414
+ df_off_target.to_csv('off_target.csv', header=batch == 1, index=False, mode='a')
415
+
416
+ # clear session to prevent memory blow up
417
+ tf.keras.backend.clear_session()