supercat666 commited on
Commit
bc641c8
·
1 Parent(s): 6392fd6

fix cas9 and cas12 output

Browse files
Files changed (3) hide show
  1. app.py +2 -4
  2. cas12lstm.py +69 -3
  3. cas9att.py +1 -4
app.py CHANGED
@@ -185,8 +185,7 @@ if selected_model == 'Cas9':
185
  if predict_button and gene_symbol:
186
  with st.spinner('Predicting... Please wait'):
187
  predictions, gene_sequence, exons = cas9att.process_gene(gene_symbol, cas9att_path)
188
-
189
- sorted_predictions = sorted(predictions)[:10]
190
  st.session_state['on_target_results'] = sorted_predictions
191
  st.session_state['gene_sequence'] = gene_sequence # Save gene sequence in session state
192
  st.session_state['exons'] = exons # Store exon data
@@ -436,8 +435,7 @@ elif selected_model == 'Cas12':
436
  if predict_button and gene_symbol:
437
  with st.spinner('Predicting... Please wait'):
438
  predictions, gene_sequence, exons = cas12lstm.process_gene(gene_symbol, cas9att_path)
439
-
440
- sorted_predictions = sorted(predictions)[:10]
441
  st.session_state['on_target_results'] = sorted_predictions
442
  st.session_state['gene_sequence'] = gene_sequence # Save gene sequence in session state
443
  st.session_state['exons'] = exons # Store exon data
 
185
  if predict_button and gene_symbol:
186
  with st.spinner('Predicting... Please wait'):
187
  predictions, gene_sequence, exons = cas9att.process_gene(gene_symbol, cas9att_path)
188
+ sorted_predictions = sorted(predictions, key=lambda x: x[8], reverse=True)[:10]
 
189
  st.session_state['on_target_results'] = sorted_predictions
190
  st.session_state['gene_sequence'] = gene_sequence # Save gene sequence in session state
191
  st.session_state['exons'] = exons # Store exon data
 
435
  if predict_button and gene_symbol:
436
  with st.spinner('Predicting... Please wait'):
437
  predictions, gene_sequence, exons = cas12lstm.process_gene(gene_symbol, cas9att_path)
438
+ sorted_predictions = sorted(predictions, key=lambda x: x[8], reverse=True)[:10]
 
439
  st.session_state['on_target_results'] = sorted_predictions
440
  st.session_state['gene_sequence'] = gene_sequence # Save gene sequence in session state
441
  st.session_state['exons'] = exons # Store exon data
cas12lstm.py CHANGED
@@ -14,6 +14,10 @@ from functools import reduce
14
  from operator import add
15
  import tabulate
16
  from difflib import SequenceMatcher
 
 
 
 
17
 
18
  import cyvcf2
19
  import parasail
@@ -184,9 +188,71 @@ def process_gene(gene_symbol, model_path):
184
  for result in results:
185
  for item in result:
186
  output.append(item)
187
- # Sort results based on prediction score (assuming score is at the 8th index)
188
- sorted_results = sorted(output, key=lambda x: x[8], reverse=True)
189
 
190
  # Return the sorted output, combined gene sequences, and all exons
191
- return sorted_results, all_gene_sequences, all_exons
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
 
 
14
  from operator import add
15
  import tabulate
16
  from difflib import SequenceMatcher
17
+ from Bio import SeqIO
18
+ from Bio.SeqRecord import SeqRecord
19
+ from Bio.SeqFeature import SeqFeature, FeatureLocation
20
+ from Bio.Seq import Seq
21
 
22
  import cyvcf2
23
  import parasail
 
188
  for result in results:
189
  for item in result:
190
  output.append(item)
 
 
191
 
192
  # Return the sorted output, combined gene sequences, and all exons
193
+ return results, all_gene_sequences, all_exons
194
+
195
+ def create_genbank_features(data):
196
+ features = []
197
+
198
+ # If the input data is a DataFrame, convert it to a list of lists
199
+ if isinstance(data, pd.DataFrame):
200
+ formatted_data = data.values.tolist()
201
+ elif isinstance(data, list):
202
+ formatted_data = data
203
+ else:
204
+ raise TypeError("Data should be either a list or a pandas DataFrame.")
205
+
206
+ for row in formatted_data:
207
+ try:
208
+ start = int(row[1])
209
+ end = int(row[2])
210
+ except ValueError as e:
211
+ print(f"Error converting start/end to int: {row[1]}, {row[2]} - {e}")
212
+ continue
213
+
214
+ strand = 1 if row[3] == '+' else -1
215
+ location = FeatureLocation(start=start, end=end, strand=strand)
216
+ feature = SeqFeature(location=location, type="misc_feature", qualifiers={
217
+ 'label': row[7], # Use gRNA as the label
218
+ 'note': f"Prediction: {row[8]}" # Include the prediction score
219
+ })
220
+ features.append(feature)
221
+
222
+ return features
223
+
224
+
225
+ def generate_genbank_file_from_df(df, gene_sequence, gene_symbol, output_path):
226
+ # Ensure gene_sequence is a string before creating Seq object
227
+ if not isinstance(gene_sequence, str):
228
+ gene_sequence = str(gene_sequence)
229
+
230
+ features = create_genbank_features(df)
231
+
232
+ # Now gene_sequence is guaranteed to be a string, suitable for Seq
233
+ seq_obj = Seq(gene_sequence)
234
+ record = SeqRecord(seq_obj, id=gene_symbol, name=gene_symbol,
235
+ description=f'CRISPR Cas12 predicted targets for {gene_symbol}', features=features)
236
+ record.annotations["molecule_type"] = "DNA"
237
+ SeqIO.write(record, output_path, "genbank")
238
+
239
+
240
+ def create_bed_file_from_df(df, output_path):
241
+ with open(output_path, 'w') as bed_file:
242
+ for index, row in df.iterrows():
243
+ chrom = row["Chr"]
244
+ start = int(row["Start Pos"])
245
+ end = int(row["End Pos"])
246
+ strand = '+' if row["Strand"] == '1' else '-'
247
+ gRNA = row["gRNA"]
248
+ score = str(row["Prediction"])
249
+ # transcript_id is not typically part of the standard BED columns but added here for completeness
250
+ transcript_id = row["Transcript"]
251
+
252
+ # Writing only standard BED columns; additional columns can be appended as needed
253
+ bed_file.write(f"{chrom}\t{start}\t{end}\t{gRNA}\t{score}\t{strand}\n")
254
+
255
+
256
+ def create_csv_from_df(df, output_path):
257
+ df.to_csv(output_path, index=False)
258
 
cas9att.py CHANGED
@@ -228,12 +228,9 @@ def process_gene(gene_symbol, model_path):
228
  for result in results:
229
  for item in result:
230
  output.append(item)
231
-
232
- # Sort results based on prediction score (assuming score is at the 8th index)
233
- sorted_results = sorted(output, key=lambda x: x[8], reverse=True)
234
 
235
  # Return the sorted output, combined gene sequences, and all exons
236
- return sorted_results, all_gene_sequences, all_exons
237
 
238
 
239
  def create_genbank_features(data):
 
228
  for result in results:
229
  for item in result:
230
  output.append(item)
 
 
 
231
 
232
  # Return the sorted output, combined gene sequences, and all exons
233
+ return results, all_gene_sequences, all_exons
234
 
235
 
236
  def create_genbank_features(data):