Spaces:
Running
Running
import collections | |
from typing import Dict, List | |
import gradio as gr | |
from get_index import get_engines | |
from protein_viz import get_pdb_title, render_html | |
# Repository ids (presumably on the Hugging Face Hub — verify in get_index)
# from which the search indexes and the model are loaded by ``get_engines``.
index_repo = "ronig/siamese_protein_index"
model_repo = "ronig/protein_search_engine"

# NOTE(review): runs at import time, so importing this module performs
# whatever loading ``get_engines`` does.
# ``engines`` is a mapping from index name to search engine; its keys populate
# the "Index" dropdown and are looked up in ``search_and_display``.
engines = get_engines(index_repo, model_repo)
available_indexes = list(engines.keys())

# Markdown shown at the top of the app.
app_description = """
# Protein Binding Search Engine
This application enables a quick protein-peptide binding search based on sequences.
You can use it to search the full [PDB](https://www.rcsb.org/) database or the genome of a specific organism.
"""

# Upper bound on the number of per-gene results a search may return; the
# results label is also configured to show up to this many classes.
max_results = 1000
# Separator joining gene, PDB name and chain id into one dropdown choice;
# ``switch_viz`` splits choices on this same string.
choice_sep = " | "
def search_and_display(seq, max_res, index_selection):
    """Run a sequence search and build the outputs for the results widgets.

    Args:
        seq: query sequence entered by the user.
        max_res: requested number of results; clamped by ``limit_n_results``
            and truncated to an ``int``.
        index_selection: key into ``engines`` choosing which index to search.

    Returns:
        A pair ``(label_scores, dropdown_update)``: the ``{label: score}`` dict
        for the results label and the dropdown update listing every hit.
    """
    limit = int(limit_n_results(max_res))
    # Always request a large fixed batch of raw hits from the engine; the
    # user's limit is enforced afterwards, per gene, during aggregation.
    raw_hits = engines[index_selection].search_by_sequence(seq, n=10000)
    by_gene = aggregate_search_results(raw_hits, limit)
    return format_search_results(by_gene), update_dropdown_menu(by_gene)
def limit_n_results(n):
    """Clamp ``n`` into the closed range ``[1, max_results]``.

    No integer conversion happens here; the value is returned as-is when it is
    already in range, so callers convert with ``int`` themselves.
    """
    capped = min(n, max_results)
    return max(capped, 1)
def aggregate_search_results(
    raw_results: List[dict], max_res: int
) -> Dict[str, List[dict]]:
    """Group raw search hits by gene name, collecting at most ``max_res`` genes.

    Each hit is reduced to its ``pdb_name``, ``chain_id``, ``score`` and
    ``organism`` fields and appended to the list of every gene it names.
    Hits whose ``genes`` field is ``None`` are skipped. Genes keep the order
    in which they are first seen, and each gene's hits keep the order of
    ``raw_results``.

    Args:
        raw_results: hits from the search engine; each must contain the keys
            above plus ``genes`` (a whitespace-separated string or ``None``).
        max_res: maximum number of distinct genes to collect. Iteration stops
            as soon as this many genes are present, so later hits (even for
            genes already collected) are not added.

    Returns:
        Mapping of gene name to the list of reduced hit dicts for that gene.
    """
    aggregated_by_gene = collections.defaultdict(list)
    for raw_result in raw_results:
        genes = raw_result["genes"]
        if genes is None:
            continue
        entry = select_keys(raw_result, ["pdb_name", "chain_id", "score", "organism"])
        # split() rather than split(" "): repeated, leading or trailing
        # whitespace (or an empty string) must not create an empty-string
        # "gene". dict.fromkeys drops names repeated within a single hit
        # (which would otherwise duplicate the entry) while keeping order.
        for gene in dict.fromkeys(genes.split()):
            aggregated_by_gene[gene].append(entry)
            if len(aggregated_by_gene) >= max_res:
                return dict(aggregated_by_gene)
    return dict(aggregated_by_gene)
def select_keys(d: dict, keys: List[str]):
    """Return a new dict holding only ``keys`` from ``d``, in ``keys`` order.

    Raises:
        KeyError: if any requested key is missing from ``d``.
    """
    picked = {}
    for name in keys:
        picked[name] = d[name]
    return picked
def format_search_results(agg_search_results):
    """Turn per-gene hit lists into the ``{label: score}`` dict for the results label.

    Each gene is represented by its first hit: that hit's organism goes into
    the label text and its score becomes the value.
    """
    first_hits = ((gene, hits[0]) for gene, hits in agg_search_results.items())
    return {
        f"Gene: {gene} | Organism: {hit['organism']}": hit["score"]
        for gene, hit in first_hits
    }
def update_dropdown_menu(agg_search_res):
    """Build the ``gr.Dropdown`` update that lists every aggregated hit.

    Each choice is the gene, ``pdb_name`` and ``chain_id`` joined with
    ``choice_sep``. When there is at least one choice the dropdown is made
    visible with the first choice preselected; otherwise it is hidden and its
    value cleared.
    """
    choices = [
        choice_sep.join([gene, hit["pdb_name"], hit["chain_id"]])
        for gene, hits in agg_search_res.items()
        for hit in hits
    ]
    if not choices:
        return gr.Dropdown.update(
            choices=choices, interactive=True, visible=False, value=None
        )
    return gr.Dropdown.update(
        choices=choices, interactive=True, value=choices[0], visible=True
    )
def parse_pdb_search_result(raw_result):
    """Format one raw hit as a ``(label, score)`` pair.

    The label is ``"PDB: <pdb_name>.<chain_id>"``; when the hit's ``genes``
    field is not ``None``, ``" | Genes: <genes> | Organism: <organism>"`` is
    appended. All five fields are read, so a missing key raises ``KeyError``.
    """
    # NOTE(review): not referenced in the visible part of this file; possibly
    # dead code — confirm before removing.
    pdb_name, chain_id, score, genes, organism = (
        raw_result[field]
        for field in ("pdb_name", "chain_id", "score", "genes", "organism")
    )
    label = f"PDB: {pdb_name}.{chain_id}"
    if genes is None:
        return label, score
    return f"{label} | Genes: {genes} | Organism: {organism}", score
def switch_viz(new_choice):
    """Update the visualization panel for the selected dropdown choice.

    Args:
        new_choice: a dropdown value built as gene, PDB name and chain id
            joined by ``choice_sep``; ``None`` or ``""`` means no selection.

    Returns:
        ``(html, header_update, description_update)``: the HTML for the
        structure viewer, the visibility update for the visualization header,
        and the markdown update carrying the PDB title.
    """
    # Treat an empty selection like None: splitting "" yields a single part,
    # and the pdb_id/chain unpacking below would raise ValueError.
    if not new_choice:
        html = ""
        title_update = gr.Markdown.update(visible=False)
        description_update = gr.Markdown.update(value=None, visible=False)
        return html, title_update, description_update

    # Choice layout is gene | pdb_id | chain; the gene part is not needed here.
    choice_parts = new_choice.split(choice_sep)
    pdb_id, chain = choice_parts[1:3]
    title_update = gr.Markdown.update(visible=True)
    pdb_title = get_pdb_title(pdb_id)
    new_value = f"""**PDB Title**: {pdb_title}"""
    description_update = gr.Markdown.update(value=new_value, visible=True)
    html = render_html(pdb_id=pdb_id, chain=chain)
    return html, title_update, description_update
# Gradio UI: layout, event wiring and entry point.
# NOTE(review): the nesting of the ``with`` blocks below was reconstructed from
# a copy of this file whose indentation had been lost; confirm it against the
# original layout before relying on it.
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown(app_description)
        with gr.Column():
            with gr.Row():
                # Query controls: sequence, result count, index and search button.
                with gr.Column():
                    seq_input = gr.Textbox(value="APTMPPPLPP", label="Input Sequence")
                    n_results = gr.Number(10, label="N Results")
                    index_selector = gr.Dropdown(
                        choices=available_indexes,
                        # NOTE(review): assumes an index named "PDB" exists in engines.
                        value="PDB",
                        multiselect=False,
                        visible=True,
                        label="Index",
                    )
                    search_button = gr.Button("Search", variant="primary")
                # Receives the {label: score} dict built by format_search_results.
                search_results = gr.Label(
                    num_top_classes=max_results, label="Search Results"
                )
            # Visualization widgets start hidden; update_dropdown_menu toggles
            # the selector and switch_viz toggles the header and description.
            viz_header = gr.Markdown("## Visualization", visible=False)
            results_selector = gr.Dropdown(
                choices=[],
                multiselect=False,
                visible=False,
                label="Visualized Search Result",
            )
            viz_body = gr.Markdown("", visible=False)
            # Initial viewer content is rendered with no structure selected.
            protein_viz = gr.HTML(
                value=render_html(pdb_id=None, chain=None),
                label="Protein Visualization",
            )
            # Example query sequences for the sequence input box.
            gr.Examples(
                ["APTMPPPLPP", "KFLIYQMECSTMIFGL", "PHFAMPPIHEDHLE", "AEERIISLD"],
                inputs=[seq_input],
            )
    # Search fills the score label and the result dropdown.
    search_button.click(
        search_and_display,
        inputs=[seq_input, n_results, index_selector],
        outputs=[search_results, results_selector],
    )
    # Changing the selected result refreshes the viewer, header and PDB title.
    results_selector.change(
        switch_viz, inputs=results_selector, outputs=[protein_viz, viz_header, viz_body]
    )

if __name__ == "__main__":
    demo.launch()