omarperacha committed · Commit 5a79fe4 · 1 Parent(s): 16bd580

embedings generate
Files changed:
- .gitignore: +1 -0
- app.py: +4 -4
- ps4_data/data/protT5/output/per_residue_embeddings0.npz: +3 -0
- ps4_data/get_embeddings.py: +3 -10
- ps4_eval/eval.py: +67 -0
- requirements.txt: +1 -1
.gitignore CHANGED
@@ -1,3 +1,4 @@
 .DS_Store
 .idea/
 ps4_data/__pycache__/
+ps4_eval/__pycache__/
app.py CHANGED
@@ -1,12 +1,12 @@
 import gradio as gr
-from
+from ps4_eval.eval import sample_new_sequence
 from ps4_data.get_embeddings import generate_embedings
 
 
 def pred(residue_seq):
-    generate_embedings(residue_seq)
-
-    return
+    embs = generate_embedings(residue_seq)["residue_embs"]["0"]
+    preds = sample_new_sequence(embs, "ps4_models/Mega/PS4-Mega_loss-0.633_acc-78.176.pt")
+    return preds
 
 
 iface = gr.Interface(fn=pred, title="Protein Secondary Structure Prediction with PS4-Mega",
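For context, a minimal usage sketch of the pipeline that the updated app.py wires together. The amino-acid sequence below is made up, and the "residue_embs"/"0" result layout and the checkpoint path are taken directly from the diff above rather than verified against the running Space.

# Hedged sketch of the end-to-end flow behind pred(); the input sequence is illustrative.
from ps4_data.get_embeddings import generate_embedings
from ps4_eval.eval import sample_new_sequence

residue_seq = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"  # hypothetical protein sequence
embs = generate_embedings(residue_seq)["residue_embs"]["0"]  # per-residue ProtT5 embeddings
preds = sample_new_sequence(embs, "ps4_models/Mega/PS4-Mega_loss-0.633_acc-78.176.pt")
print(preds)  # one secondary-structure label per residue, e.g. "CCHHHH..."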
ps4_data/data/protT5/output/per_residue_embeddings0.npz ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ce32f0bb1cced643cc12d813b19bfa11b074b0d789329d7da1a1c066d63ccf75
+size 197778
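The .npz file added above is stored as a Git LFS pointer (oid and size only), not the raw arrays. Below is a hedged sketch of inspecting the real archive once it has been pulled with git lfs; the key "0" and the (L, 1024) per-residue shape are assumptions based on how get_embeddings.py names and stores ProtT5 embeddings.

# Assumes `git lfs pull` has replaced the pointer file with the actual .npz archive.
import numpy as np

data = np.load("ps4_data/data/protT5/output/per_residue_embeddings0.npz")
print(data.files)        # expected to list the sequence id(s), e.g. ["0"]
print(data["0"].shape)   # assumed (L, 1024): one 1024-dim ProtT5 vector per residue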
ps4_data/get_embeddings.py CHANGED
@@ -22,17 +22,10 @@ def generate_embedings(input_seq, output_path=None):
 
     # Load fasta.
     all_seqs = {"0": input_seq}
-
-    chunk_size = 1000
-
     # Compute embeddings and/or secondary structure predictions
-
-    keys = list(all_seqs.keys())[i: chunk_size + i]
-    seqs = {k: all_seqs[k] for k in keys}
-    results = __get_embeddings(model, tokenizer, seqs, device)
+    results = __get_embeddings(model, tokenizer, all_seqs, device)
 
-
-    __save_embeddings(results["residue_embs"], per_residue_path + f"{i}.npz")
+    return results
 
 
 def __get_T5_model(device):
@@ -92,7 +85,7 @@ def __get_embeddings(model, tokenizer, seqs, device, per_residue=True,
             # slice off padding --> batch-size x seq_len x embedding_dim
             emb = embedding_repr.last_hidden_state[batch_idx, :s_len]
             if per_residue: # store per-residue embeddings (Lx1024)
-                results["residue_embs"][identifier] = emb.detach().cpu().
+                results["residue_embs"][identifier] = emb.detach().cpu().squeeze()
         print("emb_count:", len(results["residue_embs"]))
 
         passed_time = time.time() - start
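After this change generate_embedings no longer chunks the input or writes intermediate .npz files; it embeds the single sequence it is given and returns the results dict to the caller. A minimal sketch of calling it in isolation, assuming the nested key layout shown in the diff (the test sequence is illustrative):

# Hedged example of the new return-value contract of generate_embedings.
from ps4_data.get_embeddings import generate_embedings

results = generate_embedings("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ")  # hypothetical sequence
emb = results["residue_embs"]["0"]  # per-residue embedding tensor for the single input
print(len(emb))                     # expected to equal the sequence length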
ps4_eval/eval.py ADDED
@@ -0,0 +1,67 @@
+import torch
+from torch import nn, cuda, load, device
+from ps4_models.classifiers import PS4_Mega, PS4_Conv
+
+
+def load_trained_model(load_path, model_name='PS4_Mega'):
+
+    if model_name.lower() not in ['ps4_conv', 'ps4_mega']:
+        raise ValueError(f'Model name {model_name} not recognised, please choose from PS4_Conv, PS4_Mega')
+
+    model: nn.Module = PS4_Mega() if model_name.lower() == 'ps4_mega' else PS4_Conv()
+
+    if load_path != '':
+        try:
+            if cuda.is_available():
+                model.load_state_dict(load(load_path)['model_state_dict'])
+            else:
+                model.load_state_dict(load(load_path, map_location=device('cpu'))['model_state_dict'])
+            print("loded params from", load_path)
+        except:
+            raise ImportError(f'No file located at {load_path}, could not load parameters')
+    print(model)
+
+    pytorch_total_params = sum(par.numel() for par in model.parameters() if par.requires_grad)
+    print(pytorch_total_params)
+
+    return model
+
+
+# MARK: sampling from new sequence
+def sample_new_sequence(embs, weights_load_path, model_name='PS4_Mega'):
+
+    model = load_trained_model(weights_load_path, model_name)
+
+    seq_size = len(embs)
+    R = embs.view(1, seq_size)
+
+    pred_ss = ''
+
+    with torch.no_grad():
+        y_hat = model(R)
+        probs = torch.softmax(y_hat, 2)
+        _, ss_preds = torch.max(probs, 2)
+
+    for i in range(seq_size):
+        ss = ss_preds[0][i].item()
+        ss = ss_tokeniser(ss, reverse=True)
+        pred_ss += ss
+
+    return pred_ss
+
+
+def ss_tokeniser(ss, reverse=False):
+
+    ss_set = ['C', 'T', 'G', 'H', 'S', 'B', 'I', 'E', 'C']
+
+    if reverse:
+        return inverse_ss_tokeniser(ss)
+    else:
+        return 0 if (ss == 'P' or ss == ' ') else ss_set.index(ss)
+
+
+def inverse_ss_tokeniser(ss):
+
+    ss_set = ['C', 'T', 'G', 'H', 'S', 'B', 'I', 'E', 'C', 'C']
+
+    return ss_set[ss]
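A quick, hedged sanity check of the tokeniser pair defined in eval.py: every secondary-structure label maps to a class index, 'P' and ' ' collapse to class 0, and reverse=True routes through inverse_ss_tokeniser to recover a label.

# Round-trip check for ss_tokeniser / inverse_ss_tokeniser as committed above.
from ps4_eval.eval import ss_tokeniser

for label in ['H', 'E', 'C', 'T', 'G', 'S', 'B', 'I', 'P', ' ']:
    idx = ss_tokeniser(label)               # label -> class index
    back = ss_tokeniser(idx, reverse=True)  # class index -> label via inverse_ss_tokeniser
    print(repr(label), idx, back)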
requirements.txt CHANGED
@@ -5,5 +5,5 @@ scikit-learn~=0.24.2
 transformers~=4.26.1
 setuptools~=57.4.0
 pandas~=1.3.2
-
+sentencepiece~=0.1.97
 -e git+https://github.com/facebookresearch/mega.git@main#egg=fairseq