supercat666 committed on
Commit
4a303ce
·
1 Parent(s): 69d7c1c

change cas9

Browse files
Files changed (7) hide show
  1. app.py +9 -8
  2. cas12lstm.py +188 -0
  3. cas12lstmvcf.py +287 -0
  4. cas9att.py +299 -0
  5. cas9attvcf.py +397 -0
  6. cas9on.py +1 -3
  7. requirements.txt +3 -0
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import os
2
  import tiger
3
- import cas9on
 
4
  import cas9off
5
  import cas12
6
  import pandas as pd
@@ -22,8 +23,8 @@ st.divider()
22
  CRISPR_MODELS = ['Cas9', 'Cas12', 'Cas13d']
23
 
24
  selected_model = st.selectbox('Select CRISPR model:', CRISPR_MODELS, key='selected_model')
25
- cas9on_path = 'cas9_model/on-cla.h5'
26
- cas12_path = 'cas12_model/Seq_deepCpf1_weights.h5'
27
 
28
  #plot functions
29
  def generate_coolbox_plot(bigwig_path, region, output_image_path):
@@ -182,8 +183,8 @@ if selected_model == 'Cas9':
182
  # Process predictions
183
  if predict_button and gene_symbol:
184
  with st.spinner('Predicting... Please wait'):
185
- predictions, gene_sequence, exons = cas9on.process_gene(gene_symbol, cas9on_path)
186
- sorted_predictions = sorted(predictions, key=lambda x: x[-1], reverse=True)[:10]
187
  st.session_state['on_target_results'] = sorted_predictions
188
  st.session_state['gene_sequence'] = gene_sequence # Save gene sequence in session state
189
  st.session_state['exons'] = exons # Store exon data
@@ -283,9 +284,9 @@ if selected_model == 'Cas9':
283
 
284
 
285
  # Generate files
286
- cas9on.generate_genbank_file_from_df(df, gene_sequence, gene_symbol, genbank_file_path)
287
- cas9on.create_bed_file_from_df(df, bed_file_path)
288
- cas9on.create_csv_from_df(df, csv_file_path)
289
 
290
  # Prepare an in-memory buffer for the ZIP file
291
  zip_buffer = io.BytesIO()
 
1
  import os
2
  import tiger
3
+ import cas9att
4
+ import cas9attvcf
5
  import cas9off
6
  import cas12
7
  import pandas as pd
 
23
  CRISPR_MODELS = ['Cas9', 'Cas12', 'Cas13d']
24
 
25
  selected_model = st.selectbox('Select CRISPR model:', CRISPR_MODELS, key='selected_model')
26
+ cas9att_path = 'cas9_model/Cas9_MultiHeadAttention_weights.keras'
27
+ cas12_path = 'cas12_model/BiLSTM_Cpf1_weights.keras'
28
 
29
  #plot functions
30
  def generate_coolbox_plot(bigwig_path, region, output_image_path):
 
183
  # Process predictions
184
  if predict_button and gene_symbol:
185
  with st.spinner('Predicting... Please wait'):
186
+ predictions, gene_sequence, exons = cas9att.process_gene(gene_symbol, cas9att_path)
187
+ sorted_predictions = sorted(predictions)[:10]
188
  st.session_state['on_target_results'] = sorted_predictions
189
  st.session_state['gene_sequence'] = gene_sequence # Save gene sequence in session state
190
  st.session_state['exons'] = exons # Store exon data
 
284
 
285
 
286
  # Generate files
287
+ cas9att.generate_genbank_file_from_df(df, gene_sequence, gene_symbol, genbank_file_path)
288
+ cas9att.create_bed_file_from_df(df, bed_file_path)
289
+ cas9att.create_csv_from_df(df, csv_file_path)
290
 
291
  # Prepare an in-memory buffer for the ZIP file
292
  zip_buffer = io.BytesIO()
cas12lstm.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ from keras import regularizers
3
+ from keras.layers import Input, Dense, Dropout, Activation, Conv1D
4
+ from keras.layers import GlobalAveragePooling1D, AveragePooling1D
5
+ from keras.layers import Bidirectional, LSTM
6
+ from keras import Model
7
+ from keras.metrics import MeanSquaredError
8
+
9
+ import pandas as pd
10
+ import numpy as np
11
+
12
+ import requests
13
+ from functools import reduce
14
+ from operator import add
15
+ import tabulate
16
+ from difflib import SequenceMatcher
17
+
18
+ import cyvcf2
19
+ import parasail
20
+
21
+ import re
22
+
23
+ ntmap = {'A': (1, 0, 0, 0),
24
+ 'C': (0, 1, 0, 0),
25
+ 'G': (0, 0, 1, 0),
26
+ 'T': (0, 0, 0, 1)
27
+ }
28
+
29
+ def get_seqcode(seq):
30
+ return np.array(reduce(add, map(lambda c: ntmap[c], seq.upper()))).reshape((1, len(seq), -1))
31
+
32
+ def BiLSTM_model(input_shape):
33
+ input = Input(shape=input_shape)
34
+
35
+ conv1 = Conv1D(128, 5, activation="relu")(input)
36
+ pool1 = AveragePooling1D(2)(conv1)
37
+ drop1 = Dropout(0.1)(pool1)
38
+
39
+ conv2 = Conv1D(128, 5, activation="relu")(drop1)
40
+ pool2 = AveragePooling1D(2)(conv2)
41
+ drop2 = Dropout(0.1)(pool2)
42
+
43
+ lstm1 = Bidirectional(LSTM(128,
44
+ dropout=0.1,
45
+ activation='tanh',
46
+ return_sequences=True,
47
+ kernel_regularizer=regularizers.l2(1e-4)))(drop2)
48
+ avgpool = GlobalAveragePooling1D()(lstm1)
49
+
50
+ dense1 = Dense(128,
51
+ kernel_regularizer=regularizers.l2(1e-4),
52
+ bias_regularizer=regularizers.l2(1e-4),
53
+ activation="relu")(avgpool)
54
+ drop3 = Dropout(0.1)(dense1)
55
+
56
+ dense2 = Dense(32,
57
+ kernel_regularizer=regularizers.l2(1e-4),
58
+ bias_regularizer=regularizers.l2(1e-4),
59
+ activation="relu")(drop3)
60
+ drop4 = Dropout(0.1)(dense2)
61
+
62
+ dense3 = Dense(32,
63
+ kernel_regularizer=regularizers.l2(1e-4),
64
+ bias_regularizer=regularizers.l2(1e-4),
65
+ activation="relu")(drop4)
66
+ drop5 = Dropout(0.1)(dense3)
67
+
68
+ output = Dense(1, activation="linear")(drop5)
69
+
70
+ model = Model(inputs=[input], outputs=[output])
71
+ return model
72
+
73
+ def fetch_ensembl_transcripts(gene_symbol):
74
+ url = f"https://rest.ensembl.org/lookup/symbol/homo_sapiens/{gene_symbol}?expand=1;content-type=application/json"
75
+ response = requests.get(url)
76
+ if response.status_code == 200:
77
+ gene_data = response.json()
78
+ if 'Transcript' in gene_data:
79
+ return gene_data['Transcript']
80
+ else:
81
+ print("No transcripts found for gene:", gene_symbol)
82
+ return None
83
+ else:
84
+ print(f"Error fetching gene data from Ensembl: {response.text}")
85
+ return None
86
+
87
+ def fetch_ensembl_sequence(transcript_id):
88
+ url = f"https://rest.ensembl.org/sequence/id/{transcript_id}?content-type=application/json"
89
+ response = requests.get(url)
90
+ if response.status_code == 200:
91
+ sequence_data = response.json()
92
+ if 'seq' in sequence_data:
93
+ return sequence_data['seq']
94
+ else:
95
+ print("No sequence found for transcript:", transcript_id)
96
+ return None
97
+ else:
98
+ print(f"Error fetching sequence data from Ensembl: {response.text}")
99
+ return None
100
+
101
+ def find_crispr_targets(sequence, chr, start, end, strand, transcript_id, exon_id, pam="TTTN", target_length=34):
102
+ targets = []
103
+ len_sequence = len(sequence)
104
+ #complement = {'A': 'T', 'T': 'A', 'C': 'G', 'G': 'C'}
105
+ dnatorna = {'A': 'A', 'T': 'U', 'C': 'C', 'G': 'G'}
106
+
107
+ for i in range(len_sequence - target_length + 1):
108
+ target_seq = sequence[i:i + target_length]
109
+ if target_seq[4:7] == 'TTT':
110
+ if strand == -1:
111
+ tar_start = end - i - target_length + 1
112
+ tar_end = end -i
113
+ #seq_in_ref = ''.join([complement[base] for base in target_seq])[::-1]
114
+ else:
115
+ tar_start = start + i
116
+ tar_end = start + i + target_length - 1
117
+ #seq_in_ref = target_seq
118
+ gRNA = ''.join([dnatorna[base] for base in target_seq[8:28]])
119
+ targets.append([target_seq, gRNA, chr, str(tar_start), str(tar_end), str(strand), transcript_id, exon_id])
120
+ #targets.append([target_seq, gRNA, chr, str(tar_start), str(tar_end), str(strand), transcript_id, exon_id, seq_in_ref])
121
+ return targets
122
+
123
+ def format_prediction_output(targets, model_path):
124
+ # Loading weights for the model
125
+ Crispr_BiLSTM = BiLSTM_model(input_shape=(34, 4))
126
+ Crispr_BiLSTM.load_weights(model_path)
127
+
128
+ formatted_data = []
129
+ for target in targets:
130
+ # Predict
131
+ encoded_seq = get_seqcode(target[0])
132
+ prediction = float(list(Crispr_BiLSTM.predict(encoded_seq, verbose=0)[0])[0])
133
+ if prediction > 100:
134
+ prediction = 100
135
+
136
+ # Format output
137
+ gRNA = target[1]
138
+ chr = target[2]
139
+ start = target[3]
140
+ end = target[4]
141
+ strand = target[5]
142
+ transcript_id = target[6]
143
+ exon_id = target[7]
144
+ #seq_in_ref = target[8]
145
+ #formatted_data.append([chr, start, end, strand, transcript_id, exon_id, target[0], gRNA, seq_in_ref, prediction])
146
+ formatted_data.append([chr, start, end, strand, transcript_id, exon_id, target[0], gRNA, prediction])
147
+
148
+ return formatted_data
149
+
150
+
151
+ def process_gene(gene_symbol, model_path):
152
+ transcripts = fetch_ensembl_transcripts(gene_symbol)
153
+ results = []
154
+ all_exons = [] # To accumulate all exons
155
+ all_gene_sequences = [] # To accumulate all gene sequences
156
+
157
+ if transcripts:
158
+ for transcript in transcripts:
159
+ Exons = transcript['Exon']
160
+ all_exons.extend(Exons) # Add all exons from this transcript to the list
161
+ transcript_id = transcript['id']
162
+
163
+ for Exon in Exons:
164
+ exon_id = Exon['id']
165
+ gene_sequence = fetch_ensembl_sequence(exon_id)
166
+ if gene_sequence:
167
+ all_gene_sequences.append(gene_sequence) # Add this gene sequence to the list
168
+ chr = Exon['seq_region_name']
169
+ start = Exon['start']
170
+ end = Exon['end']
171
+ strand = Exon['strand']
172
+
173
+ targets = find_crispr_targets(gene_sequence, chr, start, end, strand, transcript_id, exon_id)
174
+ if targets:
175
+ # Predict on-target efficiency for each gRNA site
176
+ formatted_data = format_prediction_output(targets, model_path)
177
+ results.extend(formatted_data) # Flatten the results
178
+ else:
179
+ print(f"Failed to retrieve gene sequence for exon {exon_id}.")
180
+ else:
181
+ print("Failed to retrieve transcripts.")
182
+
183
+ # Sort results based on prediction score (assuming score is at the 8th index)
184
+ sorted_results = sorted(results, key=lambda x: x[8], reverse=True)
185
+
186
+ # Return the sorted output, combined gene sequences, and all exons
187
+ return sorted_results, all_gene_sequences, all_exons
188
+
cas12lstmvcf.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ from keras import regularizers
3
+ from keras.layers import Input, Dense, Dropout, Activation, Conv1D
4
+ from keras.layers import GlobalAveragePooling1D, AveragePooling1D
5
+ from keras.layers import Bidirectional, LSTM
6
+ from keras import Model
7
+ from keras.metrics import MeanSquaredError
8
+
9
+ import pandas as pd
10
+ import numpy as np
11
+
12
+ import requests
13
+ from functools import reduce
14
+ from operator import add
15
+ import tabulate
16
+ from difflib import SequenceMatcher
17
+
18
+ import cyvcf2
19
+ import parasail
20
+
21
+ import re
22
+
23
+ ntmap = {'A': (1, 0, 0, 0),
24
+ 'C': (0, 1, 0, 0),
25
+ 'G': (0, 0, 1, 0),
26
+ 'T': (0, 0, 0, 1)
27
+ }
28
+
29
+ def get_seqcode(seq):
30
+ return np.array(reduce(add, map(lambda c: ntmap[c], seq.upper()))).reshape((1, len(seq), -1))
31
+
32
+ def BiLSTM_model(input_shape):
33
+ input = Input(shape=input_shape)
34
+
35
+ conv1 = Conv1D(128, 5, activation="relu")(input)
36
+ pool1 = AveragePooling1D(2)(conv1)
37
+ drop1 = Dropout(0.1)(pool1)
38
+
39
+ conv2 = Conv1D(128, 5, activation="relu")(drop1)
40
+ pool2 = AveragePooling1D(2)(conv2)
41
+ drop2 = Dropout(0.1)(pool2)
42
+
43
+ lstm1 = Bidirectional(LSTM(128,
44
+ dropout=0.1,
45
+ activation='tanh',
46
+ return_sequences=True,
47
+ kernel_regularizer=regularizers.l2(1e-4)))(drop2)
48
+ avgpool = GlobalAveragePooling1D()(lstm1)
49
+
50
+ dense1 = Dense(128,
51
+ kernel_regularizer=regularizers.l2(1e-4),
52
+ bias_regularizer=regularizers.l2(1e-4),
53
+ activation="relu")(avgpool)
54
+ drop3 = Dropout(0.1)(dense1)
55
+
56
+ dense2 = Dense(32,
57
+ kernel_regularizer=regularizers.l2(1e-4),
58
+ bias_regularizer=regularizers.l2(1e-4),
59
+ activation="relu")(drop3)
60
+ drop4 = Dropout(0.1)(dense2)
61
+
62
+ dense3 = Dense(32,
63
+ kernel_regularizer=regularizers.l2(1e-4),
64
+ bias_regularizer=regularizers.l2(1e-4),
65
+ activation="relu")(drop4)
66
+ drop5 = Dropout(0.1)(dense3)
67
+
68
+ output = Dense(1, activation="linear")(drop5)
69
+
70
+ model = Model(inputs=[input], outputs=[output])
71
+ return model
72
+
73
+ def fetch_ensembl_transcripts(gene_symbol):
74
+ url = f"https://rest.ensembl.org/lookup/symbol/homo_sapiens/{gene_symbol}?expand=1;content-type=application/json"
75
+ response = requests.get(url)
76
+ if response.status_code == 200:
77
+ gene_data = response.json()
78
+ if 'Transcript' in gene_data:
79
+ return gene_data['Transcript']
80
+ else:
81
+ print("No transcripts found for gene:", gene_symbol)
82
+ return None
83
+ else:
84
+ print(f"Error fetching gene data from Ensembl: {response.text}")
85
+ return None
86
+
87
+ def fetch_ensembl_sequence(transcript_id):
88
+ url = f"https://rest.ensembl.org/sequence/id/{transcript_id}?content-type=application/json"
89
+ response = requests.get(url)
90
+ if response.status_code == 200:
91
+ sequence_data = response.json()
92
+ if 'seq' in sequence_data:
93
+ return sequence_data['seq']
94
+ else:
95
+ print("No sequence found for transcript:", transcript_id)
96
+ return None
97
+ else:
98
+ print(f"Error fetching sequence data from Ensembl: {response.text}")
99
+ return None
100
+
101
+ def apply_mutation(ref_sequence, offset, ref, alt):
102
+ """
103
+ Apply a single mutation to the sequence.
104
+ """
105
+ if len(ref) == len(alt) and alt != "*": # SNP
106
+ mutated_seq = ref_sequence[:offset] + alt + ref_sequence[offset+len(alt):]
107
+
108
+ elif len(ref) < len(alt): # Insertion
109
+ mutated_seq = ref_sequence[:offset] + alt + ref_sequence[offset+1:]
110
+
111
+ elif len(ref) == len(alt) and alt == "*": # Deletion
112
+ mutated_seq = ref_sequence[:offset] + ref_sequence[offset+1:]
113
+
114
+ elif len(ref) > len(alt) and alt != "*": # Deletion
115
+ mutated_seq = ref_sequence[:offset] + alt + ref_sequence[offset+len(ref):]
116
+
117
+ elif len(ref) > len(alt) and alt == "*": # Deletion
118
+ mutated_seq = ref_sequence[:offset] + ref_sequence[offset+len(ref):]
119
+
120
+ return mutated_seq
121
+
122
+
123
+ def construct_combinations(sequence, mutations):
124
+ """
125
+ Construct all combinations of mutations.
126
+ mutations is a list of tuples (position, ref, [alts])
127
+ """
128
+ if not mutations:
129
+ return [sequence]
130
+
131
+ # Take the first mutation and recursively construct combinations for the rest
132
+ first_mutation = mutations[0]
133
+ rest_mutations = mutations[1:]
134
+ offset, ref, alts = first_mutation
135
+
136
+ sequences = []
137
+ for alt in alts:
138
+ mutated_sequence = apply_mutation(sequence, offset, ref, alt)
139
+ sequences.extend(construct_combinations(mutated_sequence, rest_mutations))
140
+
141
+ return sequences
142
+
143
+ def needleman_wunsch_alignment(query_seq, ref_seq):
144
+ """
145
+ Use Needleman-Wunsch alignment to find the maximum alignment position in ref_seq
146
+ Use this position to represent the position of target sequence with mutations
147
+ """
148
+ # Needleman-Wunsch alignment
149
+ alignment = parasail.nw_trace(query_seq, ref_seq, 10, 1, parasail.blosum62)
150
+
151
+ # extract CIGAR object
152
+ cigar = alignment.cigar
153
+ cigar_string = cigar.decode.decode("utf-8")
154
+
155
+ # record ref_pos
156
+ ref_pos = 0
157
+
158
+ matches = re.findall(r'(\d+)([MIDNSHP=X])', cigar_string)
159
+ max_num_before_equal = 0
160
+ max_equal_index = -1
161
+ total_before_max_equal = 0
162
+
163
+ for i, (num_str, op) in enumerate(matches):
164
+ num = int(num_str)
165
+ if op == '=':
166
+ if num > max_num_before_equal:
167
+ max_num_before_equal = num
168
+ max_equal_index = i
169
+ total_before_max_equal = sum(int(matches[j][0]) for j in range(max_equal_index))
170
+
171
+ ref_pos = total_before_max_equal
172
+
173
+ return ref_pos
174
+
175
+ def find_gRNA_with_mutation(ref_sequence, exon_chr, start, end, strand, transcript_id,
176
+ exon_id, gene_symbol, vcf_reader, pam="TTTN", target_length=34):
177
+ # initialization
178
+ mutated_sequences = [ref_sequence]
179
+
180
+ # find mutations within interested region
181
+ mutations = vcf_reader(f"{exon_chr}:{start}-{end}")
182
+ if mutations:
183
+ # find mutations
184
+ mutation_list = []
185
+ for mutation in mutations:
186
+ offset = mutation.POS - start
187
+ ref = mutation.REF
188
+ alts = mutation.ALT[:-1]
189
+ mutation_list.append((offset, ref, alts))
190
+
191
+ # replace reference sequence of mutation
192
+ mutated_sequences = construct_combinations(ref_sequence, mutation_list)
193
+
194
+ # find gRNA in ref_sequence or all mutated_sequences
195
+ targets = []
196
+ for seq in mutated_sequences:
197
+ len_sequence = len(seq)
198
+ dnatorna = {'A': 'A', 'T': 'U', 'C': 'C', 'G': 'G'}
199
+ for i in range(len_sequence - target_length + 1):
200
+ target_seq = seq[i:i + target_length]
201
+ if target_seq[4:7] == 'TTT':
202
+ pos = ref_sequence.find(target_seq)
203
+ if pos != -1:
204
+ is_mut = False
205
+ if strand == -1:
206
+ tar_start = end - pos - target_length + 1
207
+ else:
208
+ tar_start = start + pos
209
+ else:
210
+ is_mut = True
211
+ nw_pos = needleman_wunsch_alignment(target_seq, ref_sequence)
212
+ if strand == -1:
213
+ tar_start = str(end - nw_pos - target_length + 1) + '*'
214
+ else:
215
+ tar_start = str(start + nw_pos) + '*'
216
+ gRNA = ''.join([dnatorna[base] for base in target_seq[8:28]])
217
+ targets.append([target_seq, gRNA, exon_chr, str(strand), str(tar_start), transcript_id, exon_id, gene_symbol, is_mut])
218
+
219
+ # filter duplicated targets
220
+ unique_targets_set = set(tuple(element) for element in targets)
221
+ unique_targets = [list(element) for element in unique_targets_set]
222
+
223
+ return unique_targets
224
+
225
+ def format_prediction_output_with_mutation(targets, model_path):
226
+ Crispr_BiLSTM = BiLSTM_model(input_shape=(34, 4))
227
+ Crispr_BiLSTM.load_weights(model_path)
228
+
229
+ formatted_data = []
230
+ for target in targets:
231
+ # Predict
232
+ encoded_seq = get_seqcode(target[0])
233
+ prediction = float(list(Crispr_BiLSTM.predict(encoded_seq, verbose=0)[0])[0])
234
+ if prediction > 100:
235
+ prediction = 100
236
+
237
+ # Format output
238
+ gRNA = target[1]
239
+ exon_chr = target[2]
240
+ strand = target[3]
241
+ tar_start = target[4]
242
+ transcript_id = target[5]
243
+ exon_id = target[6]
244
+ gene_symbol = target[7]
245
+ is_mut = target[8]
246
+ formatted_data.append([gene_symbol, exon_chr, strand, tar_start, transcript_id, exon_id, target[0], gRNA, prediction, is_mut])
247
+
248
+ return formatted_data
249
+
250
+ def process_gene(gene_symbol, vcf_reader, model_path):
251
+ transcripts = fetch_ensembl_transcripts(gene_symbol)
252
+ results = []
253
+ all_exons = [] # To accumulate all exons
254
+ all_gene_sequences = [] # To accumulate all gene sequences
255
+
256
+ if transcripts:
257
+ for transcript in transcripts:
258
+ Exons = transcript['Exon']
259
+ all_exons.extend(Exons) # Add all exons from this transcript to the list
260
+ transcript_id = transcript['id']
261
+
262
+ for Exon in Exons:
263
+ exon_id = Exon['id']
264
+ gene_sequence = fetch_ensembl_sequence(exon_id) # Reference exon sequence
265
+ if gene_sequence:
266
+ all_gene_sequences.append(gene_sequence) # Add this gene sequence to the list
267
+ exon_chr = Exon['seq_region_name']
268
+ start = Exon['start']
269
+ end = Exon['end']
270
+ strand = Exon['strand']
271
+
272
+ targets = find_gRNA_with_mutation(gene_sequence, exon_chr, start, end, strand, transcript_id, exon_id, gene_symbol, vcf_reader)
273
+ if targets:
274
+ # Predict on-target efficiency for each gRNA site
275
+ formatted_data = format_prediction_output_with_mutation(targets, model_path)
276
+ results.extend(formatted_data) # Flatten the results
277
+ else:
278
+ print(f"Failed to retrieve gene sequence for exon {exon_id}.")
279
+ else:
280
+ print("Failed to retrieve transcripts.")
281
+
282
+ # Sort results based on prediction score (assuming score is at the 8th index)
283
+ sorted_results = sorted(results, key=lambda x: x[8], reverse=True)
284
+
285
+ # Return the sorted output, combined gene sequences, and all exons
286
+ return sorted_results, all_gene_sequences, all_exons
287
+
cas9att.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import tensorflow as tf
3
+ import pandas as pd
4
+ import numpy as np
5
+ from operator import add
6
+ from functools import reduce
7
+ import random
8
+ import tabulate
9
+
10
+ from keras import Model
11
+ from keras import regularizers
12
+ from keras.optimizers import Adam
13
+ from keras.layers import Conv2D, BatchNormalization, ReLU, Input, Flatten, Softmax
14
+ from keras.layers import Concatenate, Activation, Dense, GlobalAveragePooling2D, Dropout
15
+ from keras.layers import AveragePooling1D, Bidirectional, LSTM, GlobalAveragePooling1D, MaxPool1D, Reshape
16
+ from keras.layers import LayerNormalization, Conv1D, MultiHeadAttention, Layer
17
+ from keras.models import load_model
18
+ from keras.callbacks import EarlyStopping, ReduceLROnPlateau
19
+ from Bio import SeqIO
20
+ from Bio.SeqRecord import SeqRecord
21
+ from Bio.SeqFeature import SeqFeature, FeatureLocation
22
+ from Bio.Seq import Seq
23
+
24
+ import cyvcf2
25
+ import parasail
26
+
27
+ import re
28
+
29
+ ntmap = {'A': (1, 0, 0, 0),
30
+ 'C': (0, 1, 0, 0),
31
+ 'G': (0, 0, 1, 0),
32
+ 'T': (0, 0, 0, 1)
33
+ }
34
+
35
+ def get_seqcode(seq):
36
+ return np.array(reduce(add, map(lambda c: ntmap[c], seq.upper()))).reshape((1, len(seq), -1))
37
+
38
+ class PositionalEncoding(Layer):
39
+ def __init__(self, sequence_len=None, embedding_dim=None,**kwargs):
40
+ super(PositionalEncoding, self).__init__()
41
+ self.sequence_len = sequence_len
42
+ self.embedding_dim = embedding_dim
43
+
44
+ def call(self, x):
45
+
46
+ position_embedding = np.array([
47
+ [pos / np.power(10000, 2. * i / self.embedding_dim) for i in range(self.embedding_dim)]
48
+ for pos in range(self.sequence_len)])
49
+
50
+ position_embedding[:, 0::2] = np.sin(position_embedding[:, 0::2]) # dim 2i
51
+ position_embedding[:, 1::2] = np.cos(position_embedding[:, 1::2]) # dim 2i+1
52
+ position_embedding = tf.cast(position_embedding, dtype=tf.float32)
53
+
54
+ return position_embedding+x
55
+
56
+ def get_config(self):
57
+ config = super().get_config().copy()
58
+ config.update({
59
+ 'sequence_len' : self.sequence_len,
60
+ 'embedding_dim' : self.embedding_dim,
61
+ })
62
+ return config
63
+
64
+ def MultiHeadAttention_model(input_shape):
65
+ input = Input(shape=input_shape)
66
+
67
+ conv1 = Conv1D(256, 3, activation="relu")(input)
68
+ pool1 = AveragePooling1D(2)(conv1)
69
+ drop1 = Dropout(0.4)(pool1)
70
+
71
+ conv2 = Conv1D(256, 3, activation="relu")(drop1)
72
+ pool2 = AveragePooling1D(2)(conv2)
73
+ drop2 = Dropout(0.4)(pool2)
74
+
75
+ lstm = Bidirectional(LSTM(128,
76
+ dropout=0.5,
77
+ activation='tanh',
78
+ return_sequences=True,
79
+ kernel_regularizer=regularizers.l2(0.01)))(drop2)
80
+
81
+ pos_embedding = PositionalEncoding(sequence_len=int(((23-3+1)/2-3+1)/2), embedding_dim=2*128)(lstm)
82
+ atten = MultiHeadAttention(num_heads=2,
83
+ key_dim=64,
84
+ dropout=0.2,
85
+ kernel_regularizer=regularizers.l2(0.01))(pos_embedding, pos_embedding)
86
+
87
+ flat = Flatten()(atten)
88
+
89
+ dense1 = Dense(512,
90
+ kernel_regularizer=regularizers.l2(1e-4),
91
+ bias_regularizer=regularizers.l2(1e-4),
92
+ activation="relu")(flat)
93
+ drop3 = Dropout(0.1)(dense1)
94
+
95
+ dense2 = Dense(128,
96
+ kernel_regularizer=regularizers.l2(1e-4),
97
+ bias_regularizer=regularizers.l2(1e-4),
98
+ activation="relu")(drop3)
99
+ drop4 = Dropout(0.1)(dense2)
100
+
101
+ dense3 = Dense(256,
102
+ kernel_regularizer=regularizers.l2(1e-4),
103
+ bias_regularizer=regularizers.l2(1e-4),
104
+ activation="relu")(drop4)
105
+ drop5 = Dropout(0.1)(dense3)
106
+
107
+ output = Dense(1, activation="linear")(drop5)
108
+
109
+ model = Model(inputs=[input], outputs=[output])
110
+ return model
111
+
112
+ def fetch_ensembl_transcripts(gene_symbol):
113
+ url = f"https://rest.ensembl.org/lookup/symbol/homo_sapiens/{gene_symbol}?expand=1;content-type=application/json"
114
+ response = requests.get(url)
115
+ if response.status_code == 200:
116
+ gene_data = response.json()
117
+ if 'Transcript' in gene_data:
118
+ return gene_data['Transcript']
119
+ else:
120
+ print("No transcripts found for gene:", gene_symbol)
121
+ return None
122
+ else:
123
+ print(f"Error fetching gene data from Ensembl: {response.text}")
124
+ return None
125
+
126
+ def fetch_ensembl_sequence(transcript_id):
127
+ url = f"https://rest.ensembl.org/sequence/id/{transcript_id}?content-type=application/json"
128
+ response = requests.get(url)
129
+ if response.status_code == 200:
130
+ sequence_data = response.json()
131
+ if 'seq' in sequence_data:
132
+ return sequence_data['seq']
133
+ else:
134
+ print("No sequence found for transcript:", transcript_id)
135
+ return None
136
+ else:
137
+ print(f"Error fetching sequence data from Ensembl: {response.text}")
138
+ return None
139
+
140
+ def find_crispr_targets(sequence, chr, start, end, strand, transcript_id, exon_id, pam="NGG", target_length=20):
141
+ targets = []
142
+ len_sequence = len(sequence)
143
+ #complement = {'A': 'T', 'T': 'A', 'C': 'G', 'G': 'C'}
144
+ dnatorna = {'A': 'A', 'T': 'U', 'C': 'C', 'G': 'G'}
145
+
146
+ for i in range(len_sequence - len(pam) + 1):
147
+ if sequence[i + 1:i + 3] == pam[1:]:
148
+ if i >= target_length:
149
+ target_seq = sequence[i - target_length:i + 3]
150
+ if strand == -1:
151
+ tar_start = end - (i + 2)
152
+ tar_end = end - (i - target_length)
153
+ #seq_in_ref = ''.join([complement[base] for base in target_seq])[::-1]
154
+ else:
155
+ tar_start = start + i - target_length
156
+ tar_end = start + i + 3 - 1
157
+ #seq_in_ref = target_seq
158
+ gRNA = ''.join([dnatorna[base] for base in sequence[i - target_length:i]])
159
+ #targets.append([target_seq, gRNA, chr, str(tar_start), str(tar_end), str(strand), transcript_id, exon_id, seq_in_ref])
160
+ targets.append([target_seq, gRNA, chr, str(tar_start), str(tar_end), str(strand), transcript_id, exon_id])
161
+
162
+ return targets
163
+
164
+ # Function to predict on-target efficiency and format output
165
+ def format_prediction_output(targets, model_path):
166
+ model = MultiHeadAttention_model(input_shape=(23, 4))
167
+ model.load_weights(model_path)
168
+
169
+ formatted_data = []
170
+
171
+ for target in targets:
172
+ # Encode the gRNA sequence
173
+ encoded_seq = get_seqcode(target[0])
174
+
175
+ # Predict on-target efficiency using the model
176
+ prediction = float(list(model.predict(encoded_seq, verbose=0)[0])[0])
177
+ if prediction > 100:
178
+ prediction = 100
179
+
180
+ # Format output
181
+ gRNA = target[1]
182
+ chr = target[2]
183
+ start = target[3]
184
+ end = target[4]
185
+ strand = target[5]
186
+ transcript_id = target[6]
187
+ exon_id = target[7]
188
+ #seq_in_ref = target[8]
189
+ #formatted_data.append([chr, start, end, strand, transcript_id, exon_id, target[0], gRNA, seq_in_ref, prediction[0]])
190
+ formatted_data.append([chr, start, end, strand, transcript_id, exon_id, target[0], gRNA, prediction])
191
+
192
+ return formatted_data
193
+
194
+ def process_gene(gene_symbol, model_path):
195
+ # Fetch transcripts for the given gene symbol
196
+ transcripts = fetch_ensembl_transcripts(gene_symbol)
197
+ results = []
198
+ all_exons = [] # To accumulate all exons
199
+ all_gene_sequences = [] # To accumulate all gene sequences
200
+
201
+ if transcripts:
202
+ for transcript in transcripts:
203
+ Exons = transcript['Exon']
204
+ all_exons.extend(Exons) # Add all exons from this transcript to the list
205
+ transcript_id = transcript['id']
206
+
207
+ for exon in Exons:
208
+ exon_id = exon['id']
209
+ gene_sequence = fetch_ensembl_sequence(exon_id)
210
+ if gene_sequence:
211
+ all_gene_sequences.append(gene_sequence) # Add this gene sequence to the list
212
+ start = exon['start']
213
+ end = exon['end']
214
+ strand = exon['strand']
215
+ chr = exon['seq_region_name']
216
+ # Find potential CRISPR targets within the exon
217
+ targets = find_crispr_targets(gene_sequence, chr, start, end, strand, transcript_id, exon_id)
218
+ if targets:
219
+ # Format the prediction output for the targets found
220
+ formatted_data = format_prediction_output(targets, model_path)
221
+ results.extend(formatted_data) # Append results
222
+ else:
223
+ print(f"Failed to retrieve gene sequence for exon {exon_id}.")
224
+ else:
225
+ print("Failed to retrieve transcripts.")
226
+
227
+ # Sort results based on prediction score (assuming score is at the 8th index)
228
+ sorted_results = sorted(results, key=lambda x: x[8], reverse=True)
229
+
230
+ # Return the sorted output, combined gene sequences, and all exons
231
+ return sorted_results, all_gene_sequences, all_exons
232
+
233
+
234
+ def create_genbank_features(data):
235
+ features = []
236
+
237
+ # If the input data is a DataFrame, convert it to a list of lists
238
+ if isinstance(data, pd.DataFrame):
239
+ formatted_data = data.values.tolist()
240
+ elif isinstance(data, list):
241
+ formatted_data = data
242
+ else:
243
+ raise TypeError("Data should be either a list or a pandas DataFrame.")
244
+
245
+ for row in formatted_data:
246
+ try:
247
+ start = int(row[1])
248
+ end = int(row[2])
249
+ except ValueError as e:
250
+ print(f"Error converting start/end to int: {row[1]}, {row[2]} - {e}")
251
+ continue
252
+
253
+ strand = 1 if row[3] == '+' else -1
254
+ location = FeatureLocation(start=start, end=end, strand=strand)
255
+ feature = SeqFeature(location=location, type="misc_feature", qualifiers={
256
+ 'label': row[7], # Use gRNA as the label
257
+ 'note': f"Prediction: {row[8]}" # Include the prediction score
258
+ })
259
+ features.append(feature)
260
+
261
+ return features
262
+
263
+
264
+ def generate_genbank_file_from_df(df, gene_sequence, gene_symbol, output_path):
265
+ # Ensure gene_sequence is a string before creating Seq object
266
+ if not isinstance(gene_sequence, str):
267
+ gene_sequence = str(gene_sequence)
268
+
269
+ features = create_genbank_features(df)
270
+
271
+ # Now gene_sequence is guaranteed to be a string, suitable for Seq
272
+ seq_obj = Seq(gene_sequence)
273
+ record = SeqRecord(seq_obj, id=gene_symbol, name=gene_symbol,
274
+ description=f'CRISPR Cas9 predicted targets for {gene_symbol}', features=features)
275
+ record.annotations["molecule_type"] = "DNA"
276
+ SeqIO.write(record, output_path, "genbank")
277
+
278
+
279
+ def create_bed_file_from_df(df, output_path):
280
+ with open(output_path, 'w') as bed_file:
281
+ for index, row in df.iterrows():
282
+ chrom = row["Chr"]
283
+ start = int(row["Start Pos"])
284
+ end = int(row["End Pos"])
285
+ strand = '+' if row["Strand"] == '1' else '-'
286
+ gRNA = row["gRNA"]
287
+ score = str(row["Prediction"])
288
+ # transcript_id is not typically part of the standard BED columns but added here for completeness
289
+ transcript_id = row["Transcript"]
290
+
291
+ # Writing only standard BED columns; additional columns can be appended as needed
292
+ bed_file.write(f"{chrom}\t{start}\t{end}\t{gRNA}\t{score}\t{strand}\n")
293
+
294
+
295
+ def create_csv_from_df(df, output_path):
296
+ df.to_csv(output_path, index=False)
297
+
298
+
299
+
cas9attvcf.py ADDED
@@ -0,0 +1,397 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import tensorflow as tf
3
+ import pandas as pd
4
+ import numpy as np
5
+ from operator import add
6
+ from functools import reduce
7
+ import random
8
+ import tabulate
9
+
10
+ from keras import Model
11
+ from keras import regularizers
12
+ from keras.optimizers import Adam
13
+ from keras.layers import Conv2D, BatchNormalization, ReLU, Input, Flatten, Softmax
14
+ from keras.layers import Concatenate, Activation, Dense, GlobalAveragePooling2D, Dropout
15
+ from keras.layers import AveragePooling1D, Bidirectional, LSTM, GlobalAveragePooling1D, MaxPool1D, Reshape
16
+ from keras.layers import LayerNormalization, Conv1D, MultiHeadAttention, Layer
17
+ from keras.models import load_model
18
+ from keras.callbacks import EarlyStopping, ReduceLROnPlateau
19
+ from Bio import SeqIO
20
+ from Bio.SeqRecord import SeqRecord
21
+ from Bio.SeqFeature import SeqFeature, FeatureLocation
22
+ from Bio.Seq import Seq
23
+
24
+ import cyvcf2
25
+ import parasail
26
+
27
+ import re
28
+
29
# One-hot vector for each nucleotide; the channel order is A, C, G, T.
ntmap = {'A': (1, 0, 0, 0),
         'C': (0, 1, 0, 0),
         'G': (0, 0, 1, 0),
         'T': (0, 0, 0, 1)
         }


def get_seqcode(seq):
    """One-hot encode a DNA sequence using ``ntmap``.

    The sequence is upper-cased before lookup, so input case does not matter.

    Args:
        seq: String of nucleotides (A, C, G, T in any case).

    Returns:
        Integer array of shape ``(1, len(seq), 4)``. An empty sequence gives
        an empty ``(1, 0, 4)`` array.

    Raises:
        KeyError: If ``seq`` contains a character that is not in ``ntmap``.
    """
    # Collect the rows in one pass; repeated tuple concatenation is quadratic
    # and has no defined result for an empty sequence.
    rows = [ntmap[base] for base in seq.upper()]
    return np.array(rows, dtype=int).reshape((1, len(seq), 4))
37
+
38
class PositionalEncoding(Layer):
    """Add a fixed (non-trainable) sinusoidal position table to the input.

    The table has shape ``(sequence_len, embedding_dim)``. Entry ``[pos, i]``
    starts as ``pos / 10000 ** (2 * i / embedding_dim)``; even columns are then
    passed through ``sin`` and odd columns through ``cos``.

    NOTE(review): the exponent uses the column index ``i`` directly rather
    than ``i // 2``. Presumably the saved weights were trained with exactly
    this table, so the formula is left untouched; confirm before changing it.

    Args:
        sequence_len: Number of positions in the table.
        embedding_dim: Width of each position vector.
        **kwargs: Standard ``Layer`` arguments (``name``, ``dtype``, ...).
    """

    def __init__(self, sequence_len=None, embedding_dim=None, **kwargs):
        # Forward the base-layer arguments: get_config() emits them, so a
        # layer re-created from its config must hand them back to Layer
        # instead of silently dropping them.
        super().__init__(**kwargs)
        self.sequence_len = sequence_len
        self.embedding_dim = embedding_dim

    def call(self, x):
        """Return ``x`` plus the position table (added element-wise)."""
        position_embedding = np.array([
            [pos / np.power(10000, 2. * i / self.embedding_dim) for i in range(self.embedding_dim)]
            for pos in range(self.sequence_len)])

        position_embedding[:, 0::2] = np.sin(position_embedding[:, 0::2])  # dim 2i
        position_embedding[:, 1::2] = np.cos(position_embedding[:, 1::2])  # dim 2i+1
        position_embedding = tf.cast(position_embedding, dtype=tf.float32)

        return position_embedding+x

    def get_config(self):
        """Return the base layer config extended with this layer's two sizes."""
        config = super().get_config().copy()
        config.update({
            'sequence_len' : self.sequence_len,
            'embedding_dim' : self.embedding_dim,
        })
        return config
63
+
64
def MultiHeadAttention_model(input_shape):
    """Build the Conv1D -> BiLSTM -> self-attention -> Dense regression network.

    The model is only constructed here: no weights are loaded and ``compile``
    is not called. ``format_prediction_output_with_mutation`` builds it with
    ``input_shape=(23, 4)`` and then loads weights into it, so the order in
    which layers are created below should not be changed casually.

    Args:
        input_shape: Shape of one input sample, passed straight to ``Input``.

    Returns:
        A ``Model`` with one input and a single linear output unit.
    """
    input = Input(shape=input_shape)

    # Two identical stages: 256 filters of width 3, average pooling by 2,
    # then 40% dropout.
    conv1 = Conv1D(256, 3, activation="relu")(input)
    pool1 = AveragePooling1D(2)(conv1)
    drop1 = Dropout(0.4)(pool1)

    conv2 = Conv1D(256, 3, activation="relu")(drop1)
    pool2 = AveragePooling1D(2)(conv2)
    drop2 = Dropout(0.4)(pool2)

    # Bidirectional LSTM that keeps the full sequence for the attention block.
    lstm = Bidirectional(LSTM(128,
                              dropout=0.5,
                              activation='tanh',
                              return_sequences=True,
                              kernel_regularizer=regularizers.l2(0.01)))(drop2)

    # sequence_len evaluates to int(4.25) == 4. The expression is written for
    # an input of length 23 and does not read input_shape.
    # NOTE(review): a different input length would presumably no longer match
    # this table; confirm before calling with anything other than (23, 4).
    pos_embedding = PositionalEncoding(sequence_len=int(((23-3+1)/2-3+1)/2), embedding_dim=2*128)(lstm)
    # Self-attention: the same tensor is passed as both query and value.
    atten = MultiHeadAttention(num_heads=2,
                               key_dim=64,
                               dropout=0.2,
                               kernel_regularizer=regularizers.l2(0.01))(pos_embedding, pos_embedding)

    flat = Flatten()(atten)

    # Dense head: 512 -> 128 -> 256 units, each L2-regularised and followed by
    # 10% dropout.
    dense1 = Dense(512,
                   kernel_regularizer=regularizers.l2(1e-4),
                   bias_regularizer=regularizers.l2(1e-4),
                   activation="relu")(flat)
    drop3 = Dropout(0.1)(dense1)

    dense2 = Dense(128,
                   kernel_regularizer=regularizers.l2(1e-4),
                   bias_regularizer=regularizers.l2(1e-4),
                   activation="relu")(drop3)
    drop4 = Dropout(0.1)(dense2)

    dense3 = Dense(256,
                   kernel_regularizer=regularizers.l2(1e-4),
                   bias_regularizer=regularizers.l2(1e-4),
                   activation="relu")(drop4)
    drop5 = Dropout(0.1)(dense3)

    # Single unbounded output value (linear activation).
    output = Dense(1, activation="linear")(drop5)

    model = Model(inputs=[input], outputs=[output])
    return model
111
+
112
def fetch_ensembl_transcripts(gene_symbol, timeout=60):
    """Look up a human gene by symbol on the Ensembl REST API.

    Args:
        gene_symbol: Gene symbol inserted into the lookup URL.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout a stalled connection would block the caller indefinitely.

    Returns:
        The value of the ``'Transcript'`` key of the JSON reply, or ``None``
        when the request fails, the status is not 200, or the key is absent.
        Every failure is reported on stdout.
    """
    url = f"https://rest.ensembl.org/lookup/symbol/homo_sapiens/{gene_symbol}?expand=1;content-type=application/json"
    try:
        response = requests.get(url, timeout=timeout)
    except requests.RequestException as exc:
        # Connection problems and timeouts follow the same "print and return
        # None" contract as an HTTP error status.
        print(f"Error fetching gene data from Ensembl: {exc}")
        return None

    if response.status_code != 200:
        print(f"Error fetching gene data from Ensembl: {response.text}")
        return None

    gene_data = response.json()
    if 'Transcript' not in gene_data:
        print("No transcripts found for gene:", gene_symbol)
        return None
    return gene_data['Transcript']
125
+
126
def fetch_ensembl_sequence(transcript_id, timeout=60):
    """Download the sequence of an Ensembl feature by its stable id.

    Args:
        transcript_id: Ensembl id inserted into the ``sequence/id`` URL
            (``process_gene`` passes exon ids here).
        timeout: Seconds to wait for the server before giving up. Without a
            timeout a stalled connection would block the caller indefinitely.

    Returns:
        The value of the ``'seq'`` key of the JSON reply, or ``None`` when the
        request fails, the status is not 200, or the key is absent. Every
        failure is reported on stdout.
    """
    url = f"https://rest.ensembl.org/sequence/id/{transcript_id}?content-type=application/json"
    try:
        response = requests.get(url, timeout=timeout)
    except requests.RequestException as exc:
        # Connection problems and timeouts follow the same "print and return
        # None" contract as an HTTP error status.
        print(f"Error fetching sequence data from Ensembl: {exc}")
        return None

    if response.status_code != 200:
        print(f"Error fetching sequence data from Ensembl: {response.text}")
        return None

    sequence_data = response.json()
    if 'seq' not in sequence_data:
        print("No sequence found for transcript:", transcript_id)
        return None
    return sequence_data['seq']
139
+
140
def apply_mutation(ref_sequence, offset, ref, alt):
    """Apply a single variant to a sequence and return the new sequence.

    The ``len(ref)`` bases starting at ``offset`` are replaced by ``alt``.
    This covers substitutions, insertions and deletions alike. An ``alt`` of
    ``"*"`` means the reference bases are removed with nothing put in their
    place.

    Args:
        ref_sequence: Sequence to modify (not changed in place).
        offset: 0-based index of the first reference base of the variant.
        ref: Reference allele; only its length is used.
        alt: Alternate allele, or ``"*"`` for a deletion.

    Returns:
        The mutated sequence as a new string.
    """
    replacement = "" if alt == "*" else alt
    # Always skip the whole reference allele. Skipping a single base is only
    # right when len(ref) == 1 and otherwise leaves reference bases behind.
    return ref_sequence[:offset] + replacement + ref_sequence[offset + len(ref):]
160
+
161
def construct_combinations(sequence, mutations):
    """Build every sequence obtainable by choosing one alt per variant site.

    Args:
        sequence: Reference sequence the offsets refer to.
        mutations: List of ``(offset, ref, alts)`` tuples, where ``offset`` is
            the 0-based position in ``sequence`` and ``alts`` is a list of
            alternate alleles.

    Returns:
        A list of mutated sequences, one per combination of alts. The first
        site varies slowest. Sites with no alt alleles are ignored, so the
        result is ``[sequence]`` when there is nothing to apply.
    """
    from itertools import product

    # A site without alternates offers no choice; keeping it in the product
    # would make the whole result empty.
    sites = [(offset, ref, list(alts)) for offset, ref, alts in mutations if len(alts) > 0]
    if not sites:
        return [sequence]

    # All offsets are relative to the unmodified sequence. Applying variants
    # from the highest offset down means an insertion or deletion never shifts
    # a position that is still waiting to be edited.
    right_to_left = sorted(range(len(sites)), key=lambda k: sites[k][0], reverse=True)

    sequences = []
    for chosen_alts in product(*(alts for _, _, alts in sites)):
        mutated_sequence = sequence
        for k in right_to_left:
            offset, ref, _ = sites[k]
            mutated_sequence = apply_mutation(mutated_sequence, offset, ref, chosen_alts[k])
        sequences.append(mutated_sequence)

    return sequences
180
+
181
def needleman_wunsch_alignment(query_seq, ref_seq):
    """Estimate where ``query_seq`` sits in ``ref_seq`` from a global alignment.

    The two sequences are aligned with ``parasail.nw_trace`` and the CIGAR
    string of the alignment is scanned for its longest exact-match (``=``)
    run. The returned value is the sum of the lengths of all CIGAR operations
    that come before that run; it is 0 when the CIGAR has no ``=`` run.
    """
    result = parasail.nw_trace(query_seq, ref_seq, 10, 1, parasail.blosum62)
    cigar_text = result.cigar.decode.decode("utf-8")

    longest_run = 0
    offset_of_longest = 0
    consumed = 0  # total length of the operations seen so far
    for length_text, op in re.findall(r'(\d+)([MIDNSHP=X])', cigar_text):
        length = int(length_text)
        # Strict comparison: on ties the earliest run is kept.
        if op == '=' and length > longest_run:
            longest_run = length
            offset_of_longest = consumed
        consumed += length

    return offset_of_longest
212
+
213
def find_gRNA_with_mutation(ref_sequence, exon_chr, start, end, strand, transcript_id,
                            exon_id, gene_symbol, vcf_reader, pam="NGG", target_length=20):
    """Find PAM-adjacent targets in an exon and in its variant-carrying forms.

    Variants in ``exon_chr:start-end`` are read from ``vcf_reader`` and
    combined by ``construct_combinations``. Every resulting sequence is
    scanned for ``target_length`` bases followed by the PAM. Targets found
    verbatim in ``ref_sequence`` get a plain start coordinate. The others are
    located with ``needleman_wunsch_alignment`` and their start is suffixed
    with ``'*'``.

    Args:
        ref_sequence: Reference sequence of the exon.
        exon_chr, start, end, strand: Location of the exon; ``strand == -1``
            selects the reverse-strand coordinate formula.
        transcript_id, exon_id, gene_symbol: Copied into each result row.
        vcf_reader: Callable taking a ``"chr:start-end"`` region string and
            returning an iterable of records with ``POS``, ``REF``, ``ALT``.
        pam: PAM motif. Only ``pam[1:]`` is matched literally; the first
            position is treated as a wildcard.
        target_length: Number of bases in front of the PAM.

    Returns:
        A list of unique rows ``[target_seq, gRNA, exon_chr, strand,
        tar_start, transcript_id, exon_id, gene_symbol, is_mut]`` in order of
        first discovery. ``strand`` and ``tar_start`` are strings.
        Windows whose protospacer contains anything other than A/C/G/T are
        skipped.
    """
    mutated_sequences = [ref_sequence]

    mutations = vcf_reader(f"{exon_chr}:{start}-{end}")
    if mutations:
        mutation_list = []
        for mutation in mutations:
            # NOTE(review): the offset is measured from the forward-strand
            # exon start; confirm this is also right for strand == -1.
            offset = mutation.POS - start
            ref = mutation.REF
            # NOTE(review): the last ALT entry is dropped; presumably the input
            # carries a trailing symbolic allele. Verify against the VCFs used.
            alts = mutation.ALT[:-1]
            mutation_list.append((offset, ref, alts))

        mutated_sequences = construct_combinations(ref_sequence, mutation_list)

    dna_to_rna = str.maketrans("T", "U")
    plain_bases = set("ACGT")
    pam_len = len(pam)
    pam_tail = pam[1:]
    site_len = target_length + pam_len

    # A dict keyed by the row keeps one copy of each target in first-seen
    # order, so the output does not depend on hash ordering.
    unique_targets = {}
    for seq in mutated_sequences:
        for i in range(max(target_length, 0), len(seq) - pam_len + 1):
            if seq[i + 1:i + pam_len] != pam_tail:
                continue

            protospacer = seq[i - target_length:i]
            if not plain_bases.issuperset(protospacer):
                # Ambiguous or non-nucleotide characters cannot be turned
                # into a guide; skip the window instead of raising KeyError.
                continue

            target_seq = seq[i - target_length:i + pam_len]
            pos = ref_sequence.find(target_seq)
            if pos != -1:
                is_mut = False
                anchor = pos
                marker = ''
            else:
                is_mut = True
                anchor = needleman_wunsch_alignment(target_seq, ref_sequence)
                marker = '*'

            if strand == -1:
                tar_start = end - anchor - site_len + 1
            else:
                tar_start = start + anchor

            row = (target_seq, protospacer.translate(dna_to_rna), exon_chr, str(strand),
                   f"{tar_start}{marker}", transcript_id, exon_id, gene_symbol, is_mut)
            unique_targets.setdefault(row, None)

    return [list(row) for row in unique_targets]
263
+
264
# Networks already built and loaded, keyed by weights path.
_MODEL_CACHE = {}


def _load_prediction_model(model_path):
    """Return the network for ``model_path``, building and loading it only once."""
    model = _MODEL_CACHE.get(model_path)
    if model is None:
        model = MultiHeadAttention_model(input_shape=(23, 4))
        model.load_weights(model_path)
        _MODEL_CACHE[model_path] = model
    return model


def format_prediction_output_with_mutation(targets, model_path):
    """Score each target with the on-target model and lay out result rows.

    Args:
        targets: Rows as produced by ``find_gRNA_with_mutation``:
            ``[target_seq, gRNA, chr, strand, start, transcript_id, exon_id,
            gene_symbol, is_mut]``.
        model_path: Weights file loaded into ``MultiHeadAttention_model``.

    Returns:
        One row per target: ``[gene_symbol, chr, strand, start,
        transcript_id, exon_id, target_seq, gRNA, prediction, is_mut]``.
        Predictions above 100 are capped at 100.
    """
    if not targets:
        return []

    # The function is called once per exon, so the network is rebuilt only
    # when a new weights path is seen.
    model = _load_prediction_model(model_path)

    # Encode everything up front and predict in a single call rather than
    # paying the per-call overhead of predict() for each target.
    encoded = np.concatenate([get_seqcode(target[0]) for target in targets], axis=0)
    scores = model.predict(encoded, verbose=0)

    formatted_data = []
    for target, score in zip(targets, scores):
        prediction = float(score[0])
        if prediction > 100:
            prediction = 100

        (target_seq, gRNA, exon_chr, strand, tar_start,
         transcript_id, exon_id, gene_symbol, is_mut) = target[:9]
        formatted_data.append([gene_symbol, exon_chr, strand, tar_start, transcript_id,
                               exon_id, target_seq, gRNA, prediction, is_mut])

    return formatted_data
293
+
294
+
295
def process_gene(gene_symbol, vcf_reader, model_path):
    """Run target discovery and scoring over every exon of every transcript.

    Args:
        gene_symbol: Gene to look up with ``fetch_ensembl_transcripts``.
        vcf_reader: Variant source forwarded to ``find_gRNA_with_mutation``.
        model_path: Weights file forwarded to
            ``format_prediction_output_with_mutation``.

    Returns:
        A tuple ``(rows, sequences, exons)``: the prediction rows sorted by
        score (index 8) from highest to lowest, the list of exon sequences
        that were fetched successfully, and the list of all exon records.
        All three are empty when no transcripts could be retrieved.
    """
    results = []
    all_exons = []
    all_gene_sequences = []

    transcripts = fetch_ensembl_transcripts(gene_symbol)
    if not transcripts:
        print("Failed to retrieve transcripts.")
        return results, all_gene_sequences, all_exons

    for transcript in transcripts:
        exons = transcript['Exon']
        all_exons.extend(exons)
        transcript_id = transcript['id']

        for exon in exons:
            exon_id = exon['id']
            exon_sequence = fetch_ensembl_sequence(exon_id)
            if not exon_sequence:
                print(f"Failed to retrieve gene sequence for exon {exon_id}.")
                continue

            all_gene_sequences.append(exon_sequence)
            targets = find_gRNA_with_mutation(
                exon_sequence,
                exon['seq_region_name'],
                exon['start'],
                exon['end'],
                exon['strand'],
                transcript_id,
                exon_id,
                gene_symbol,
                vcf_reader,
            )
            if targets:
                # Score every candidate site found in this exon.
                results.extend(format_prediction_output_with_mutation(targets, model_path))

    # Best predictions first; the score sits at index 8 of each row.
    results.sort(key=lambda row: row[8], reverse=True)
    return results, all_gene_sequences, all_exons
333
+
334
+
335
def create_genbank_features(data):
    """Turn prediction rows into ``SeqFeature`` objects.

    Rows are read by position: columns 1 and 2 give start and end, column 3
    the strand marker (``'+'`` becomes 1, anything else -1), column 7 the
    feature label and column 8 the prediction shown in the note. Rows whose
    start or end cannot be converted to ``int`` are reported on stdout and
    skipped.

    Args:
        data: A pandas DataFrame (converted to a list of row lists) or a list
            of rows.

    Returns:
        A list of ``misc_feature`` ``SeqFeature`` objects.

    Raises:
        TypeError: If ``data`` is neither a list nor a DataFrame.
    """
    if isinstance(data, pd.DataFrame):
        rows = data.values.tolist()
    elif isinstance(data, list):
        rows = data
    else:
        raise TypeError("Data should be either a list or a pandas DataFrame.")

    features = []
    for row in rows:
        try:
            start, end = int(row[1]), int(row[2])
        except ValueError as e:
            print(f"Error converting start/end to int: {row[1]}, {row[2]} - {e}")
            continue

        location = FeatureLocation(start=start, end=end, strand=1 if row[3] == '+' else -1)
        qualifiers = {
            'label': row[7],  # gRNA shown as the feature label
            'note': f"Prediction: {row[8]}",  # prediction score
        }
        features.append(SeqFeature(location=location, type="misc_feature", qualifiers=qualifiers))

    return features
363
+
364
+
365
def generate_genbank_file_from_df(df, gene_sequence, gene_symbol, output_path):
    """Write the predicted targets in ``df`` to a GenBank file.

    Args:
        df: Prediction table (or list of rows) handed to
            ``create_genbank_features`` to build the record's features.
        gene_sequence: Sequence for the record; anything that is not already
            a ``str`` is converted with ``str()``.
        gene_symbol: Used as the record id and name, and in the description.
        output_path: Destination passed to ``SeqIO.write``.
    """
    # Seq is built from plain text, so coerce non-string input first.
    sequence_text = gene_sequence if isinstance(gene_sequence, str) else str(gene_sequence)

    target_features = create_genbank_features(df)

    record = SeqRecord(
        Seq(sequence_text),
        id=gene_symbol,
        name=gene_symbol,
        description=f'CRISPR Cas9 predicted targets for {gene_symbol}',
        features=target_features,
    )
    # Annotate the record as DNA before it is serialised.
    record.annotations["molecule_type"] = "DNA"
    SeqIO.write(record, output_path, "genbank")
378
+
379
+
380
def create_bed_file_from_df(df, output_path):
    """Write one tab-separated BED-style line per row of ``df``.

    Each line holds chrom, start, end, gRNA, prediction and strand, read from
    the "Chr", "Start Pos", "End Pos", "gRNA", "Prediction" and "Strand"
    columns.

    The strand column is compared by its text form, so the forward strand is
    recognised whether it is stored as the integer ``1`` or as the string
    ``"1"``, ``"+1"`` or ``"+"``; every other value is written as ``'-'``.

    Args:
        df: pandas DataFrame with the columns listed above.
        output_path: File to create or overwrite.

    Raises:
        KeyError: If a required column is missing.
        ValueError: If a start or end value cannot be converted to ``int``.
    """
    forward_markers = {"1", "+1", "+", "1.0"}

    with open(output_path, 'w') as bed_file:
        for _, row in df.iterrows():
            chrom = row["Chr"]
            start = int(row["Start Pos"])
            end = int(row["End Pos"])
            # Normalise to text first: an integer 1 must not fall through to '-'.
            strand = '+' if str(row["Strand"]).strip() in forward_markers else '-'
            gRNA = row["gRNA"]
            score = str(row["Prediction"])

            bed_file.write(f"{chrom}\t{start}\t{end}\t{gRNA}\t{score}\t{strand}\n")
394
+
395
+
396
def create_csv_from_df(df, output_path):
    """Save ``df`` as a CSV file at ``output_path``, leaving out the index column."""
    df.to_csv(path_or_buf=output_path, index=False)
cas9on.py CHANGED
@@ -8,9 +8,7 @@ from Bio import SeqIO
8
  from Bio.SeqRecord import SeqRecord
9
  from Bio.SeqFeature import SeqFeature, FeatureLocation
10
  from Bio.Seq import Seq
11
- from keras.models import load_model
12
- import random
13
- import pyBigWig
14
 
15
  # configure GPUs
16
  for gpu in tf.config.list_physical_devices('GPU'):
 
8
  from Bio.SeqRecord import SeqRecord
9
  from Bio.SeqFeature import SeqFeature, FeatureLocation
10
  from Bio.Seq import Seq
11
+
 
 
12
 
13
  # configure GPUs
14
  for gpu in tf.config.list_physical_devices('GPU'):
requirements.txt CHANGED
@@ -4,5 +4,8 @@ pandas==1.5.2
4
  tensorflow==2.11.0
5
  tensorflow-probability==0.19.0
6
  plotly==5.18.0
 
 
 
7
  gtracks
8
  pyGenomeTracks
 
4
  tensorflow==2.11.0
5
  tensorflow-probability==0.19.0
6
  plotly==5.18.0
7
+ tabulate
8
+ cyvcf2
9
+ parasail
10
  gtracks
11
  pyGenomeTracks