Atom Bioworks commited on
Commit
d3248a6
·
verified ·
1 Parent(s): 71b9472

Create gui.py

Browse files
Files changed (1) hide show
  1. gui.py +109 -0
gui.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from api_prediction import AptaTransPipeline_Dist
2
+ import gradio as gr
3
+ import pandas as pd
4
+ import torch
5
+ import tempfile
6
+ from tabulate import tabulate
7
+ from PIL import Image
8
+ import itertools
9
+ import os
10
+ import RNA
11
+ import matplotlib.pyplot as plt
12
+ import matplotlib.image as mpimg
13
+ import random
14
+ from scipy.cluster.hierarchy import dendrogram, linkage
15
+ # Visualization
16
+ from Bio.Phylo.PhyloXML import Phylogeny
17
+ from Bio import SeqIO
18
+ from Bio.Seq import Seq
19
+ from Bio.SeqRecord import SeqRecord
20
+ from Bio import AlignIO
21
+ from Bio.Align.Applications import MafftCommandline
22
+ from Bio import Phylo
23
+ from Bio.Phylo.TreeConstruction import DistanceCalculator, DistanceTreeConstructor
24
+ import io
25
+
26
+ os.environ['GRADIO_SERVER_NAME'] = '0.0.0.0'
27
+ title='DNAptaESM2 Model Infernence'
28
+ desc='AptaBLE (cross-attention network), trained to predict the likelihood a DNA aptamer will form a complex with a target protein!\n\nPass in a FASTA-formatted file of all aptamers and input your protein target amino acid sequence. Your output scores are available for download via an Excel file.'
29
+
30
+ global pipeline
31
+
32
+ pipeline = AptaTransPipeline_Dist(
33
+ lr=1e-6,
34
+ weight_decay=None,
35
+ epochs=None,
36
+ model_type=None,
37
+ model_version=None,
38
+ model_save_path=None,
39
+ accelerate_save_path=None,
40
+ tensorboard_logdir=None,
41
+ d_model=128,
42
+ d_ff=512,
43
+ n_layers=6,
44
+ n_heads=8,
45
+ dropout=0.1,
46
+ load_best_pt=True, # already loads the pretrained model using the datasets included in repo -- no need to run the bottom two cells
47
+ device='cuda',
48
+ seed=1004)
49
+
50
+ def comparison(protein, aptamer_file, analysis):
51
+ print('analysis: ', analysis)
52
+ display = []
53
+ table_data = pd.DataFrame()
54
+ r_names, aptamers = read_fasta(aptamer_file)
55
+ proteins = [protein for i in range(len(aptamers))]
56
+ df = pd.DataFrame(columns=['Protein', 'Protein Seq', 'Aptamer', 'Aptamer Seq', 'Score'])
57
+ # print('Number of aptamers: ', len(aptamers))
58
+ scores = get_scores(aptamers, proteins)
59
+ df['Protein'] = ['protein_prov.']*len(aptamers)
60
+ df['Aptamer'] = r_names
61
+ df['Protein Seq'] = proteins
62
+ df['Aptamer Seq'] = aptamers
63
+ df['Score'] = scores
64
+
65
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".xlsx") as temp_file:
66
+ with pd.ExcelWriter(temp_file.name, engine='openpyxl') as writer:
67
+ df.to_excel(writer, index=False)
68
+ temp_file_path = temp_file.name
69
+
70
+ print('Saving to excel!')
71
+ df.to_excel(f'{aptamer_file}.xlsx')
72
+
73
+ torch.cuda.empty_cache()
74
+
75
+ return '\n'.join(display), temp_file_path
76
+
77
+ def read_fasta(file_path):
78
+ headers = []
79
+ sequences = []
80
+ with open(file_path, 'r') as file:
81
+ content = file.readlines()
82
+ for i in range(0, len(content), 2):
83
+ header = content[i].strip()
84
+ if header.startswith('>'):
85
+ headers.append(header)
86
+ sequences.append(content[i+1].strip())
87
+ return headers, sequences
88
+
89
+ def get_scores(aptamers, proteins):
90
+ pipeline.model.to('cuda')
91
+ scores = pipeline.inference(aptamers, proteins, [0]*len(aptamers))
92
+ pipeline.model.to('cpu')
93
+ return scores
94
+
95
+
96
+ iface = gr.Interface(
97
+ fn=comparison,
98
+ inputs=[
99
+ gr.Textbox(lines=2, placeholder="Protein"),
100
+ gr.File(type="filepath"),
101
+ ],
102
+ outputs=[
103
+ gr.Textbox(placeholder="Scores"),
104
+ gr.File(label="Download Excel")
105
+ ],
106
+ description=desc
107
+ )
108
+
109
+ iface.launch()