from typing import Optional
import weave
from medrag_multi_modal.assistant.figure_annotation import FigureAnnotatorFromPageImage
from medrag_multi_modal.assistant.llm_client import LLMClient
from medrag_multi_modal.assistant.schema import (
MedQACitation,
MedQAMCQResponse,
MedQAResponse,
)
from medrag_multi_modal.retrieval.common import SimilarityMetric
from medrag_multi_modal.retrieval.text_retrieval import BM25sRetriever
class MedQAAssistant(weave.Model):
"""
`MedQAAssistant` is a class designed to assist with medical queries by leveraging a
language model client, a retriever model, and an optional figure annotator.
!!! example "Usage Example"
```python
import weave
from dotenv import load_dotenv
from medrag_multi_modal.assistant import (
FigureAnnotatorFromPageImage,
LLMClient,
MedQAAssistant,
)
from medrag_multi_modal.retrieval import MedCPTRetriever
load_dotenv()
weave.init(project_name="ml-colabs/medrag-multi-modal")
llm_client = LLMClient(model_name="gemini-1.5-flash")
retriever = MedCPTRetriever.from_wandb_artifact(
chunk_dataset_name="grays-anatomy-chunks:v0",
index_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-medcpt:v0",
)
figure_annotator = FigureAnnotatorFromPageImage(
figure_extraction_llm_client=LLMClient(model_name="pixtral-12b-2409"),
structured_output_llm_client=LLMClient(model_name="gpt-4o"),
image_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-images-marker:v6",
)
medqa_assistant = MedQAAssistant(
llm_client=llm_client, retriever=retriever, figure_annotator=figure_annotator
)
medqa_assistant.predict(query="What is a ribosome?")
```
Args:
llm_client (LLMClient): The language model client used to generate responses.
retriever (weave.Model): The model used to retrieve relevant chunks of text from a medical document.
figure_annotator (Optional[FigureAnnotatorFromPageImage]): The annotator used to extract figure descriptions from pages. When omitted, figure descriptions are not retrieved.
top_k_chunks_for_query (int): The number of top chunks to retrieve based on similarity metric for the query.
top_k_chunks_for_options (int): The number of top chunks to retrieve based on similarity metric for the options.
rely_only_on_context (bool): Whether the assistant must answer using only the retrieved context, without drawing on external knowledge.
retrieval_similarity_metric (SimilarityMetric): The metric used to measure similarity for retrieval.
"""
llm_client: LLMClient
retriever: weave.Model
figure_annotator: Optional[FigureAnnotatorFromPageImage] = None
top_k_chunks_for_query: int = 2
top_k_chunks_for_options: int = 2
rely_only_on_context: bool = True
retrieval_similarity_metric: SimilarityMetric = SimilarityMetric.COSINE
@weave.op()
def retrieve_chunks_for_query(self, query: str) -> list[dict]:
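"""
Retrieves the top-k chunks for the query. The similarity metric is passed only to
retrievers that support one; BM25s ranks lexically and takes no metric argument.
"""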
retriever_kwargs = {"top_k": self.top_k_chunks_for_query}
if not isinstance(self.retriever, BM25sRetriever):
retriever_kwargs["metric"] = self.retrieval_similarity_metric
return self.retriever.predict(query, **retriever_kwargs)
@weave.op()
def retrieve_chunks_for_options(self, options: list[str]) -> list[dict]:
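"""
Retrieves the top-k chunks for each option independently and concatenates the
results; chunks retrieved for multiple options are not deduplicated.
"""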
retriever_kwargs = {"top_k": self.top_k_chunks_for_options}
if not isinstance(self.retriever, BM25sRetriever):
retriever_kwargs["metric"] = self.retrieval_similarity_metric
retrieved_chunks = []
for option in options:
retrieved_chunks += self.retriever.predict(query=option, **retriever_kwargs)
return retrieved_chunks
@weave.op()
def predict(self, query: str, options: Optional[list[str]] = None) -> MedQAResponse:
"""
Generates a response to a medical query by retrieving relevant text chunks and figure descriptions
from a medical document and passing them, along with the query, to a language model.
This function performs the following steps:
1. Retrieves relevant text chunks from the medical document based on the query and any provided options
using the retriever model.
2. Extracts the text and page indices from the retrieved chunks.
3. Retrieves figure descriptions from the pages identified in the previous step using the figure annotator, if one is configured.
4. Constructs a system prompt and user prompt combining the query, options (if provided), retrieved text chunks,
and figure descriptions.
5. Uses the language model client to generate a response based on the constructed prompts, either choosing
from provided options or generating a free-form response.
6. Returns the generated response, which includes the answer and explanation if options were provided.
The function can operate in two modes:
- Multiple choice: when options are provided, it selects the best answer from the options and explains the choice (see the sketch below).
- Free response: when no options are provided, it generates a comprehensive response based on the context.
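A minimal sketch of multiple-choice mode, assuming `medqa_assistant` was constructed as in the
class-level usage example; the question and option strings below are illustrative.
!!! example "Multiple-choice usage"
```python
response = medqa_assistant.predict(
    query="Which organelle is the primary site of protein synthesis?",
    options=["Ribosome", "Golgi apparatus", "Lysosome", "Mitochondrion"],
)
print(response.response)   # MedQAMCQResponse: the chosen option plus an explanation
print(response.citations)  # list of MedQACitation entries
```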
Args:
query (str): The medical query to be answered.
options (Optional[list[str]]): The list of options to choose from. When provided, the assistant operates in multiple-choice mode.
Returns:
MedQAResponse: The generated response to the query, including source information.
"""
retrieved_chunks = self.retrieve_chunks_for_query(query)
options = options or []
retrieved_chunks += self.retrieve_chunks_for_options(options)
retrieved_chunk_texts = []
page_indices = set()
for chunk in retrieved_chunks:
retrieved_chunk_texts.append(chunk["text"])
page_indices.add(int(chunk["page_idx"]))
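# Collect figure descriptions for every page that contributed a chunk; skipped when no figure annotator is configured.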
figure_descriptions = []
if self.figure_annotator is not None:
for page_idx in page_indices:
figure_annotations = self.figure_annotator.predict(page_idx=page_idx)[
page_idx
]
figure_descriptions += [
item["figure_description"] for item in figure_annotations
]
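# Build the prompts; MCQ mode appends the options and asks the model to justify its choice.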
system_prompt = """You are an expert in medical science. You are given a question
and a list of excerpts from various medical documents.
"""
query = f"""# Question
{query}
"""
if len(options) > 0:
system_prompt += """\nYou are also given a list of options to choose your answer from.
You are supposed to choose the best possible option based on the context provided. You should also
explain your answer to justify why you chose that option.
"""
query += "## Options\n"
for option in options:
query += f"- {option}\n"
else:
system_prompt += "\nYou are supposed to answer the question based on the context provided."
if self.rely_only_on_context:
system_prompt += """\n\nYou are only allowed to use the context provided to answer the question.
You are not allowed to use any external knowledge to answer the question.
"""
response = self.llm_client.predict(
system_prompt=system_prompt,
user_prompt=[query, *retrieved_chunk_texts, *figure_descriptions],
schema=MedQAMCQResponse if len(options) > 0 else None,
)
# TODO: Add figure citations
# TODO: Add source document name from retrieved chunks as citations
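# page_idx appears to be zero-based in the retrieved chunks, hence the +1 for human-readable page numbers.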
citations = []
for page_idx in page_indices:
citations.append(
MedQACitation(page_number=page_idx + 1, document_name="Gray's Anatomy")
)
return MedQAResponse(response=response, citations=citations)