from typing import Optional

import weave

from medrag_multi_modal.assistant.figure_annotation import FigureAnnotatorFromPageImage
from medrag_multi_modal.assistant.llm_client import LLMClient
from medrag_multi_modal.assistant.schema import (
    MedQACitation,
    MedQAMCQResponse,
    MedQAResponse,
)
from medrag_multi_modal.retrieval.common import SimilarityMetric
from medrag_multi_modal.retrieval.text_retrieval import BM25sRetriever


class MedQAAssistant(weave.Model):
    """
    `MedQAAssistant` is a class designed to assist with medical queries by leveraging
    a language model client, a retriever model, and an optional figure annotator.

    !!! example "Usage Example"
        ```python
        import weave
        from dotenv import load_dotenv

        from medrag_multi_modal.assistant import (
            FigureAnnotatorFromPageImage,
            LLMClient,
            MedQAAssistant,
        )
        from medrag_multi_modal.retrieval import MedCPTRetriever

        load_dotenv()
        weave.init(project_name="ml-colabs/medrag-multi-modal")
        llm_client = LLMClient(model_name="gemini-1.5-flash")
        retriever = MedCPTRetriever.from_wandb_artifact(
            chunk_dataset_name="grays-anatomy-chunks:v0",
            index_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-medcpt:v0",
        )
        figure_annotator = FigureAnnotatorFromPageImage(
            figure_extraction_llm_client=LLMClient(model_name="pixtral-12b-2409"),
            structured_output_llm_client=LLMClient(model_name="gpt-4o"),
            image_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-images-marker:v6",
        )
        medqa_assistant = MedQAAssistant(
            llm_client=llm_client,
            retriever=retriever,
            figure_annotator=figure_annotator,
        )
        medqa_assistant.predict(query="What is a ribosome?")
        ```

    Args:
        llm_client (LLMClient): The language model client used to generate responses.
        retriever (weave.Model): The model used to retrieve relevant chunks of text
            from a medical document.
        figure_annotator (Optional[FigureAnnotatorFromPageImage]): The annotator used
            to extract figure descriptions from pages.
        top_k_chunks_for_query (int): The number of top chunks to retrieve for the
            query, ranked by the similarity metric.
        top_k_chunks_for_options (int): The number of top chunks to retrieve for each
            of the options, ranked by the similarity metric.
        rely_only_on_context (bool): Whether the response must rely only on the
            retrieved context, or may also draw on the language model's own knowledge.
        retrieval_similarity_metric (SimilarityMetric): The metric used to measure
            similarity for retrieval.
    """

    llm_client: LLMClient
    retriever: weave.Model
    figure_annotator: Optional[FigureAnnotatorFromPageImage] = None
    top_k_chunks_for_query: int = 2
    top_k_chunks_for_options: int = 2
    rely_only_on_context: bool = True
    retrieval_similarity_metric: SimilarityMetric = SimilarityMetric.COSINE

    @weave.op()
    def retrieve_chunks_for_query(self, query: str) -> list[dict]:
        retriever_kwargs = {"top_k": self.top_k_chunks_for_query}
        # BM25s is a lexical retriever and takes no similarity metric;
        # dense retrievers do.
        if not isinstance(self.retriever, BM25sRetriever):
            retriever_kwargs["metric"] = self.retrieval_similarity_metric
        return self.retriever.predict(query, **retriever_kwargs)

    @weave.op()
    def retrieve_chunks_for_options(self, options: list[str]) -> list[dict]:
        retriever_kwargs = {"top_k": self.top_k_chunks_for_options}
        if not isinstance(self.retriever, BM25sRetriever):
            retriever_kwargs["metric"] = self.retrieval_similarity_metric
        retrieved_chunks = []
        # Retrieve chunks for each option independently and concatenate the results.
        for option in options:
            retrieved_chunks += self.retriever.predict(query=option, **retriever_kwargs)
        return retrieved_chunks

    @weave.op()
    def predict(self, query: str, options: Optional[list[str]] = None) -> MedQAResponse:
        """
        Generates a response to a medical query by retrieving relevant text chunks and
        figure descriptions from a medical document and using a language model to
        generate the final response.

        This function performs the following steps:
        1. Retrieves relevant text chunks from the medical document based on the query
           and any provided options using the retriever model.
        2. Extracts the text and page indices from the retrieved chunks.
        3. Retrieves figure descriptions from the pages identified in the previous
           step using the figure annotator.
        4. Constructs a system prompt and a user prompt combining the query, options
           (if provided), retrieved text chunks, and figure descriptions.
        5. Uses the language model client to generate a response based on the
           constructed prompts, either choosing from the provided options or
           generating a free-form response.
        6. Returns the generated response, which includes the answer and an
           explanation if options were provided.

        The function can operate in two modes:

        - Multiple choice: When options are provided, it selects the best answer from
          the options and explains the choice.
        - Free response: When no options are provided, it generates a comprehensive
          response based on the context.

        Args:
            query (str): The medical query to be answered.
            options (Optional[list[str]]): The list of options to choose from.

        Returns:
            MedQAResponse: The generated response to the query, including source
                information.
        """
        # Retrieve context for the query and, in multiple-choice mode, for each option.
        retrieved_chunks = self.retrieve_chunks_for_query(query)
        options = options or []
        retrieved_chunks += self.retrieve_chunks_for_options(options)

        # Collect the chunk texts and the (zero-based) page indices they came from.
        retrieved_chunk_texts = []
        page_indices = set()
        for chunk in retrieved_chunks:
            retrieved_chunk_texts.append(chunk["text"])
            page_indices.add(int(chunk["page_idx"]))

        # Optionally gather figure descriptions for every retrieved page.
        figure_descriptions = []
        if self.figure_annotator is not None:
            for page_idx in page_indices:
                figure_annotations = self.figure_annotator.predict(page_idx=page_idx)[
                    page_idx
                ]
                figure_descriptions += [
                    item["figure_description"] for item in figure_annotations
                ]

        system_prompt = """You are an expert in medical science.
You are given a question and a list of excerpts from various medical documents.
"""
        query = f"""# Question
{query}
"""

        if len(options) > 0:
            system_prompt += """\nYou are also given a list of options to choose your answer from.
You are supposed to choose the best possible option based on the context provided.
You should also explain your answer to justify why you chose that option.
"""
            query += "## Options\n"
            for option in options:
                query += f"- {option}\n"
        else:
            system_prompt += "\nYou are supposed to answer the question based on the context provided."

        if self.rely_only_on_context:
            system_prompt += """\n\nYou are only allowed to use the context provided to answer the question.
You are not allowed to use any external knowledge to answer the question.
"""

        # In multiple-choice mode, constrain the output to the MedQAMCQResponse schema;
        # otherwise let the model respond free-form.
        response = self.llm_client.predict(
            system_prompt=system_prompt,
            user_prompt=[query, *retrieved_chunk_texts, *figure_descriptions],
            schema=MedQAMCQResponse if len(options) > 0 else None,
        )

        # TODO: Add figure citations
        # TODO: Add source document name from retrieved chunks as citations
        citations = []
        for page_idx in page_indices:
            # Page indices are zero-based; cited page numbers are one-based.
            citations.append(
                MedQACitation(page_number=page_idx + 1, document_name="Gray's Anatomy")
            )

        return MedQAResponse(response=response, citations=citations)
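

# A minimal usage sketch for the multiple-choice mode, mirroring the setup in the
# class docstring. The model name and W&B artifact addresses below are assumptions
# carried over from that example, not a definitive configuration; `figure_annotator`
# is omitted here since it is optional.
if __name__ == "__main__":
    from dotenv import load_dotenv

    from medrag_multi_modal.retrieval import MedCPTRetriever

    load_dotenv()
    weave.init(project_name="ml-colabs/medrag-multi-modal")
    medqa_assistant = MedQAAssistant(
        llm_client=LLMClient(model_name="gemini-1.5-flash"),
        retriever=MedCPTRetriever.from_wandb_artifact(
            chunk_dataset_name="grays-anatomy-chunks:v0",
            index_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-medcpt:v0",
        ),
    )
    # With `options` provided, `predict` selects one option and returns a
    # MedQAResponse whose `response` follows the MedQAMCQResponse schema.
    response = medqa_assistant.predict(
        query="Which organelle is the primary site of protein synthesis?",
        options=["Ribosome", "Mitochondrion", "Golgi apparatus", "Lysosome"],
    )
    print(response)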