supercat666 committed on
Commit
dc94424
·
1 Parent(s): 73dcc35
Files changed (3) hide show
  1. app.py +101 -3
  2. cas12.py +175 -0
  3. cas12_model/Seq_deepCpf1_weights.h5 +3 -0
app.py CHANGED
@@ -2,6 +2,7 @@ import os
2
  import tiger
3
  import cas9on
4
  import cas9off
 
5
  import pandas as pd
6
  import streamlit as st
7
  import plotly.graph_objs as go
@@ -18,6 +19,7 @@ CRISPR_MODELS = ['Cas9', 'Cas12', 'Cas13d']
18
 
19
  selected_model = st.selectbox('Select CRISPR model:', CRISPR_MODELS, key='selected_model')
20
  cas9on_path = 'cas9_model/on-cla.h5'
 
21
 
22
  @st.cache_data
23
  def convert_df(df):
@@ -287,9 +289,105 @@ if selected_model == 'Cas9':
287
  st.experimental_rerun()
288
 
289
  elif selected_model == 'Cas12':
290
- # Placeholder for Cas12 model loading
291
- # TODO: Implement Cas12 model loading logic
292
- raise NotImplementedError("Cas12 model loading not implemented yet.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
293
  elif selected_model == 'Cas13d':
294
  ENTRY_METHODS = dict(
295
  manual='Manual entry of single transcript',
 
2
  import tiger
3
  import cas9on
4
  import cas9off
5
+ import cas12
6
  import pandas as pd
7
  import streamlit as st
8
  import plotly.graph_objs as go
 
19
 
20
  selected_model = st.selectbox('Select CRISPR model:', CRISPR_MODELS, key='selected_model')
21
  cas9on_path = 'cas9_model/on-cla.h5'
22
+ cas12_path = 'cas12_model/Seq_deepCpf1_weights.h5'
23
 
24
  @st.cache_data
25
  def convert_df(df):
 
289
  st.experimental_rerun()
290
 
291
  elif selected_model == 'Cas12':
292
+ # Gene symbol entry
293
+ gene_symbol = st.text_input('Enter a Gene Symbol:', key='gene_symbol')
294
+
295
+ # Initialize the current_gene_symbol in the session state if it doesn't exist
296
+ if 'current_gene_symbol' not in st.session_state:
297
+ st.session_state['current_gene_symbol'] = ""
298
+
299
+ # Prediction button
300
+ predict_button = st.button('Predict on-target')
301
+
302
+ # Function to clean up old files
303
+ def clean_up_old_files(gene_symbol):
304
+ genbank_file_path = f"{gene_symbol}_crispr_targets.gb"
305
+ bed_file_path = f"{gene_symbol}_crispr_targets.bed"
306
+ if os.path.exists(genbank_file_path):
307
+ os.remove(genbank_file_path)
308
+ if os.path.exists(bed_file_path):
309
+ os.remove(bed_file_path)
310
+
311
+ # Clean up files if a new gene symbol is entered
312
+ if st.session_state['current_gene_symbol'] and gene_symbol != st.session_state['current_gene_symbol']:
313
+ clean_up_old_files(st.session_state['current_gene_symbol'])
314
+
315
+ # Process predictions
316
+ if predict_button and gene_symbol:
317
+ # Update the current gene symbol
318
+ st.session_state['current_gene_symbol'] = gene_symbol
319
+
320
+ # Run the prediction process
321
+ predictions, gene_sequence = cas12.process_gene(gene_symbol,cas12_path)
322
+ sorted_predictions = sorted(predictions, key=lambda x: x[-1], reverse=True)[:10]
323
+ st.session_state['on_target_results'] = sorted_predictions
324
+
325
+ # Visualization and file generation
326
+ if 'on_target_results' in st.session_state and st.session_state['on_target_results']:
327
+ df = pd.DataFrame(st.session_state['on_target_results'],
328
+ columns=["Gene ID", "Start Pos", "End Pos", "Strand", "gRNA", "Prediction"])
329
+
330
+ # Now create a Plotly plot with the sorted_predictions
331
+ fig = go.Figure()
332
+
333
+ # Iterate over the sorted predictions to create the plot
334
+ for i, prediction in enumerate(sorted_predictions, start=1):
335
+ # Extract data for plotting
336
+ chrom, start, end, strand, gRNA, pred_score = prediction
337
+ # Strand is not used in this plot, but you could use it to determine marker symbol, for example
338
+ fig.add_trace(go.Scatter(
339
+ x=[start, end],
340
+ y=[i, i], # Y-values are just the rank of the prediction
341
+ mode='lines+markers+text',
342
+ name=f"gRNA: {gRNA}",
343
+ text=[f"Rank: {i}", ""], # Text at the start position only
344
+ hoverinfo='text',
345
+ hovertext=[
346
+ f"Rank: {i}<br>Chromosome: {chrom}<br>Target: {gRNA}<br>Start: {start}<br>End: {end}<br>Strand: {'+' if strand == 1 else '-'}<br>Prediction Score: {pred_score:.4f}",
347
+ ""
348
+ ],
349
+ ))
350
+ # Update the layout of the plot
351
+ fig.update_layout(
352
+ title='Top 10 gRNA Sequences by Prediction Score',
353
+ xaxis_title='Genomic Position',
354
+ yaxis_title='Rank',
355
+ yaxis=dict(showticklabels=False)
356
+ # We hide the y-axis labels since the rank is indicated in the hovertext
357
+ )
358
+ # Display the plot
359
+ st.plotly_chart(fig)
360
+
361
+ # Ensure gene_sequence is not empty before generating files
362
+ if gene_sequence:
363
+ genbank_file_path = f"{gene_symbol}_crispr_targets.gb"
364
+ bed_file_path = f"{gene_symbol}_crispr_targets.bed"
365
+
366
+ # Generate GenBank file
367
+ cas12.generate_genbank_file_from_df(df, gene_sequence, gene_symbol, genbank_file_path)
368
+
369
+ # Generate BED file
370
+ cas12.create_bed_file_from_df(df, bed_file_path)
371
+
372
+ st.write('Top on-target predictions:')
373
+ st.dataframe(df)
374
+
375
+ # Download buttons
376
+ with open(genbank_file_path, "rb") as file:
377
+ st.download_button(
378
+ label="Download GenBank File",
379
+ data=file,
380
+ file_name=genbank_file_path,
381
+ mime="text/x-genbank"
382
+ )
383
+
384
+ with open(bed_file_path, "rb") as file:
385
+ st.download_button(label="Download BED File", data=file,
386
+ file_name=bed_file_path, mime="text/plain")
387
+
388
+ # Clean up old files after download buttons are created
389
+ clean_up_old_files(gene_symbol)
390
+
391
  elif selected_model == 'Cas13d':
392
  ENTRY_METHODS = dict(
393
  manual='Manual entry of single transcript',
cas12.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from keras import Model
2
+ from keras.layers import Input
3
+ from keras.layers import Multiply
4
+ from keras.layers import Dense, Dropout, Activation, Flatten
5
+ from keras.layers import Convolution1D, AveragePooling1D
6
+ import pandas as pd
7
+ import numpy as np
8
+ import keras
9
+ import requests
10
+ from functools import reduce
11
+ from operator import add
12
+ from Bio.SeqRecord import SeqRecord
13
+ from Bio.SeqFeature import SeqFeature, FeatureLocation
14
+ from Bio.Seq import Seq
15
+ from Bio import SeqIO
16
+
17
# One-hot code for each nucleotide; tuple positions are ordered A, C, G, T.
ntmap = {
    'A': (1, 0, 0, 0),
    'C': (0, 1, 0, 0),
    'G': (0, 0, 1, 0),
    'T': (0, 0, 0, 1),
}


def get_seqcode(seq):
    """One-hot encode a nucleotide string as an array of shape (1, len(seq), 4).

    The sequence is upper-cased before lookup, so lower-case bases are accepted.
    Any character other than A, C, G or T raises ``KeyError``.
    """
    per_base_codes = (ntmap[base] for base in seq.upper())
    # Adding the tuples together concatenates them into one flat run of 0/1 flags.
    flat_flags = reduce(add, per_base_codes)
    return np.array(flat_flags).reshape((1, len(seq), -1))
25
+
26
def Seq_DeepCpf1_model(input_shape):
    """Build the sequence-only Seq-deepCpf1 network.

    Layer stack: Convolution1D(80 filters, kernel 5, ReLU) -> AveragePooling1D(2)
    -> Flatten -> three Dropout(0.3) + Dense(ReLU) stages with 80, 40 and 40
    units -> Dropout(0.3) -> Dense(1, linear).

    Args:
        input_shape: Shape of a single encoded sequence, handed to ``Input``.

    Returns:
        A ``Model`` (not compiled here) mapping the sequence input to one score.
    """
    seq_input = Input(shape=input_shape)

    hidden = Convolution1D(80, 5, activation='relu')(seq_input)
    hidden = AveragePooling1D(2)(hidden)
    hidden = Flatten()(hidden)

    # Every fully connected stage is preceded by its own dropout layer.
    for units in (80, 40, 40):
        hidden = Dropout(0.3)(hidden)
        hidden = Dense(units, activation='relu')(hidden)

    hidden = Dropout(0.3)(hidden)
    score = Dense(1, activation='linear')(hidden)

    return Model(inputs=[seq_input], outputs=[score])
41
+
42
+ # seq-ca model (DeepCpf1)
43
def DeepCpf1_model(input_shape):
    """Build the two-input DeepCpf1 network (sequence plus one scalar feature).

    The sequence branch is Convolution1D(80 filters, kernel 5, ReLU) ->
    AveragePooling1D(2) -> Flatten -> three Dropout(0.3) + Dense(ReLU) stages
    with 80, 40 and 40 units. The scalar branch is a single Dense(40, ReLU).
    Both 40-unit outputs are multiplied element-wise, then passed through
    Dropout(0.3) and Dense(1, linear).

    Args:
        input_shape: Shape of a single encoded sequence, handed to ``Input``.

    Returns:
        A ``Model`` (not compiled here) taking ``[sequence, scalar]`` inputs and
        producing one score.
    """
    seq_input = Input(shape=input_shape)

    seq_branch = Convolution1D(80, 5, activation='relu')(seq_input)
    seq_branch = AveragePooling1D(2)(seq_branch)
    seq_branch = Flatten()(seq_branch)
    for units in (80, 40, 40):
        seq_branch = Dropout(0.3)(seq_branch)
        seq_branch = Dense(units, activation='relu')(seq_branch)

    # NOTE(review): "CA" presumably stands for chromatin accessibility - confirm.
    ca_input = Input(shape=(1,))
    ca_branch = Dense(40, activation='relu')(ca_input)

    combined = Multiply()([seq_branch, ca_branch])
    combined = Dropout(0.3)(combined)
    score = Dense(1, activation='linear')(combined)

    return Model(inputs=[seq_input, ca_input], outputs=[score])
61
+
62
def fetch_ensembl_transcripts(gene_symbol):
    """Look up a human gene symbol on the Ensembl REST API and return its transcripts.

    Args:
        gene_symbol: Gene symbol to look up (surrounding whitespace is ignored).

    Returns:
        The value of the ``Transcript`` field of the lookup response, or ``None``
        when the request fails, the server answers with a non-200 status, or the
        response has no ``Transcript`` field. Failures are reported with ``print``.
    """
    from urllib.parse import quote

    # The symbol comes from user input: percent-encode it so characters such as
    # '/', '?' or '#' cannot alter the request path or query.
    encoded_symbol = quote(str(gene_symbol).strip(), safe='')
    url = (
        "https://rest.ensembl.org/lookup/symbol/homo_sapiens/"
        f"{encoded_symbol}?expand=1;content-type=application/json"
    )

    try:
        # Without a timeout, requests waits indefinitely on an unresponsive server.
        response = requests.get(url, timeout=30)
    except requests.RequestException as exc:
        print(f"Error fetching gene data from Ensembl: {exc}")
        return None

    if response.status_code != 200:
        print(f"Error fetching gene data from Ensembl: {response.text}")
        return None

    gene_data = response.json()
    if 'Transcript' not in gene_data:
        print("No transcripts found for gene:", gene_symbol)
        return None
    return gene_data['Transcript']
75
+
76
def fetch_ensembl_sequence(transcript_id):
    """Fetch the sequence for an Ensembl identifier from the Ensembl REST API.

    Args:
        transcript_id: Ensembl stable ID whose sequence is requested.

    Returns:
        The ``seq`` field of the response, or ``None`` when the request fails,
        the server answers with a non-200 status, or the response has no ``seq``
        field. Failures are reported with ``print``.
    """
    from urllib.parse import quote

    # Percent-encode the ID so it cannot alter the request path or query.
    encoded_id = quote(str(transcript_id).strip(), safe='')
    url = f"https://rest.ensembl.org/sequence/id/{encoded_id}?content-type=application/json"

    try:
        # Without a timeout, requests waits indefinitely on an unresponsive server.
        response = requests.get(url, timeout=30)
    except requests.RequestException as exc:
        print(f"Error fetching sequence data from Ensembl: {exc}")
        return None

    if response.status_code != 200:
        print(f"Error fetching sequence data from Ensembl: {response.text}")
        return None

    sequence_data = response.json()
    if 'seq' not in sequence_data:
        print("No sequence found for transcript:", transcript_id)
        return None
    return sequence_data['seq']
89
+
90
+
91
def find_crispr_targets(sequence, chr, start, strand, pam="TTTN", target_length=34):
    """Scan *sequence* for fixed-length target windows whose PAM matches *pam*.

    Every window of ``target_length`` bases is examined. The PAM is expected to
    begin at offset 4 of the window and the 20 bases that follow the PAM are
    reported as the guide. With the default ``pam="TTTN"`` that means window
    positions 4-6 must be ``TTT`` and the guide is ``window[8:28]``.

    Args:
        sequence: Nucleotide string to scan (matching is case-sensitive).
        chr: Chromosome / sequence-region name copied into every result.
        start: Coordinate of the first base of *sequence*; window offsets are
            added to it.
        strand: Strand value copied (as text) into every result.
        pam: PAM pattern. IUPAC ambiguity codes are honoured; ``N`` matches any
            character.
        target_length: Length of each candidate window.

    Returns:
        A list of ``[window, guide, chr, str(window_start), str(window_end),
        str(strand)]`` entries, where ``window_end`` is
        ``window_start + target_length``.
    """
    # Bases each IUPAC code stands for ('N' is handled separately as a wildcard).
    iupac = {
        'A': 'A', 'C': 'C', 'G': 'G', 'T': 'T',
        'R': 'AG', 'Y': 'CT', 'S': 'CG', 'W': 'AT', 'K': 'GT', 'M': 'AC',
        'B': 'CGT', 'D': 'AGT', 'H': 'ACT', 'V': 'ACG',
    }
    pam_offset = 4   # bases in the window that precede the PAM
    guide_length = 20

    pam_pattern = pam.upper()
    # Trailing N positions accept anything, so only the prefix before them has
    # to be compared (for the default this is exactly the 'TTT' at [4:7]).
    required = pam_pattern.rstrip('N')
    guide_start = pam_offset + len(pam_pattern)
    guide_end = guide_start + guide_length

    def matches_pam(candidate):
        if len(candidate) != len(required):
            return False
        return all(
            code == 'N' or base in iupac.get(code, code)
            for code, base in zip(required, candidate)
        )

    targets = []
    for offset in range(len(sequence) - target_length + 1):
        target_seq = sequence[offset:offset + target_length]
        if not matches_pam(target_seq[pam_offset:pam_offset + len(required)]):
            continue
        tar_start = start + offset
        tar_end = tar_start + target_length
        gRNA = target_seq[guide_start:guide_end]
        targets.append([target_seq, gRNA, chr, str(tar_start), str(tar_end), str(strand)])
    return targets
103
+
104
def format_prediction_output(targets, seq_deepCpf1):
    """Score every target with the model and return rows ready for export.

    Args:
        targets: Iterable of ``[target_seq, gRNA, chr, start, end, strand]``
            entries as produced by ``find_crispr_targets``. All target
            sequences must have the same length so they can be stacked.
        seq_deepCpf1: Model object exposing ``predict``.

    Returns:
        A list of ``[chr, start, end, strand, target_seq, gRNA, score]`` rows in
        the same order as *targets*; an empty list when there are no targets.
    """
    targets = list(targets)
    if not targets:
        return []

    # Encode all windows up front and score them with a single predict() call
    # instead of one call per target.
    encoded_batch = np.concatenate([get_seqcode(target[0]) for target in targets], axis=0)
    predictions = seq_deepCpf1.predict(encoded_batch)

    formatted_data = []
    for target, prediction in zip(targets, predictions):
        target_seq = target[0]  # full window that was fed to the model
        gRNA = target[1]
        chrom = target[2]
        start = target[3]
        end = target[4]
        strand = target[5]
        # Each prediction row holds the single output value of the network.
        formatted_data.append([chrom, start, end, strand, target_seq, gRNA, prediction[0]])
    return formatted_data
119
+
120
def process_gene(gene_symbol, model_path):
    """Fetch a gene's transcripts from Ensembl and score their Cas12 targets.

    Each transcript's own sequence is scanned for targets, so every target is
    reported once, together with the chromosome, start and strand of the
    transcript it came from.

    Args:
        gene_symbol: Gene symbol passed to ``fetch_ensembl_transcripts``.
        model_path: Path of the weights file loaded into the Seq-deepCpf1 model.

    Returns:
        A tuple ``(all_data, gene_sequence)``: the scored rows produced by
        ``format_prediction_output`` for all transcripts, and the concatenation
        of every transcript sequence that could be fetched. Both are empty when
        no transcripts are available.
    """
    transcripts = fetch_ensembl_transcripts(gene_symbol)
    all_data = []

    if not transcripts:
        print("Failed to retrieve transcripts.")
        return all_data, ''

    # Build and load the model only once there is something to score.
    seq_deepCpf1 = Seq_DeepCpf1_model(input_shape=(34, 4))
    seq_deepCpf1.load_weights(model_path)

    sequence_parts = []
    for transcript in transcripts:
        transcript_id = transcript['id']
        chrom = transcript.get('seq_region_name', 'unknown')
        start = transcript.get('start', 0)
        strand = transcript.get('strand', 'unknown')

        transcript_sequence = fetch_ensembl_sequence(transcript_id) or ''
        if not transcript_sequence:
            continue
        sequence_parts.append(transcript_sequence)

        # Scan only this transcript's sequence: scanning the running
        # concatenation would re-score earlier transcripts with this
        # transcript's chromosome/start/strand.
        targets = find_crispr_targets(transcript_sequence, chrom, start, strand)
        if targets:
            all_data.extend(format_prediction_output(targets, seq_deepCpf1))

    return all_data, ''.join(sequence_parts)
147
+
148
def create_genbank_features(formatted_data):
    """Turn scored target rows into GenBank ``misc_feature`` annotations.

    Args:
        formatted_data: Rows of ``[chr, start, end, strand, target_seq, gRNA,
            score]``. Start and end must be convertible with ``int``.

    Returns:
        A list of ``SeqFeature`` objects labelled with the gRNA and carrying the
        prediction score in their ``note`` qualifier.
    """
    # Rows store the strand as text; it may be the numeric form ('1' / '-1') or
    # the symbolic form ('+' / '-'). Anything else yields strand=None.
    strand_codes = {'+': 1, '1': 1, '+1': 1, '-': -1, '-1': -1}

    features = []
    for data in formatted_data:
        location = FeatureLocation(
            start=int(data[1]),
            end=int(data[2]),
            strand=strand_codes.get(str(data[3]).strip()),
        )
        feature = SeqFeature(location=location, type="misc_feature", qualifiers={
            'label': data[5],                  # gRNA as label
            'note': f"Prediction: {data[6]}",  # Prediction score in note
        })
        features.append(feature)
    return features
158
+
159
def generate_genbank_file_from_data(formatted_data, gene_sequence, gene_symbol, output_path):
    """Write the gene sequence and its predicted targets to a GenBank file.

    Args:
        formatted_data: Scored target rows, converted to features by
            ``create_genbank_features``.
        gene_sequence: Sequence stored in the record.
        gene_symbol: Used as both the record ID and the record name.
        output_path: Destination handed to ``SeqIO.write``.
    """
    target_features = create_genbank_features(formatted_data)
    genbank_record = SeqRecord(
        Seq(gene_sequence),
        id=gene_symbol,
        name=gene_symbol,
        description='CRISPR Cas12 predicted targets',
        features=target_features,
    )
    # The molecule type annotation is set on the record before it is written.
    genbank_record.annotations["molecule_type"] = "DNA"
    SeqIO.write(genbank_record, output_path, "genbank")
165
+
166
def generate_bed_file_from_data(formatted_data, output_path):
    """Write scored target rows to a tab-separated BED file.

    Each output line is ``chrom, start, end, gRNA, score, strand``.

    Args:
        formatted_data: Rows of ``[chr, start, end, strand, target_seq, gRNA,
            score]``.
        output_path: Path of the file to create (overwritten if it exists).
    """
    # BED expects '+', '-' or '.' in the strand column, while rows store the
    # strand as text that may be numeric ('1' / '-1'). Unknown values become '.'.
    strand_symbols = {'+': '+', '1': '+', '+1': '+', '-': '-', '-1': '-'}

    lines = []
    for data in formatted_data:
        chrom, start, end = data[0], data[1], data[2]
        strand = strand_symbols.get(str(data[3]).strip(), '.')
        gRNA, score = data[5], data[6]
        lines.append(f"{chrom}\t{start}\t{end}\t{gRNA}\t{score}\t{strand}\n")

    with open(output_path, 'w') as bed_file:
        bed_file.writelines(lines)
cas12_model/Seq_deepCpf1_weights.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c52c1f93169ea1da55d4cb464f4d948551b9aeafb9ee47dc55fa76e23486526d
3
+ size 1285864