Spaces:
Running
Running
File size: 3,418 Bytes
e873d33 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 |
import dataclasses
import math
from typing import List, Optional
import torch
from pymilvus import MilvusClient, connections
from transformers import AutoModel, AutoTokenizer
from credentials import get_token
@dataclasses.dataclass
class MilvusParams:
    """Connection settings for the Milvus collection that is searched."""

    # Endpoint handed to the Milvus connection and client as ``uri``.
    uri: str
    # Credential handed to the Milvus connection and client as ``token``.
    token: str
    # Database name handed to the Milvus connection as ``db_name``.
    db_name: str
    # Collection that search and query calls are issued against.
    collection_name: str
class ProteinSearchEngine:
    """Embed a sequence with a Hugging Face model and search a Milvus collection.

    The constructor opens the Milvus connection/client and loads the tokenizer
    and model from ``model_repo``; the remaining methods embed a query
    sequence, run the vector search and post-process the hits.
    """

    # Used by ``_format_search_results``: sqrt(2 * n_dims) is the distance
    # that maps to a score of 0.
    n_dims = 128
    # Not referenced inside this class.
    dist_metric = "euclidean"
    # ``search_by_sequence`` tokenizes its query to ``max_lengths[0]`` tokens;
    # the second entry is not referenced inside this class.
    max_lengths = (30, 300)

    def __init__(self, milvus_params: "MilvusParams", model_repo: str):
        """Connect to Milvus and load the embedding model.

        Args:
            milvus_params: Connection settings (uri, token, database,
                collection).
            model_repo: Hugging Face repository id for both tokenizer and
                model.
        """
        self.model_repo = model_repo
        self.milvus_params = milvus_params
        connections.connect(
            "default",
            uri=milvus_params.uri,
            token=milvus_params.token,
            db_name=milvus_params.db_name,
        )
        # The client issues every search/query, so it must target the same
        # database as the connection above instead of the server default.
        self.client = MilvusClient(
            uri=milvus_params.uri,
            token=milvus_params.token,
            db_name=milvus_params.db_name,
        )
        # Fetch the hub credential once and reuse it for both downloads.
        hub_token = get_token()
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_repo, use_auth_token=hub_token
        )
        self.model = AutoModel.from_pretrained(
            self.model_repo, use_auth_token=hub_token, trust_remote_code=True
        )
        self.model.eval()

    def search_by_sequence(self, sequence: str, n: int, organism: Optional[str] = None):
        """Embed ``sequence`` and return the ``n`` nearest non-peptide entries.

        Args:
            sequence: Query sequence; tokenized (truncated/padded) to
                ``max_lengths[0]`` tokens.
            n: Maximum number of hits to return.
            organism: When given, only entries with this exact organism value
                are considered.

        Returns:
            A list of entity dicts, each extended with ``dist`` and ``score``
            (see ``_format_search_results``).
        """
        max_length = self.max_lengths[0]
        vec = self._embed_sequence(max_length, sequence)
        response = self.search(vec, n_results=n, is_peptide=False, organism=organism)
        return self._format_search_results(response)

    def _embed_sequence(self, max_length, sequence):
        """Return the model's ``forward1`` output for ``sequence`` as a NumPy array.

        The sequence is tokenized with special tokens, truncated and padded to
        exactly ``max_length`` tokens, and run through the model without
        gradient tracking.
        """
        encoded = self.tokenizer.encode_plus(
            sequence,
            add_special_tokens=True,
            truncation=True,
            max_length=max_length,
            padding="max_length",
            return_tensors="pt",
        )
        with torch.no_grad():
            output = self.model.forward1(encoded.to(self.model.device))
        # squeeze() drops size-1 dimensions (e.g. a batch of one) before the
        # tensor is moved to the CPU and converted.
        return output.squeeze().cpu().numpy()

    def _format_search_results(self, response):
        """Attach ``dist`` and ``score`` to every hit's entity and return the entities.

        ``dist`` is the square root of the hit's reported ``distance``
        (presumably the backend reports squared L2 -- confirm against the
        collection's index metric). ``score`` rescales it linearly so that a
        distance of 0 gives 1.0 and a distance of sqrt(2 * n_dims) gives 0.0.

        Note: the entity dicts inside ``response`` are updated in place.
        """
        max_dist = math.sqrt(2 * self.n_dims)
        search_results = []
        for hit in response:
            entry = hit["entity"]
            dist = math.sqrt(hit["distance"])
            entry["dist"] = dist
            entry["score"] = (max_dist - dist) / max_dist
            search_results.append(entry)
        return search_results

    @staticmethod
    def _quote_filter_string(value) -> str:
        """Return ``value`` as a single-quoted filter-expression string literal.

        Backslashes and single quotes are backslash-escaped so the value
        cannot terminate the literal early and alter the filter.
        """
        escaped = str(value).replace("\\", "\\\\").replace("'", "\\'")
        return f"'{escaped}'"

    def search(
        self,
        vec: List[float],
        n_results: int,
        is_peptide: bool,
        organism: Optional[str] = None,
    ):
        """Run a vector search for ``vec`` restricted by peptide flag and organism.

        Args:
            vec: Query embedding.
            n_results: Maximum number of hits to return.
            is_peptide: Only entries whose ``is_peptide`` field equals this
                value (coerced to ``bool``) are searched.
            organism: Optional exact-match restriction on the ``organism``
                field.

        Returns:
            The hits for the single query vector (first element of the
            client's per-query result list).
        """
        # bool() guarantees the f-string renders exactly "True" or "False".
        filter_str = f"is_peptide == {bool(is_peptide)}"
        if organism is not None:
            # Escape the caller-supplied value before splicing it into the
            # expression.
            filter_str += f" and organism == {self._quote_filter_string(organism)}"
        results = self.client.search(
            collection_name=self.milvus_params.collection_name,
            data=[vec],
            limit=n_results,
            output_fields=[
                "genes",
                "uniprot_id",
                "pdb_name",
                "chain_id",
                "is_peptide",
                "organism",
            ],
            filter=filter_str,
        )
        return results[0]

    def get_organisms(self):
        """Query the ``organism`` field of every entry with ``entry_id > 0``.

        Returns:
            The raw query result from the Milvus client (one row per matched
            entry; values are not de-duplicated here).
        """
        return self.client.query(
            collection_name=self.milvus_params.collection_name,
            output_fields=["organism"],
            filter="entry_id > 0",
        )
|