roni
per gene aggregation
33eb5d4
raw
history blame
5.59 kB
import collections
from typing import Dict, List
import gradio as gr
from get_index import get_engines
from protein_viz import get_pdb_title, render_html
# Repository identifiers handed to get_engines below.
# NOTE(review): presumably Hugging Face Hub repo ids — confirm in get_index.
index_repo = "ronig/siamese_protein_index"
model_repo = "ronig/protein_search_engine"
# Mapping from index name to a search engine object; built once at import time.
# Each engine is used through its search_by_sequence(seq, n=...) method.
engines = get_engines(index_repo, model_repo)
# Index names offered as choices in the "Index" dropdown.
available_indexes = list(engines.keys())
# Markdown shown at the top of the page.
app_description = """
# Protein Binding Search Engine
This application enables a quick protein-peptide binding search based on sequences.
You can use it to search the full [PDB](https://www.rcsb.org/) database or in a specific organism genome.
"""
# Upper bound on the number of results a user may request; also used as the
# number of classes the results label displays.
max_results = 1000
# Separator between gene, PDB name and chain id in the result-selector choices;
# switch_viz splits on the same string to recover the parts.
choice_sep = " | "
def search_and_display(seq, max_res, index_selection):
    """Run a sequence search and build the outputs for the results widgets.

    Args:
        seq: Query sequence forwarded to the engine's ``search_by_sequence``.
        max_res: Requested number of results; clamped by ``limit_n_results``
            and converted to ``int`` before use.
        index_selection: Key into ``engines`` selecting which index to search.

    Returns:
        A ``(formatted_results, dropdown_update)`` pair: the mapping produced
        by ``format_search_results`` and the update produced by
        ``update_dropdown_menu``, both derived from the same per-gene
        aggregation.
    """
    result_cap = int(limit_n_results(max_res))
    # Ask the engine for up to 10000 raw hits; the per-gene aggregation below
    # is what applies the user-facing cap.
    hits = engines[index_selection].search_by_sequence(seq, n=10000)
    per_gene = aggregate_search_results(hits, result_cap)
    return format_search_results(per_gene), update_dropdown_menu(per_gene)
def limit_n_results(n, upper=None):
    """Clamp a requested result count to the range ``[1, upper]``.

    Args:
        n: Requested number of results. May be ``None`` (for example when
            the numeric input field is left empty); the lower bound ``1`` is
            then returned instead of raising ``TypeError``.
        upper: Largest value allowed. Defaults to the module-level
            ``max_results`` when omitted.

    Returns:
        ``n`` limited to at least 1 and at most ``upper``.
    """
    if upper is None:
        upper = max_results
    if n is None:
        # Nothing was requested: fall back to the smallest valid count.
        return 1
    return max(min(n, upper), 1)
def aggregate_search_results(
    raw_results: List[dict], max_res: int
) -> Dict[str, List[dict]]:
    """Group raw search hits by gene name, keeping at most ``max_res`` genes.

    Hits are consumed in the order given. Each hit contributes one entry
    (its ``pdb_name``, ``chain_id``, ``score`` and ``organism``) to every gene
    named in its whitespace-separated ``genes`` field. Hits whose ``genes``
    is ``None`` or blank contribute nothing.

    Collection stops as soon as ``max_res`` distinct genes have been
    gathered, so the result never holds more than ``max_res`` genes even when
    a single hit lists several of them.

    Args:
        raw_results: Search hits; each must provide the keys ``pdb_name``,
            ``chain_id``, ``score``, ``organism`` and ``genes``.
        max_res: Maximum number of distinct genes to return. Values below 1
            yield an empty result.

    Returns:
        A plain ``dict`` mapping gene name to the list of entries collected
        for that gene, in encounter order.

    Raises:
        KeyError: If a hit lacks one of the required keys.
    """
    if max_res <= 0:
        return {}

    entry_keys = ("pdb_name", "chain_id", "score", "organism")
    aggregated_by_gene: Dict[str, List[dict]] = collections.defaultdict(list)
    for raw_result in raw_results:
        entry = {key: raw_result[key] for key in entry_keys}
        genes = raw_result["genes"]
        if genes is None:
            continue
        # split() without a separator ignores repeated/leading/trailing
        # whitespace, so no empty-string gene names are created.
        for gene in genes.split():
            aggregated_by_gene[gene].append(entry)
            # Check after every gene so a multi-gene hit cannot push the
            # number of genes past the cap.
            if len(aggregated_by_gene) >= max_res:
                return dict(aggregated_by_gene)
    return dict(aggregated_by_gene)
def select_keys(d: dict, keys: List[str]):
    """Return a new dict containing only the entries of ``d`` named in ``keys``.

    Keys are copied in the order they appear in ``keys``; a key missing from
    ``d`` raises ``KeyError``.
    """
    subset = {}
    for name in keys:
        subset[name] = d[name]
    return subset
def format_search_results(agg_search_results):
    """Build the label mapping shown in the search-results widget.

    For every gene, only the first entry in its list is used: the label is
    ``"Gene: <gene> | Organism: <organism>"`` and the value is that entry's
    score.

    Args:
        agg_search_results: Mapping of gene name to a non-empty list of
            entries, each with ``organism`` and ``score`` keys.

    Returns:
        Dict mapping each label to a score, in the input's iteration order.
    """
    return {
        f"Gene: {gene} | Organism: {hits[0]['organism']}": hits[0]["score"]
        for gene, hits in agg_search_results.items()
    }
def update_dropdown_menu(agg_search_res):
    """Create the update for the result-selector dropdown.

    One choice is produced per (gene, entry) pair by joining the gene name,
    the entry's ``pdb_name`` and its ``chain_id`` with ``choice_sep``.

    When there is at least one choice the dropdown is made visible with the
    first choice selected; otherwise it is hidden and its value cleared. It
    is interactive in both cases.

    Args:
        agg_search_res: Mapping of gene name to a list of entries, each with
            ``pdb_name`` and ``chain_id`` keys.

    Returns:
        The value returned by ``gr.Dropdown.update`` for the new state.
    """
    choices = [
        choice_sep.join([gene, hit["pdb_name"], hit["chain_id"]])
        for gene, hits in agg_search_res.items()
        for hit in hits
    ]
    has_choices = bool(choices)
    return gr.Dropdown.update(
        choices=choices,
        interactive=True,
        value=choices[0] if has_choices else None,
        visible=has_choices,
    )
def parse_pdb_search_result(raw_result):
    """Turn one raw search hit into a ``(label, score)`` pair.

    The label always starts with ``"PDB: <pdb_name>.<chain_id>"``. When the
    hit's ``genes`` field is not ``None``, the gene names and organism are
    appended to it.

    Args:
        raw_result: Hit with ``pdb_name``, ``chain_id``, ``score``, ``genes``
            and ``organism`` keys.

    Returns:
        Tuple of the display label and the hit's score.
    """
    label = f"PDB: {raw_result['pdb_name']}.{raw_result['chain_id']}"
    score = raw_result["score"]
    genes = raw_result["genes"]
    organism = raw_result["organism"]
    if genes is None:
        return label, score
    return f"{label} | Genes: {genes} | Organism: {organism}", score
def switch_viz(new_choice):
    """Produce the visualization outputs for the selected search result.

    Args:
        new_choice: A dropdown choice whose ``choice_sep``-separated parts are
            gene, PDB name and chain id (the format built by
            ``update_dropdown_menu``), or ``None`` when nothing is selected.

    Returns:
        A ``(html, title_update, description_update)`` triple. With no
        selection the HTML is empty and both Markdown updates hide their
        component. Otherwise the HTML comes from ``render_html`` for the
        chosen PDB/chain, the header is shown, and the description shows the
        title returned by ``get_pdb_title``.
    """
    if new_choice is None:
        return (
            "",
            gr.Markdown.update(visible=False),
            gr.Markdown.update(value=None, visible=False),
        )

    # Parts are gene, PDB name, chain id; only the last two are needed here.
    pdb_id, chain = new_choice.split(choice_sep)[1:3]
    header_update = gr.Markdown.update(visible=True)
    body_text = f"**PDB Title**: {get_pdb_title(pdb_id)}"
    body_update = gr.Markdown.update(value=body_text, visible=True)
    return render_html(pdb_id=pdb_id, chain=chain), header_update, body_update
# Build the Gradio UI (description, search inputs, results label and the
# visualization widgets), then wire the event handlers.
# NOTE(review): the nesting of the layout context managers below cannot be
# determined from the indentation as it appears here — confirm against the
# running app before restructuring.
with gr.Blocks() as demo:
with gr.Column():
gr.Markdown(app_description)
with gr.Column():
with gr.Row():
with gr.Column():
# Search inputs: query sequence, requested number of results, and the index
# to search.
seq_input = gr.Textbox(value="APTMPPPLPP", label="Input Sequence")
n_results = gr.Number(10, label="N Results")
index_selector = gr.Dropdown(
choices=available_indexes,
value="PDB",
multiselect=False,
visible=True,
label="Index",
)
search_button = gr.Button("Search", variant="primary")
# Results label; num_top_classes is set to the same max_results cap that
# limit_n_results applies to the requested count.
search_results = gr.Label(
num_top_classes=max_results, label="Search Results"
)
# Visualization widgets start hidden: update_dropdown_menu reveals the
# selector when a search yields choices, and switch_viz shows or hides the
# header and body.
viz_header = gr.Markdown("## Visualization", visible=False)
results_selector = gr.Dropdown(
choices=[],
multiselect=False,
visible=False,
label="Visualized Search Result",
)
viz_body = gr.Markdown("", visible=False)
# Initial HTML is rendered with no PDB id or chain selected.
protein_viz = gr.HTML(
value=render_html(pdb_id=None, chain=None),
label="Protein Visualization",
)
# Example query sequences for the sequence input.
gr.Examples(
["APTMPPPLPP", "KFLIYQMECSTMIFGL", "PHFAMPPIHEDHLE", "AEERIISLD"],
inputs=[seq_input],
)
# Clicking Search runs search_and_display; its two return values go to the
# results label and the result selector.
search_button.click(
search_and_display,
inputs=[seq_input, n_results, index_selector],
outputs=[search_results, results_selector],
)
# Changing the selected result recomputes the visualization via switch_viz.
results_selector.change(
switch_viz, inputs=results_selector, outputs=[protein_viz, viz_header, viz_body]
)
# Launch the app only when this file is run as a script.
if __name__ == "__main__":
demo.launch()