File size: 6,709 Bytes
c5d6bef
 
 
4fcd52e
c5d6bef
c0cd7ac
b773936
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c5d6bef
2256afd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c5d6bef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c0cd7ac
c5d6bef
 
 
 
 
 
 
 
2256afd
eae90c3
c5d6bef
4fcd52e
 
 
 
 
 
 
 
 
 
67e2f88
4fcd52e
 
 
 
 
 
 
c5d6bef
4fcd52e
c5d6bef
7f124ce
 
2256afd
97ef52e
7f124ce
 
eae90c3
97ef52e
 
 
 
7f124ce
c5d6bef
 
4fcd52e
d0c22b9
98e60d4
4fcd52e
78afc4e
4fcd52e
 
 
78afc4e
0b8f649
74de55c
67e2f88
78afc4e
67f0454
4fcd52e
67e2f88
9f44693
4fcd52e
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
from transformers import AutoTokenizer, EsmForProteinFolding
from transformers.models.esm.openfold_utils.protein import to_pdb, Protein as OFProtein
from transformers.models.esm.openfold_utils.feats import atom14_to_atom37
from Bio import SeqIO
import gradio as gr
import spaces
from gradio_molecule3d import Molecule3D

# Representation settings passed to the Molecule3D output component
# (``Molecule3D(..., reps=reps)`` in the single-sequence interface):
# a single entry selecting model 0 with no chain/residue filter, drawn
# as "stick" with the "whiteCarbon" colour scheme.
reps =    [
    {
      "model": 0,
      "chain": "",
      "resname": "",
      "style": "stick",
      "color": "whiteCarbon",
      "residue_range": "",
      "around": 0,
      "byres": False,
      "visible": False
    }
]

def read_mol(molpath):
    """Return the full text content of the file at ``molpath``.

    The file is opened in text mode with the platform default encoding and
    returned as one string, line endings included.
    """
    with open(molpath, "r") as handle:
        return handle.read()


def molecule(input_pdb):
    """Build an ``<iframe>`` snippet that displays a PDB file with 3Dmol.js.

    The file at ``input_pdb`` is read with ``read_mol`` and pasted into a
    self-contained HTML page. That page loads jQuery and 3Dmol.js from their
    CDNs, adds the PDB text as a model, styles it as a cartoon with the
    "whiteCarbon" colour scheme, and zooms to it. The page is returned as the
    ``srcdoc`` of a 600px-high iframe.

    NOTE(review): the PDB text is inserted verbatim inside a JavaScript
    template literal and the page inside a single-quoted ``srcdoc``
    attribute; nothing is escaped, so a backtick or single quote in the file
    would break the markup.
    """
    pdb_text = read_mol(input_pdb)

    # Page content up to (and including) the opening backtick of the
    # JavaScript template literal that will hold the PDB text.
    page_head = """<!DOCTYPE html>
        <html>
        <head>    
    <meta http-equiv="content-type" content="text/html; charset=UTF-8" />
    <style>
    body{
        font-family:sans-serif
    }
    .mol-container {
    width: 100%;
    height: 600px;
    position: relative;
    }
    .mol-container select{
        background-image:None;
    }
    </style>
     <script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js" integrity="sha512-STof4xm1wgkfm7heWqFJVn58Hm3EtS31XFaagaa8VMReCXAkQnJZ+jEy8PCC/iT18dFy95WcExNHFTqLyp72eQ==" crossorigin="anonymous" referrerpolicy="no-referrer"></script>
    <script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
    </head>
    <body>  
    <div id="container" class="mol-container"></div>
  
            <script>
               let pdb = `"""

    # Closing backtick followed by the viewer bootstrap code.
    page_tail = """`  
      
             $(document).ready(function () {
                let element = $("#container");
                let config = { backgroundColor: "white" };
                let viewer = $3Dmol.createViewer(element, config);
                viewer.addModel(pdb, "pdb");
                viewer.getModel(0).setStyle({}, { cartoon: { colorscheme:"whiteCarbon" } });
                viewer.zoomTo();
                viewer.render();
                viewer.zoom(0.8, 2000);
              })
        </script>
        </body></html>"""

    page = page_head + pdb_text + page_tail

    return f"""<iframe style="width: 100%; height: 600px" name="result" allow="midi; geolocation; microphone; camera; 
    display-capture; encrypted-media;" sandbox="allow-modals allow-forms 
    allow-scripts allow-same-origin allow-popups 
    allow-top-navigation-by-user-activation allow-downloads" allowfullscreen="" 
    allowpaymentrequest="" frameborder="0" srcdoc='{page}'></iframe>"""


def convert_outputs_to_pdb(outputs):
    """Turn folding-model outputs into a list of PDB-format strings.

    Args:
        outputs: mapping of model output tensors. The keys read here are
            ``positions``, ``atom37_atom_exists``, ``aatype``,
            ``residue_index``, ``plddt`` and, when present, ``chain_index``.

    Returns:
        A list with one PDB string per entry along the first axis of
        ``outputs["aatype"]``.
    """
    # Computed from the raw tensors, before they are converted to NumPy.
    atom37_positions = atom14_to_atom37(outputs["positions"][-1], outputs)
    arrays = {name: tensor.to("cpu").numpy() for name, tensor in outputs.items()}
    atom37_positions = atom37_positions.cpu().numpy()
    atom37_mask = arrays["atom37_atom_exists"]
    has_chain_index = "chain_index" in arrays

    def protein_at(idx):
        """Assemble the OFProtein record for batch entry ``idx``."""
        return OFProtein(
            aatype=arrays["aatype"][idx],
            atom_positions=atom37_positions[idx],
            atom_mask=atom37_mask[idx],
            # residue_index is shifted by one before being written out
            # (presumably 0-based -> 1-based numbering; confirm).
            residue_index=arrays["residue_index"][idx] + 1,
            # The per-entry plddt values are stored in the B-factor field.
            b_factors=arrays["plddt"][idx],
            chain_index=arrays["chain_index"][idx] if has_chain_index else None,
        )

    batch_size = arrays["aatype"].shape[0]
    return [to_pdb(protein_at(idx)) for idx in range(batch_size)]

# Module-level model setup: runs once at import time, and the statements
# below depend on this order.

# Load the tokenizer and the folding model from the "facebook/esmfold_v1"
# checkpoint.
tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1", low_cpu_mem_usage=True)

# Move the whole model to the CUDA device.
model = model.cuda()

# Convert only the ``esm`` submodule to half precision.
model.esm = model.esm.half()

# NOTE(review): torch is imported here, mid-file, rather than with the other
# imports at the top; the functions defined below rely on this name.
import torch

# Allow TF32 for CUDA matrix multiplications.
torch.backends.cuda.matmul.allow_tf32 = True

# Set the trunk chunk size to 64 (presumably trades speed for lower memory
# use; confirm against the transformers documentation).
model.trunk.set_chunk_size(64)

@spaces.GPU(duration=120)
def fold_protein(test_protein):
    """Predict the structure of one sequence and write it to a PDB file.

    Args:
        test_protein: amino-acid sequence string passed to the tokenizer.

    Returns:
        A ``(html, path)`` tuple: the iframe markup produced by ``molecule``
        and the path of the written file, always ``"output_structure.pdb"``.
    """
    output_path = "output_structure.pdb"

    encoded = tokenizer([test_protein], return_tensors="pt", add_special_tokens=False)
    input_ids = encoded['input_ids'].cuda()

    # Inference only: no gradients are needed.
    with torch.no_grad():
        prediction = model(input_ids)

    pdb_records = convert_outputs_to_pdb(prediction)
    with open(output_path, "w") as pdb_file:
        pdb_file.write("".join(pdb_records))

    return molecule(output_path), output_path

@spaces.GPU(duration=180)
def fold_protein_wpdb(test_protein, pdb_path):
    """Predict the structure of one sequence and save it at ``pdb_path``.

    Args:
        test_protein: amino-acid sequence string passed to the tokenizer.
        pdb_path: destination file for the predicted structure; it is
            opened in write mode, so existing content is replaced.

    Returns:
        The iframe markup produced by ``molecule`` for the written file
        (a single string, not a tuple).
    """
    encoded = tokenizer([test_protein], return_tensors="pt", add_special_tokens=False)
    input_ids = encoded['input_ids'].cuda()

    # Inference only: no gradients are needed.
    with torch.no_grad():
        prediction = model(input_ids)

    pdb_records = convert_outputs_to_pdb(prediction)
    with open(pdb_path, "w") as pdb_file:
        pdb_file.write("".join(pdb_records))

    return molecule(pdb_path)

def load_protein_sequences(fasta_file):
    """Parse a FASTA file into a ``{record id: sequence string}`` dict.

    Records are read with ``SeqIO.parse(fasta_file, "fasta")``. When two
    records share an id, the later one replaces the earlier one.
    """
    return {
        entry.id: str(entry.seq)
        for entry in SeqIO.parse(fasta_file, "fasta")
    }

# Single-sequence tab: a text box feeding ``fold_protein``, whose two return
# values populate the HTML viewer and the Molecule3D component.
# NOTE(review): the hint text is supplied as the box's initial ``value``, so
# it is submitted as the sequence unless the user replaces it.
_sequence_box = gr.Textbox(
    label="Protein Sequence",
    info="Find sequences examples below, and complete examples with images at: https://github.com/AstraBert/proteinviz/tree/main/examples.md; if you input a sequence, you're gonna get the static image and the 3D model to explore and play with",
    lines=5,
    value="Paste or write amino-acidic sequence here",
)

_structure_outputs = [
    gr.HTML(label="Protein 3D model"),
    Molecule3D(label="Molecular 3D model", reps=reps),
]

# Example sequences offered below the input box.
_example_sequences = [
    "MVHLTPEEKSAVTALWGKVNVDEVGGEALGRLLVVYPWTQRFFESFGDLSTPDAVMGNPKVKAHGKKVLGAFSDGLAHLDNLKGTFATLSELHCDKLHVDPENFRLLGNVLVCVLAHHFGKEFTPPVQAAYQKVVAGVANALAHKYH",
    "MTEYKLVVVGAGGVGKSALTIQLIQNHFVDEYDPTIEDSYRKQVVIDGETCLLDILDTAGQEEYSAMRDQYMRTGEGFLCVFAINNTKSFEDIHQYREQIKRVKDSDDVPMVLVGNKCDLAARTVESRQAQDLARSYGIPYIETSAKTRQGVEDAFYTLVREIRQHKLRKLNPPDESGPGCMSCKCVLS",
    "MQIFVKTLTGKTITLEVEPSDTIENVKAKIQDKEGIPPDQQRLIFAGKQLEDGRTLSDYNIQKESTLHLVLRLRGG",
]

iface = gr.Interface(
    fn=fold_protein,
    inputs=_sequence_box,
    outputs=_structure_outputs,
    examples=_example_sequences,
    title="Proteinviz",
)

# Bulk tab: upload a FASTA file and render one structure viewer per record.
with gr.Blocks() as demo1:
    input_seqs = gr.File(label="FASTA File With Protein Sequences")

    @gr.render(inputs=input_seqs)
    def show_split(inputfile):
        """Fold every sequence in the uploaded FASTA file and show each one.

        Args:
            inputfile: value of the ``input_seqs`` file component, or
                ``None`` when nothing has been uploaded yet.

        For each record, the prediction is written to ``<record id>.pdb`` and
        an HTML viewer is added to the page.
        """
        if inputfile is None:
            gr.Markdown("## No Input Provided")
            return

        seqs = load_protein_sequences(inputfile)
        print("Loaded sequences")
        for seq_id, sequence in seqs.items():
            # The record id comes from an uploaded file: besides the original
            # space/comma clean-up, neutralise path separators so the output
            # cannot be written outside the working directory.
            safe_name = (
                seq_id.replace(" ", "_")
                .replace(",", "")
                .replace("/", "_")
                .replace("\\", "_")
            )
            pdb_path = f"{safe_name}.pdb"
            # fold_protein_wpdb returns only the HTML string; unpacking it
            # into two names (as before) raised ValueError.
            html = fold_protein_wpdb(sequence, pdb_path)
            print(f"Prediction for {seq_id} is over")
            gr.HTML(html, label=f"{seq_id} structural representation")


# Combine the single-sequence interface and the bulk FASTA page into one
# tabbed app; the tab titles are given in the same order as the interfaces.
demo = gr.TabbedInterface([iface, demo1], ["Single Protein Structure Prediction", "Bulk Protein Structure Prediction"])

# Start the server on all network interfaces, without a public share link.
demo.launch(server_name="0.0.0.0", share=False)