roni
updating gradio version
d46e61e
raw
history blame
6.41 kB
import collections
import os
from typing import Dict, List
import gradio as gr
from index_list import read_index_list
from protein_viz import get_pdb_title, render_html
from search_engine import MilvusParams, ProteinSearchEngine
# Model repository passed to the search engine (presumably a Hugging Face
# Hub repo id for the sequence bi-encoder; ProteinSearchEngine is defined
# elsewhere — confirm).
model_repo = "ronig/protein_biencoder"
# Index names offered in the "Index" dropdown. The UI pre-selects
# "All Species", so that entry is presumably part of this list — confirm.
available_indexes = read_index_list()
# Search engine created once at import time, backed by a Milvus collection
# hosted on Zilliz Cloud (see the uri). The access token is read from the
# MILVUS_TOKEN environment variable; if it is unset, None is passed through.
engine = ProteinSearchEngine(
    milvus_params=MilvusParams(
        uri="https://in03-ddab8e9a5a09fcc.api.gcp-us-west1.zillizcloud.com",
        token=os.environ.get("MILVUS_TOKEN"),
        db_name="Protein",
        collection_name="Peptriever",
    ),
    model_repo=model_repo,
)
# Upper bound on the number of proteins a search may return; also used as
# num_top_classes of the results Label.
max_results = 1000
# Separator used to join dropdown labels in update_dropdown_menu and to split
# them again in switch_viz; the two must agree.
choice_sep = " | "
# Longest accepted query (in characters); longer inputs are rejected.
max_seq_length = 50
def search_and_display(seq, max_res, index_selection):
    """Run a sequence search and build the UI updates for its results.

    Args:
        seq: Query sequence from the input textbox.
        max_res: Requested number of proteins from the "N Results" field;
            clamped to ``[1, max_results]`` and truncated to an int.
        index_selection: Selected index name. ``"All Species"`` searches
            without an organism filter.

    Returns:
        A ``(label_values, dropdown_update)`` pair:
            - the ``{label: score}`` mapping shown in the results Label;
            - a Gradio update for the visualization dropdown.

    Raises:
        gr.Error: If the sequence is too long or no result count was given.
    """
    # Over-fetch raw hits: several hits can share a UniProt ID and collapse
    # into a single protein during aggregation.
    n_search_res = 1024
    _validate_sequence_length(seq)
    # gr.Number can deliver None when the field is cleared. min()/max() in
    # limit_n_results would then fail with an opaque TypeError.
    if max_res is None:
        raise gr.Error("Please specify the number of results")
    max_res = int(limit_n_results(max_res))
    organism = None if index_selection == "All Species" else index_selection
    search_res = engine.search_by_sequence(seq, n=n_search_res, organism=organism)
    agg_search_results = aggregate_search_results(search_res, max_res)
    return (
        format_search_results(agg_search_results),
        update_dropdown_menu(agg_search_results),
    )
def _validate_sequence_length(seq):
    """Reject query sequences longer than ``max_seq_length`` characters.

    Raises:
        gr.Error: Shown to the user when the input exceeds the limit.
    """
    if len(seq) <= max_seq_length:
        return
    raise gr.Error("Only peptide input is currently supported")
def limit_n_results(n):
    """Clamp a requested result count into ``[1, max_results]``.

    The upper bound is applied first and the lower bound second, so a result
    of at least 1 is guaranteed. The value is returned unconverted (a float
    stays a float).
    """
    capped = max_results if max_results < n else n
    return 1 if 1 > capped else capped
def aggregate_search_results(
    raw_results: List[dict], max_res: int
) -> Dict[str, List[dict]]:
    """Group raw search hits by UniProt ID, keeping the first ``max_res`` proteins.

    Hits are visited in the given order (presumably best score first —
    confirm against ``ProteinSearchEngine.search_by_sequence``).

    A UniProt ID is admitted the first time it is seen, until ``max_res``
    distinct IDs have been collected. After that, hits for new IDs are
    ignored, but hits for already-admitted IDs are still appended. This way
    each protein's list of structures does not depend on ``max_res``.

    Hits whose ``uniprot_id`` is None are skipped.

    Args:
        raw_results: Raw hit dicts. Each must contain the keys "pdb_name",
            "chain_id", "score", "organism", "uniprot_id" and "genes".
        max_res: Maximum number of distinct UniProt IDs to keep
            (``<= 0`` yields an empty result).

    Returns:
        Mapping from UniProt ID to its hits, trimmed to those keys. IDs
        appear in first-seen order, and each hit list is in input order.
    """
    entry_keys = ["pdb_name", "chain_id", "score", "organism", "uniprot_id", "genes"]
    aggregated_by_uniprot = collections.defaultdict(list)
    for raw_result in raw_results:
        uniprot_id = raw_result["uniprot_id"]
        if uniprot_id is None:
            continue
        # Protein quota reached: only extend proteins we already admitted.
        if (
            uniprot_id not in aggregated_by_uniprot
            and len(aggregated_by_uniprot) >= max_res
        ):
            continue
        aggregated_by_uniprot[uniprot_id].append(
            select_keys(raw_result, keys=entry_keys)
        )
    return dict(aggregated_by_uniprot)


def select_keys(d: dict, keys: List[str]) -> dict:
    """Return a new dict holding only ``keys`` from ``d``.

    Raises:
        KeyError: If any of ``keys`` is missing from ``d``.
    """
    return {key: d[key] for key in keys}
def format_search_results(agg_search_results):
    """Turn aggregated hits into the ``{label: score}`` mapping for gr.Label.

    Each protein is summarised by its first hit: the organism, score and
    gene names come from that entry.

    Args:
        agg_search_results: Mapping from UniProt ID to a non-empty list of
            hit dicts, each with "organism", "score" and "genes" keys.

    Returns:
        Dict from a human-readable label to the first hit's score, in the
        same order as the input mapping.
    """

    def _summarise(uniprot_id, hits):
        top_hit = hits[0]
        organism, score, genes = (
            top_hit[field] for field in ("organism", "score", "genes")
        )
        label = f"Uniprot ID: {uniprot_id} | Organism: {organism} | Gene Names: {genes}"
        return label, score

    return dict(
        _summarise(uniprot_id, hits)
        for uniprot_id, hits in agg_search_results.items()
    )
def update_dropdown_menu(agg_search_res):
    """Build the Gradio update for the "Visualized Search Result" dropdown.

    One choice is created per hit, as ``uniprot_id | pdb_name | chain_id |
    genes`` joined with ``choice_sep``; a missing gene string becomes "".
    Choices follow the order of the aggregated results.

    Returns:
        A ``gr.update`` that:
            - when there are choices, shows the dropdown and pre-selects the
              first one;
            - otherwise, hides it and clears its value.
    """
    choices = [
        choice_sep.join(
            [uniprot_id, hit["pdb_name"], hit["chain_id"], hit["genes"] or ""]
        )
        for uniprot_id, hits in agg_search_res.items()
        for hit in hits
    ]
    has_choices = bool(choices)
    # NOTE(review): the component class id is passed positionally to
    # gr.update; in some Gradio versions that first parameter is `elem_id`.
    # Confirm this is intended for the pinned Gradio version.
    return gr.update(
        gr.Dropdown.get_component_class_id(),
        choices=choices,
        interactive=True,
        value=choices[0] if has_choices else None,
        visible=has_choices,
    )
def parse_pdb_search_result(raw_result):
    """Convert one raw hit into a ``(label, score)`` pair.

    The label is ``PDB: <pdb_name>.<chain_id>``. When gene names are present,
    it is extended with the genes and organism. All five fields are read
    up front, so a missing key raises KeyError regardless of the genes value.
    (Not referenced elsewhere in the visible code.)
    """
    pdb_name, chain_id, score, genes, organism = (
        raw_result[field]
        for field in ("pdb_name", "chain_id", "score", "genes", "organism")
    )
    label = f"PDB: {pdb_name}.{chain_id}"
    if genes is None:
        return label, score
    return f"{label} | Genes: {genes} | Organism: {organism}", score
def switch_viz(new_choice):
    """Update the structure viewer when a dropdown result is selected.

    Args:
        new_choice: The selected dropdown label. It has the form
            ``uniprot_id | pdb_name | chain_id | genes``, joined with
            ``choice_sep`` by update_dropdown_menu. It is None or empty when
            nothing is selected.

    Returns:
        ``(html, title_update, description_update)`` for, in order:
            - the HTML viewer;
            - the "Visualization" header;
            - the PDB-title markdown.
        With no selection, the viewer is emptied and both markdowns are hidden.
    """
    # Treat "" like None. An empty label would otherwise fail the two-value
    # unpacking of split(...)[1:3] below with a ValueError.
    if not new_choice:
        hidden_title = gr.update(gr.Markdown.get_component_class_id(), visible=False)
        hidden_description = gr.update(
            gr.Markdown.get_component_class_id(), value=None, visible=False
        )
        return "", hidden_title, hidden_description
    # Fields 1 and 2 of the label are the PDB id and the chain id.
    pdb_id, chain = new_choice.split(choice_sep)[1:3]
    title_update = gr.update(gr.Markdown.get_component_class_id(), visible=True)
    pdb_title = get_pdb_title(pdb_id)
    description_update = gr.update(
        gr.Markdown.get_component_class_id(),
        value=f"**PDB Title**: {pdb_title}",
        visible=True,
    )
    html = render_html(pdb_id=pdb_id, chain=chain)
    return html, title_update, description_update
# Gradio UI. Query inputs and ranked results sit on top; the structure
# viewer sits below. Component creation order inside the context managers
# defines the page layout.
with gr.Blocks() as demo:
    with gr.Column():
        with gr.Column():
            with gr.Row():
                # Query inputs: sequence, number of proteins, species index.
                with gr.Column():
                    seq_input = gr.Textbox(value="APTMPPPLPP", label="Input Sequence")
                    n_results = gr.Number(10, label="N Results")
                    index_selector = gr.Dropdown(
                        choices=available_indexes,
                        value="All Species",
                        multiselect=False,
                        visible=True,
                        label="Index",
                    )
                    search_button = gr.Button("Search", variant="primary")
                # Ranked {label: score} results (see format_search_results),
                # given a wider share of the row via scale=2. num_top_classes
                # matches max_results so no returned protein is cut off.
                search_results = gr.Label(
                    num_top_classes=max_results, label="Search Results", scale=2
                )
        # Visualization widgets start hidden. update_dropdown_menu reveals the
        # selector after a search; switch_viz reveals the header and title.
        viz_header = gr.Markdown("## Visualization", visible=False)
        results_selector = gr.Dropdown(
            choices=[],
            multiselect=False,
            visible=False,
            label="Visualized Search Result",
        )
        viz_body = gr.Markdown("", visible=False)
        # Initial viewer contents are rendered with no structure selected.
        protein_viz = gr.HTML(
            value=render_html(pdb_id=None, chain=None),
            label="Protein Visualization",
        )
    # Clickable example queries that fill the sequence textbox.
    gr.Examples(
        ["APTMPPPLPP", "KFLIYQMECSTMIFGL", "PHFAMPPIHEDHLE", "AEERIISLD"],
        inputs=[seq_input],
    )
    # Search fills the results Label and repopulates the result dropdown.
    search_button.click(
        search_and_display,
        inputs=[seq_input, n_results, index_selector],
        outputs=[search_results, results_selector],
    )
    # Picking a result re-renders the viewer and updates the header/title.
    results_selector.change(
        switch_viz, inputs=results_selector, outputs=[protein_viz, viz_header, viz_body]
    )
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()