Spaces:
Running
Running
import dataclasses | |
import math | |
from typing import List, Optional | |
import torch | |
from pymilvus import MilvusClient, connections | |
from transformers import AutoModel, AutoTokenizer | |
from credentials import get_token | |
@dataclasses.dataclass
class MilvusParams:
    """Connection settings for a Milvus deployment plus the collection to query.

    Without the ``@dataclass`` decorator the annotated fields below would be
    bare annotations and ``MilvusParams(uri=..., ...)`` would raise TypeError.

    Attributes:
        uri: Milvus server URI, passed to both ``connections.connect`` and
            ``MilvusClient``.
        token: authentication token passed to pymilvus.
        db_name: name of the Milvus database to connect to.
        collection_name: collection that searches and queries are run against.
    """

    uri: str
    token: str
    db_name: str
    collection_name: str
class ProteinSearchEngine:
    """Embed amino-acid sequences with a Hugging Face model and search Milvus.

    The model is loaded from ``model_repo`` with ``trust_remote_code=True`` and
    queried through its custom ``forward1`` method. The resulting vector is
    searched against ``milvus_params.collection_name``.
    """

    # Only used to derive the score normalisation in _format_search_results.
    # Presumably equals the model's embedding size and the collection's vector
    # dimension — TODO confirm.
    n_dims = 128
    # Not read anywhere in this class; the metric actually applied is whatever
    # the collection's index was built with.
    dist_metric = "euclidean"
    # Tokenizer max lengths; only index 0 is used below (search_by_sequence).
    # Presumably (peptide, protein) — TODO confirm.
    max_lengths = (30, 300)

    def __init__(self, milvus_params: MilvusParams, model_repo: str):
        """Open the Milvus connections and load the tokenizer and model.

        Args:
            milvus_params: Milvus connection settings and target collection.
            model_repo: Hugging Face repo id of the embedding model.
        """
        self.model_repo = model_repo
        self.milvus_params = milvus_params
        # Registers the ORM-style "default" connection alias.
        connections.connect(
            "default",
            uri=milvus_params.uri,
            token=milvus_params.token,
            db_name=milvus_params.db_name,
        )
        # MilvusClient manages its own connection, so it needs db_name too;
        # without it, search/query would run against the server's default
        # database rather than milvus_params.db_name.
        self.client = MilvusClient(
            uri=milvus_params.uri,
            token=milvus_params.token,
            db_name=milvus_params.db_name,
        )
        hf_token = get_token()
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_repo, use_auth_token=hf_token
        )
        self.model = AutoModel.from_pretrained(
            self.model_repo, use_auth_token=hf_token, trust_remote_code=True
        )
        # Inference only: switch off training-mode layers such as dropout.
        self.model.eval()

    def search_by_sequence(self, sequence: str, n: int, organism: Optional[str] = None):
        """Embed ``sequence`` and return its ``n`` nearest non-peptide entries.

        Args:
            sequence: amino-acid sequence to embed; it is truncated/padded to
                ``max_lengths[0]`` tokens.
            n: maximum number of hits to return.
            organism: if given, only entries whose ``organism`` field equals
                this string exactly are returned.

        Returns:
            List of entity dicts augmented with ``dist`` and ``score`` (see
            ``_format_search_results``).
        """
        # NOTE(review): the query is tokenized with max_lengths[0] but matched
        # against rows with is_peptide == False — presumably a peptide query
        # searched against protein targets; confirm.
        max_length = self.max_lengths[0]
        vec = self._embed_sequence(max_length, sequence)
        response = self.search(vec, n_results=n, is_peptide=False, organism=organism)
        return self._format_search_results(response)

    def _embed_sequence(self, max_length, sequence):
        """Tokenize one sequence and return the model's embedding as a NumPy array.

        The sequence is truncated/padded to exactly ``max_length`` tokens and
        passed through ``self.model.forward1`` with gradients disabled.
        ``squeeze()`` drops size-1 dimensions, such as the batch of one.
        """
        encoded = self.tokenizer.encode_plus(
            sequence,
            add_special_tokens=True,
            truncation=True,
            max_length=max_length,
            padding="max_length",
            return_tensors="pt",
        )
        with torch.no_grad():
            vec = (
                self.model.forward1(encoded.to(self.model.device))
                .squeeze()
                .cpu()
                .numpy()
            )
        return vec

    def _format_search_results(self, response):
        """Attach ``dist`` and ``score`` to each hit's entity dict and return them.

        ``res["distance"]`` is square-rooted, i.e. it is assumed to be a squared
        L2 distance (what Milvus reports for the L2 metric) — TODO confirm the
        collection's index uses L2. ``score`` maps ``dist`` linearly from
        [0, sqrt(2 * n_dims)] onto [1, 0] and is not clamped.
        The entity dicts inside ``response`` are modified in place.
        """
        max_dist = math.sqrt(2 * self.n_dims)
        search_results = []
        for res in response:
            entry = res["entity"]
            # Clamp at 0: a (near-)exact match can come back as a tiny negative
            # number from floating-point round-off, and math.sqrt would raise.
            dist = math.sqrt(max(res["distance"], 0.0))
            entry["dist"] = dist
            entry["score"] = (max_dist - dist) / max_dist
            search_results.append(entry)
        return search_results

    @staticmethod
    def _quote_filter_string(value: str) -> str:
        """Return ``value`` as a single-quoted, escaped Milvus string literal."""
        # Escape backslashes first so the quote escapes are not doubled.
        escaped = value.replace("\\", "\\\\").replace("'", "\\'")
        return f"'{escaped}'"

    def search(
        self,
        vec: List[float],
        n_results: int,
        is_peptide: bool,
        organism: Optional[str] = None,
    ):
        """Run a filtered vector search in the configured collection.

        Args:
            vec: query embedding (callers in this class pass the NumPy array
                from ``_embed_sequence``).
            n_results: maximum number of hits (Milvus ``limit``).
            is_peptide: only rows whose ``is_peptide`` field equals this value.
            organism: if given, only rows whose ``organism`` equals it exactly.
                The value is escaped, so quotes or backslashes cannot break or
                extend the filter expression.

        Returns:
            The hit list for the single query vector.
        """
        # Normalise truthy/falsy input so the f-string renders True/False.
        is_peptide = bool(is_peptide)
        filter_str = f"is_peptide == {is_peptide}"
        if organism is not None:
            filter_str += f" and organism == {self._quote_filter_string(organism)}"
        results = self.client.search(
            collection_name=self.milvus_params.collection_name,
            data=[vec],
            limit=n_results,
            output_fields=[
                "genes",
                "uniprot_id",
                "pdb_name",
                "chain_id",
                "is_peptide",
                "organism",
            ],
            filter=filter_str,
        )
        # One query vector was sent, so return the hits for that (only) query.
        return results[0]

    def get_organisms(self):
        """Return the raw ``organism`` rows of the collection.

        Queries every row with ``entry_id > 0`` (rows with ``entry_id <= 0`` are
        excluded) and returns the query result unchanged. The result is not
        deduplicated. No ``limit`` is passed, so the number of rows returned
        depends on the Milvus server's query limits — TODO confirm it covers
        the whole collection.
        """
        res = self.client.query(
            collection_name=self.milvus_params.collection_name,
            output_fields=["organism"],
            filter="entry_id > 0",
        )
        return res