from typing import Optional

import weave

from medrag_multi_modal.assistant.figure_annotation import FigureAnnotatorFromPageImage
from medrag_multi_modal.assistant.llm_client import LLMClient
from medrag_multi_modal.assistant.schema import (
    MedQACitation,
    MedQAMCQResponse,
    MedQAResponse,
)
from medrag_multi_modal.retrieval.common import SimilarityMetric
from medrag_multi_modal.retrieval.text_retrieval import BM25sRetriever


class MedQAAssistant(weave.Model):
"""
`MedQAAssistant` is a class designed to assist with medical queries by leveraging a
language model client, a retriever model, and a figure annotator.
    !!! example "Usage Example"
        ```python
        import weave
        from dotenv import load_dotenv

        from medrag_multi_modal.assistant import (
            FigureAnnotatorFromPageImage,
            LLMClient,
            MedQAAssistant,
        )
        from medrag_multi_modal.retrieval import MedCPTRetriever

        load_dotenv()
        weave.init(project_name="ml-colabs/medrag-multi-modal")

        llm_client = LLMClient(model_name="gemini-1.5-flash")

        retriever = MedCPTRetriever.from_wandb_artifact(
            chunk_dataset_name="grays-anatomy-chunks:v0",
            index_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-medcpt:v0",
        )

        figure_annotator = FigureAnnotatorFromPageImage(
            figure_extraction_llm_client=LLMClient(model_name="pixtral-12b-2409"),
            structured_output_llm_client=LLMClient(model_name="gpt-4o"),
            image_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-images-marker:v6",
        )

        medqa_assistant = MedQAAssistant(
            llm_client=llm_client, retriever=retriever, figure_annotator=figure_annotator
        )
        medqa_assistant.predict(query="What is ribosome?")
        ```

    Args:
        llm_client (LLMClient): The language model client used to generate responses.
        retriever (weave.Model): The model used to retrieve relevant chunks of text from a medical document.
        figure_annotator (Optional[FigureAnnotatorFromPageImage]): The annotator used to extract figure descriptions from pages.
        top_k_chunks_for_query (int): The number of top chunks to retrieve for the query, ranked by the similarity metric.
        top_k_chunks_for_options (int): The number of top chunks to retrieve for each option, ranked by the similarity metric.
        rely_only_on_context (bool): Whether the answer must rely only on the retrieved context rather than external knowledge.
        retrieval_similarity_metric (SimilarityMetric): The metric used to measure similarity for retrieval.
    """

    llm_client: LLMClient
    retriever: weave.Model
    figure_annotator: Optional[FigureAnnotatorFromPageImage] = None
    top_k_chunks_for_query: int = 2
    top_k_chunks_for_options: int = 2
    rely_only_on_context: bool = True
    retrieval_similarity_metric: SimilarityMetric = SimilarityMetric.COSINE

    @weave.op()
    def retrieve_chunks_for_query(self, query: str) -> list[dict]:
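        """
        Retrieves the top chunks of text relevant to the given query using the retriever model.
        """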
retriever_kwargs = {"top_k": self.top_k_chunks_for_query}
        if not isinstance(self.retriever, BM25sRetriever):
            retriever_kwargs["metric"] = self.retrieval_similarity_metric
        return self.retriever.predict(query, **retriever_kwargs)

    @weave.op()
    def retrieve_chunks_for_options(self, options: list[str]) -> list[dict]:
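        """
        Retrieves the top chunks of text relevant to each of the given options using the retriever model.
        """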
retriever_kwargs = {"top_k": self.top_k_chunks_for_options}
if not isinstance(self.retriever, BM25sRetriever):
retriever_kwargs["metric"] = self.retrieval_similarity_metric
        retrieved_chunks = []
        for option in options:
            retrieved_chunks += self.retriever.predict(query=option, **retriever_kwargs)
        return retrieved_chunks

    @weave.op()
    def predict(self, query: str, options: Optional[list[str]] = None) -> MedQAResponse:
"""
Generates a response to a medical query by retrieving relevant text chunks and figure descriptions
from a medical document and using a language model to generate the final response.
This function performs the following steps:
1. Retrieves relevant text chunks from the medical document based on the query and any provided options
using the retriever model.
2. Extracts the text and page indices from the retrieved chunks.
3. Retrieves figure descriptions from the pages identified in the previous step using the figure annotator.
4. Constructs a system prompt and user prompt combining the query, options (if provided), retrieved text chunks,
and figure descriptions.
5. Uses the language model client to generate a response based on the constructed prompts, either choosing
from provided options or generating a free-form response.
6. Returns the generated response, which includes the answer and explanation if options were provided.
The function can operate in two modes:
- Multiple choice: When options are provided, it selects the best answer from the options and explains the choice
- Free response: When no options are provided, it generates a comprehensive response based on the context
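        !!! example "Prediction modes"
            A minimal sketch of both modes, assuming a `medqa_assistant` instance constructed as in the
            class-level usage example; the question and option strings below are illustrative placeholders.

            ```python
            # Free response: no options provided
            response = medqa_assistant.predict(query="What is the function of the mitral valve?")

            # Multiple choice: options provided, so the answer is parsed against `MedQAMCQResponse`
            response = medqa_assistant.predict(
                query="Which nerve innervates the lateral rectus muscle?",
                options=[
                    "Oculomotor nerve",
                    "Trochlear nerve",
                    "Abducens nerve",
                    "Trigeminal nerve",
                ],
            )
            ```
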
        Args:
            query (str): The medical query to be answered.
            options (Optional[list[str]]): The list of options to choose from, if the query is a multiple-choice question.

        Returns:
            MedQAResponse: The generated response to the query, including source citations.
        """
        retrieved_chunks = self.retrieve_chunks_for_query(query)
        options = options or []
        retrieved_chunks += self.retrieve_chunks_for_options(options)

        retrieved_chunk_texts = []
        page_indices = set()
        for chunk in retrieved_chunks:
            retrieved_chunk_texts.append(chunk["text"])
            page_indices.add(int(chunk["page_idx"]))

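        # Collect figure descriptions for every page that contributed a retrieved chunk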
        figure_descriptions = []
        if self.figure_annotator is not None:
            for page_idx in page_indices:
                figure_annotations = self.figure_annotator.predict(page_idx=page_idx)[
                    page_idx
                ]
                figure_descriptions += [
                    item["figure_description"] for item in figure_annotations
                ]

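        # Assemble the system prompt and the user-facing question, adapting to MCQ or free-response mode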
system_prompt = """You are an expert in medical science. You are given a question
and a list of excerpts from various medical documents.
"""
query = f"""# Question
{query}
"""
        if len(options) > 0:
            system_prompt += """\nYou are also given a list of options to choose your answer from.
You are supposed to choose the best possible option based on the context provided. You should also
explain your answer to justify why you chose that option.
"""
            query += "## Options\n"
            for option in options:
                query += f"- {option}\n"
        else:
            system_prompt += "\nYou are supposed to answer the question based on the context provided."

        if self.rely_only_on_context:
            system_prompt += """\n\nYou are only allowed to use the context provided to answer the question.
You are not allowed to use any external knowledge to answer the question.
"""

        response = self.llm_client.predict(
            system_prompt=system_prompt,
            user_prompt=[query, *retrieved_chunk_texts, *figure_descriptions],
            schema=MedQAMCQResponse if len(options) > 0 else None,
        )

        # TODO: Add figure citations
        # TODO: Add source document name from retrieved chunks as citations
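        # Page indices are zero-based, so convert them to one-based page numbers for the citations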
        citations = []
        for page_idx in page_indices:
            citations.append(
                MedQACitation(page_number=page_idx + 1, document_name="Gray's Anatomy")
            )
        return MedQAResponse(response=response, citations=citations)