Spaces:
Sleeping
Sleeping
supercat666
commited on
Commit
·
4a303ce
1
Parent(s):
69d7c1c
change cas9
Browse files- app.py +9 -8
- cas12lstm.py +188 -0
- cas12lstmvcf.py +287 -0
- cas9att.py +299 -0
- cas9attvcf.py +397 -0
- cas9on.py +1 -3
- requirements.txt +3 -0
app.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import os
|
2 |
import tiger
|
3 |
-
import
|
|
|
4 |
import cas9off
|
5 |
import cas12
|
6 |
import pandas as pd
|
@@ -22,8 +23,8 @@ st.divider()
|
|
22 |
CRISPR_MODELS = ['Cas9', 'Cas12', 'Cas13d']
|
23 |
|
24 |
selected_model = st.selectbox('Select CRISPR model:', CRISPR_MODELS, key='selected_model')
|
25 |
-
|
26 |
-
cas12_path = 'cas12_model/
|
27 |
|
28 |
#plot functions
|
29 |
def generate_coolbox_plot(bigwig_path, region, output_image_path):
|
@@ -182,8 +183,8 @@ if selected_model == 'Cas9':
|
|
182 |
# Process predictions
|
183 |
if predict_button and gene_symbol:
|
184 |
with st.spinner('Predicting... Please wait'):
|
185 |
-
predictions, gene_sequence, exons =
|
186 |
-
sorted_predictions = sorted(predictions
|
187 |
st.session_state['on_target_results'] = sorted_predictions
|
188 |
st.session_state['gene_sequence'] = gene_sequence # Save gene sequence in session state
|
189 |
st.session_state['exons'] = exons # Store exon data
|
@@ -283,9 +284,9 @@ if selected_model == 'Cas9':
|
|
283 |
|
284 |
|
285 |
# Generate files
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
|
290 |
# Prepare an in-memory buffer for the ZIP file
|
291 |
zip_buffer = io.BytesIO()
|
|
|
1 |
import os
|
2 |
import tiger
|
3 |
+
import cas9att
|
4 |
+
import cas9attvcf
|
5 |
import cas9off
|
6 |
import cas12
|
7 |
import pandas as pd
|
|
|
23 |
CRISPR_MODELS = ['Cas9', 'Cas12', 'Cas13d']
|
24 |
|
25 |
selected_model = st.selectbox('Select CRISPR model:', CRISPR_MODELS, key='selected_model')
|
26 |
+
cas9att_path = 'cas9_model/Cas9_MultiHeadAttention_weights.keras'
|
27 |
+
cas12_path = 'cas12_model/BiLSTM_Cpf1_weights.keras'
|
28 |
|
29 |
#plot functions
|
30 |
def generate_coolbox_plot(bigwig_path, region, output_image_path):
|
|
|
183 |
# Process predictions
|
184 |
if predict_button and gene_symbol:
|
185 |
with st.spinner('Predicting... Please wait'):
|
186 |
+
predictions, gene_sequence, exons = cas9att.process_gene(gene_symbol, cas9att_path)
|
187 |
+
# Rank by predicted score (index 8 of each result row) before keeping the
# top 10; a bare sorted() orders the rows lexicographically by their leading
# fields and discards the score ranking.
sorted_predictions = sorted(predictions, key=lambda row: row[8], reverse=True)[:10]
|
188 |
st.session_state['on_target_results'] = sorted_predictions
|
189 |
st.session_state['gene_sequence'] = gene_sequence # Save gene sequence in session state
|
190 |
st.session_state['exons'] = exons # Store exon data
|
|
|
284 |
|
285 |
|
286 |
# Generate files
|
287 |
+
cas9att.generate_genbank_file_from_df(df, gene_sequence, gene_symbol, genbank_file_path)
|
288 |
+
cas9att.create_bed_file_from_df(df, bed_file_path)
|
289 |
+
cas9att.create_csv_from_df(df, csv_file_path)
|
290 |
|
291 |
# Prepare an in-memory buffer for the ZIP file
|
292 |
zip_buffer = io.BytesIO()
|
cas12lstm.py
ADDED
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
from keras import regularizers
|
3 |
+
from keras.layers import Input, Dense, Dropout, Activation, Conv1D
|
4 |
+
from keras.layers import GlobalAveragePooling1D, AveragePooling1D
|
5 |
+
from keras.layers import Bidirectional, LSTM
|
6 |
+
from keras import Model
|
7 |
+
from keras.metrics import MeanSquaredError
|
8 |
+
|
9 |
+
import pandas as pd
|
10 |
+
import numpy as np
|
11 |
+
|
12 |
+
import requests
|
13 |
+
from functools import reduce
|
14 |
+
from operator import add
|
15 |
+
import tabulate
|
16 |
+
from difflib import SequenceMatcher
|
17 |
+
|
18 |
+
import cyvcf2
|
19 |
+
import parasail
|
20 |
+
|
21 |
+
import re
|
22 |
+
|
23 |
+
# One-hot code for each nucleotide (channel order A, C, G, T).
ntmap = {
    'A': (1, 0, 0, 0),
    'C': (0, 1, 0, 0),
    'G': (0, 0, 1, 0),
    'T': (0, 0, 0, 1),
}


def get_seqcode(seq):
    """One-hot encode a nucleotide string using ``ntmap``.

    Args:
        seq: String of A/C/G/T characters; case is ignored.

    Returns:
        An integer ``numpy`` array of shape ``(1, len(seq), 4)``.

    Raises:
        KeyError: If ``seq`` contains a character that is not in ``ntmap``.
    """
    # Build the rows directly instead of concatenating tuples with reduce():
    # that copied the growing tuple for every base and raised on an empty
    # string.
    rows = [ntmap[base] for base in seq.upper()]
    return np.array(rows, dtype=int).reshape((1, len(seq), 4))
|
31 |
+
|
32 |
+
def BiLSTM_model(input_shape):
    """Build the Conv1D + bidirectional-LSTM regression network.

    Args:
        input_shape: Shape of one encoded sequence, passed to ``Input``.

    Returns:
        An uncompiled ``Model`` with a single linear output unit.
    """
    seq_in = Input(shape=input_shape)

    # Two identical convolution -> average-pooling -> dropout stages.
    features = seq_in
    for _ in range(2):
        features = Conv1D(128, 5, activation="relu")(features)
        features = AveragePooling1D(2)(features)
        features = Dropout(0.1)(features)

    # Bidirectional LSTM over the pooled features, averaged across positions.
    recurrent = Bidirectional(LSTM(128,
                                   dropout=0.1,
                                   activation='tanh',
                                   return_sequences=True,
                                   kernel_regularizer=regularizers.l2(1e-4)))
    features = recurrent(features)
    features = GlobalAveragePooling1D()(features)

    # Fully connected head: 128 -> 32 -> 32, each followed by dropout.
    for units in (128, 32, 32):
        features = Dense(units,
                         kernel_regularizer=regularizers.l2(1e-4),
                         bias_regularizer=regularizers.l2(1e-4),
                         activation="relu")(features)
        features = Dropout(0.1)(features)

    score = Dense(1, activation="linear")(features)

    return Model(inputs=[seq_in], outputs=[score])
|
72 |
+
|
73 |
+
def fetch_ensembl_transcripts(gene_symbol):
    """Look up a human gene symbol on the Ensembl REST API.

    Args:
        gene_symbol: Gene symbol to look up.

    Returns:
        The value stored under ``'Transcript'`` in the lookup response, or
        ``None`` when the request fails, the status is not 200, or the
        response has no ``'Transcript'`` entry.
    """
    from urllib.parse import quote

    # Percent-encode the symbol so characters such as '/', '?' or spaces
    # cannot change the request path or query string.
    symbol = quote(str(gene_symbol), safe='')
    url = f"https://rest.ensembl.org/lookup/symbol/homo_sapiens/{symbol}?expand=1;content-type=application/json"
    try:
        # Without a timeout a stalled connection would block the caller forever.
        response = requests.get(url, timeout=30)
    except requests.RequestException as exc:
        print(f"Error fetching gene data from Ensembl: {exc}")
        return None

    if response.status_code != 200:
        print(f"Error fetching gene data from Ensembl: {response.text}")
        return None

    gene_data = response.json()
    if 'Transcript' in gene_data:
        return gene_data['Transcript']
    print("No transcripts found for gene:", gene_symbol)
    return None
|
86 |
+
|
87 |
+
def fetch_ensembl_sequence(transcript_id):
    """Fetch the sequence of an Ensembl feature by its stable ID.

    Args:
        transcript_id: Ensembl stable ID placed in the ``sequence/id`` URL.
            Despite the name, ``process_gene`` in this file passes exon IDs.

    Returns:
        The ``'seq'`` string from the response, or ``None`` when the request
        fails, the status is not 200, or the response has no ``'seq'`` entry.
    """
    url = f"https://rest.ensembl.org/sequence/id/{transcript_id}?content-type=application/json"
    try:
        # Without a timeout a stalled connection would block the caller forever.
        response = requests.get(url, timeout=30)
    except requests.RequestException as exc:
        print(f"Error fetching sequence data from Ensembl: {exc}")
        return None

    if response.status_code != 200:
        print(f"Error fetching sequence data from Ensembl: {response.text}")
        return None

    sequence_data = response.json()
    if 'seq' in sequence_data:
        return sequence_data['seq']
    print("No sequence found for transcript:", transcript_id)
    return None
|
100 |
+
|
101 |
+
def find_crispr_targets(sequence, chr, start, end, strand, transcript_id, exon_id, pam="TTTN", target_length=34):
    """Scan a sequence for fixed-length target windows that carry the PAM.

    Every window of ``target_length`` bases is examined. A window is kept when
    the bases starting at window position 4 match ``pam`` (``'N'`` in the
    pattern matches any base). The gRNA is window positions 8..27 with T
    replaced by U.

    Args:
        sequence: Sequence to scan, in the orientation given by ``strand``.
        chr: Chromosome/region name copied into each result.
        start: Genomic start coordinate of ``sequence``.
        end: Genomic end coordinate of ``sequence``.
        strand: ``-1`` when index 0 of ``sequence`` corresponds to ``end``;
            any other value means index 0 corresponds to ``start``.
        transcript_id: Identifier copied into each result.
        exon_id: Identifier copied into each result.
        pam: PAM pattern. The default reproduces the previous hard-coded
            ``window[4:7] == 'TTT'`` test.
        target_length: Window length in bases.

    Returns:
        A list of ``[target_seq, gRNA, chr, start, end, strand, transcript_id,
        exon_id]`` rows with coordinates and strand as strings.
    """
    pam_offset = 4  # number of bases that precede the PAM inside a window
    dnatorna = {'A': 'A', 'T': 'U', 'C': 'C', 'G': 'G'}
    targets = []

    for i in range(len(sequence) - target_length + 1):
        target_seq = sequence[i:i + target_length]

        pam_site = target_seq[pam_offset:pam_offset + len(pam)]
        if len(pam_site) != len(pam):
            continue
        if any(p != 'N' and p != base for p, base in zip(pam, pam_site)):
            continue

        # Windows with anything other than A/C/G/T cannot be transcribed or
        # one-hot encoded; skip them instead of raising KeyError.
        if any(base not in dnatorna for base in target_seq):
            continue

        if strand == -1:
            tar_start = end - i - target_length + 1
            tar_end = end - i
        else:
            tar_start = start + i
            tar_end = start + i + target_length - 1

        gRNA = ''.join(dnatorna[base] for base in target_seq[8:28])
        targets.append([target_seq, gRNA, chr, str(tar_start), str(tar_end), str(strand), transcript_id, exon_id])

    return targets
|
122 |
+
|
123 |
+
def format_prediction_output(targets, model_path):
    """Score targets with the BiLSTM model and build the output rows.

    Args:
        targets: Rows as produced by ``find_crispr_targets``:
            ``[target_seq, gRNA, chr, start, end, strand, transcript_id,
            exon_id]``.
        model_path: Path of the weights file loaded into the model.

    Returns:
        A list of ``[chr, start, end, strand, transcript_id, exon_id,
        target_seq, gRNA, prediction]`` rows in the order of ``targets``.
        Predictions above 100 are capped at 100.
    """
    if not targets:
        # Nothing to score; avoid building the model and loading weights.
        return []

    # Loading weights for the model
    Crispr_BiLSTM = BiLSTM_model(input_shape=(34, 4))
    Crispr_BiLSTM.load_weights(model_path)

    # Score all targets in one batched call rather than one predict() call
    # per sequence.
    encoded = np.concatenate([get_seqcode(target[0]) for target in targets], axis=0)
    scores = Crispr_BiLSTM.predict(encoded, verbose=0)

    formatted_data = []
    for target, score in zip(targets, scores):
        prediction = float(score[0])
        if prediction > 100:
            prediction = 100

        target_seq, gRNA, chr, start, end, strand, transcript_id, exon_id = target[:8]
        formatted_data.append([chr, start, end, strand, transcript_id, exon_id, target_seq, gRNA, prediction])

    return formatted_data
|
149 |
+
|
150 |
+
|
151 |
+
def process_gene(gene_symbol, model_path):
    """Find and score targets in every exon of every transcript of a gene.

    Args:
        gene_symbol: Gene symbol passed to ``fetch_ensembl_transcripts``.
        model_path: Path of the model weights used for scoring.

    Returns:
        A tuple ``(sorted_results, all_gene_sequences, all_exons)``:
        result rows sorted by prediction score (index 8) in descending order,
        the exon sequences that were retrieved, and every exon record seen.
    """
    transcripts = fetch_ensembl_transcripts(gene_symbol)
    all_exons = []  # To accumulate all exons
    all_gene_sequences = []  # To accumulate all gene sequences
    all_targets = []
    # The same exon id can be listed under several transcripts; remember
    # successful downloads so each sequence is fetched only once.
    sequence_cache = {}

    if transcripts:
        for transcript in transcripts:
            Exons = transcript['Exon']
            all_exons.extend(Exons)  # Add all exons from this transcript to the list
            transcript_id = transcript['id']

            for Exon in Exons:
                exon_id = Exon['id']
                gene_sequence = sequence_cache.get(exon_id)
                if gene_sequence is None:
                    gene_sequence = fetch_ensembl_sequence(exon_id)
                    if gene_sequence:
                        sequence_cache[exon_id] = gene_sequence
                if not gene_sequence:
                    print(f"Failed to retrieve gene sequence for exon {exon_id}.")
                    continue

                all_gene_sequences.append(gene_sequence)  # Add this gene sequence to the list
                all_targets.extend(find_crispr_targets(
                    gene_sequence,
                    Exon['seq_region_name'],
                    Exon['start'],
                    Exon['end'],
                    Exon['strand'],
                    transcript_id,
                    exon_id,
                ))
    else:
        print("Failed to retrieve transcripts.")

    # Score everything in a single call so the model is built and its weights
    # are loaded once, not once per exon.
    results = format_prediction_output(all_targets, model_path) if all_targets else []

    # The prediction score is the element at index 8 of each row.
    sorted_results = sorted(results, key=lambda x: x[8], reverse=True)

    # Return the sorted output, combined gene sequences, and all exons
    return sorted_results, all_gene_sequences, all_exons
|
188 |
+
|
cas12lstmvcf.py
ADDED
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
from keras import regularizers
|
3 |
+
from keras.layers import Input, Dense, Dropout, Activation, Conv1D
|
4 |
+
from keras.layers import GlobalAveragePooling1D, AveragePooling1D
|
5 |
+
from keras.layers import Bidirectional, LSTM
|
6 |
+
from keras import Model
|
7 |
+
from keras.metrics import MeanSquaredError
|
8 |
+
|
9 |
+
import pandas as pd
|
10 |
+
import numpy as np
|
11 |
+
|
12 |
+
import requests
|
13 |
+
from functools import reduce
|
14 |
+
from operator import add
|
15 |
+
import tabulate
|
16 |
+
from difflib import SequenceMatcher
|
17 |
+
|
18 |
+
import cyvcf2
|
19 |
+
import parasail
|
20 |
+
|
21 |
+
import re
|
22 |
+
|
23 |
+
# One-hot code for each nucleotide (channel order A, C, G, T).
ntmap = {
    'A': (1, 0, 0, 0),
    'C': (0, 1, 0, 0),
    'G': (0, 0, 1, 0),
    'T': (0, 0, 0, 1),
}


def get_seqcode(seq):
    """One-hot encode a nucleotide string using ``ntmap``.

    Args:
        seq: String of A/C/G/T characters; case is ignored.

    Returns:
        An integer ``numpy`` array of shape ``(1, len(seq), 4)``.

    Raises:
        KeyError: If ``seq`` contains a character that is not in ``ntmap``.
    """
    # Build the rows directly instead of concatenating tuples with reduce():
    # that copied the growing tuple for every base and raised on an empty
    # string.
    rows = [ntmap[base] for base in seq.upper()]
    return np.array(rows, dtype=int).reshape((1, len(seq), 4))
|
31 |
+
|
32 |
+
def BiLSTM_model(input_shape):
    """Build the Conv1D + bidirectional-LSTM regression network.

    Args:
        input_shape: Shape of one encoded sequence, passed to ``Input``.

    Returns:
        An uncompiled ``Model`` with a single linear output unit.
    """
    seq_in = Input(shape=input_shape)

    # Two identical convolution -> average-pooling -> dropout stages.
    features = seq_in
    for _ in range(2):
        features = Conv1D(128, 5, activation="relu")(features)
        features = AveragePooling1D(2)(features)
        features = Dropout(0.1)(features)

    # Bidirectional LSTM over the pooled features, averaged across positions.
    recurrent = Bidirectional(LSTM(128,
                                   dropout=0.1,
                                   activation='tanh',
                                   return_sequences=True,
                                   kernel_regularizer=regularizers.l2(1e-4)))
    features = recurrent(features)
    features = GlobalAveragePooling1D()(features)

    # Fully connected head: 128 -> 32 -> 32, each followed by dropout.
    for units in (128, 32, 32):
        features = Dense(units,
                         kernel_regularizer=regularizers.l2(1e-4),
                         bias_regularizer=regularizers.l2(1e-4),
                         activation="relu")(features)
        features = Dropout(0.1)(features)

    score = Dense(1, activation="linear")(features)

    return Model(inputs=[seq_in], outputs=[score])
|
72 |
+
|
73 |
+
def fetch_ensembl_transcripts(gene_symbol):
    """Look up a human gene symbol on the Ensembl REST API.

    Args:
        gene_symbol: Gene symbol to look up.

    Returns:
        The value stored under ``'Transcript'`` in the lookup response, or
        ``None`` when the request fails, the status is not 200, or the
        response has no ``'Transcript'`` entry.
    """
    from urllib.parse import quote

    # Percent-encode the symbol so characters such as '/', '?' or spaces
    # cannot change the request path or query string.
    symbol = quote(str(gene_symbol), safe='')
    url = f"https://rest.ensembl.org/lookup/symbol/homo_sapiens/{symbol}?expand=1;content-type=application/json"
    try:
        # Without a timeout a stalled connection would block the caller forever.
        response = requests.get(url, timeout=30)
    except requests.RequestException as exc:
        print(f"Error fetching gene data from Ensembl: {exc}")
        return None

    if response.status_code != 200:
        print(f"Error fetching gene data from Ensembl: {response.text}")
        return None

    gene_data = response.json()
    if 'Transcript' in gene_data:
        return gene_data['Transcript']
    print("No transcripts found for gene:", gene_symbol)
    return None
|
86 |
+
|
87 |
+
def fetch_ensembl_sequence(transcript_id):
    """Fetch the sequence of an Ensembl feature by its stable ID.

    Args:
        transcript_id: Ensembl stable ID placed in the ``sequence/id`` URL.
            Despite the name, ``process_gene`` in this file passes exon IDs.

    Returns:
        The ``'seq'`` string from the response, or ``None`` when the request
        fails, the status is not 200, or the response has no ``'seq'`` entry.
    """
    url = f"https://rest.ensembl.org/sequence/id/{transcript_id}?content-type=application/json"
    try:
        # Without a timeout a stalled connection would block the caller forever.
        response = requests.get(url, timeout=30)
    except requests.RequestException as exc:
        print(f"Error fetching sequence data from Ensembl: {exc}")
        return None

    if response.status_code != 200:
        print(f"Error fetching sequence data from Ensembl: {response.text}")
        return None

    sequence_data = response.json()
    if 'seq' in sequence_data:
        return sequence_data['seq']
    print("No sequence found for transcript:", transcript_id)
    return None
|
100 |
+
|
101 |
+
def apply_mutation(ref_sequence, offset, ref, alt):
    """
    Apply a single mutation to the sequence.

    The ``len(ref)`` bases starting at ``offset`` are replaced by ``alt``.
    An ``alt`` of ``"*"`` means the reference bases are removed with nothing
    inserted.

    Args:
        ref_sequence: Sequence to mutate.
        offset: 0-based index in ``ref_sequence`` where ``ref`` begins.
        ref: Reference allele.
        alt: Alternate allele, or ``"*"`` for a deletion.

    Returns:
        The mutated sequence.
    """
    replacement = "" if alt == "*" else alt
    # Always skip the full reference allele. The previous insertion branch
    # skipped a single base, which duplicated reference bases whenever the
    # reference allele was longer than one base.
    return ref_sequence[:offset] + replacement + ref_sequence[offset + len(ref):]
|
121 |
+
|
122 |
+
|
123 |
+
def construct_combinations(sequence, mutations):
    """
    Construct all combinations of mutations.
    mutations is a list of tuples (position, ref, [alts])

    Every returned sequence has one alternate allele applied at each mutated
    position; all choices of alternate alleles are enumerated.

    Args:
        sequence: Sequence the offsets refer to.
        mutations: List of ``(offset, ref, alts)`` tuples.

    Returns:
        A list of mutated sequences, or ``[sequence]`` when there is nothing
        to apply.
    """
    # A mutation without alternate alleles changes nothing. Previously it
    # produced no sequences at all and so emptied the whole result.
    applicable = [mutation for mutation in mutations if mutation[2]]

    # Apply from the highest offset downwards: an insertion or deletion only
    # shifts bases to its right, so the remaining (lower) offsets stay valid.
    applicable.sort(key=lambda mutation: mutation[0], reverse=True)

    sequences = [sequence]
    for offset, ref, alts in applicable:
        sequences = [
            apply_mutation(partial, offset, ref, alt)
            for partial in sequences
            for alt in alts
        ]

    return sequences
|
142 |
+
|
143 |
+
def needleman_wunsch_alignment(query_seq, ref_seq):
    """
    Globally align ``query_seq`` to ``ref_seq`` and locate the longest run of
    exact matches in the alignment.

    Returns the sum of the lengths of all CIGAR operations that precede the
    first longest ``=`` run (0 when the CIGAR has no ``=`` run). Callers use
    this value as the approximate position of a mutated target sequence.
    """
    # NOTE(review): blosum62 is a protein substitution matrix although both
    # inputs look like nucleotide sequences - confirm this is intended.
    result = parasail.nw_trace(query_seq, ref_seq, 10, 1, parasail.blosum62)
    cigar_text = result.cigar.decode.decode("utf-8")

    best_run = 0         # length of the longest '=' run seen so far
    best_run_offset = 0  # total length of the operations before that run
    consumed = 0         # total length of the operations processed so far

    # NOTE(review): every operation type is counted in the offset, including
    # ones that may not advance the reference - confirm against the CIGAR
    # convention used by parasail.
    for count_text, op in re.findall(r'(\d+)([MIDNSHP=X])', cigar_text):
        count = int(count_text)
        if op == '=' and count > best_run:
            best_run = count
            best_run_offset = consumed
        consumed += count

    return best_run_offset
|
174 |
+
|
175 |
+
def find_gRNA_with_mutation(ref_sequence, exon_chr, start, end, strand, transcript_id,
                            exon_id, gene_symbol, vcf_reader, pam="TTTN", target_length=34):
    """Find target windows in an exon after applying the variants from a VCF.

    Args:
        ref_sequence: Reference exon sequence in the orientation given by
            ``strand``.
        exon_chr: Chromosome/region name used in the VCF query.
        start: Genomic start coordinate of the exon.
        end: Genomic end coordinate of the exon.
        strand: ``-1`` when index 0 of ``ref_sequence`` corresponds to ``end``.
        transcript_id: Identifier copied into each result.
        exon_id: Identifier copied into each result.
        gene_symbol: Gene symbol copied into each result.
        vcf_reader: Callable taking a ``"chr:start-end"`` region string and
            returning an iterable of records with ``POS``, ``REF`` and ``ALT``.
        pam: PAM pattern matched at window position 4 (``'N'`` matches any
            base). The default reproduces the previous hard-coded
            ``window[4:7] == 'TTT'`` test.
        target_length: Window length in bases.

    Returns:
        A de-duplicated list, in first-seen order, of ``[target_seq, gRNA,
        exon_chr, strand, tar_start, transcript_id, exon_id, gene_symbol,
        is_mut]`` rows. ``tar_start`` is a string; for windows that do not
        occur in ``ref_sequence`` it is an aligned estimate suffixed with
        ``'*'`` and ``is_mut`` is ``True``.
    """
    dnatorna = {'A': 'A', 'T': 'U', 'C': 'C', 'G': 'G'}
    revcomp_table = str.maketrans("ACGTacgt", "TGCAtgca")

    # find mutations within interested region
    mutation_list = []
    for mutation in vcf_reader(f"{exon_chr}:{start}-{end}") or ():
        offset = mutation.POS - start
        if offset < 0:
            # The variant begins before the exon; a negative offset would
            # index from the end of the sequence and corrupt it.
            continue
        # NOTE(review): the last ALT allele is always dropped, as in the
        # original code - presumably a trailing placeholder allele; confirm
        # against the VCFs actually used.
        mutation_list.append((offset, mutation.REF, mutation.ALT[:-1]))

    # replace reference sequence of mutation
    mutated_sequences = [ref_sequence]
    if mutation_list:
        if strand == -1:
            # Offsets are counted from the genomic start, but a strand -1
            # sequence begins at the genomic end (see the coordinate maths
            # below). Mutate the forward-strand sequence, then convert back.
            forward = ref_sequence.translate(revcomp_table)[::-1]
            mutated_sequences = [
                mutated.translate(revcomp_table)[::-1]
                for mutated in construct_combinations(forward, mutation_list)
            ]
        else:
            mutated_sequences = construct_combinations(ref_sequence, mutation_list)

    # find gRNA in ref_sequence or all mutated_sequences
    targets = []
    for seq in mutated_sequences:
        for i in range(len(seq) - target_length + 1):
            target_seq = seq[i:i + target_length]

            pam_site = target_seq[4:4 + len(pam)]
            if len(pam_site) != len(pam):
                continue
            if any(p != 'N' and p != base for p, base in zip(pam, pam_site)):
                continue
            # Windows with anything other than A/C/G/T cannot be transcribed
            # or one-hot encoded; skip them instead of raising KeyError.
            if any(base not in dnatorna for base in target_seq):
                continue

            pos = ref_sequence.find(target_seq)
            is_mut = pos == -1
            if is_mut:
                # Not present in the reference: estimate the position by
                # alignment and flag the coordinate with '*'.
                anchor = needleman_wunsch_alignment(target_seq, ref_sequence)
                marker = '*'
            else:
                anchor = pos
                marker = ''

            if strand == -1:
                tar_start = end - anchor - target_length + 1
            else:
                tar_start = start + anchor

            gRNA = ''.join(dnatorna[base] for base in target_seq[8:28])
            targets.append([target_seq, gRNA, exon_chr, str(strand), f"{tar_start}{marker}",
                            transcript_id, exon_id, gene_symbol, is_mut])

    # filter duplicated targets, keeping first-seen order so the output is
    # deterministic (a set gave a different order from run to run)
    unique_targets = [list(row) for row in dict.fromkeys(tuple(element) for element in targets)]

    return unique_targets
|
224 |
+
|
225 |
+
def format_prediction_output_with_mutation(targets, model_path):
    """Score targets with the BiLSTM model and build the output rows.

    Args:
        targets: Rows as produced by ``find_gRNA_with_mutation``:
            ``[target_seq, gRNA, exon_chr, strand, tar_start, transcript_id,
            exon_id, gene_symbol, is_mut]``.
        model_path: Path of the weights file loaded into the model.

    Returns:
        A list of ``[gene_symbol, exon_chr, strand, tar_start, transcript_id,
        exon_id, target_seq, gRNA, prediction, is_mut]`` rows in the order of
        ``targets``. Predictions above 100 are capped at 100.
    """
    if not targets:
        # Nothing to score; avoid building the model and loading weights.
        return []

    Crispr_BiLSTM = BiLSTM_model(input_shape=(34, 4))
    Crispr_BiLSTM.load_weights(model_path)

    # Score all targets in one batched call rather than one predict() call
    # per sequence.
    encoded = np.concatenate([get_seqcode(target[0]) for target in targets], axis=0)
    scores = Crispr_BiLSTM.predict(encoded, verbose=0)

    formatted_data = []
    for target, score in zip(targets, scores):
        prediction = float(score[0])
        if prediction > 100:
            prediction = 100

        (target_seq, gRNA, exon_chr, strand, tar_start,
         transcript_id, exon_id, gene_symbol, is_mut) = target[:9]
        formatted_data.append([gene_symbol, exon_chr, strand, tar_start, transcript_id,
                               exon_id, target_seq, gRNA, prediction, is_mut])

    return formatted_data
|
249 |
+
|
250 |
+
def process_gene(gene_symbol, vcf_reader, model_path):
    """Find and score variant-aware targets in every exon of a gene.

    Args:
        gene_symbol: Gene symbol passed to ``fetch_ensembl_transcripts``.
        vcf_reader: Variant source passed to ``find_gRNA_with_mutation``.
        model_path: Path of the model weights used for scoring.

    Returns:
        A tuple ``(sorted_results, all_gene_sequences, all_exons)``:
        result rows sorted by prediction score (index 8) in descending order,
        the reference exon sequences that were retrieved, and every exon
        record seen.
    """
    transcripts = fetch_ensembl_transcripts(gene_symbol)
    all_exons = []  # To accumulate all exons
    all_gene_sequences = []  # To accumulate all gene sequences
    all_targets = []
    # The same exon id can be listed under several transcripts; remember
    # successful downloads so each sequence is fetched only once.
    sequence_cache = {}

    if transcripts:
        for transcript in transcripts:
            Exons = transcript['Exon']
            all_exons.extend(Exons)  # Add all exons from this transcript to the list
            transcript_id = transcript['id']

            for Exon in Exons:
                exon_id = Exon['id']
                gene_sequence = sequence_cache.get(exon_id)  # Reference exon sequence
                if gene_sequence is None:
                    gene_sequence = fetch_ensembl_sequence(exon_id)
                    if gene_sequence:
                        sequence_cache[exon_id] = gene_sequence
                if not gene_sequence:
                    print(f"Failed to retrieve gene sequence for exon {exon_id}.")
                    continue

                all_gene_sequences.append(gene_sequence)  # Add this gene sequence to the list
                all_targets.extend(find_gRNA_with_mutation(
                    gene_sequence,
                    Exon['seq_region_name'],
                    Exon['start'],
                    Exon['end'],
                    Exon['strand'],
                    transcript_id,
                    exon_id,
                    gene_symbol,
                    vcf_reader,
                ))
    else:
        print("Failed to retrieve transcripts.")

    # Score everything in a single call so the model is built and its weights
    # are loaded once, not once per exon.
    if all_targets:
        results = format_prediction_output_with_mutation(all_targets, model_path)
    else:
        results = []

    # The prediction score is the element at index 8 of each row.
    sorted_results = sorted(results, key=lambda x: x[8], reverse=True)

    # Return the sorted output, combined gene sequences, and all exons
    return sorted_results, all_gene_sequences, all_exons
|
287 |
+
|
cas9att.py
ADDED
@@ -0,0 +1,299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import requests
|
2 |
+
import tensorflow as tf
|
3 |
+
import pandas as pd
|
4 |
+
import numpy as np
|
5 |
+
from operator import add
|
6 |
+
from functools import reduce
|
7 |
+
import random
|
8 |
+
import tabulate
|
9 |
+
|
10 |
+
from keras import Model
|
11 |
+
from keras import regularizers
|
12 |
+
from keras.optimizers import Adam
|
13 |
+
from keras.layers import Conv2D, BatchNormalization, ReLU, Input, Flatten, Softmax
|
14 |
+
from keras.layers import Concatenate, Activation, Dense, GlobalAveragePooling2D, Dropout
|
15 |
+
from keras.layers import AveragePooling1D, Bidirectional, LSTM, GlobalAveragePooling1D, MaxPool1D, Reshape
|
16 |
+
from keras.layers import LayerNormalization, Conv1D, MultiHeadAttention, Layer
|
17 |
+
from keras.models import load_model
|
18 |
+
from keras.callbacks import EarlyStopping, ReduceLROnPlateau
|
19 |
+
from Bio import SeqIO
|
20 |
+
from Bio.SeqRecord import SeqRecord
|
21 |
+
from Bio.SeqFeature import SeqFeature, FeatureLocation
|
22 |
+
from Bio.Seq import Seq
|
23 |
+
|
24 |
+
import cyvcf2
|
25 |
+
import parasail
|
26 |
+
|
27 |
+
import re
|
28 |
+
|
29 |
+
# One-hot code for each nucleotide (channel order A, C, G, T).
ntmap = {
    'A': (1, 0, 0, 0),
    'C': (0, 1, 0, 0),
    'G': (0, 0, 1, 0),
    'T': (0, 0, 0, 1),
}


def get_seqcode(seq):
    """One-hot encode a nucleotide string using ``ntmap``.

    Args:
        seq: String of A/C/G/T characters; case is ignored.

    Returns:
        An integer ``numpy`` array of shape ``(1, len(seq), 4)``.

    Raises:
        KeyError: If ``seq`` contains a character that is not in ``ntmap``.
    """
    # Build the rows directly instead of concatenating tuples with reduce():
    # that copied the growing tuple for every base and raised on an empty
    # string.
    rows = [ntmap[base] for base in seq.upper()]
    return np.array(rows, dtype=int).reshape((1, len(seq), 4))
|
37 |
+
|
38 |
+
class PositionalEncoding(Layer):
    """Layer that adds a fixed sinusoidal position table to its input.

    Args:
        sequence_len: Number of positions in the table.
        embedding_dim: Number of channels in the table.
        **kwargs: Standard ``Layer`` keyword arguments (for example ``name``).
    """

    def __init__(self, sequence_len=None, embedding_dim=None, **kwargs):
        # Forward kwargs so options such as ``name`` - including those put
        # into the config by ``get_config`` - are honoured instead of being
        # silently discarded.
        super().__init__(**kwargs)
        self.sequence_len = sequence_len
        self.embedding_dim = embedding_dim

    def call(self, x):
        """Return ``x`` plus the ``(sequence_len, embedding_dim)`` table."""
        position_embedding = np.array([
            [pos / np.power(10000, 2. * i / self.embedding_dim) for i in range(self.embedding_dim)]
            for pos in range(self.sequence_len)])

        position_embedding[:, 0::2] = np.sin(position_embedding[:, 0::2])  # dim 2i
        position_embedding[:, 1::2] = np.cos(position_embedding[:, 1::2])  # dim 2i+1
        position_embedding = tf.cast(position_embedding, dtype=tf.float32)

        return position_embedding + x

    def get_config(self):
        """Return the layer config including the two constructor arguments."""
        config = super().get_config().copy()
        config.update({
            'sequence_len': self.sequence_len,
            'embedding_dim': self.embedding_dim,
        })
        return config
|
63 |
+
|
64 |
+
def MultiHeadAttention_model(input_shape):
    """Build the Conv1D + BiLSTM + multi-head-attention regression network.

    Args:
        input_shape: ``(sequence_length, channels)`` of one encoded sequence.

    Returns:
        An uncompiled ``Model`` with a single linear output unit.
    """
    inputs = Input(shape=input_shape)

    conv1 = Conv1D(256, 3, activation="relu")(inputs)
    pool1 = AveragePooling1D(2)(conv1)
    drop1 = Dropout(0.4)(pool1)

    conv2 = Conv1D(256, 3, activation="relu")(drop1)
    pool2 = AveragePooling1D(2)(conv2)
    drop2 = Dropout(0.4)(pool2)

    lstm = Bidirectional(LSTM(128,
                              dropout=0.5,
                              activation='tanh',
                              return_sequences=True,
                              kernel_regularizer=regularizers.l2(0.01)))(drop2)

    # Number of positions left after the two conv(kernel 3) + pool(2) stages.
    # Derived from input_shape instead of a hard-coded 23, so the positional
    # table always matches the LSTM output; for a length of 23 this is 4, as
    # before.
    pooled_len = ((input_shape[0] - 3 + 1) // 2 - 3 + 1) // 2

    pos_embedding = PositionalEncoding(sequence_len=pooled_len, embedding_dim=2*128)(lstm)
    atten = MultiHeadAttention(num_heads=2,
                               key_dim=64,
                               dropout=0.2,
                               kernel_regularizer=regularizers.l2(0.01))(pos_embedding, pos_embedding)

    flat = Flatten()(atten)

    dense1 = Dense(512,
                   kernel_regularizer=regularizers.l2(1e-4),
                   bias_regularizer=regularizers.l2(1e-4),
                   activation="relu")(flat)
    drop3 = Dropout(0.1)(dense1)

    dense2 = Dense(128,
                   kernel_regularizer=regularizers.l2(1e-4),
                   bias_regularizer=regularizers.l2(1e-4),
                   activation="relu")(drop3)
    drop4 = Dropout(0.1)(dense2)

    dense3 = Dense(256,
                   kernel_regularizer=regularizers.l2(1e-4),
                   bias_regularizer=regularizers.l2(1e-4),
                   activation="relu")(drop4)
    drop5 = Dropout(0.1)(dense3)

    output = Dense(1, activation="linear")(drop5)

    model = Model(inputs=[inputs], outputs=[output])
    return model
|
111 |
+
|
112 |
+
def fetch_ensembl_transcripts(gene_symbol):
    """Look up a human gene symbol on the Ensembl REST API.

    Args:
        gene_symbol: Gene symbol to look up.

    Returns:
        The value stored under ``'Transcript'`` in the lookup response, or
        ``None`` when the request fails, the status is not 200, or the
        response has no ``'Transcript'`` entry.
    """
    from urllib.parse import quote

    # Percent-encode the symbol so characters such as '/', '?' or spaces
    # cannot change the request path or query string.
    symbol = quote(str(gene_symbol), safe='')
    url = f"https://rest.ensembl.org/lookup/symbol/homo_sapiens/{symbol}?expand=1;content-type=application/json"
    try:
        # Without a timeout a stalled connection would block the caller forever.
        response = requests.get(url, timeout=30)
    except requests.RequestException as exc:
        print(f"Error fetching gene data from Ensembl: {exc}")
        return None

    if response.status_code != 200:
        print(f"Error fetching gene data from Ensembl: {response.text}")
        return None

    gene_data = response.json()
    if 'Transcript' in gene_data:
        return gene_data['Transcript']
    print("No transcripts found for gene:", gene_symbol)
    return None
|
125 |
+
|
126 |
+
def fetch_ensembl_sequence(transcript_id):
    """Fetch the sequence of an Ensembl feature by its stable ID.

    Args:
        transcript_id: Ensembl stable ID placed in the ``sequence/id`` URL.
            Despite the name, ``process_gene`` in this file passes exon IDs.

    Returns:
        The ``'seq'`` string from the response, or ``None`` when the request
        fails, the status is not 200, or the response has no ``'seq'`` entry.
    """
    url = f"https://rest.ensembl.org/sequence/id/{transcript_id}?content-type=application/json"
    try:
        # Without a timeout a stalled connection would block the caller forever.
        response = requests.get(url, timeout=30)
    except requests.RequestException as exc:
        print(f"Error fetching sequence data from Ensembl: {exc}")
        return None

    if response.status_code != 200:
        print(f"Error fetching sequence data from Ensembl: {response.text}")
        return None

    sequence_data = response.json()
    if 'seq' in sequence_data:
        return sequence_data['seq']
    print("No sequence found for transcript:", transcript_id)
    return None
|
139 |
+
|
140 |
+
def find_crispr_targets(sequence, chr, start, end, strand, transcript_id, exon_id, pam="NGG", target_length=20):
    """Scan a sequence for protospacers that are immediately followed by a PAM.

    For every index ``i`` where ``sequence[i:i + len(pam)]`` matches ``pam``
    (``'N'`` matches any base) and at least ``target_length`` bases precede
    it, the target is the ``target_length`` bases before the PAM plus the PAM
    itself; the gRNA is those preceding bases with T replaced by U.

    Args:
        sequence: Sequence to scan, in the orientation given by ``strand``.
        chr: Chromosome/region name copied into each result.
        start: Genomic start coordinate of ``sequence``.
        end: Genomic end coordinate of ``sequence``.
        strand: ``-1`` when index 0 of ``sequence`` corresponds to ``end``;
            any other value means index 0 corresponds to ``start``.
        transcript_id: Identifier copied into each result.
        exon_id: Identifier copied into each result.
        pam: PAM pattern. The default behaves exactly like the previous
            hard-coded three-base ``?GG`` test.
        target_length: Number of bases taken before the PAM.

    Returns:
        A list of ``[target_seq, gRNA, chr, start, end, strand, transcript_id,
        exon_id]`` rows with coordinates and strand as strings.
    """
    pam_length = len(pam)
    dnatorna = {'A': 'A', 'T': 'U', 'C': 'C', 'G': 'G'}
    targets = []

    # Start at target_length: earlier PAM sites lack a full protospacer.
    for i in range(target_length, len(sequence) - pam_length + 1):
        pam_site = sequence[i:i + pam_length]
        if any(p != 'N' and p != base for p, base in zip(pam, pam_site)):
            continue

        target_seq = sequence[i - target_length:i + pam_length]
        # Windows with anything other than A/C/G/T cannot be transcribed or
        # one-hot encoded; skip them instead of raising KeyError.
        if any(base not in dnatorna for base in target_seq):
            continue

        if strand == -1:
            tar_start = end - (i + pam_length - 1)
            tar_end = end - (i - target_length)
        else:
            tar_start = start + i - target_length
            tar_end = start + i + pam_length - 1

        gRNA = ''.join(dnatorna[base] for base in sequence[i - target_length:i])
        targets.append([target_seq, gRNA, chr, str(tar_start), str(tar_end), str(strand), transcript_id, exon_id])

    return targets
|
163 |
+
|
164 |
+
# Function to predict on-target efficiency and format output
|
165 |
+
def format_prediction_output(targets, model_path):
    """Score targets with the attention model and build the output rows.

    Args:
        targets: Rows as produced by ``find_crispr_targets``:
            ``[target_seq, gRNA, chr, start, end, strand, transcript_id,
            exon_id]``.
        model_path: Path of the weights file loaded into the model.

    Returns:
        A list of ``[chr, start, end, strand, transcript_id, exon_id,
        target_seq, gRNA, prediction]`` rows in the order of ``targets``.
        Predictions above 100 are capped at 100.
    """
    if not targets:
        # Nothing to score; avoid building the model and loading weights.
        return []

    model = MultiHeadAttention_model(input_shape=(23, 4))
    model.load_weights(model_path)

    # Encode every target sequence and score them in one batched call rather
    # than one predict() call per sequence.
    encoded = np.concatenate([get_seqcode(target[0]) for target in targets], axis=0)
    scores = model.predict(encoded, verbose=0)

    formatted_data = []
    for target, score in zip(targets, scores):
        prediction = float(score[0])
        if prediction > 100:
            prediction = 100

        target_seq, gRNA, chr, start, end, strand, transcript_id, exon_id = target[:8]
        formatted_data.append([chr, start, end, strand, transcript_id, exon_id, target_seq, gRNA, prediction])

    return formatted_data
|
193 |
+
|
194 |
+
def process_gene(gene_symbol, model_path):
    """Find and score CRISPR targets in every exon of every transcript of a gene.

    Args:
        gene_symbol: Gene symbol passed to ``fetch_ensembl_transcripts``.
        model_path: Weights path forwarded to ``format_prediction_output``.

    Returns:
        A tuple ``(sorted_results, all_gene_sequences, all_exons)``: the
        prediction rows sorted by score (highest first), the list of exon
        sequences that were fetched, and the list of all exon records.
    """
    results = []
    all_exons = []           # every exon record of every transcript
    all_gene_sequences = []  # every exon sequence that could be fetched

    transcripts = fetch_ensembl_transcripts(gene_symbol)
    if not transcripts:
        print("Failed to retrieve transcripts.")
    else:
        for transcript in transcripts:
            exons = transcript['Exon']
            all_exons.extend(exons)
            transcript_id = transcript['id']

            for exon in exons:
                exon_id = exon['id']
                gene_sequence = fetch_ensembl_sequence(exon_id)
                if not gene_sequence:
                    print(f"Failed to retrieve gene sequence for exon {exon_id}.")
                    continue

                all_gene_sequences.append(gene_sequence)
                start = exon['start']
                end = exon['end']
                strand = exon['strand']
                chr = exon['seq_region_name']

                targets = find_crispr_targets(gene_sequence, chr, start, end, strand,
                                              transcript_id, exon_id)
                if targets:
                    results.extend(format_prediction_output(targets, model_path))

    # The prediction score is stored at index 8 of every row.
    sorted_results = sorted(results, key=lambda x: x[8], reverse=True)

    return sorted_results, all_gene_sequences, all_exons
|
232 |
+
|
233 |
+
|
234 |
+
def create_genbank_features(data):
    """Convert prediction rows into GenBank ``misc_feature`` annotations.

    Args:
        data: A list of rows or a pandas DataFrame. Each row is read
            positionally: index 1 = start, 2 = end, 3 = strand, 7 = gRNA
            (used as the label) and 8 = prediction score.

    Returns:
        A list of ``SeqFeature`` objects. Rows whose start or end cannot be
        converted to ``int`` are reported and skipped.

    Raises:
        TypeError: If *data* is neither a list nor a DataFrame.
    """
    if isinstance(data, pd.DataFrame):
        formatted_data = data.values.tolist()
    elif isinstance(data, list):
        formatted_data = data
    else:
        raise TypeError("Data should be either a list or a pandas DataFrame.")

    features = []
    for row in formatted_data:
        try:
            start = int(row[1])
            end = int(row[2])
        except ValueError as e:
            print(f"Error converting start/end to int: {row[1]}, {row[2]} - {e}")
            continue

        # Rows produced in this module store the strand as '1' / '-1'; the
        # previous check only recognised '+', so every feature ended up on
        # the reverse strand. Accept '+', '1', '+1' and the number 1.
        raw_strand = row[3]
        if isinstance(raw_strand, str):
            is_forward = raw_strand.strip() in ('+', '1', '+1')
        else:
            is_forward = raw_strand == 1
        strand = 1 if is_forward else -1

        location = FeatureLocation(start=start, end=end, strand=strand)
        feature = SeqFeature(location=location, type="misc_feature", qualifiers={
            'label': row[7],  # Use gRNA as the label
            'note': f"Prediction: {row[8]}"  # Include the prediction score
        })
        features.append(feature)

    return features
|
262 |
+
|
263 |
+
|
264 |
+
def generate_genbank_file_from_df(df, gene_sequence, gene_symbol, output_path):
    """Write a GenBank record holding the sequence and the predicted targets.

    Args:
        df: Prediction rows (DataFrame or list) passed to
            ``create_genbank_features``.
        gene_sequence: The record sequence. A string is used as-is; a list or
            tuple of sequence fragments (as returned by ``process_gene``) is
            concatenated; anything else is converted with ``str()``.
        gene_symbol: Used as the record id, name and in the description.
        output_path: Destination file handed to ``SeqIO.write``.
    """
    if isinstance(gene_sequence, (list, tuple)):
        # str() on a list would embed brackets, quotes and commas in the
        # sequence; join the fragments instead.
        gene_sequence = "".join(str(fragment) for fragment in gene_sequence)
    elif not isinstance(gene_sequence, str):
        gene_sequence = str(gene_sequence)

    features = create_genbank_features(df)

    seq_obj = Seq(gene_sequence)
    record = SeqRecord(seq_obj, id=gene_symbol, name=gene_symbol,
                       description=f'CRISPR Cas9 predicted targets for {gene_symbol}',
                       features=features)
    record.annotations["molecule_type"] = "DNA"
    SeqIO.write(record, output_path, "genbank")
|
277 |
+
|
278 |
+
|
279 |
+
def create_bed_file_from_df(df, output_path):
    """Write one tab-separated BED-style line per prediction row.

    Args:
        df: DataFrame with the columns ``Chr``, ``Start Pos``, ``End Pos``,
            ``Strand``, ``gRNA`` and ``Prediction``.
        output_path: File to create (overwritten if it exists).

    Each line is ``chrom, start, end, gRNA, score, strand``. The strand is
    written as ``+`` when the ``Strand`` value is ``'1'``, ``'+1'``, ``'+'``
    or the number 1, and as ``-`` otherwise.
    """
    with open(output_path, 'w') as bed_file:
        for _, row in df.iterrows():
            chrom = row["Chr"]
            start = int(row["Start Pos"])
            end = int(row["End Pos"])

            # Accept both textual and numeric strands; previously a numeric
            # 1 was written as '-'.
            raw_strand = row["Strand"]
            if isinstance(raw_strand, str):
                is_forward = raw_strand.strip() in ('1', '+1', '+')
            else:
                is_forward = raw_strand == 1
            strand = '+' if is_forward else '-'

            gRNA = row["gRNA"]
            score = str(row["Prediction"])

            bed_file.write(f"{chrom}\t{start}\t{end}\t{gRNA}\t{score}\t{strand}\n")
|
293 |
+
|
294 |
+
|
295 |
+
def create_csv_from_df(df, output_path):
    """Save *df* as a CSV file at *output_path*, leaving out the index column."""
    df.to_csv(path_or_buf=output_path, index=False)
|
297 |
+
|
298 |
+
|
299 |
+
|
cas9attvcf.py
ADDED
@@ -0,0 +1,397 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import requests
|
2 |
+
import tensorflow as tf
|
3 |
+
import pandas as pd
|
4 |
+
import numpy as np
|
5 |
+
from operator import add
|
6 |
+
from functools import reduce
|
7 |
+
import random
|
8 |
+
import tabulate
|
9 |
+
|
10 |
+
from keras import Model
|
11 |
+
from keras import regularizers
|
12 |
+
from keras.optimizers import Adam
|
13 |
+
from keras.layers import Conv2D, BatchNormalization, ReLU, Input, Flatten, Softmax
|
14 |
+
from keras.layers import Concatenate, Activation, Dense, GlobalAveragePooling2D, Dropout
|
15 |
+
from keras.layers import AveragePooling1D, Bidirectional, LSTM, GlobalAveragePooling1D, MaxPool1D, Reshape
|
16 |
+
from keras.layers import LayerNormalization, Conv1D, MultiHeadAttention, Layer
|
17 |
+
from keras.models import load_model
|
18 |
+
from keras.callbacks import EarlyStopping, ReduceLROnPlateau
|
19 |
+
from Bio import SeqIO
|
20 |
+
from Bio.SeqRecord import SeqRecord
|
21 |
+
from Bio.SeqFeature import SeqFeature, FeatureLocation
|
22 |
+
from Bio.Seq import Seq
|
23 |
+
|
24 |
+
import cyvcf2
|
25 |
+
import parasail
|
26 |
+
|
27 |
+
import re
|
28 |
+
|
29 |
+
# One-hot code of each nucleotide, in A, C, G, T column order.
ntmap = {'A': (1, 0, 0, 0),
         'C': (0, 1, 0, 0),
         'G': (0, 0, 1, 0),
         'T': (0, 0, 0, 1)
         }


def get_seqcode(seq):
    """One-hot encode a DNA string.

    Args:
        seq: Sequence of A/C/G/T characters; case is ignored.

    Returns:
        An integer array of shape ``(1, len(seq), 4)``. The empty string
        yields an array of shape ``(1, 0, 4)``.

    Raises:
        KeyError: If *seq* contains a character other than A, C, G or T.
    """
    # A list of rows avoids the repeated tuple concatenation of
    # reduce(add, ...) and also works for an empty sequence.
    rows = [ntmap[c] for c in seq.upper()]
    return np.array(rows, dtype=int).reshape((1, len(seq), 4))
|
37 |
+
|
38 |
+
class PositionalEncoding(Layer):
    """Layer that adds a fixed sinusoidal position table to its input.

    Args:
        sequence_len: Number of positions (rows) in the table.
        embedding_dim: Number of columns in the table.
        **kwargs: Standard layer arguments (``name``, ``dtype``, ...),
            forwarded to the base ``Layer``.
    """

    def __init__(self, sequence_len=None, embedding_dim=None, **kwargs):
        # Forward kwargs so that arguments such as ``name`` (which
        # get_config() emits) are honoured instead of silently dropped.
        super().__init__(**kwargs)
        self.sequence_len = sequence_len
        self.embedding_dim = embedding_dim

    def call(self, x):
        # NOTE(review): the exponent uses 2*i/embedding_dim for every column
        # i; left unchanged because the formula is part of the model's
        # behaviour — presumably the saved weights were trained with it.
        position_embedding = np.array([
            [pos / np.power(10000, 2. * i / self.embedding_dim) for i in range(self.embedding_dim)]
            for pos in range(self.sequence_len)])

        position_embedding[:, 0::2] = np.sin(position_embedding[:, 0::2])  # even columns
        position_embedding[:, 1::2] = np.cos(position_embedding[:, 1::2])  # odd columns
        position_embedding = tf.cast(position_embedding, dtype=tf.float32)

        return position_embedding + x

    def get_config(self):
        """Return the base config extended with this layer's two arguments."""
        config = super().get_config().copy()
        config.update({
            'sequence_len': self.sequence_len,
            'embedding_dim': self.embedding_dim,
        })
        return config
|
63 |
+
|
64 |
+
def MultiHeadAttention_model(input_shape):
    """Build the Conv1D -> BiLSTM -> multi-head-attention -> dense regressor.

    Args:
        input_shape: Shape of one encoded sequence, passed to ``Input``.

    Returns:
        An uncompiled Keras ``Model`` with a single linear output unit.
    """
    inputs = Input(shape=input_shape)

    # Two identical convolution / pooling / dropout stages.
    x = inputs
    for _ in range(2):
        x = Conv1D(256, 3, activation="relu")(x)
        x = AveragePooling1D(2)(x)
        x = Dropout(0.4)(x)

    recurrent = LSTM(128,
                     dropout=0.5,
                     activation='tanh',
                     return_sequences=True,
                     kernel_regularizer=regularizers.l2(0.01))
    x = Bidirectional(recurrent)(x)

    # Sequence length left after the two conv/pool stages for a 23-long input.
    remaining_len = int(((23-3+1)/2-3+1)/2)
    encoded = PositionalEncoding(sequence_len=remaining_len, embedding_dim=2*128)(x)
    attended = MultiHeadAttention(num_heads=2,
                                  key_dim=64,
                                  dropout=0.2,
                                  kernel_regularizer=regularizers.l2(0.01))(encoded, encoded)

    x = Flatten()(attended)

    # Fully connected head: 512 -> 128 -> 256, each followed by dropout.
    for units in (512, 128, 256):
        x = Dense(units,
                  kernel_regularizer=regularizers.l2(1e-4),
                  bias_regularizer=regularizers.l2(1e-4),
                  activation="relu")(x)
        x = Dropout(0.1)(x)

    output = Dense(1, activation="linear")(x)

    return Model(inputs=[inputs], outputs=[output])
|
111 |
+
|
112 |
+
def fetch_ensembl_transcripts(gene_symbol):
    """Look up a human gene by symbol on the Ensembl REST API.

    Args:
        gene_symbol: Gene symbol inserted into the lookup URL.

    Returns:
        The ``'Transcript'`` list of the JSON response, or ``None`` when the
        request fails, returns a non-200 status, or the response has no
        ``'Transcript'`` entry.
    """
    url = f"https://rest.ensembl.org/lookup/symbol/homo_sapiens/{gene_symbol}?expand=1;content-type=application/json"
    try:
        # Without a timeout a stalled connection would block forever.
        response = requests.get(url, timeout=30)
    except requests.RequestException as exc:
        print(f"Error fetching gene data from Ensembl: {exc}")
        return None

    if response.status_code != 200:
        print(f"Error fetching gene data from Ensembl: {response.text}")
        return None

    gene_data = response.json()
    if 'Transcript' in gene_data:
        return gene_data['Transcript']

    print("No transcripts found for gene:", gene_symbol)
    return None
|
125 |
+
|
126 |
+
def fetch_ensembl_sequence(transcript_id):
    """Fetch the sequence of an Ensembl feature by its stable id.

    Args:
        transcript_id: Stable id inserted into the sequence URL. Despite the
            parameter name, ``process_gene`` in this module passes exon ids.

    Returns:
        The ``'seq'`` string of the JSON response, or ``None`` when the
        request fails, returns a non-200 status, or the response has no
        ``'seq'`` entry.
    """
    url = f"https://rest.ensembl.org/sequence/id/{transcript_id}?content-type=application/json"
    try:
        # Without a timeout a stalled connection would block forever.
        response = requests.get(url, timeout=30)
    except requests.RequestException as exc:
        print(f"Error fetching sequence data from Ensembl: {exc}")
        return None

    if response.status_code != 200:
        print(f"Error fetching sequence data from Ensembl: {response.text}")
        return None

    sequence_data = response.json()
    if 'seq' in sequence_data:
        return sequence_data['seq']

    print("No sequence found for transcript:", transcript_id)
    return None
|
139 |
+
|
140 |
+
def apply_mutation(ref_sequence, offset, ref, alt):
    """Apply a single mutation to the sequence.

    The ``len(ref)`` bases starting at *offset* are replaced by *alt*; an
    *alt* of ``"*"`` replaces them with nothing. This covers substitutions,
    insertions and deletions alike.

    Args:
        ref_sequence: Sequence to mutate.
        offset: 0-based index in *ref_sequence* where *ref* starts.
        ref: Reference allele; only its length is used.
        alt: Alternate allele, or ``"*"`` for a deleted allele.

    Returns:
        The mutated sequence as a new string.
    """
    replacement = "" if alt == "*" else alt
    # Always skip the whole reference allele. The former insertion branch
    # skipped a single base, which duplicated bases whenever len(ref) > 1.
    return ref_sequence[:offset] + replacement + ref_sequence[offset + len(ref):]
|
160 |
+
|
161 |
+
def construct_combinations(sequence, mutations):
    """Construct all combinations of mutations.

    Args:
        sequence: Reference sequence the offsets refer to.
        mutations: List of tuples ``(offset, ref, [alts])``. The offsets must
            be in ascending order and the mutations must not overlap.

    Returns:
        One sequence per combination of alternate alleles, ordered with the
        alleles of the first mutation varying slowest. A mutation with an
        empty alt list is ignored. With no mutations, ``[sequence]``.
    """
    if not mutations:
        return [sequence]

    offset, ref, alts = mutations[0]

    # Build the combinations of the later (higher-offset) mutations first and
    # apply this one last: an insertion or deletion only shifts the bases
    # behind it, so the offset of this mutation is still valid.
    tails = construct_combinations(sequence, mutations[1:])

    if not alts:
        # Nothing to substitute; previously this discarded every sequence.
        return tails

    return [apply_mutation(tail, offset, ref, alt) for alt in alts for tail in tails]
|
180 |
+
|
181 |
+
def needleman_wunsch_alignment(query_seq, ref_seq):
    """Locate *query_seq* in *ref_seq* through a global alignment.

    The sequences are aligned with Needleman-Wunsch and the CIGAR string of
    the alignment is scanned for its longest run of exact matches (``=``).
    The position in *ref_seq* where that run starts is used to represent the
    position of a target sequence that carries mutations.

    Args:
        query_seq: Target sequence (possibly mutated).
        ref_seq: Reference sequence to align against.

    Returns:
        0-based reference offset of the longest exact-match run, or 0 when
        the alignment contains no exact match.
    """
    # NOTE(review): blosum62 is a protein substitution matrix; a nucleotide
    # matrix is probably intended here. Left unchanged — confirm.
    alignment = parasail.nw_trace(query_seq, ref_seq, 10, 1, parasail.blosum62)

    cigar_string = alignment.cigar.decode
    if isinstance(cigar_string, bytes):
        cigar_string = cigar_string.decode("utf-8")

    # Only these operations advance the position on the reference; 'I' (and
    # clips/padding) consume the query only and must not be counted.
    ref_consuming_ops = "MDN=X"

    ref_pos = 0        # reference offset of the longest '=' run found so far
    ref_cursor = 0     # reference offset of the operation being visited
    longest_match = 0

    for num_str, op in re.findall(r'(\d+)([MIDNSHP=X])', cigar_string):
        num = int(num_str)
        if op == '=' and num > longest_match:
            longest_match = num
            ref_pos = ref_cursor
        if op in ref_consuming_ops:
            ref_cursor += num

    return ref_pos
|
212 |
+
|
213 |
+
def find_gRNA_with_mutation(ref_sequence, exon_chr, start, end, strand, transcript_id,
                            exon_id, gene_symbol, vcf_reader, pam="NGG", target_length=20):
    """Find protospacer + PAM sites in an exon after applying VCF variants.

    Args:
        ref_sequence: Reference exon sequence.
        exon_chr: Chromosome / region name used in the VCF query.
        start: Coordinate of the first base of the exon.
        end: Coordinate of the last base of the exon.
        strand: ``-1`` selects the reverse-strand coordinate arithmetic.
        transcript_id: Identifier copied into every result row.
        exon_id: Identifier copied into every result row.
        gene_symbol: Identifier copied into every result row.
        vcf_reader: Callable taking a ``"chr:start-end"`` region string and
            returning an iterable of records with ``POS``, ``REF`` and
            ``ALT`` attributes.
        pam: PAM pattern; ``'N'`` matches any base.
        target_length: Number of protospacer bases preceding the PAM.

    Returns:
        A list of unique rows, in order of first occurrence:
        ``[target_seq, gRNA, exon_chr, strand, tar_start, transcript_id,
        exon_id, gene_symbol, is_mut]``. ``is_mut`` is ``True`` when the
        target does not occur verbatim in *ref_sequence*; its ``tar_start``
        is then estimated by alignment and suffixed with ``'*'``.
    """
    mutated_sequences = [ref_sequence]

    # Collect the variants inside the exon.
    mutations = vcf_reader(f"{exon_chr}:{start}-{end}")
    if mutations:
        mutation_list = []
        for mutation in mutations:
            # NOTE(review): the last ALT entry is dropped — presumably the
            # VCF carries a trailing symbolic allele; confirm against the
            # input files.
            alts = mutation.ALT[:-1]
            if not alts:
                continue  # nothing to substitute for this record
            # NOTE(review): assumes ref_sequence runs in ascending genomic
            # order from `start`; verify for reverse-strand exons.
            mutation_list.append((mutation.POS - start, mutation.REF, alts))

        if mutation_list:
            mutated_sequences = construct_combinations(ref_sequence, mutation_list)

    dnatorna = {'A': 'A', 'T': 'U', 'C': 'C', 'G': 'G'}
    pam_len = len(pam)

    # A dict keeps the first occurrence of every row and preserves order, so
    # the result is deterministic (a set is not).
    unique_targets = {}

    for seq in mutated_sequences:
        # i is the index of the first PAM base.
        for i in range(target_length, len(seq) - pam_len + 1):
            site_pam = seq[i:i + pam_len]
            if any(p != 'N' and p != base for p, base in zip(pam, site_pam)):
                continue

            target_seq = seq[i - target_length:i + pam_len]
            # Skip sites with bases that cannot be converted to a gRNA.
            if any(base not in dnatorna for base in target_seq):
                continue

            pos = ref_sequence.find(target_seq)
            is_mut = pos == -1
            if is_mut:
                pos = needleman_wunsch_alignment(target_seq, ref_sequence)

            if strand == -1:
                coord = end - pos - target_length - (pam_len - 1)
            else:
                coord = start + pos
            tar_start = f"{coord}*" if is_mut else str(coord)

            gRNA = ''.join(dnatorna[base] for base in seq[i - target_length:i])
            row = (target_seq, gRNA, exon_chr, str(strand), tar_start,
                   transcript_id, exon_id, gene_symbol, is_mut)
            unique_targets[row] = None

    return [list(row) for row in unique_targets]
|
263 |
+
|
264 |
+
def format_prediction_output_with_mutation(targets, model_path):
    """Score every (possibly mutated) target and build output rows.

    Args:
        targets: Rows shaped like the output of ``find_gRNA_with_mutation``:
            ``[target_seq, gRNA, exon_chr, strand, tar_start, transcript_id,
            exon_id, gene_symbol, is_mut]``.
        model_path: Path of the weights file loaded into the model.

    Returns:
        A list of rows ``[gene_symbol, exon_chr, strand, tar_start,
        transcript_id, exon_id, target_seq, gRNA, prediction, is_mut]``.
        Predictions above 100 are capped at 100.
    """
    model = MultiHeadAttention_model(input_shape=(23, 4))
    model.load_weights(model_path)

    rows = []
    for entry in targets:
        # The model is fed the encoded target sequence (entry[0]).
        encoded = get_seqcode(entry[0])
        score = float(list(model.predict(encoded, verbose=0)[0])[0])
        score = 100 if score > 100 else score

        rows.append([
            entry[7],  # gene symbol
            entry[2],  # chromosome
            entry[3],  # strand
            entry[4],  # target start
            entry[5],  # transcript id
            entry[6],  # exon id
            entry[0],  # target sequence
            entry[1],  # gRNA
            score,
            entry[8],  # is_mut flag
        ])

    return rows
|
293 |
+
|
294 |
+
|
295 |
+
def process_gene(gene_symbol, vcf_reader, model_path):
    """Find and score variant-aware CRISPR targets in every exon of a gene.

    Args:
        gene_symbol: Gene symbol passed to ``fetch_ensembl_transcripts``.
        vcf_reader: Variant source forwarded to ``find_gRNA_with_mutation``.
        model_path: Weights path forwarded to
            ``format_prediction_output_with_mutation``.

    Returns:
        A tuple ``(sorted_results, all_gene_sequences, all_exons)``: the
        prediction rows sorted by score (highest first), the list of exon
        sequences that were fetched, and the list of all exon records.
    """
    results = []
    all_exons = []           # every exon record of every transcript
    all_gene_sequences = []  # every exon sequence that could be fetched

    transcripts = fetch_ensembl_transcripts(gene_symbol)
    if not transcripts:
        print("Failed to retrieve transcripts.")
    else:
        for transcript in transcripts:
            exons = transcript['Exon']
            all_exons.extend(exons)
            transcript_id = transcript['id']

            for exon in exons:
                exon_id = exon['id']
                gene_sequence = fetch_ensembl_sequence(exon_id)  # reference exon sequence
                if not gene_sequence:
                    print(f"Failed to retrieve gene sequence for exon {exon_id}.")
                    continue

                all_gene_sequences.append(gene_sequence)
                exon_chr = exon['seq_region_name']
                start = exon['start']
                end = exon['end']
                strand = exon['strand']

                targets = find_gRNA_with_mutation(gene_sequence, exon_chr, start, end, strand,
                                                  transcript_id, exon_id, gene_symbol, vcf_reader)
                if targets:
                    # Score each gRNA site, including the mutated ones.
                    results.extend(format_prediction_output_with_mutation(targets, model_path))

    # The prediction score is stored at index 8 of every row.
    sorted_results = sorted(results, key=lambda x: x[8], reverse=True)

    return sorted_results, all_gene_sequences, all_exons
|
333 |
+
|
334 |
+
|
335 |
+
def create_genbank_features(data):
    """Convert prediction rows into GenBank ``misc_feature`` annotations.

    Args:
        data: A list of rows or a pandas DataFrame. Each row is read
            positionally: index 1 = start, 2 = end, 3 = strand, 7 = gRNA
            (used as the label) and 8 = prediction score.

    Returns:
        A list of ``SeqFeature`` objects. Rows whose start or end cannot be
        converted to ``int`` are reported and skipped.

    Raises:
        TypeError: If *data* is neither a list nor a DataFrame.
    """
    # NOTE(review): rows built by format_prediction_output_with_mutation in
    # this module are laid out as [gene_symbol, chr, strand, tar_start, ...],
    # which does not match the indices read below — confirm that the caller
    # reshapes the data before passing it in.
    if isinstance(data, pd.DataFrame):
        formatted_data = data.values.tolist()
    elif isinstance(data, list):
        formatted_data = data
    else:
        raise TypeError("Data should be either a list or a pandas DataFrame.")

    features = []
    for row in formatted_data:
        try:
            start = int(row[1])
            end = int(row[2])
        except ValueError as e:
            print(f"Error converting start/end to int: {row[1]}, {row[2]} - {e}")
            continue

        # Strands are stored as '1' / '-1' in this module; the previous check
        # only recognised '+', so every feature ended up on the reverse
        # strand. Accept '+', '1', '+1' and the number 1.
        raw_strand = row[3]
        if isinstance(raw_strand, str):
            is_forward = raw_strand.strip() in ('+', '1', '+1')
        else:
            is_forward = raw_strand == 1
        strand = 1 if is_forward else -1

        location = FeatureLocation(start=start, end=end, strand=strand)
        feature = SeqFeature(location=location, type="misc_feature", qualifiers={
            'label': row[7],  # Use gRNA as the label
            'note': f"Prediction: {row[8]}"  # Include the prediction score
        })
        features.append(feature)

    return features
|
363 |
+
|
364 |
+
|
365 |
+
def generate_genbank_file_from_df(df, gene_sequence, gene_symbol, output_path):
    """Write a GenBank record holding the sequence and the predicted targets.

    Args:
        df: Prediction rows (DataFrame or list) passed to
            ``create_genbank_features``.
        gene_sequence: The record sequence. A string is used as-is; a list or
            tuple of sequence fragments (as returned by ``process_gene``) is
            concatenated; anything else is converted with ``str()``.
        gene_symbol: Used as the record id, name and in the description.
        output_path: Destination file handed to ``SeqIO.write``.
    """
    if isinstance(gene_sequence, (list, tuple)):
        # str() on a list would embed brackets, quotes and commas in the
        # sequence; join the fragments instead.
        gene_sequence = "".join(str(fragment) for fragment in gene_sequence)
    elif not isinstance(gene_sequence, str):
        gene_sequence = str(gene_sequence)

    features = create_genbank_features(df)

    seq_obj = Seq(gene_sequence)
    record = SeqRecord(seq_obj, id=gene_symbol, name=gene_symbol,
                       description=f'CRISPR Cas9 predicted targets for {gene_symbol}',
                       features=features)
    record.annotations["molecule_type"] = "DNA"
    SeqIO.write(record, output_path, "genbank")
|
378 |
+
|
379 |
+
|
380 |
+
def create_bed_file_from_df(df, output_path):
    """Write one tab-separated BED-style line per prediction row.

    Args:
        df: DataFrame with the columns ``Chr``, ``Start Pos``, ``End Pos``,
            ``Strand``, ``gRNA`` and ``Prediction``.
        output_path: File to create (overwritten if it exists).

    Each line is ``chrom, start, end, gRNA, score, strand``. The strand is
    written as ``+`` when the ``Strand`` value is ``'1'``, ``'+1'``, ``'+'``
    or the number 1, and as ``-`` otherwise.
    """
    with open(output_path, 'w') as bed_file:
        for _, row in df.iterrows():
            chrom = row["Chr"]
            start = int(row["Start Pos"])
            end = int(row["End Pos"])

            # Accept both textual and numeric strands; previously a numeric
            # 1 was written as '-'.
            raw_strand = row["Strand"]
            if isinstance(raw_strand, str):
                is_forward = raw_strand.strip() in ('1', '+1', '+')
            else:
                is_forward = raw_strand == 1
            strand = '+' if is_forward else '-'

            gRNA = row["gRNA"]
            score = str(row["Prediction"])

            bed_file.write(f"{chrom}\t{start}\t{end}\t{gRNA}\t{score}\t{strand}\n")
|
394 |
+
|
395 |
+
|
396 |
+
def create_csv_from_df(df, output_path):
    """Save *df* as a CSV file at *output_path*, leaving out the index column."""
    df.to_csv(path_or_buf=output_path, index=False)
|
cas9on.py
CHANGED
@@ -8,9 +8,7 @@ from Bio import SeqIO
|
|
8 |
from Bio.SeqRecord import SeqRecord
|
9 |
from Bio.SeqFeature import SeqFeature, FeatureLocation
|
10 |
from Bio.Seq import Seq
|
11 |
-
|
12 |
-
import random
|
13 |
-
import pyBigWig
|
14 |
|
15 |
# configure GPUs
|
16 |
for gpu in tf.config.list_physical_devices('GPU'):
|
|
|
8 |
from Bio.SeqRecord import SeqRecord
|
9 |
from Bio.SeqFeature import SeqFeature, FeatureLocation
|
10 |
from Bio.Seq import Seq
|
11 |
+
|
|
|
|
|
12 |
|
13 |
# configure GPUs
|
14 |
for gpu in tf.config.list_physical_devices('GPU'):
|
requirements.txt
CHANGED
@@ -4,5 +4,8 @@ pandas==1.5.2
|
|
4 |
tensorflow==2.11.0
|
5 |
tensorflow-probability==0.19.0
|
6 |
plotly==5.18.0
|
|
|
|
|
|
|
7 |
gtracks
|
8 |
pyGenomeTracks
|
|
|
4 |
tensorflow==2.11.0
|
5 |
tensorflow-probability==0.19.0
|
6 |
plotly==5.18.0
|
7 |
+
tabulate
|
8 |
+
cyvcf2
|
9 |
+
parasail
|
10 |
gtracks
|
11 |
pyGenomeTracks
|