"""Gradio app: search a protein index by peptide sequence and view a PDB hit."""

import collections
import os
from typing import Dict, List

import gradio as gr

from index_list import read_index_list
from protein_viz import get_pdb_title, render_html
from search_engine import MilvusParams, ProteinSearchEngine

model_repo = "ronig/protein_biencoder"
available_indexes = read_index_list()
engine = ProteinSearchEngine(
    milvus_params=MilvusParams(
        uri="https://in03-ddab8e9a5a09fcc.api.gcp-us-west1.zillizcloud.com",
        # os.environ.get yields None when MILVUS_TOKEN is not set.
        token=os.environ.get("MILVUS_TOKEN"),
        db_name="Protein",
        collection_name="Peptriever",
    ),
    model_repo=model_repo,
)
# Upper bound on displayed results; also used as the Label's num_top_classes.
max_results = 1000
# Separator used to build dropdown choices and to split them again in
# switch_viz.
choice_sep = " | "
# Longest input sequence accepted by _validate_sequence_length.
max_seq_length = 50


def search_and_display(seq, max_res, index_selection):
    """Run a sequence search and build the two UI outputs.

    Args:
        seq: Query sequence from the input textbox.
        max_res: Requested number of results; clamped to [1, max_results].
        index_selection: Selected index name; "All Species" disables the
            organism filter.

    Returns:
        A tuple ``(formatted_search_results, results_options)``: the
        label-to-score mapping for the results Label and the update for the
        results Dropdown.

    Raises:
        gr.Error: If the sequence is longer than ``max_seq_length`` or the
            number of results is missing.
    """
    # Raw hits requested from the engine; hits are grouped by UniProt ID
    # afterwards, so this is not the number of displayed rows.
    n_search_res = 1024
    _validate_sequence_length(seq)
    # A cleared Number field arrives as None, which would otherwise fail
    # inside min() with an opaque TypeError.
    if max_res is None:
        raise gr.Error("Please enter the number of results to display")
    max_res = int(limit_n_results(max_res))
    if index_selection == "All Species":
        index_selection = None
    search_res = engine.search_by_sequence(
        seq, n=n_search_res, organism=index_selection
    )
    agg_search_results = aggregate_search_results(search_res, max_res)
    formatted_search_results = format_search_results(agg_search_results)
    results_options = update_dropdown_menu(agg_search_results)
    return formatted_search_results, results_options


def _validate_sequence_length(seq):
    """Raise ``gr.Error`` when *seq* is longer than ``max_seq_length``."""
    if len(seq) > max_seq_length:
        raise gr.Error("Only peptide input is currently supported")


def limit_n_results(n):
    """Clamp *n* to the inclusive range ``[1, max_results]``."""
    return max(min(n, max_results), 1)


def aggregate_search_results(
    raw_results: List[dict], max_res: int
) -> Dict[str, List[dict]]:
    """Group raw hits by UniProt ID, stopping at *max_res* distinct IDs.

    Hits whose ``uniprot_id`` is None are skipped. Iteration stops as soon as
    *max_res* distinct IDs have been collected, so later hits (including
    further hits for IDs already collected) are not included.

    Args:
        raw_results: Hit dicts; each must contain the keys ``pdb_name``,
            ``chain_id``, ``score``, ``organism``, ``uniprot_id`` and
            ``genes``.
        max_res: Maximum number of distinct UniProt IDs to return.

    Returns:
        A dict mapping UniProt ID to the list of its entries, in encounter
        order.
    """
    aggregated_by_uniprot = collections.defaultdict(list)
    for raw_result in raw_results:
        entry = select_keys(
            raw_result,
            keys=["pdb_name", "chain_id", "score", "organism", "uniprot_id", "genes"],
        )
        uniprot_id = raw_result["uniprot_id"]
        if uniprot_id is not None:
            aggregated_by_uniprot[uniprot_id].append(entry)
        if len(aggregated_by_uniprot) >= max_res:
            return dict(aggregated_by_uniprot)
    return dict(aggregated_by_uniprot)


def select_keys(d: dict, keys: List[str]):
    """Return a new dict holding only *keys* from *d* (KeyError if absent)."""
    return {key: d[key] for key in keys}


def format_search_results(agg_search_results):
    """Build the ``{display label: score}`` mapping for the results Label.

    Each UniProt ID contributes one row, described by its first entry.
    """
    formatted_search_results = {}
    for uniprot_id, entries in agg_search_results.items():
        # Only the first entry is shown; presumably the best-scoring one if
        # the engine returns hits in rank order -- TODO confirm.
        entry = entries[0]
        organism = entry["organism"]
        score = entry["score"]
        genes = entry["genes"]
        key = f"Uniprot ID: {uniprot_id} | Organism: {organism} | Gene Names: {genes}"
        formatted_search_results[key] = score
    return formatted_search_results


def update_dropdown_menu(agg_search_res):
    """Build the Dropdown update listing every aggregated entry.

    Each choice is ``uniprot_id | pdb_name | chain_id | genes`` joined with
    ``choice_sep``. With at least one choice the dropdown is shown with the
    first choice selected; with none it is hidden and its value cleared.
    """
    choices = [
        choice_sep.join(
            [
                uniprot_id,
                entry["pdb_name"],
                entry["chain_id"],
                entry["genes"] or "",  # genes may be None/empty
            ]
        )
        for uniprot_id, entries in agg_search_res.items()
        for entry in entries
    ]
    has_choices = bool(choices)
    # NOTE(review): confirm how gr.update treats this positional class id on
    # the pinned Gradio version.
    return gr.update(
        gr.Dropdown.get_component_class_id(),
        choices=choices,
        interactive=True,
        value=choices[0] if has_choices else None,
        visible=has_choices,
    )


def parse_pdb_search_result(raw_result):
    """Return ``(label, score)`` for a single raw hit.

    The label is ``PDB: <pdb_name>.<chain_id>``, extended with gene names and
    organism when ``genes`` is not None. Not called anywhere in this module.
    """
    prot = raw_result["pdb_name"]
    chain = raw_result["chain_id"]
    value = raw_result["score"]
    gene_names = raw_result["genes"]
    species = raw_result["organism"]
    key = f"PDB: {prot}.{chain}"
    if gene_names is not None:
        key += f" | Genes: {gene_names} | Organism: {species}"
    return key, value


def switch_viz(new_choice):
    """Update the visualization widgets for the selected dropdown choice.

    Args:
        new_choice: A choice string built by ``update_dropdown_menu``, or
            None when nothing is selected.

    Returns:
        A tuple ``(html, title_update, description_update)`` for the HTML
        viewer, the visualization header and the description Markdown.
    """
    if new_choice is None:
        html = ""
        title_update = gr.update(gr.Markdown.get_component_class_id(), visible=False)
        description_update = gr.update(
            gr.Markdown.get_component_class_id(), value=None, visible=False
        )
    else:
        choice_parts = new_choice.split(choice_sep)
        # Choice layout: uniprot_id | pdb_name | chain_id | genes.
        pdb_id, chain = choice_parts[1:3]
        title_update = gr.update(gr.Markdown.get_component_class_id(), visible=True)
        pdb_title = get_pdb_title(pdb_id)
        new_value = f"""**PDB Title**: {pdb_title}"""
        description_update = gr.update(
            gr.Markdown.get_component_class_id(), value=new_value, visible=True
        )
        html = render_html(pdb_id=pdb_id, chain=chain)
    return html, title_update, description_update


# NOTE(review): the nesting below is inferred (the results Label carries
# scale=2, so it is placed in the Row beside the input column) -- confirm
# against the intended layout.
with gr.Blocks() as demo:
    with gr.Column():
        with gr.Column():
            with gr.Row():
                with gr.Column():
                    seq_input = gr.Textbox(value="APTMPPPLPP", label="Input Sequence")
                    n_results = gr.Number(10, label="N Results")
                    index_selector = gr.Dropdown(
                        choices=available_indexes,
                        value="All Species",
                        multiselect=False,
                        visible=True,
                        label="Index",
                    )
                    search_button = gr.Button("Search", variant="primary")
                search_results = gr.Label(
                    num_top_classes=max_results, label="Search Results", scale=2
                )
            # Visualization widgets start hidden; they are revealed through
            # the updates returned by update_dropdown_menu and switch_viz.
            viz_header = gr.Markdown("## Visualization", visible=False)
            results_selector = gr.Dropdown(
                choices=[],
                multiselect=False,
                visible=False,
                label="Visualized Search Result",
            )
            viz_body = gr.Markdown("", visible=False)
            protein_viz = gr.HTML(
                value=render_html(pdb_id=None, chain=None),
                label="Protein Visualization",
            )
            gr.Examples(
                ["APTMPPPLPP", "KFLIYQMECSTMIFGL", "PHFAMPPIHEDHLE", "AEERIISLD"],
                inputs=[seq_input],
            )
        search_button.click(
            search_and_display,
            inputs=[seq_input, n_results, index_selector],
            outputs=[search_results, results_selector],
        )
        results_selector.change(
            switch_viz,
            inputs=results_selector,
            outputs=[protein_viz, viz_header, viz_body],
        )

if __name__ == "__main__":
    demo.launch()