File size: 3,509 Bytes
39b7b6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29d72d7
39b7b6a
 
 
 
29d72d7
39b7b6a
 
 
 
29d72d7
39b7b6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import streamlit as st

from medrag_multi_modal.assistant import LLMClient, MedQAAssistant
from medrag_multi_modal.retrieval.text_retrieval import (
    BM25sRetriever,
    ContrieverRetriever,
    MedCPTRetriever,
    NVEmbed2Retriever,
)

# Model names offered in the "LLM Model" selectbox, in display order.
ALL_AVAILABLE_MODELS = [
    # Gemini 1.5 variants
    "gemini-1.5-flash-latest",
    "gemini-1.5-pro-latest",
    # GPT-4o variants
    "gpt-4o",
    "gpt-4o-mini",
]

# All configuration widgets are rendered inside the sidebar container; each
# widget's current value is bound to a module-level name for later use.
with st.sidebar:
    st.title("Configuration Settings")

    project_name = st.text_input(
        label="Project Name",
        value="ml-colabs/medrag-multi-modal",
        placeholder="wandb project name",
        help="format: wandb_username/wandb_project_name",
    )

    chunk_dataset_id = st.selectbox(
        label="Chunk Dataset ID",
        options=["ashwiniai/medrag-text-corpus-chunks"],
    )

    llm_model = st.selectbox(label="LLM Model", options=ALL_AVAILABLE_MODELS)

    # Both sliders share the 1..20 range and differ only in their default.
    top_k_chunks_for_query = st.slider(
        label="Top K Chunks for Query", min_value=1, max_value=20, value=5
    )
    top_k_chunks_for_options = st.slider(
        label="Top K Chunks for Options", min_value=1, max_value=20, value=3
    )

    rely_only_on_context = st.checkbox(label="Rely Only on Context", value=False)

    # The leading "" entry is the initial "nothing chosen" state; the chat
    # section of the script is skipped while it is selected.
    retriever_type = st.selectbox(
        label="Retriever Type",
        options=["", "BM25S", "Contriever", "MedCPT", "NV-Embed-v2"],
    )

@st.cache_resource(show_spinner="Loading retriever index...")
def _load_retriever(retriever_type: str, chunk_dataset_id: str):
    """Build the retriever selected in the sidebar, cached across reruns.

    Streamlit re-executes this entire script on every widget interaction and
    every chat submission, so constructing the retriever inline would call
    ``from_index`` again for each question. ``st.cache_resource`` keys the
    cache on both arguments, so picking a different retriever type or chunk
    dataset still produces a fresh instance.

    NOTE(review): a ``cache_resource`` object is shared by all sessions of the
    app; confirm the retriever classes are safe to use concurrently.

    Args:
        retriever_type: A non-empty option of the "Retriever Type" selectbox.
        chunk_dataset_id: Dataset ID chosen in the sidebar. It is forwarded to
            every retriever except BM25S, matching the original wiring.

    Returns:
        The retriever instance, or ``None`` if ``retriever_type`` is not one
        of the known options.
    """
    if retriever_type == "BM25S":
        # BM25S is loaded from its index alone (no chunk_dataset argument).
        return BM25sRetriever.from_index(
            index_repo_id="ashwiniai/medrag-text-corpus-chunks-bm25s"
        )

    # The remaining retrievers share one constructor shape: an index repo
    # plus the selected chunk dataset.
    retrievers_with_chunks = {
        "Contriever": (
            ContrieverRetriever,
            "ashwiniai/medrag-text-corpus-chunks-contriever",
        ),
        "MedCPT": (
            MedCPTRetriever,
            "ashwiniai/medrag-text-corpus-chunks-medcpt",
        ),
        "NV-Embed-v2": (
            NVEmbed2Retriever,
            "ashwiniai/medrag-text-corpus-chunks-nv-embed-2",
        ),
    }
    entry = retrievers_with_chunks.get(retriever_type)
    if entry is None:
        return None
    retriever_cls, index_repo_id = entry
    return retriever_cls.from_index(
        index_repo_id=index_repo_id,
        chunk_dataset=chunk_dataset_id,
    )


# The chat UI is only shown once a retriever has been chosen ("" is the
# selectbox's initial placeholder option).
if retriever_type:
    # Keep ``llm_model`` as the selected model-name string; the client object
    # gets its own name instead of rebinding that variable.
    llm_client = LLMClient(model_name=llm_model)
    retriever = _load_retriever(retriever_type, chunk_dataset_id)

    # NOTE(review): ``rely_only_on_context`` and ``project_name`` are collected
    # in the sidebar but are not used anywhere in this file; confirm whether
    # they should be forwarded (e.g. to MedQAAssistant) so those controls
    # actually take effect.
    medqa_assistant = MedQAAssistant(
        llm_client=llm_client,
        retriever=retriever,
        top_k_chunks_for_query=top_k_chunks_for_query,
        top_k_chunks_for_options=top_k_chunks_for_options,
    )

    # Static greeting: plain Markdown only, so raw-HTML rendering is not
    # enabled.
    with st.chat_message("assistant"):
        st.markdown(
            """
Hi! I am Medrag, your medical assistant. You can ask me any questions about the medical and the life sciences.
I am currently a work-in-progress, so please bear with my stupidity and overall lack of knowledge.

**Note:** that I am not a medical professional, so please do not rely on my answers for medical decisions.
Please consult a medical professional for any medical advice.

In order to learn more about how I am being developed, please visit [soumik12345/medrag-multi-modal](https://github.com/soumik12345/medrag-multi-modal).
            """,
        )

    # One question/answer turn per submission: echo the query, then show the
    # assistant's answer.
    query = st.chat_input("Enter your question here")
    if query:
        with st.chat_message("user"):
            st.markdown(query)
        response = medqa_assistant.predict(query=query)
        with st.chat_message("assistant"):
            st.markdown(response.response)