Spaces:
Running
Running
# -*- coding: utf-8 -*- | |
import gradio as gr | |
import pandas as pd | |
import os | |
import subprocess | |
def predict_top_100_genes(disease_id):
    """Return the top-100 gene predictions for a disease.

    Results are cached on disk: when ``data/downstream/<id>_top100.csv``
    already exists it is read directly. Otherwise a model input CSV is built
    from ``data/pretrain/disgenet_latest.csv`` and ``model.sh`` is run to
    produce the output file.

    Args:
        disease_id: Identifier matched against the ``diseaseId`` column of the
            DisGeNET CSV. Leading/trailing whitespace is ignored.

    Returns:
        A DataFrame with columns ``Gene`` and ``Score`` holding the first 100
        rows of the model output, or an ``"Error: ..."`` string when the ID is
        empty, contains a path separator, or is absent from the database.

    Raises:
        subprocess.CalledProcessError: If ``model.sh`` exits with a non-zero
            status. Any partially written output file is removed first.
    """
    # The ID is interpolated into file paths below, so normalise it and refuse
    # anything that could point outside data/downstream/.
    disease_id = (disease_id or "").strip()
    if not disease_id or any(sep in disease_id for sep in ("/", "\\")):
        # NOTE(review): like the "not found" case below, this returns a string
        # to keep the existing contract; confirm the UI output component
        # renders a plain string as intended.
        return (
            f"Error: Disease ID '{disease_id}' is not a valid identifier. "
            "Please check the ID and try again."
        )

    input_csv_path = f'data/downstream/{disease_id}_disease.csv'
    output_csv_path = f'data/downstream/{disease_id}_top100.csv'

    # Only build the model input and run the model when no cached output exists.
    if not os.path.exists(output_csv_path):
        df = pd.read_csv('data/pretrain/disgenet_latest.csv')
        df = df[df['proteinSeq'].notna()]

        if disease_id not in df['diseaseId'].values:
            return f"Error: Disease ID '{disease_id}' not found in the database. Please check the ID and try again."

        desired_diseaseDes = df.loc[df['diseaseId'] == disease_id, 'diseaseDes'].iloc[0]
        related_proteins = df.loc[df['diseaseDes'] == desired_diseaseDes, 'proteinSeq'].unique()

        # Label every protein 1 if it is linked to this disease description,
        # else 0. Kept as a separate Series rather than assigned into the
        # filtered frame, which would trigger pandas' SettingWithCopyWarning.
        score = df['proteinSeq'].isin(related_proteins).astype(int)

        new_df = pd.DataFrame({
            'diseaseId': disease_id,
            'diseaseDes': desired_diseaseDes,
            'geneSymbol': df['geneSymbol'],
            'proteinSeq': df['proteinSeq'],
            'score': score,
        }).drop_duplicates().reset_index(drop=True)

        os.makedirs(os.path.dirname(input_csv_path), exist_ok=True)
        new_df.to_csv(input_csv_path, index=False)

        script_path = 'model.sh'
        try:
            subprocess.run(['bash', script_path, input_csv_path, output_csv_path], check=True)
        except subprocess.CalledProcessError:
            # A half-written output would otherwise be served as a valid
            # cached result on every later request for this disease.
            if os.path.exists(output_csv_path):
                os.remove(output_csv_path)
            raise

    # Read the fresh or cached model output.
    output_df = pd.read_csv(output_csv_path)

    # The first 100 rows are taken as-is; this assumes model.sh writes rows
    # already ranked by Prediction_score — TODO confirm.
    result_df = (
        output_df[['geneSymbol', 'Prediction_score']]
        .rename(columns={'geneSymbol': 'Gene', 'Prediction_score': 'Score'})
        .head(100)
    )
    return result_df
# Markdown shown under the title: explains how to look up a Disease ID on
# DisGeNET and states the expected runtime for a disease with no cached result.
_DESCRIPTION_MD = "".join((
    "This AI model predicts the top 100 genes associated with a given disease based on 16,733 genes.",
    " To get started, you need a Disease ID (UMLS CUI), which can be obtained from the DisGeNET database. ",
    "\n\n**Steps to Obtain a Disease ID from DisGeNET:**\n",
    "1. Visit the DisGeNET website: [https://www.disgenet.org/search](https://www.disgenet.org/search).\n",
    "2. Use the search bar to enter your disease of interest. For instance, if you're interested in 'Alzheimer's Disease', type 'Alzheimer's Disease' into the search bar.\n",
    "3. From the search results, identify the disease you're researching. The Disease ID (UMLS CUI) is listed alongside each disease name, e.g. C0002395.\n",
    "4. Enter the Disease ID into the input box below and submit.\n\n",
    "The DisGeNET database contains all known gene-disease associations and associated evidence. In addition, it is able to find the corresponding diseases based on a gene.\n",
    "\n**The model will take about 18 minutes to inference a new disease.**\n",
))

# Single-line text input for the disease identifier.
_disease_id_box = gr.Textbox(
    label="Disease ID",
    placeholder="Enter Disease ID Here...",
    lines=1,
)

# Table that displays the value returned by predict_top_100_genes.
_results_table = gr.Dataframe(label="Predicted Top 100 Related Genes")

iface = gr.Interface(
    fn=predict_top_100_genes,
    inputs=_disease_id_box,
    outputs=_results_table,
    title="Gene Disease Association Prediction",
    description=_DESCRIPTION_MD,
)

# Start the app at import/run time, passing share=True to launch().
iface.launch(share=True)