diff --git a/README.md b/README.md index 8fe83325c94cba1f9d51e6fbe2de518e3603541f..6d8f70c3997c29c9c651be8fc13a2ce108ae0f95 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,13 @@ --- -title: Medrag Multi Modal -emoji: 🏆 -colorFrom: green -colorTo: yellow +title: MedRAG Multi-Modal +emoji: 🩺 +colorFrom: blue +colorTo: pink sdk: streamlit -sdk_version: 1.40.0 +sdk_version: "1.39.0" app_file: app.py pinned: false -short_description: Multi-modal assistant for medical professionals --- +# MedRAG Multi-Modal -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +Multi-modal RAG for the medical domain. diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..fbb47e6591b88c69e3efd1994cf1fd99b9e1005a --- /dev/null +++ b/app.py @@ -0,0 +1,114 @@ +import streamlit as st + +from medrag_multi_modal.assistant import LLMClient, MedQAAssistant +from medrag_multi_modal.retrieval.text_retrieval import ( + BM25sRetriever, + ContrieverRetriever, + MedCPTRetriever, + NVEmbed2Retriever, +) + +# Define constants +ALL_AVAILABLE_MODELS = [ + "gemini-1.5-flash-latest", + "gemini-1.5-pro-latest", + "gpt-4o", + "gpt-4o-mini", +] + +# Sidebar for configuration settings +st.sidebar.title("Configuration Settings") +project_name = st.sidebar.text_input( + label="Project Name", + value="ml-colabs/medrag-multi-modal", + placeholder="wandb project name", + help="format: wandb_username/wandb_project_name", +) +chunk_dataset_id = st.sidebar.selectbox( + label="Chunk Dataset ID", + options=["ashwiniai/medrag-text-corpus-chunks"], +) +llm_model = st.sidebar.selectbox( + label="LLM Model", + options=ALL_AVAILABLE_MODELS, +) +top_k_chunks_for_query = st.sidebar.slider( + label="Top K Chunks for Query", + min_value=1, + max_value=20, + value=5, +) +top_k_chunks_for_options = st.sidebar.slider( + label="Top K Chunks for Options", + min_value=1, + max_value=20, + value=3, +) +rely_only_on_context = st.sidebar.checkbox( + label="Rely Only on Context", + value=False, +) +retriever_type = st.sidebar.selectbox( + label="Retriever Type", + options=[ + "", + "BM25S", + "Contriever", + "MedCPT", + "NV-Embed-v2", + ], +) + +if retriever_type != "": + + llm_model = LLMClient(model_name=llm_model) + + retriever = None + + if retriever_type == "BM25S": + retriever = BM25sRetriever.from_index( + index_repo_id="ashwiniai/medrag-text-corpus-chunks-bm25s" + ) + elif retriever_type == "Contriever": + retriever = ContrieverRetriever.from_index( + index_repo_id="ashwiniai/medrag-text-corpus-chunks-contriever", + chunk_dataset_id=chunk_dataset_id, + ) + elif retriever_type == "MedCPT": + retriever = MedCPTRetriever.from_index( + index_repo_id="ashwiniai/medrag-text-corpus-chunks-medcpt", + chunk_dataset_id=chunk_dataset_id, + ) + elif retriever_type == "NV-Embed-v2": + retriever = NVEmbed2Retriever.from_index( + index_repo_id="ashwiniai/medrag-text-corpus-chunks-nv-embed-2", + chunk_dataset_id=chunk_dataset_id, + ) + + medqa_assistant = MedQAAssistant( + llm_client=llm_model, + retriever=retriever, + top_k_chunks_for_query=top_k_chunks_for_query, + top_k_chunks_for_options=top_k_chunks_for_options, + rely_only_on_context=rely_only_on_context, + ) + + with st.chat_message("assistant"): + st.markdown( + """ +Hi! I am MedRAG, your medical assistant. You can ask me any questions about medicine and the life sciences. +I am currently a work in progress, so please bear with my limitations and occasional gaps in knowledge.
+ +**Note:** I am not a medical professional, so please do not rely on my answers for medical decisions. +Please consult a medical professional for any medical advice. + +In order to learn more about how I am being developed, please visit [soumik12345/medrag-multi-modal](https://github.com/soumik12345/medrag-multi-modal). + """, + unsafe_allow_html=True, + ) + query = st.chat_input("Enter your question here") + if query: + with st.chat_message("user"): + st.markdown(query) + response = medqa_assistant.predict(query=query) + with st.chat_message("assistant"): + st.markdown(response.response) diff --git a/medrag_multi_modal/__init__.py b/medrag_multi_modal/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/medrag_multi_modal/__pycache__/__init__.cpython-310.pyc b/medrag_multi_modal/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2fa49d0b06f6966ca8571b186b6f5df8731fe885 Binary files /dev/null and b/medrag_multi_modal/__pycache__/__init__.cpython-310.pyc differ diff --git a/medrag_multi_modal/__pycache__/__init__.cpython-39.pyc b/medrag_multi_modal/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59c00481373ca79b82b61d60014ecb8a249882d9 Binary files /dev/null and b/medrag_multi_modal/__pycache__/__init__.cpython-39.pyc differ diff --git a/medrag_multi_modal/__pycache__/cli.cpython-310.pyc b/medrag_multi_modal/__pycache__/cli.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06eb6b04321081a1f5a30045cf0b2659f112a6f7 Binary files /dev/null and b/medrag_multi_modal/__pycache__/cli.cpython-310.pyc differ diff --git a/medrag_multi_modal/__pycache__/utils.cpython-310.pyc b/medrag_multi_modal/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f8445b913c68d810a2bb8f33c3d53a57797a1ce8 Binary files /dev/null and b/medrag_multi_modal/__pycache__/utils.cpython-310.pyc differ diff --git a/medrag_multi_modal/__pycache__/utils.cpython-39.pyc b/medrag_multi_modal/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..40139b2fb48e976667ac4f1310bcaca54af718df Binary files /dev/null and b/medrag_multi_modal/__pycache__/utils.cpython-39.pyc differ diff --git a/medrag_multi_modal/assistant/__init__.py b/medrag_multi_modal/assistant/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8bef57a938785df32db1cf835ee8b1a0fbc7d276 --- /dev/null +++ b/medrag_multi_modal/assistant/__init__.py @@ -0,0 +1,5 @@ +from .figure_annotation import FigureAnnotatorFromPageImage +from .llm_client import ClientType, LLMClient +from .medqa_assistant import MedQAAssistant + +__all__ = ["LLMClient", "ClientType", "MedQAAssistant", "FigureAnnotatorFromPageImage"] diff --git a/medrag_multi_modal/assistant/__pycache__/__init__.cpython-310.pyc b/medrag_multi_modal/assistant/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..35f51f75db22f247bbd5191203330d3da55ac60a Binary files /dev/null and b/medrag_multi_modal/assistant/__pycache__/__init__.cpython-310.pyc differ diff --git a/medrag_multi_modal/assistant/__pycache__/__init__.cpython-39.pyc b/medrag_multi_modal/assistant/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed6e54de0a2027b22a8047ee6d2df9d23a89b8ea Binary files
/dev/null and b/medrag_multi_modal/assistant/__pycache__/__init__.cpython-39.pyc differ diff --git a/medrag_multi_modal/assistant/__pycache__/figure_annotation.cpython-310.pyc b/medrag_multi_modal/assistant/__pycache__/figure_annotation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9e3e45e700ee14de0853ef095151dd6cb603d44 Binary files /dev/null and b/medrag_multi_modal/assistant/__pycache__/figure_annotation.cpython-310.pyc differ diff --git a/medrag_multi_modal/assistant/__pycache__/figure_annotation.cpython-39.pyc b/medrag_multi_modal/assistant/__pycache__/figure_annotation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..684968207bc589f187fc8a5cefa99875fd4e253e Binary files /dev/null and b/medrag_multi_modal/assistant/__pycache__/figure_annotation.cpython-39.pyc differ diff --git a/medrag_multi_modal/assistant/__pycache__/llm_client.cpython-310.pyc b/medrag_multi_modal/assistant/__pycache__/llm_client.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9976b496c227a3001cb51628d867abb4492dc394 Binary files /dev/null and b/medrag_multi_modal/assistant/__pycache__/llm_client.cpython-310.pyc differ diff --git a/medrag_multi_modal/assistant/__pycache__/llm_client.cpython-39.pyc b/medrag_multi_modal/assistant/__pycache__/llm_client.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7256ed212132d24cc59b084221f8fadc5c11ea6b Binary files /dev/null and b/medrag_multi_modal/assistant/__pycache__/llm_client.cpython-39.pyc differ diff --git a/medrag_multi_modal/assistant/__pycache__/medqa_assistant.cpython-310.pyc b/medrag_multi_modal/assistant/__pycache__/medqa_assistant.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..793cde5aa95a66b382e6e05814385bbd7664c6ed Binary files /dev/null and b/medrag_multi_modal/assistant/__pycache__/medqa_assistant.cpython-310.pyc differ diff --git a/medrag_multi_modal/assistant/__pycache__/schema.cpython-310.pyc b/medrag_multi_modal/assistant/__pycache__/schema.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7c6c3be809cb097ff3154bae6daf11a2ea8d962 Binary files /dev/null and b/medrag_multi_modal/assistant/__pycache__/schema.cpython-310.pyc differ diff --git a/medrag_multi_modal/assistant/figure_annotation.py b/medrag_multi_modal/assistant/figure_annotation.py new file mode 100644 index 0000000000000000000000000000000000000000..fb3838004688355f117e922720b2c8558917e0e0 --- /dev/null +++ b/medrag_multi_modal/assistant/figure_annotation.py @@ -0,0 +1,147 @@ +import os +from glob import glob +from typing import Optional, Union + +import cv2 +import weave +from PIL import Image + +from medrag_multi_modal.assistant.llm_client import LLMClient +from medrag_multi_modal.assistant.schema import FigureAnnotations +from medrag_multi_modal.utils import get_wandb_artifact, read_jsonl_file + + +class FigureAnnotatorFromPageImage(weave.Model): + """ + `FigureAnnotatorFromPageImage` is a class that leverages two LLM clients to annotate + figures from a page image of a scientific textbook. + + !!! 
example "Example Usage" + ```python + import weave + from dotenv import load_dotenv + + from medrag_multi_modal.assistant import ( + FigureAnnotatorFromPageImage, LLMClient + ) + + load_dotenv() + weave.init(project_name="ml-colabs/medrag-multi-modal") + figure_annotator = FigureAnnotatorFromPageImage( + figure_extraction_llm_client=LLMClient(model_name="pixtral-12b-2409"), + structured_output_llm_client=LLMClient(model_name="gpt-4o"), + image_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-images-marker:v6", + ) + annotations = figure_annotator.predict(page_idx=34) + ``` + + Args: + figure_extraction_llm_client (LLMClient): An LLM client used to extract figure annotations + from the page image. + structured_output_llm_client (LLMClient): An LLM client used to convert the extracted + annotations into a structured format. + image_artifact_address (Optional[str]): The address of the image artifact containing the + page images. + """ + + figure_extraction_llm_client: LLMClient + structured_output_llm_client: LLMClient + _artifact_dir: str + + def __init__( + self, + figure_extraction_llm_client: LLMClient, + structured_output_llm_client: LLMClient, + image_artifact_address: Optional[str] = None, + ): + super().__init__( + figure_extraction_llm_client=figure_extraction_llm_client, + structured_output_llm_client=structured_output_llm_client, + ) + self._artifact_dir = get_wandb_artifact(image_artifact_address, "dataset") + + @weave.op() + def annotate_figures( + self, page_image: Image.Image + ) -> dict[str, Union[Image.Image, str]]: + annotation = self.figure_extraction_llm_client.predict( + system_prompt=""" +You are an expert in the domain of scientific textbooks, especially medical texts. +You are presented with a page from a scientific textbook from the domain of biology, specifically anatomy. +You are to first identify all the figures in the page image, which could be images or biological diagrams, charts, graphs, etc. +Then you are to identify the figure IDs associated with each figure in the page image. +Then, you are to extract only the exact figure descriptions from the page image. +You need to output the figure IDs and figure descriptions only, in a structured manner as a JSON object. + +Here are some clues you need to follow: +1. Figure IDs are unique identifiers for each figure in the page image. +2. Sometimes figure IDs can also be found as captions to the immediate left, right, top, or bottom of the figure. +3. Figure IDs are in the form "Fig X.Y" where X and Y are integers. For example, 1.1, 1.2, 1.3, etc. +4. Figure descriptions are contained as captions under the figures in the image, just after the figure ID. +5. The text in the page image is written in English and is present in a two-column format. +6. There is a clear distinction between the figure caption and the regular text in the page image in the form of extra white space. + You are to carefully identify all the figures in the page image. +7. There might be multiple figures or even no figures present in the page image. Sometimes the figures can be present side-by-side + or one above the other. +8. The figures may or may not have a distinct border against a white background. +10. You are not supposed to alter the figure description in any way present in the page image and you are to extract it as is. 
+""", + user_prompt=[page_image], + ) + return {"page_image": page_image, "annotations": annotation} + + @weave.op + def extract_structured_output(self, annotations: str) -> FigureAnnotations: + return self.structured_output_llm_client.predict( + system_prompt="You are suppossed to extract a list of figure annotations consisting of figure IDs and corresponding figure descriptions.", + user_prompt=[annotations], + schema=FigureAnnotations, + ) + + @weave.op() + def predict(self, page_idx: int) -> dict[int, list[FigureAnnotations]]: + """ + Predicts figure annotations for a specific page in a document. + + This function retrieves the artifact directory from the given image artifact address, + reads the metadata from the 'metadata.jsonl' file, and iterates through the metadata + to find the specified page index. If the page index matches, it reads the page image + and associated figure images, and then uses the `annotate_figures` method to extract + figure annotations from the page image. The extracted annotations are then structured + using the `extract_structured_output` method and returned as a dictionary. + + Args: + page_idx (int): The index of the page to annotate. + image_artifact_address (str): The address of the image artifact containing the + page images. + + Returns: + dict: A dictionary containing the page index as the key and the extracted figure + annotations as the value. + """ + + metadata = read_jsonl_file(os.path.join(self._artifact_dir, "metadata.jsonl")) + annotations = {} + for item in metadata: + if item["page_idx"] == page_idx: + page_image_file = os.path.join( + self._artifact_dir, f"page{item['page_idx']}.png" + ) + figure_image_files = glob( + os.path.join(self._artifact_dir, f"page{item['page_idx']}_fig*.png") + ) + if len(figure_image_files) > 0: + page_image = cv2.imread(page_image_file) + page_image = cv2.cvtColor(page_image, cv2.COLOR_BGR2RGB) + page_image = Image.fromarray(page_image) + figure_extracted_annotations = self.annotate_figures( + page_image=page_image + ) + figure_extracted_annotations = self.extract_structured_output( + figure_extracted_annotations["annotations"] + ).model_dump() + annotations[item["page_idx"]] = figure_extracted_annotations[ + "annotations" + ] + break + return annotations diff --git a/medrag_multi_modal/assistant/llm_client.py b/medrag_multi_modal/assistant/llm_client.py new file mode 100644 index 0000000000000000000000000000000000000000..ee8ff5a637a5b342ac1689dc85fee903dff64b27 --- /dev/null +++ b/medrag_multi_modal/assistant/llm_client.py @@ -0,0 +1,245 @@ +import json +import os +from enum import Enum +from typing import Any, Optional, Union + +import instructor +import weave +from PIL import Image + +from ..utils import base64_encode_image + + +class ClientType(str, Enum): + GEMINI = "gemini" + MISTRAL = "mistral" + OPENAI = "openai" + + +GOOGLE_MODELS = [ + "gemini-1.0-pro-latest", + "gemini-1.0-pro", + "gemini-pro", + "gemini-1.0-pro-001", + "gemini-1.0-pro-vision-latest", + "gemini-pro-vision", + "gemini-1.5-pro-latest", + "gemini-1.5-pro-001", + "gemini-1.5-pro-002", + "gemini-1.5-pro", + "gemini-1.5-pro-exp-0801", + "gemini-1.5-pro-exp-0827", + "gemini-1.5-flash-latest", + "gemini-1.5-flash-001", + "gemini-1.5-flash-001-tuning", + "gemini-1.5-flash", + "gemini-1.5-flash-exp-0827", + "gemini-1.5-flash-002", + "gemini-1.5-flash-8b", + "gemini-1.5-flash-8b-001", + "gemini-1.5-flash-8b-latest", + "gemini-1.5-flash-8b-exp-0827", + "gemini-1.5-flash-8b-exp-0924", +] + +MISTRAL_MODELS = [ + "ministral-3b-latest", + 
"ministral-8b-latest", + "mistral-large-latest", + "mistral-small-latest", + "codestral-latest", + "pixtral-12b-2409", + "open-mistral-nemo", + "open-codestral-mamba", + "open-mistral-7b", + "open-mixtral-8x7b", + "open-mixtral-8x22b", +] + +OPENAI_MODELS = ["gpt-4o", "gpt-4o-2024-08-06", "gpt-4o-mini", "gpt-4o-mini-2024-07-18"] + + +class LLMClient(weave.Model): + """ + LLMClient is a class that interfaces with different large language model (LLM) providers + such as Google Gemini, Mistral, and OpenAI. It abstracts the complexity of interacting with + these different APIs and provides a unified interface for making predictions. + + Args: + model_name (str): The name of the model to be used for predictions. + client_type (Optional[ClientType]): The type of client (e.g., GEMINI, MISTRAL, OPENAI). + If not provided, it is inferred from the model_name. + """ + + model_name: str + client_type: Optional[ClientType] + + def __init__(self, model_name: str, client_type: Optional[ClientType] = None): + if client_type is None: + if model_name in GOOGLE_MODELS: + client_type = ClientType.GEMINI + elif model_name in MISTRAL_MODELS: + client_type = ClientType.MISTRAL + elif model_name in OPENAI_MODELS: + client_type = ClientType.OPENAI + else: + raise ValueError(f"Invalid model name: {model_name}") + super().__init__(model_name=model_name, client_type=client_type) + + @weave.op() + def execute_gemini_sdk( + self, + user_prompt: Union[str, list[str]], + system_prompt: Optional[Union[str, list[str]]] = None, + schema: Optional[Any] = None, + ) -> Union[str, Any]: + import google.generativeai as genai + from google.generativeai.types import HarmBlockThreshold, HarmCategory + + system_prompt = ( + [system_prompt] if isinstance(system_prompt, str) else system_prompt + ) + user_prompt = [user_prompt] if isinstance(user_prompt, str) else user_prompt + + genai.configure(api_key=os.environ.get("GOOGLE_API_KEY")) + model = genai.GenerativeModel(self.model_name, system_instruction=system_prompt) + generation_config = ( + None + if schema is None + else genai.GenerationConfig( + response_mime_type="application/json", response_schema=schema + ) + ) + response = model.generate_content( + user_prompt, + generation_config=generation_config, + # This is necessary in order to answer questions about anatomy, sexual diseases, + # medical devices, medicines, etc. 
+ safety_settings={ + HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE, + HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE, + }, + ) + return response.text if schema is None else json.loads(response.text) + + @weave.op() + def execute_mistral_sdk( + self, + user_prompt: Union[str, list[str]], + system_prompt: Optional[Union[str, list[str]]] = None, + schema: Optional[Any] = None, + ) -> Union[str, Any]: + from mistralai import Mistral + + system_prompt = ( + [system_prompt] if isinstance(system_prompt, str) else system_prompt + ) + user_prompt = [user_prompt] if isinstance(user_prompt, str) else user_prompt + system_messages = [{"type": "text", "text": prompt} for prompt in system_prompt] + user_messages = [] + for prompt in user_prompt: + if isinstance(prompt, Image.Image): + user_messages.append( + { + "type": "image_url", + "image_url": base64_encode_image(prompt, "image/png"), + } + ) + else: + user_messages.append({"type": "text", "text": prompt}) + messages = [ + {"role": "system", "content": system_messages}, + {"role": "user", "content": user_messages}, + ] + + client = Mistral(api_key=os.environ.get("MISTRAL_API_KEY")) + client = instructor.from_mistral(client) if schema is not None else client + + if schema is None: + response = client.chat.complete(model=self.model_name, messages=messages) + return response.choices[0].message.content + else: + raise NotImplementedError( + "Mistral does not support structured output using a schema" + ) + + @weave.op() + def execute_openai_sdk( + self, + user_prompt: Union[str, list[str]], + system_prompt: Optional[Union[str, list[str]]] = None, + schema: Optional[Any] = None, + ) -> Union[str, Any]: + from openai import OpenAI + + system_prompt = ( + [system_prompt] if isinstance(system_prompt, str) else system_prompt + ) + user_prompt = [user_prompt] if isinstance(user_prompt, str) else user_prompt + + system_messages = [ + {"role": "system", "content": prompt} for prompt in system_prompt + ] + user_messages = [] + for prompt in user_prompt: + if isinstance(prompt, Image.Image): + user_messages.append( + { + "type": "image_url", + "image_url": { + "url": base64_encode_image(prompt, "image/png"), + }, + }, + ) + else: + user_messages.append({"type": "text", "text": prompt}) + messages = system_messages + [{"role": "user", "content": user_messages}] + + client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY")) + + if schema is None: + completion = client.chat.completions.create( + model=self.model_name, messages=messages + ) + return completion.choices[0].message.content + + completion = weave.op()(client.beta.chat.completions.parse)( + model=self.model_name, messages=messages, response_format=schema + ) + return completion.choices[0].message.parsed + + @weave.op() + def predict( + self, + user_prompt: Union[str, list[str]], + system_prompt: Optional[Union[str, list[str]]] = None, + schema: Optional[Any] = None, + ) -> Union[str, Any]: + """ + Predicts the response from a language model based on the provided prompts and schema. + + This function determines the client type and calls the appropriate SDK execution function + to get the response from the language model. It supports multiple client types including + GEMINI, MISTRAL, and OPENAI. Depending on the client type, it calls the corresponding + execution function with the provided user and system prompts, and an optional schema. + + Args: + user_prompt (Union[str, list[str]]): The user prompt(s) to be sent to the language model.
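+ Items may be plain strings or `PIL.Image.Image` objects for multimodal prompts.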
+ system_prompt (Optional[Union[str, list[str]]]): The system prompt(s) to be sent to the language model. + schema (Optional[Any]): The schema to be used for parsing the response, if applicable. + + Returns: + Union[str, Any]: The response from the language model, which could be a string or any other type + depending on the schema provided. + + Raises: + ValueError: If the client type is invalid. + """ + if self.client_type == ClientType.GEMINI: + return self.execute_gemini_sdk(user_prompt, system_prompt, schema) + elif self.client_type == ClientType.MISTRAL: + return self.execute_mistral_sdk(user_prompt, system_prompt, schema) + elif self.client_type == ClientType.OPENAI: + return self.execute_openai_sdk(user_prompt, system_prompt, schema) + else: + raise ValueError(f"Invalid client type: {self.client_type}") diff --git a/medrag_multi_modal/assistant/medqa_assistant.py b/medrag_multi_modal/assistant/medqa_assistant.py new file mode 100644 index 0000000000000000000000000000000000000000..95cc5e539958e9046aee6f69045f85bb86f1cf37 --- /dev/null +++ b/medrag_multi_modal/assistant/medqa_assistant.py @@ -0,0 +1,174 @@ +from typing import Optional + +import weave + +from medrag_multi_modal.assistant.figure_annotation import FigureAnnotatorFromPageImage +from medrag_multi_modal.assistant.llm_client import LLMClient +from medrag_multi_modal.assistant.schema import ( + MedQACitation, + MedQAMCQResponse, + MedQAResponse, +) +from medrag_multi_modal.retrieval.common import SimilarityMetric +from medrag_multi_modal.retrieval.text_retrieval import BM25sRetriever + + +class MedQAAssistant(weave.Model): + """ + `MedQAAssistant` is a class designed to assist with medical queries by leveraging a + language model client, a retriever model, and a figure annotator. + + !!! example "Usage Example" + ```python + import weave + from dotenv import load_dotenv + + from medrag_multi_modal.assistant import ( + FigureAnnotatorFromPageImage, + LLMClient, + MedQAAssistant, + ) + from medrag_multi_modal.retrieval import MedCPTRetriever + + load_dotenv() + weave.init(project_name="ml-colabs/medrag-multi-modal") + + llm_client = LLMClient(model_name="gemini-1.5-flash") + + retriever=MedCPTRetriever.from_wandb_artifact( + chunk_dataset_name="grays-anatomy-chunks:v0", + index_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-medcpt:v0", + ) + + figure_annotator=FigureAnnotatorFromPageImage( + figure_extraction_llm_client=LLMClient(model_name="pixtral-12b-2409"), + structured_output_llm_client=LLMClient(model_name="gpt-4o"), + image_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-images-marker:v6", + ) + medqa_assistant = MedQAAssistant( + llm_client=llm_client, retriever=retriever, figure_annotator=figure_annotator + ) + medqa_assistant.predict(query="What is ribosome?") + ``` + + Args: + llm_client (LLMClient): The language model client used to generate responses. + retriever (weave.Model): The model used to retrieve relevant chunks of text from a medical document. + figure_annotator (FigureAnnotatorFromPageImage): The annotator used to extract figure descriptions from pages. + top_k_chunks_for_query (int): The number of top chunks to retrieve based on similarity metric for the query. + top_k_chunks_for_options (int): The number of top chunks to retrieve based on similarity metric for the options. + retrieval_similarity_metric (SimilarityMetric): The metric used to measure similarity for retrieval. 
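+ rely_only_on_context (bool): Whether the assistant should answer using only the retrieved context rather than its own knowledge.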
+ """ + + llm_client: LLMClient + retriever: weave.Model + figure_annotator: Optional[FigureAnnotatorFromPageImage] = None + top_k_chunks_for_query: int = 2 + top_k_chunks_for_options: int = 2 + rely_only_on_context: bool = True + retrieval_similarity_metric: SimilarityMetric = SimilarityMetric.COSINE + + @weave.op() + def retrieve_chunks_for_query(self, query: str) -> list[dict]: + retriever_kwargs = {"top_k": self.top_k_chunks_for_query} + if not isinstance(self.retriever, BM25sRetriever): + retriever_kwargs["metric"] = self.retrieval_similarity_metric + return self.retriever.predict(query, **retriever_kwargs) + + @weave.op() + def retrieve_chunks_for_options(self, options: list[str]) -> list[dict]: + retriever_kwargs = {"top_k": self.top_k_chunks_for_options} + if not isinstance(self.retriever, BM25sRetriever): + retriever_kwargs["metric"] = self.retrieval_similarity_metric + retrieved_chunks = [] + for option in options: + retrieved_chunks += self.retriever.predict(query=option, **retriever_kwargs) + return retrieved_chunks + + @weave.op() + def predict(self, query: str, options: Optional[list[str]] = None) -> MedQAResponse: + """ + Generates a response to a medical query by retrieving relevant text chunks and figure descriptions + from a medical document and using a language model to generate the final response. + + This function performs the following steps: + 1. Retrieves relevant text chunks from the medical document based on the query and any provided options + using the retriever model. + 2. Extracts the text and page indices from the retrieved chunks. + 3. Retrieves figure descriptions from the pages identified in the previous step using the figure annotator. + 4. Constructs a system prompt and user prompt combining the query, options (if provided), retrieved text chunks, + and figure descriptions. + 5. Uses the language model client to generate a response based on the constructed prompts, either choosing + from provided options or generating a free-form response. + 6. Returns the generated response, which includes the answer and explanation if options were provided. + + The function can operate in two modes: + - Multiple choice: When options are provided, it selects the best answer from the options and explains the choice + - Free response: When no options are provided, it generates a comprehensive response based on the context + + Args: + query (str): The medical query to be answered. + options (Optional[list[str]]): The list of options to choose from. + rely_only_on_context (bool): Whether to rely only on the context provided or not during response generation. + + Returns: + MedQAResponse: The generated response to the query, including source information. + """ + retrieved_chunks = self.retrieve_chunks_for_query(query) + options = options or [] + retrieved_chunks += self.retrieve_chunks_for_options(options) + + retrieved_chunk_texts = [] + page_indices = set() + for chunk in retrieved_chunks: + retrieved_chunk_texts.append(chunk["text"]) + page_indices.add(int(chunk["page_idx"])) + + figure_descriptions = [] + if self.figure_annotator is not None: + for page_idx in page_indices: + figure_annotations = self.figure_annotator.predict(page_idx=page_idx)[ + page_idx + ] + figure_descriptions += [ + item["figure_description"] for item in figure_annotations + ] + + system_prompt = """You are an expert in medical science. You are given a question +and a list of excerpts from various medical documents. 
+ """ + query = f"""# Question +{query} + """ + + if len(options) > 0: + system_prompt += """\nYou are also given a list of options to choose your answer from. +You are supposed to choose the best possible option based on the context provided. You should also +explain your answer to justify why you chose that option. +""" + query += "## Options\n" + for option in options: + query += f"- {option}\n" + else: + system_prompt += "\nYou are supposed to answer the question based on the context provided." + + if self.rely_only_on_context: + system_prompt += """\n\nYou are only allowed to use the context provided to answer the question. +You are not allowed to use any external knowledge to answer the question. +""" + + response = self.llm_client.predict( + system_prompt=system_prompt, + user_prompt=[query, *retrieved_chunk_texts, *figure_descriptions], + schema=MedQAMCQResponse if len(options) > 0 else None, + ) + + # TODO: Add figure citations + # TODO: Add source document name from retrieved chunks as citations + citations = [] + for page_idx in page_indices: + citations.append( + MedQACitation(page_number=page_idx + 1, document_name="Gray's Anatomy") + ) + + return MedQAResponse(response=response, citations=citations) diff --git a/medrag_multi_modal/assistant/schema.py b/medrag_multi_modal/assistant/schema.py new file mode 100644 index 0000000000000000000000000000000000000000..9c4ca6cfb280436a70aa04a50e7d48bd6dbc08f1 --- /dev/null +++ b/medrag_multi_modal/assistant/schema.py @@ -0,0 +1,27 @@ +from typing import Union + +from pydantic import BaseModel + + +class FigureAnnotation(BaseModel): + figure_id: str + figure_description: str + + +class FigureAnnotations(BaseModel): + annotations: list[FigureAnnotation] + + +class MedQAMCQResponse(BaseModel): + answer: str + explanation: str + + +class MedQACitation(BaseModel): + page_number: int + document_name: str + + +class MedQAResponse(BaseModel): + response: Union[str, MedQAMCQResponse] + citations: list[MedQACitation] diff --git a/medrag_multi_modal/cli.py b/medrag_multi_modal/cli.py new file mode 100644 index 0000000000000000000000000000000000000000..419362447c14f4a3994fd7563818603bd271e24a --- /dev/null +++ b/medrag_multi_modal/cli.py @@ -0,0 +1,68 @@ +import argparse +import os +import subprocess +import sys + + +def main(): + parser = argparse.ArgumentParser(description="MedRAG Multi-Modal CLI") + subparsers = parser.add_subparsers(dest="command", required=True) + + # Run subcommand + run_parser = subparsers.add_parser("run", help="Run the Streamlit application") + run_parser.add_argument( + "--port", type=int, default=8501, help="Port to run Streamlit on" + ) + + # Evaluate subcommand + eval_parser = subparsers.add_parser("evaluate", help="Run evaluation tests") + eval_parser.add_argument( + "--test-file", + default=os.path.join("tests", "evals", "test_assistant_mmlu_anatomy.py"), + help="Path to test file", + ) + eval_parser.add_argument( + "--test-case", + type=str, + help="Only run tests which match the given substring expression", + ) + eval_parser.add_argument( + "--model-name", + type=str, + default="gemini-1.5-flash", + help="Model name to use for evaluation", + ) + + args = parser.parse_args() + + if args.command == "run": + subprocess.run( + [ + sys.executable, + "-m", + "streamlit", + "run", + "app.py", + "--server.port", + str(args.port), + ] + ) + + elif args.command == "evaluate": + test_file = ( + args.test_file + "::" + args.test_case if args.test_case else args.test_file + ) + cmd = [ + sys.executable, + "-m", + "pytest", + 
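+ # -s streams test output instead of capturing it; -v prints verbose test names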
"-s", + test_file, + "-v", + f"--model-name={args.model_name}", + ] + subprocess.run(cmd) + + +if __name__ == "__main__": + main() diff --git a/medrag_multi_modal/document_loader/__init__.py b/medrag_multi_modal/document_loader/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0c48cb0307455aa952a0a060aba96d0964d1683a --- /dev/null +++ b/medrag_multi_modal/document_loader/__init__.py @@ -0,0 +1,25 @@ +from .image_loader import ( + FitzPILImageLoader, + MarkerImageLoader, + PDF2ImageLoader, + PDFPlumberImageLoader, + PyMuPDFImageLoader, +) +from .text_loader import ( + MarkerTextLoader, + PDFPlumberTextLoader, + PyMuPDF4LLMTextLoader, + PyPDF2TextLoader, +) + +__all__ = [ + "PyMuPDF4LLMTextLoader", + "PyPDF2TextLoader", + "PDFPlumberTextLoader", + "MarkerTextLoader", + "PDF2ImageLoader", + "MarkerImageLoader", + "PDFPlumberImageLoader", + "PyMuPDFImageLoader", + "FitzPILImageLoader", +] diff --git a/medrag_multi_modal/document_loader/image_loader/__init__.py b/medrag_multi_modal/document_loader/image_loader/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9e0f43d7477e0d2ab5205f4861d14564c9e8597b --- /dev/null +++ b/medrag_multi_modal/document_loader/image_loader/__init__.py @@ -0,0 +1,13 @@ +from .fitzpil_img_loader import FitzPILImageLoader +from .marker_img_loader import MarkerImageLoader +from .pdf2image_img_loader import PDF2ImageLoader +from .pdfplumber_img_loader import PDFPlumberImageLoader +from .pymupdf_img_loader import PyMuPDFImageLoader + +__all__ = [ + "PDF2ImageLoader", + "MarkerImageLoader", + "PDFPlumberImageLoader", + "PyMuPDFImageLoader", + "FitzPILImageLoader", +] diff --git a/medrag_multi_modal/document_loader/image_loader/base_img_loader.py b/medrag_multi_modal/document_loader/image_loader/base_img_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..bdc99e99fb6102812b4106e3a5ae6eae73a1836e --- /dev/null +++ b/medrag_multi_modal/document_loader/image_loader/base_img_loader.py @@ -0,0 +1,180 @@ +import asyncio +import os +from abc import abstractmethod +from glob import glob +from typing import Dict, List, Optional + +import huggingface_hub +import jsonlines +import rich +from datasets import ( + Dataset, + Features, + Image, + Sequence, + Value, + concatenate_datasets, + load_dataset, +) + +from medrag_multi_modal.document_loader.text_loader.base_text_loader import ( + BaseTextLoader, +) + + +class BaseImageLoader(BaseTextLoader): + def __init__(self, url: str, document_name: str, document_file_path: str): + super().__init__(url, document_name, document_file_path) + + @abstractmethod + async def extract_page_data( + self, page_idx: int, image_save_dir: str, **kwargs + ) -> Dict[str, str]: + """ + Abstract method to process a single page of the PDF and extract the image data. + + Overwrite this method in the subclass to provide the actual implementation and + processing logic for each page of the PDF using various PDF processing libraries. + + Args: + page_idx (int): The index of the page to process. + image_save_dir (str): The directory to save the extracted images. + **kwargs: Additional keyword arguments that may be used by underlying libraries. + + Returns: + Dict[str, str]: A dictionary containing the processed page data. 
+ """ + pass + + def save_as_dataset( + self, + start_page: int, + end_page: int, + image_save_dir: str, + dataset_repo_id: Optional[str] = None, + overwrite_dataset: bool = False, + ): + features = Features( + { + "page_image": Image(decode=True), + "page_figure_images": Sequence(Image(decode=True)), + "document_name": Value(dtype="string"), + "page_idx": Value(dtype="int32"), + } + ) + + all_examples = [] + for page_idx in range(start_page, end_page): + page_image_file_paths = glob( + os.path.join(image_save_dir, f"page{page_idx}*.png") + ) + if len(page_image_file_paths) > 0: + page_image_path = page_image_file_paths[0] + figure_image_paths = [ + image_file_path + for image_file_path in glob( + os.path.join(image_save_dir, f"page{page_idx}*_fig*.png") + ) + ] + + example = { + "page_image": page_image_path, + "page_figure_images": figure_image_paths, + "document_name": self.document_name, + "page_idx": page_idx, + } + all_examples.append(example) + + dataset = Dataset.from_list(all_examples, features=features) + + if dataset_repo_id: + if huggingface_hub.repo_exists(dataset_repo_id, repo_type="dataset"): + if not overwrite_dataset: + dataset = concatenate_datasets( + [dataset, load_dataset(dataset_repo_id)["corpus"]] + ) + + dataset.push_to_hub(dataset_repo_id, split="corpus") + + return dataset + + def cleanup_image_dir(self, image_save_dir: str = "./images"): + for file in os.listdir(image_save_dir): + file_path = os.path.join(image_save_dir, file) + if os.path.isfile(file_path): + os.remove(file_path) + + async def load_data( + self, + start_page: Optional[int] = None, + end_page: Optional[int] = None, + dataset_repo_id: Optional[str] = None, + overwrite_dataset: bool = False, + image_save_dir: str = "./images", + exclude_file_extensions: list[str] = [], + **kwargs, + ) -> Dataset: + """ + Asynchronously loads images from a PDF file specified by a URL or local file path. + The overridden `extract_page_data` method then processes each page, and the results + can optionally be published to a HuggingFace dataset. + + This function downloads a PDF from a given URL if it does not already exist locally, + reads the specified range of pages, scans each page's content to extract images, and + returns a HuggingFace dataset containing the images and metadata. + + It uses `PyPDF2` to calculate the number of pages in the PDF and the + overridden `extract_page_data` method provides the actual implementation to process + each page, extract the image content from the PDF, and convert it to png format. + It processes pages concurrently using `asyncio` for efficiency. + + If a dataset_repo_id is provided, the processed pages are published to a HuggingFace dataset. + + Args: + start_page (Optional[int]): The starting page index (0-based) to process. + end_page (Optional[int]): The ending page index (0-based) to process. + dataset_repo_id (Optional[str]): The repository ID of the HuggingFace dataset to publish the pages to, if provided. + overwrite_dataset (bool): Whether to overwrite the existing dataset if it exists. Defaults to False. + image_save_dir (str): The directory to save the extracted images. + exclude_file_extensions (list[str]): A list of file extensions to exclude from the image_save_dir. + **kwargs: Additional keyword arguments that will be passed to extract_page_data method and the underlying library. + + Returns: + Dataset: A HuggingFace dataset containing the processed pages. + + Raises: + ValueError: If the specified start_page or end_page is out of bounds of the document's page count.
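+ Note: the extracted page and figure images, along with a `metadata.jsonl` file, are written to `image_save_dir` as a side effect.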
+ """ + os.makedirs(image_save_dir, exist_ok=True) + start_page, end_page = self.get_page_indices(start_page, end_page) + pages = [] + processed_pages_counter: int = 1 + total_pages = end_page - start_page + + async def process_page(page_idx): + nonlocal processed_pages_counter + page_data = await self.extract_page_data(page_idx, image_save_dir, **kwargs) + pages.append(page_data) + rich.print( + f"Processed page idx: {page_idx}, progress: {processed_pages_counter}/{total_pages}" + ) + processed_pages_counter += 1 + + tasks = [process_page(page_idx) for page_idx in range(start_page, end_page)] + for task in asyncio.as_completed(tasks): + await task + + with jsonlines.open( + os.path.join(image_save_dir, "metadata.jsonl"), mode="w" + ) as writer: + writer.write(pages) + + for file in os.listdir(image_save_dir): + if file.endswith(tuple(exclude_file_extensions)): + os.remove(os.path.join(image_save_dir, file)) + + dataset = self.save_as_dataset( + start_page, end_page, image_save_dir, dataset_repo_id, overwrite_dataset + ) + + return dataset diff --git a/medrag_multi_modal/document_loader/image_loader/fitzpil_img_loader.py b/medrag_multi_modal/document_loader/image_loader/fitzpil_img_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..836c89e78b014df842328c2eaac8b6f20216ea39 --- /dev/null +++ b/medrag_multi_modal/document_loader/image_loader/fitzpil_img_loader.py @@ -0,0 +1,127 @@ +import io +import os +from typing import Any, Dict + +import fitz +from pdf2image.pdf2image import convert_from_path +from PIL import Image, ImageOps, UnidentifiedImageError + +from medrag_multi_modal.document_loader.image_loader.base_img_loader import ( + BaseImageLoader, +) + + +class FitzPILImageLoader(BaseImageLoader): + """ + `FitzPILImageLoader` is a class that extends the `BaseImageLoader` class to handle the extraction and + loading of pages from a PDF file as images using the fitz and PIL libraries. + + This class provides functionality to extract images from a PDF file using fitz and PIL libraries, + and optionally publish these images to a WandB artifact. + + !!! example "Example Usage" + ```python + import asyncio + + from medrag_multi_modal.document_loader.image_loader import FitzPILImageLoader + + URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf" + + loader = FitzPILImageLoader( + url=URL, + document_name="Gray's Anatomy", + document_file_path="grays_anatomy.pdf", + ) + dataset = asyncio.run(loader.load_data(start_page=32, end_page=37)) + ``` + + Args: + url (str): The URL of the PDF document. + document_name (str): The name of the document. + document_file_path (str): The path to the PDF file. + """ + + def __init__(self, url: str, document_name: str, document_file_path: str): + super().__init__(url, document_name, document_file_path) + + async def extract_page_data( + self, page_idx: int, image_save_dir: str, **kwargs + ) -> Dict[str, Any]: + """ + Extracts a single page from the PDF as an image using fitz and PIL libraries. + + Args: + page_idx (int): The index of the page to process. + image_save_dir (str): The directory to save the extracted image. + **kwargs: Additional keyword arguments that may be used by fitz and PIL. + + Returns: + Dict[str, Any]: A dictionary containing the processed page data. + The dictionary will have the following keys and values: + + - "page_idx": (int) the index of the page. + - "document_name": (str) the name of the document. 
+ - "file_path": (str) the local file path where the PDF is stored. + - "file_url": (str) the URL of the PDF file. + - "image_file_paths": (list) the local file paths where the images are stored. + """ + image_file_paths = [] + + pdf_document = fitz.open(self.document_file_path) + page = pdf_document.load_page(page_idx) + + images = page.get_images(full=True) + for img_idx, image in enumerate(images): + xref = image[0] + base_image = pdf_document.extract_image(xref) + image_bytes = base_image["image"] + image_ext = base_image["ext"] + + try: + img = Image.open(io.BytesIO(image_bytes)) + + if img.mode in ["L"]: + # images in greyscale looks inverted, need to test on other PDFs + img = ImageOps.invert(img) + + if img.mode == "CMYK": + img = img.convert("RGB") + + if image_ext not in ["png", "jpg", "jpeg"]: + image_ext = "png" + image_file_name = f"page{page_idx}_fig{img_idx}.png" + image_file_path = os.path.join(image_save_dir, image_file_name) + + img.save(image_file_path, format="PNG") + else: + image_file_name = f"page{page_idx}_fig{img_idx}.{image_ext}" + image_file_path = os.path.join(image_save_dir, image_file_name) + + with open(image_file_path, "wb") as image_file: + image_file.write(image_bytes) + + image_file_paths.append(image_file_path) + + except (UnidentifiedImageError, OSError) as e: + print( + f"Skipping image at page {page_idx}, fig {img_idx} due to an error: {e}" + ) + continue + + pdf_document.close() + + page_image = convert_from_path( + self.document_file_path, + first_page=page_idx + 1, + last_page=page_idx + 1, + **kwargs, + )[0] + page_image.save(os.path.join(image_save_dir, f"page{page_idx}.png")) + + return { + "page_idx": page_idx, + "document_name": self.document_name, + "file_path": self.document_file_path, + "file_url": self.url, + "image_file_paths": image_file_paths, + } diff --git a/medrag_multi_modal/document_loader/image_loader/marker_img_loader.py b/medrag_multi_modal/document_loader/image_loader/marker_img_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..e66cf0af74985c58200099c50ab103d6c8af6250 --- /dev/null +++ b/medrag_multi_modal/document_loader/image_loader/marker_img_loader.py @@ -0,0 +1,131 @@ +import os +from typing import Any, Coroutine, Dict, List + +from marker.convert import convert_single_pdf +from marker.models import load_all_models +from pdf2image.pdf2image import convert_from_path + +from medrag_multi_modal.document_loader.image_loader.base_img_loader import ( + BaseImageLoader, +) + +os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" + + +class MarkerImageLoader(BaseImageLoader): + """ + `MarkerImageLoader` is a class that extends the `BaseImageLoader` class to handle the extraction and + loading of pages from a PDF file as images using the marker library. + + This class provides functionality to extract images from a PDF file using marker library, + and optionally publish these images to a WandB artifact. + + !!! example "Example Usage" + ```python + import asyncio + + from medrag_multi_modal.document_loader.image_loader import MarkerImageLoader + + URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf" + + loader = MarkerImageLoader( + url=URL, + document_name="Gray's Anatomy", + document_file_path="grays_anatomy.pdf", + ) + dataset = asyncio.run(loader.load_data(start_page=32, end_page=37)) + ``` + + Args: + url (str): The URL of the PDF document. + document_name (str): The name of the document. + document_file_path (str): The path to the PDF file. 
+ save_page_image (bool): Whether to additionally save the image of the entire page. + """ + + def __init__( + self, + url: str, + document_name: str, + document_file_path: str, + save_page_image: bool = False, + ): + super().__init__(url, document_name, document_file_path) + self.save_page_image = save_page_image + self.model_lst = load_all_models() + + async def extract_page_data( + self, page_idx: int, image_save_dir: str, **kwargs + ) -> Dict[str, Any]: + """ + Extracts a single page from the PDF as an image using marker library. + + Args: + page_idx (int): The index of the page to process. + image_save_dir (str): The directory to save the extracted image. + **kwargs: Additional keyword arguments that may be used by marker. + + Returns: + Dict[str, Any]: A dictionary containing the processed page data. + The dictionary will have the following keys and values: + + - "page_idx": (int) the index of the page. + - "document_name": (str) the name of the document. + - "file_path": (str) the local file path where the PDF is stored. + - "file_url": (str) the URL of the PDF file. + - "image_file_path": (str) the local file path where the image is stored. + """ + _, images, _ = convert_single_pdf( + self.document_file_path, + self.model_lst, + max_pages=1, + batch_multiplier=1, + start_page=page_idx, + ocr_all_pages=True, + **kwargs, + ) + + image_file_paths = [] + for img_idx, (_, image) in enumerate(images.items()): + image_file_name = f"page{page_idx}_fig{img_idx}.png" + image_file_path = os.path.join(image_save_dir, image_file_name) + image.save(image_file_path, "png") + image_file_paths.append(image_file_path) + + page_image = convert_from_path( + self.document_file_path, + first_page=page_idx + 1, + last_page=page_idx + 1, + **kwargs, + )[0] + page_image.save(os.path.join(image_save_dir, f"page{page_idx}.png")) + + return { + "page_idx": page_idx, + "document_name": self.document_name, + "file_path": self.document_file_path, + "file_url": self.url, + "image_file_paths": os.path.join(image_save_dir, "*.png"), + } + + def load_data( + self, + start_page: int | None = None, + end_page: int | None = None, + wandb_artifact_name: str | None = None, + image_save_dir: str = "./images", + exclude_file_extensions: list[str] = [], + cleanup: bool = False, + **kwargs, + ) -> Coroutine[Any, Any, List[Dict[str, str]]]: + start_page = start_page - 1 if start_page is not None else None + end_page = end_page - 1 if end_page is not None else None + return super().load_data( + start_page, + end_page, + wandb_artifact_name, + image_save_dir, + exclude_file_extensions, + cleanup, + **kwargs, + ) diff --git a/medrag_multi_modal/document_loader/image_loader/pdf2image_img_loader.py b/medrag_multi_modal/document_loader/image_loader/pdf2image_img_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..bd5abaad407c8e7781275fc30a19a771f902cf74 --- /dev/null +++ b/medrag_multi_modal/document_loader/image_loader/pdf2image_img_loader.py @@ -0,0 +1,83 @@ +import os +from typing import Any, Dict + +from pdf2image.pdf2image import convert_from_path + +from medrag_multi_modal.document_loader.image_loader.base_img_loader import ( + BaseImageLoader, +) + + +class PDF2ImageLoader(BaseImageLoader): + """ + `PDF2ImageLoader` is a class that extends the `BaseImageLoader` class to handle the extraction and + loading of pages from a PDF file as images using the pdf2image library. 
+ + This class provides functionality to convert specific pages of a PDF document into images + and optionally publish these images to a WandB artifact. + It is like a snapshot image version of each of the pages from the PDF. + + !!! example "Example Usage" + ```python + import asyncio + + from medrag_multi_modal.document_loader.image_loader import PDF2ImageLoader + + URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf" + + loader = PDF2ImageLoader( + url=URL, + document_name="Gray's Anatomy", + document_file_path="grays_anatomy.pdf", + ) + dataset = asyncio.run(loader.load_data(start_page=32, end_page=37)) + ``` + + Args: + url (str): The URL of the PDF document. + document_name (str): The name of the document. + document_file_path (str): The path to the PDF file. + """ + + def __init__(self, url: str, document_name: str, document_file_path: str): + super().__init__(url, document_name, document_file_path) + + async def extract_page_data( + self, page_idx: int, image_save_dir: str, **kwargs + ) -> Dict[str, Any]: + """ + Extracts a single page from the PDF as an image using pdf2image library. + + Args: + page_idx (int): The index of the page to process. + image_save_dir (str): The directory to save the extracted image. + **kwargs: Additional keyword arguments that may be used by pdf2image. + + Returns: + Dict[str, Any]: A dictionary containing the processed page data. + The dictionary will have the following keys and values: + + - "page_idx": (int) the index of the page. + - "document_name": (str) the name of the document. + - "file_path": (str) the local file path where the PDF is stored. + - "file_url": (str) the URL of the PDF file. + - "image_file_path": (str) the local file path where the image is stored. + """ + image = convert_from_path( + self.document_file_path, + first_page=page_idx + 1, + last_page=page_idx + 1, + **kwargs, + )[0] + + image_file_name = f"page{page_idx}.png" + image_file_path = os.path.join(image_save_dir, image_file_name) + image.save(image_file_path) + + return { + "page_idx": page_idx, + "document_name": self.document_name, + "file_path": self.document_file_path, + "file_url": self.url, + "image_file_path": image_file_path, + } diff --git a/medrag_multi_modal/document_loader/image_loader/pdfplumber_img_loader.py b/medrag_multi_modal/document_loader/image_loader/pdfplumber_img_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..2635071c6b58751c8791c07e0c96877e8210f3c1 --- /dev/null +++ b/medrag_multi_modal/document_loader/image_loader/pdfplumber_img_loader.py @@ -0,0 +1,101 @@ +import os +from typing import Any, Dict + +import pdfplumber +from pdf2image.pdf2image import convert_from_path + +from medrag_multi_modal.document_loader.image_loader.base_img_loader import ( + BaseImageLoader, +) + + +class PDFPlumberImageLoader(BaseImageLoader): + """ + `PDFPlumberImageLoader` is a class that extends the `BaseImageLoader` class to handle the extraction and + loading of pages from a PDF file as images using the pdfplumber library. + + This class provides functionality to extract images from a PDF file using pdfplumber library, + and optionally publish these images to a WandB artifact. + + !!! 
example "Example Usage" + ```python + import asyncio + + from medrag_multi_modal.document_loader.image_loader import PDFPlumberImageLoader + + URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf" + + loader = PDFPlumberImageLoader( + url=URL, + document_name="Gray's Anatomy", + document_file_path="grays_anatomy.pdf", + ) + dataset = asyncio.run(loader.load_data(start_page=32, end_page=37)) + ``` + + Args: + url (str): The URL of the PDF document. + document_name (str): The name of the document. + document_file_path (str): The path to the PDF file. + """ + + def __init__(self, url: str, document_name: str, document_file_path: str): + super().__init__(url, document_name, document_file_path) + + async def extract_page_data( + self, page_idx: int, image_save_dir: str, **kwargs + ) -> Dict[str, Any]: + """ + Extracts a single page from the PDF as an image using pdfplumber library. + + Args: + page_idx (int): The index of the page to process. + image_save_dir (str): The directory to save the extracted image. + **kwargs: Additional keyword arguments that may be used by pdfplumber. + + Returns: + Dict[str, Any]: A dictionary containing the processed page data. + The dictionary will have the following keys and values: + + - "page_idx": (int) the index of the page. + - "document_name": (str) the name of the document. + - "file_path": (str) the local file path where the PDF is stored. + - "file_url": (str) the URL of the PDF file. + - "image_file_path": (str) the local file path where the image is stored. + """ + with pdfplumber.open(self.document_file_path) as pdf: + page = pdf.pages[page_idx] + images = page.images + + image_file_paths = [] + for img_idx, image in enumerate(images): + extracted_image = page.crop( + ( + image["x0"], + image["top"], + image["x1"], + image["bottom"], + ) + ).to_image(resolution=300) + + image_file_name = f"page{page_idx}_fig{img_idx}.png" + image_file_path = os.path.join(image_save_dir, image_file_name) + + extracted_image.save(image_file_path, "png") + image_file_paths.append(image_file_path) + + page_image = convert_from_path( + self.document_file_path, + first_page=page_idx + 1, + last_page=page_idx + 1, + **kwargs, + )[0] + page_image.save(os.path.join(image_save_dir, f"page{page_idx}.png")) + + return { + "page_idx": page_idx, + "document_name": self.document_name, + "file_path": self.document_file_path, + "file_url": self.url, + "image_file_paths": image_file_paths, + } diff --git a/medrag_multi_modal/document_loader/image_loader/pymupdf_img_loader.py b/medrag_multi_modal/document_loader/image_loader/pymupdf_img_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..336b8afc01fa421f8b2b7ae4d6beedd3cdf54ace --- /dev/null +++ b/medrag_multi_modal/document_loader/image_loader/pymupdf_img_loader.py @@ -0,0 +1,124 @@ +import io +import os +from typing import Any, Dict + +import fitz +from pdf2image.pdf2image import convert_from_path +from PIL import Image + +from medrag_multi_modal.document_loader.image_loader.base_img_loader import ( + BaseImageLoader, +) + + +class PyMuPDFImageLoader(BaseImageLoader): + """ + `PyMuPDFImageLoader` is a class that extends the `BaseImageLoader` class to handle the extraction and + loading of pages from a PDF file as images using the pymupdf library. + + This class provides functionality to extract images from a PDF file using pymupdf library, + and optionally publish these images to a WandB artifact. + + !!! 
example "Example Usage" + ```python + import asyncio + + from medrag_multi_modal.document_loader.image_loader import PyMuPDFImageLoader + + URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf" + + loader = PyMuPDFImageLoader( + url=URL, + document_name="Gray's Anatomy", + document_file_path="grays_anatomy.pdf", + ) + dataset = asyncio.run(loader.load_data(start_page=32, end_page=37)) + ``` + + Args: + url (str): The URL of the PDF document. + document_name (str): The name of the document. + document_file_path (str): The path to the PDF file. + """ + + def __init__(self, url: str, document_name: str, document_file_path: str): + super().__init__(url, document_name, document_file_path) + + async def extract_page_data( + self, page_idx: int, image_save_dir: str, **kwargs + ) -> Dict[str, Any]: + """ + Extracts a single page from the PDF as an image using pymupdf library. + + Args: + page_idx (int): The index of the page to process. + image_save_dir (str): The directory to save the extracted image. + **kwargs: Additional keyword arguments that may be used by pymupdf. + + Returns: + Dict[str, Any]: A dictionary containing the processed page data. + The dictionary will have the following keys and values: + + - "page_idx": (int) the index of the page. + - "document_name": (str) the name of the document. + - "file_path": (str) the local file path where the PDF is stored. + - "file_url": (str) the URL of the PDF file. + - "image_file_paths": (list) the local file paths where the images are stored. + """ + image_file_paths = [] + + pdf_document = fitz.open(self.document_file_path) + page = pdf_document[page_idx] + + images = page.get_images(full=True) + for img_idx, image in enumerate(images): + xref = image[0] + base_image = pdf_document.extract_image(xref) + image_bytes = base_image["image"] + image_ext = base_image["ext"] + + if image_ext == "jb2": + image_ext = "png" + elif image_ext == "jpx": + image_ext = "jpg" + + image_file_name = f"page{page_idx}_fig{img_idx}.{image_ext}" + image_file_path = os.path.join(image_save_dir, image_file_name) + + # For JBIG2 and JPEG2000, we need to convert the image + if base_image["ext"] in ["jb2", "jpx"]: + try: + pix = fitz.Pixmap(image_bytes) + pix.save(image_file_path) + except Exception as err_fitz: + print(f"Error processing image with fitz: {err_fitz}") + # Fallback to using PIL for image conversion + try: + img = Image.open(io.BytesIO(image_bytes)) + img.save(image_file_path) + except Exception as err_pil: + print(f"Failed to process image with PIL: {err_pil}") + continue # Skip this image if both methods fail + else: + with open(image_file_path, "wb") as image_file: + image_file.write(image_bytes) + + image_file_paths.append(image_file_path) + + pdf_document.close() + + page_image = convert_from_path( + self.document_file_path, + first_page=page_idx + 1, + last_page=page_idx + 1, + **kwargs, + )[0] + page_image.save(os.path.join(image_save_dir, f"page{page_idx}.png")) + + return { + "page_idx": page_idx, + "document_name": self.document_name, + "file_path": self.document_file_path, + "file_url": self.url, + "image_file_paths": image_file_paths, + } diff --git a/medrag_multi_modal/document_loader/text_loader/__init__.py b/medrag_multi_modal/document_loader/text_loader/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..80d18dbb893992b9196adc9277b3b758d64a2bce --- /dev/null +++ b/medrag_multi_modal/document_loader/text_loader/__init__.py @@ -0,0 +1,11 @@ +from 
.marker_text_loader import MarkerTextLoader +from .pdfplumber_text_loader import PDFPlumberTextLoader +from .pymupdf4llm_text_loader import PyMuPDF4LLMTextLoader +from .pypdf2_text_loader import PyPDF2TextLoader + +__all__ = [ + "PyMuPDF4LLMTextLoader", + "PyPDF2TextLoader", + "PDFPlumberTextLoader", + "MarkerTextLoader", +] diff --git a/medrag_multi_modal/document_loader/text_loader/base_text_loader.py b/medrag_multi_modal/document_loader/text_loader/base_text_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..b6bc4dc455223eaaf75b61de500bf740d5fe9446 --- /dev/null +++ b/medrag_multi_modal/document_loader/text_loader/base_text_loader.py @@ -0,0 +1,185 @@ +import asyncio +import os +from abc import ABC, abstractmethod +from typing import Any, Dict, Optional + +import huggingface_hub +import PyPDF2 +from datasets import Dataset, concatenate_datasets, load_dataset +from firerequests import FireRequests +from rich.progress import Progress + + +class BaseTextLoader(ABC): + """ + An abstract base class for loading text from a PDF file, processing it into markdown, and optionally publishing it to a Weave dataset. + + This class handles the downloading of a PDF file from a given URL if it does not already exist locally. + Subclasses should implement the specific PDF reading, text extraction, and markdown conversion methods. + + The processed pages are finally stored in a list of Page objects, which can be optionally published to a Weave dataset. + + Args: + url (str): The URL of the PDF file to download if not present locally. + document_name (str): The name of the document for metadata purposes. + document_file_path (str): The local file path where the PDF is stored or will be downloaded. + metadata (Optional[dict[str, any]]): Additional metadata to be added to each row of the dataset. + """ + + def __init__( + self, + url: str, + document_name: str, + document_file_path: str, + metadata: Optional[dict[str, Any]] = None, + ): + self.url = url + self.document_name = document_name + self.document_file_path = document_file_path + self.metadata = metadata or {} + if not os.path.exists(self.document_file_path): + FireRequests().download(url, filenames=self.document_file_path) + with open(self.document_file_path, "rb") as file: + pdf_reader = PyPDF2.PdfReader(file) + self.page_count = len(pdf_reader.pages) + + def get_page_indices( + self, start_page: Optional[int] = None, end_page: Optional[int] = None + ) -> tuple[int, int]: + """ + Get the start and end page indices for processing. + + Args: + start_page (Optional[int]): The starting page index (0-based) to process. Defaults to the first page. + end_page (Optional[int]): The ending page index (0-based) to process. Defaults to the last page. + + Returns: + tuple[int, int]: A tuple containing the start and end page indices. + """ + + if start_page: + if start_page > self.page_count: + raise ValueError( + f"Start page {start_page} is greater than the total page count {self.page_count}" + ) + else: + start_page = 0 + if end_page: + if end_page > self.page_count: + raise ValueError( + f"End page {end_page} is greater than the total page count {self.page_count}" + ) + else: + end_page = self.page_count - 1 + return start_page, end_page + + @abstractmethod + async def extract_page_data(self, page_idx: int, **kwargs) -> Dict[str, str]: + """ + Abstract method to process a single page of the PDF and extract the text data. 
+
+        Overwrite this method in the subclass to provide the actual implementation and
+        processing logic for each page of the PDF using various PDF processing libraries.
+
+        Args:
+            page_idx (int): The index of the page to process.
+            **kwargs: Additional keyword arguments that may be used by underlying libraries.
+
+        Returns:
+            Dict[str, str]: A dictionary containing the processed page data.
+        """
+        pass
+
+    async def load_data(
+        self,
+        start_page: Optional[int] = None,
+        end_page: Optional[int] = None,
+        exclude_pages: Optional[list[int]] = None,
+        dataset_repo_id: Optional[str] = None,
+        overwrite_dataset: bool = False,
+        **kwargs,
+    ) -> Dataset:
+        """
+        Asynchronously loads text from a PDF file specified by a URL or local file path.
+        The overridden abstract processing method converts the text of each page into
+        markdown, and the result can optionally be published to a HuggingFace dataset.
+
+        This function downloads a PDF from a given URL if it does not already exist locally,
+        reads the specified range of pages, converts each page's content to markdown, and
+        returns a HuggingFace `Dataset` containing the text and metadata.
+
+        It uses `PyPDF2` to determine the number of pages in the PDF, while the overridden
+        `extract_page_data` method provides the actual implementation to process each page,
+        extract the text from the PDF, and convert it to markdown.
+        It processes pages concurrently using `asyncio` for efficiency.
+
+        If a `dataset_repo_id` is provided, the processed pages are published to a HuggingFace dataset.
+
+        Args:
+            start_page (Optional[int]): The starting page index (0-based) to process. Defaults to the first page.
+            end_page (Optional[int]): The ending page index (0-based) to process. Defaults to the last page.
+            exclude_pages (Optional[list[int]]): The list of page indices to exclude from processing.
+            dataset_repo_id (Optional[str]): The repository ID of the HuggingFace dataset to publish the pages to, if provided.
+            overwrite_dataset (bool): Whether to overwrite the existing dataset if it exists. Defaults to False.
+            **kwargs: Additional keyword arguments that will be passed to the `extract_page_data` method and the underlying library.
+
+        Returns:
+            Dataset: A HuggingFace Dataset object containing the text and metadata for processed pages.
+            Each entry in the dataset will have the following keys and values:
+
+            - "text": (str) the processed page data in markdown format.
+            - "page_idx": (int) the index of the page.
+            - "document_name": (str) the name of the document.
+            - "file_path": (str) the local file path where the PDF is stored.
+            - "file_url": (str) the URL of the PDF file.
+            - "loader_name": (str) the name of the loader class used to process the page.
+
+        Raises:
+            ValueError: If the specified start_page or end_page is out of bounds of the document's page count.
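+
+        !!! example "Example (illustrative sketch)"
+            A minimal sketch of driving `load_data` through a concrete subclass; the
+            `dataset_repo_id` below is a placeholder, so substitute a dataset repository
+            you own or drop the argument to skip publishing.
+
+            ```python
+            import asyncio
+
+            from medrag_multi_modal.document_loader import PyPDF2TextLoader
+
+            URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
+
+            loader = PyPDF2TextLoader(
+                url=URL,
+                document_name="Gray's Anatomy",
+                document_file_path="grays_anatomy.pdf",
+            )
+            dataset = asyncio.run(
+                loader.load_data(
+                    start_page=31,
+                    end_page=36,
+                    exclude_pages=[33],  # skip pages that should not be indexed
+                    dataset_repo_id="your-username/grays-anatomy-text",  # placeholder repo id
+                )
+            )
+            print(dataset[0]["text"])
+            ```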
+ """ + start_page, end_page = self.get_page_indices(start_page, end_page) + pages = [] + processed_pages_counter: int = 1 + total_pages = end_page - start_page + exclude_pages = exclude_pages or [] + + async def process_page(page_idx): + nonlocal processed_pages_counter + page_data = await self.extract_page_data(page_idx, **kwargs) + page_data["loader_name"] = self.__class__.__name__ + for key, value in self.metadata.items(): + if key not in page_data: + page_data[key] = value + pages.append(page_data) + progress.update( + task_id, + advance=1, + description=f"Loading page {page_idx} using {self.__class__.__name__}", + ) + processed_pages_counter += 1 + + progress = Progress() + with progress: + task_id = progress.add_task("Starting...", total=total_pages) + tasks = [ + process_page(page_idx) + for page_idx in range(start_page, end_page + 1) + if page_idx not in exclude_pages + ] + for task in asyncio.as_completed(tasks): + await task + + pages.sort(key=lambda x: x["page_idx"]) + + dataset = Dataset.from_list(pages) + if dataset_repo_id: + if huggingface_hub.repo_exists(dataset_repo_id, repo_type="dataset"): + print("Dataset already exists") + if not overwrite_dataset: + print("Not overwriting dataset") + dataset = concatenate_datasets( + [dataset, load_dataset(dataset_repo_id, split="corpus")] + ) + dataset.push_to_hub(repo_id=dataset_repo_id, split="corpus", private=False) + + return dataset diff --git a/medrag_multi_modal/document_loader/text_loader/marker_text_loader.py b/medrag_multi_modal/document_loader/text_loader/marker_text_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..16a19e3343bcd2d44f1481def4c7ad031838b815 --- /dev/null +++ b/medrag_multi_modal/document_loader/text_loader/marker_text_loader.py @@ -0,0 +1,89 @@ +import os +from typing import Dict + +from marker.convert import convert_single_pdf +from marker.models import load_all_models + +from medrag_multi_modal.document_loader.text_loader.base_text_loader import ( + BaseTextLoader, +) + +os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" + + +class MarkerTextLoader(BaseTextLoader): + """ + A concrete implementation of the BaseTextLoader for loading text from a PDF file + using `marker-pdf`, processing it into a structured text format, and optionally publishing + it to a Weave dataset. + + This class extends the BaseTextLoader and implements the abstract methods to + load and process pages from a PDF file using marker-pdf, which is a pipeline of deep learning models. + + This class will handle the downloading of a PDF file from a given URL if it does not already exist locally. + It uses marker-pdf to read the PDF and extract structured text from each page. The processed pages are stored + in a list of Page objects, which can be optionally published to a Weave dataset. + + !!! example "Example Usage" + ```python + import asyncio + + from medrag_multi_modal.document_loader import MarkerTextLoader + + URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf" + + loader = MarkerTextLoader( + url=URL, + document_name="Gray's Anatomy", + document_file_path="grays_anatomy.pdf", + ) + dataset = asyncio.run(loader.load_data(start_page=31, end_page=36)) + ``` + + Args: + url (str): The URL of the PDF file to download if not present locally. + document_name (str): The name of the document for metadata purposes. + document_file_path (str): The local file path where the PDF is stored or will be downloaded. 
+ """ + + async def extract_page_data(self, page_idx: int, **kwargs) -> Dict[str, str]: + """ + Process a single page of the PDF and extract its structured text using marker-pdf. + + Returns: + Dict[str, str]: A dictionary with the processed page data. + The dictionary will have the following keys and values: + + - "text": (str) the extracted structured text from the page. + - "page_idx": (int) the index of the page. + - "document_name": (str) the name of the document. + - "file_path": (str) the local file path where the PDF is stored. + - "file_url": (str) the URL of the PDF file. + - "meta": (dict) the metadata extracted from the page by marker-pdf. + + Args: + page_idx (int): The index of the page to process. + **kwargs: Additional keyword arguments to be passed to `marker.convert.convert_single_pdf`. + + Returns: + Dict[str, str]: A dictionary containing the processed page data. + """ + model_lst = load_all_models() + + text, _, _ = convert_single_pdf( + self.document_file_path, + model_lst, + max_pages=1, + batch_multiplier=1, + start_page=page_idx, + ocr_all_pages=True, + **kwargs, + ) + + return { + "text": text, + "page_idx": page_idx, + "document_name": self.document_name, + "file_path": self.document_file_path, + "file_url": self.url, + } diff --git a/medrag_multi_modal/document_loader/text_loader/pdfplumber_text_loader.py b/medrag_multi_modal/document_loader/text_loader/pdfplumber_text_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..337aed66516170baded9a6faa997c7a2e2503319 --- /dev/null +++ b/medrag_multi_modal/document_loader/text_loader/pdfplumber_text_loader.py @@ -0,0 +1,76 @@ +from typing import Dict + +import pdfplumber + +from medrag_multi_modal.document_loader.text_loader.base_text_loader import ( + BaseTextLoader, +) + + +class PDFPlumberTextLoader(BaseTextLoader): + """ + A concrete implementation of the BaseTextLoader for loading text from a PDF file + using `pdfplumber`, processing it into a simple text format, and optionally publishing + it to a Weave dataset. + + This class extends the BaseTextLoader and implements the abstract methods to + load and process pages from a PDF file. + + This class will handle the downloading of a PDF file from a given URL if it does not already exist locally. + It uses pdfplumber to read the PDF and extract text from each page. The processed pages are stored in a list + of Page objects, which can be optionally published to a Weave dataset. + + !!! example "Example Usage" + ```python + import asyncio + + from medrag_multi_modal.document_loader import PDFPlumberTextLoader + + URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf" + + loader = PDFPlumberTextLoader( + url=URL, + document_name="Gray's Anatomy", + document_file_path="grays_anatomy.pdf", + ) + dataset = asyncio.run(loader.load_data(start_page=31, end_page=36)) + ``` + + Args: + url (str): The URL of the PDF file to download if not present locally. + document_name (str): The name of the document for metadata purposes. + document_file_path (str): The local file path where the PDF is stored or will be downloaded. + """ + + async def extract_page_data(self, page_idx: int, **kwargs) -> Dict[str, str]: + """ + Process a single page of the PDF and extract its text using pdfplumber. + + Returns: + Dict[str, str]: A dictionary with the processed page data. + The dictionary will have the following keys and values: + + - "text": (str) the extracted text from the page. 
+ - "page_idx": (int) the index of the page. + - "document_name": (str) the name of the document. + - "file_path": (str) the local file path where the PDF is stored. + - "file_url": (str) the URL of the PDF file. + + Args: + page_idx (int): The index of the page to process. + **kwargs: Additional keyword arguments to be passed to `pdfplumber.Page.extract_text`. + + Returns: + Dict[str, str]: A dictionary containing the processed page data. + """ + with pdfplumber.open(self.document_file_path) as pdf: + page = pdf.pages[page_idx] + text = page.extract_text(**kwargs) + + return { + "text": text, + "page_idx": page_idx, + "document_name": self.document_name, + "file_path": self.document_file_path, + "file_url": self.url, + } diff --git a/medrag_multi_modal/document_loader/text_loader/pymupdf4llm_text_loader.py b/medrag_multi_modal/document_loader/text_loader/pymupdf4llm_text_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..05493656683fa7378dff415e3861663667f1340a --- /dev/null +++ b/medrag_multi_modal/document_loader/text_loader/pymupdf4llm_text_loader.py @@ -0,0 +1,73 @@ +from typing import Dict + +import pymupdf4llm + +from medrag_multi_modal.document_loader.text_loader.base_text_loader import ( + BaseTextLoader, +) + + +class PyMuPDF4LLMTextLoader(BaseTextLoader): + """ + A concrete implementation of the BaseTextLoader for loading text from a PDF file, + processing it into markdown using `pymupdf4llm`, and optionally publishing it to a Weave dataset. + + This class extends the BaseTextLoader and implements the abstract methods to load and process pages from a PDF file. + + This class will handle the downloading of a PDF file from a given URL if it does not already exist locally. + It uses PyPDF2 to read the PDF and pymupdf4llm to convert pages to markdown. The processed pages are stored in a list + of Page objects, which can be optionally published to a Weave dataset. + + !!! example "Example Usage" + ```python + import asyncio + + from medrag_multi_modal.document_loader import PyMuPDF4LLMTextLoader + + URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf" + + loader = PyMuPDF4LLMTextLoader( + url=URL, + document_name="Gray's Anatomy", + document_file_path="grays_anatomy.pdf", + ) + dataset = asyncio.run(loader.load_data(start_page=31, end_page=36)) + ``` + + Args: + url (str): The URL of the PDF file to download if not present locally. + document_name (str): The name of the document for metadata purposes. + document_file_path (str): The local file path where the PDF is stored or will be downloaded. + """ + + async def extract_page_data(self, page_idx: int, **kwargs) -> Dict[str, str]: + """ + Process a single page of the PDF and convert it to markdown using `pymupdf4llm`. + + Returns: + Dict[str, str]: A dictionary with the processed page data. + The dictionary will have the following keys and values: + + - "text": (str) the processed page data in markdown format. + - "page_idx": (int) the index of the page. + - "document_name": (str) the name of the document. + - "file_path": (str) the local file path where the PDF is stored. + - "file_url": (str) the URL of the PDF file. + + Args: + page_idx (int): The index of the page to process. + **kwargs: Additional keyword arguments to be passed to `pymupdf4llm.to_markdown`. + + Returns: + Dict[str, str]: A dictionary containing the processed page data. 
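+
+        !!! example "Example (illustrative sketch)"
+            A quick, illustrative way to inspect the markdown produced for a single page
+            without running the full `load_data` pipeline; it assumes the Gray's Anatomy
+            PDF from the class-level example, which is downloaded on construction if missing.
+
+            ```python
+            import asyncio
+
+            from medrag_multi_modal.document_loader import PyMuPDF4LLMTextLoader
+
+            URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
+
+            loader = PyMuPDF4LLMTextLoader(
+                url=URL,
+                document_name="Gray's Anatomy",
+                document_file_path="grays_anatomy.pdf",
+            )
+            page_data = asyncio.run(loader.extract_page_data(page_idx=32))
+            print(page_data["text"][:500])  # preview the markdown for page 32
+            ```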
+ """ + text = pymupdf4llm.to_markdown( + doc=self.document_file_path, pages=[page_idx], show_progress=False, **kwargs + ) + return { + "text": text, + "page_idx": page_idx, + "document_name": self.document_name, + "file_path": self.document_file_path, + "file_url": self.url, + } diff --git a/medrag_multi_modal/document_loader/text_loader/pypdf2_text_loader.py b/medrag_multi_modal/document_loader/text_loader/pypdf2_text_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..df6cc011e1a93d9c5f8dd1d25d88f41dff928623 --- /dev/null +++ b/medrag_multi_modal/document_loader/text_loader/pypdf2_text_loader.py @@ -0,0 +1,77 @@ +from typing import Dict + +import PyPDF2 + +from medrag_multi_modal.document_loader.text_loader.base_text_loader import ( + BaseTextLoader, +) + + +class PyPDF2TextLoader(BaseTextLoader): + """ + A concrete implementation of the BaseTextLoader for loading text from a PDF file + using `PyPDF2`, processing it into a simple text format, and optionally publishing + it to a Weave dataset. + + This class extends the BaseTextLoader and implements the abstract methods to + load and process pages from a PDF file. + + This class will handle the downloading of a PDF file from a given URL if it does not already exist locally. + It uses PyPDF2 to read the PDF and extract text from each page. The processed pages are stored in a list + of Page objects, which can be optionally published to a Weave dataset. + + !!! example "Example Usage" + ```python + import asyncio + + from medrag_multi_modal.document_loader import PyPDF2TextLoader + + URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf" + + loader = PyPDF2TextLoader( + url=URL, + document_name="Gray's Anatomy", + document_file_path="grays_anatomy.pdf", + ) + dataset = asyncio.run(loader.load_data(start_page=31, end_page=36)) + ``` + + Args: + url (str): The URL of the PDF file to download if not present locally. + document_name (str): The name of the document for metadata purposes. + document_file_path (str): The local file path where the PDF is stored or will be downloaded. + """ + + async def extract_page_data(self, page_idx: int, **kwargs) -> Dict[str, str]: + """ + Process a single page of the PDF and extract its text using PyPDF2. + + Returns: + Dict[str, str]: A dictionary with the processed page data. + The dictionary will have the following keys and values: + + - "text": (str) the extracted text from the page. + - "page_idx": (int) the index of the page. + - "document_name": (str) the name of the document. + - "file_path": (str) the local file path where the PDF is stored. + - "file_url": (str) the URL of the PDF file. + + Args: + page_idx (int): The index of the page to process. + **kwargs: Additional keyword arguments to be passed to `PyPDF2.PdfReader.pages[0].extract_text`. + + Returns: + Dict[str, str]: A dictionary containing the processed page data. 
+ """ + with open(self.document_file_path, "rb") as file: + pdf_reader = PyPDF2.PdfReader(file) + page = pdf_reader.pages[page_idx] + text = page.extract_text(**kwargs) + + return { + "text": text, + "page_idx": page_idx, + "document_name": self.document_name, + "file_path": self.document_file_path, + "file_url": self.url, + } diff --git a/medrag_multi_modal/metrics/__init__.py b/medrag_multi_modal/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..18b7bf585104b0331b74407570a1d6744c836320 --- /dev/null +++ b/medrag_multi_modal/metrics/__init__.py @@ -0,0 +1,3 @@ +from .mmlu import MMLUOptionAccuracy + +__all__ = ["MMLUOptionAccuracy"] diff --git a/medrag_multi_modal/metrics/__pycache__/__init__.cpython-310.pyc b/medrag_multi_modal/metrics/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..086528d16161f1a2c58f2d66f19ff84dfc96b4e7 Binary files /dev/null and b/medrag_multi_modal/metrics/__pycache__/__init__.cpython-310.pyc differ diff --git a/medrag_multi_modal/metrics/__pycache__/base.cpython-310.pyc b/medrag_multi_modal/metrics/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a8dcebbc939fb67e7993da6a8a84696402a47e5a Binary files /dev/null and b/medrag_multi_modal/metrics/__pycache__/base.cpython-310.pyc differ diff --git a/medrag_multi_modal/metrics/__pycache__/mmlu.cpython-310.pyc b/medrag_multi_modal/metrics/__pycache__/mmlu.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e03ba593229697c48a5efb03842f43be32d33618 Binary files /dev/null and b/medrag_multi_modal/metrics/__pycache__/mmlu.cpython-310.pyc differ diff --git a/medrag_multi_modal/metrics/base.py b/medrag_multi_modal/metrics/base.py new file mode 100644 index 0000000000000000000000000000000000000000..e16cb9219d1f47115fc497f4ddc4f49d376a7c63 --- /dev/null +++ b/medrag_multi_modal/metrics/base.py @@ -0,0 +1,108 @@ +from typing import Optional + +import numpy as np +import weave + + +class BaseAccuracyMetric(weave.Scorer): + """ + BaseAccuracyMetric is a class that extends the + [`weave.Scorer`](https://weave-docs.wandb.ai/guides/evaluation/scorers#class-based-scorers) + to provide a comprehensive evaluation of accuracy metrics for a given set of score rows. + + This class is designed to process a list of score rows, each containing a + 'correct' key that indicates whether a particular prediction was correct. + The `summarize` method calculates various statistical measures and metrics + based on this data, including: + + - True and false counts: The number of true and false predictions. + - True and false fractions: The proportion of true and false predictions. + - Standard error: The standard error of the mean for the true predictions. + - Precision: The ratio of true positive predictions to the total number of + positive predictions. + - Recall: The ratio of true positive predictions to the total number of + actual positives. + - F1 Score: The harmonic mean of precision and recall, providing a balance + between the two metrics. + + The `summarize` method returns a dictionary containing these metrics, + allowing for a detailed analysis of the model's performance. + + Methods: + summarize(score_rows: list) -> Optional[dict]: + Processes the input score rows to compute and return a dictionary + of accuracy metrics. 
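+
+    !!! example "Worked example (illustrative)"
+        The arithmetic below mirrors what `summarize` computes, shown standalone on a
+        made-up set of score rows so the fractions, precision, recall, and F1 score can
+        be verified by hand.
+
+        ```python
+        # Toy score rows: 3 correct, 1 incorrect, 1 with no prediction ("correct" is None)
+        score_rows = [
+            {"correct": True},
+            {"correct": True},
+            {"correct": True},
+            {"correct": False},
+            {"correct": None},
+        ]
+
+        valid = [row["correct"] for row in score_rows if row["correct"] is not None]
+        true_count = valid.count(True)                  # 3
+        false_positives = len(valid) - true_count       # 1
+        false_negatives = len(score_rows) - len(valid)  # 1
+
+        precision = true_count / (true_count + false_positives)   # 0.75
+        recall = true_count / (true_count + false_negatives)      # 0.75
+        f1_score = 2 * precision * recall / (precision + recall)  # 0.75
+        ```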
+ """ + @weave.op() + def summarize(self, score_rows: list) -> Optional[dict]: + """ + Summarizes the accuracy metrics from a list of score rows. + + This method processes a list of score rows, each containing a 'correct' key + that indicates whether a particular prediction was correct. It calculates + various statistical measures and metrics based on this data, including: + + - True and false counts: The number of true and false predictions. + - True and false fractions: The proportion of true and false predictions. + - Standard error: The standard error of the mean for the true predictions. + - Precision: The ratio of true positive predictions to the total number of + positive predictions. + - Recall: The ratio of true positive predictions to the total number of + actual positives. + - F1 Score: The harmonic mean of precision and recall, providing a balance + between the two metrics. + + The method returns a dictionary containing these metrics, allowing for a + detailed analysis of the model's performance. + + Args: + score_rows (list): A list of dictionaries, each containing a 'correct' + key with a boolean value indicating the correctness of a prediction. + + Returns: + Optional[dict]: A dictionary containing the calculated accuracy metrics, + or None if the input list is empty. + """ + valid_data = [ + x.get("correct") for x in score_rows if x.get("correct") is not None + ] + count_true = list(valid_data).count(True) + int_data = [int(x) for x in valid_data] + + sample_mean = np.mean(int_data) if int_data else 0 + sample_variance = np.var(int_data) if int_data else 0 + sample_error = np.sqrt(sample_variance / len(int_data)) if int_data else 0 + + # Calculate precision, recall, and F1 score + true_positives = count_true + false_positives = len(valid_data) - count_true + false_negatives = len(score_rows) - len(valid_data) + + precision = ( + true_positives / (true_positives + false_positives) + if (true_positives + false_positives) > 0 + else 0 + ) + recall = ( + true_positives / (true_positives + false_negatives) + if (true_positives + false_negatives) > 0 + else 0 + ) + f1_score = ( + (2 * precision * recall) / (precision + recall) + if (precision + recall) > 0 + else 0 + ) + + return { + "correct": { + "true_count": count_true, + "false_count": len(score_rows) - count_true, + "true_fraction": float(sample_mean), + "false_fraction": 1.0 - float(sample_mean), + "stderr": float(sample_error), + "precision": precision, + "recall": recall, + "f1_score": f1_score, + } + } diff --git a/medrag_multi_modal/metrics/mmlu.py b/medrag_multi_modal/metrics/mmlu.py new file mode 100644 index 0000000000000000000000000000000000000000..3e182084fd5cecef8834710d611ae0b5680dfe4b --- /dev/null +++ b/medrag_multi_modal/metrics/mmlu.py @@ -0,0 +1,24 @@ +import weave + +from medrag_multi_modal.assistant.schema import MedQAResponse +from medrag_multi_modal.metrics.base import BaseAccuracyMetric + + +class MMLUOptionAccuracy(BaseAccuracyMetric): + """ + MMLUOptionAccuracy is a metric class that inherits from `BaseAccuracyMetric`. + + This class is designed to evaluate the accuracy of a multiple-choice question + response by comparing the provided answer with the correct answer from the + given options. It uses the MedQAResponse schema to extract the response + and checks if it matches the correct answer. 
+ + Methods: + -------- + score(output: MedQAResponse, options: list[str], answer: str) -> dict: + Compares the provided answer with the correct answer and returns a + dictionary indicating whether the answer is correct. + """ + @weave.op() + def score(self, output: MedQAResponse, options: list[str], answer: str): + return {"correct": options[answer] == output.response.answer} diff --git a/medrag_multi_modal/retrieval/__init__.py b/medrag_multi_modal/retrieval/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..20f0a3bdfd27fd93ea9681dc031ead5b885c1909 --- /dev/null +++ b/medrag_multi_modal/retrieval/__init__.py @@ -0,0 +1,3 @@ +from .colpali_retrieval import CalPaliRetriever + +__all__ = ["CalPaliRetriever"] diff --git a/medrag_multi_modal/retrieval/__pycache__/__init__.cpython-310.pyc b/medrag_multi_modal/retrieval/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d4050794c922e39d8c9ff5248802458a642aa88b Binary files /dev/null and b/medrag_multi_modal/retrieval/__pycache__/__init__.cpython-310.pyc differ diff --git a/medrag_multi_modal/retrieval/__pycache__/colpali_retrieval.cpython-310.pyc b/medrag_multi_modal/retrieval/__pycache__/colpali_retrieval.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e34fff250c919944461288b8d3de0985be872809 Binary files /dev/null and b/medrag_multi_modal/retrieval/__pycache__/colpali_retrieval.cpython-310.pyc differ diff --git a/medrag_multi_modal/retrieval/__pycache__/common.cpython-310.pyc b/medrag_multi_modal/retrieval/__pycache__/common.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0f7a40a45da5990618ca42d6beb0dc6cf691709b Binary files /dev/null and b/medrag_multi_modal/retrieval/__pycache__/common.cpython-310.pyc differ diff --git a/medrag_multi_modal/retrieval/colpali_retrieval.py b/medrag_multi_modal/retrieval/colpali_retrieval.py new file mode 100644 index 0000000000000000000000000000000000000000..522d964058abda29649115fa732f5a30ae659607 --- /dev/null +++ b/medrag_multi_modal/retrieval/colpali_retrieval.py @@ -0,0 +1,255 @@ +import os +from typing import TYPE_CHECKING, Any, Optional + +import weave + +if TYPE_CHECKING: + from byaldi import RAGMultiModalModel + +import wandb +from PIL import Image + +from medrag_multi_modal.utils import get_wandb_artifact + + +class CalPaliRetriever(weave.Model): + """ + CalPaliRetriever is a class that facilitates the retrieval of page images using ColPali. + + This class leverages the `byaldi.RAGMultiModalModel` to perform document retrieval tasks. + It can be initialized with a pre-trained model or from a specified W&B artifact. The class + also provides methods to index new data and to predict/retrieve documents based on a query. + + Attributes: + model_name (str): The name of the model to be used for retrieval. 
+ """ + + model_name: str + _docs_retrieval_model: Optional["RAGMultiModalModel"] = None + _metadata: Optional[dict] = None + _data_artifact_dir: Optional[str] = None + + def __init__( + self, + model_name: str = "vidore/colpali-v1.2", + docs_retrieval_model: Optional["RAGMultiModalModel"] = None, + data_artifact_dir: Optional[str] = None, + metadata_dataset_name: Optional[str] = None, + ): + super().__init__(model_name=model_name) + from byaldi import RAGMultiModalModel + + self._docs_retrieval_model = ( + docs_retrieval_model or RAGMultiModalModel.from_pretrained(self.model_name) + ) + self._data_artifact_dir = data_artifact_dir + self._metadata = ( + [dict(row) for row in weave.ref(metadata_dataset_name).get().rows] + if metadata_dataset_name + else None + ) + + def index(self, data_artifact_name: str, weave_dataset_name: str, index_name: str): + """ + Indexes a dataset of documents and saves the index as a Weave artifact. + + This method retrieves a dataset of documents from a Weave artifact using the provided + data artifact name. It then indexes the documents using the document retrieval model + and assigns the specified index name. The index is stored locally without storing the + collection with the index and overwrites any existing index with the same name. + + If a Weave run is active, the method creates a new Weave artifact with the specified + index name and type "colpali-index". It adds the local index directory to the artifact + and saves it to Weave, including metadata with the provided Weave dataset name. + + !!! example "Indexing Data" + First you need to install `Byaldi` library by Answer.ai. + + ```bash + uv pip install Byaldi>=0.0.5 + ``` + + Next, you can index the data by running the following code: + + ```python + import wandb + from medrag_multi_modal.retrieval import CalPaliRetriever + + wandb.init(project="medrag-multi-modal", entity="ml-colabs", job_type="index") + retriever = CalPaliRetriever() + retriever.index( + data_artifact_name="ml-colabs/medrag-multi-modal/grays-anatomy-images:v1", + weave_dataset_name="grays-anatomy-images:v0", + index_name="grays-anatomy", + ) + ``` + + ??? note "Optional Speedup using Flash Attention" + If you have a GPU with Flash Attention support, you can enable it for ColPali by simply + installing the `flash-attn` package. + + ```bash + uv pip install flash-attn --no-build-isolation + ``` + + Args: + data_artifact_name (str): The name of the Weave artifact containing the dataset. + weave_dataset_name (str): The name of the Weave dataset to include in the artifact metadata. + index_name (str): The name to assign to the created index. + """ + data_artifact_dir = get_wandb_artifact(data_artifact_name, "dataset") + self._docs_retrieval_model.index( + input_path=data_artifact_dir, + index_name=index_name, + store_collection_with_index=False, + overwrite=True, + ) + if wandb.run: + artifact = wandb.Artifact( + name=index_name, + type="colpali-index", + metadata={"weave_dataset_name": weave_dataset_name}, + ) + artifact.add_dir( + local_path=os.path.join(".byaldi", index_name), name="index" + ) + artifact.save() + + @classmethod + def from_wandb_artifact( + cls, + index_artifact_name: str, + metadata_dataset_name: str, + data_artifact_name: str, + ): + """ + Creates an instance of the class from Weights & Biases (wandb) artifacts. + + This method retrieves the necessary artifacts from wandb to initialize the + ColPaliRetriever. It fetches the index artifact directory and the data artifact + directory using the provided artifact names. 
It then loads the document retrieval + model from the index path within the index artifact directory. Finally, it returns + an instance of the class initialized with the retrieved document retrieval model, + metadata dataset name, and data artifact directory. + + !!! example "Retrieving Documents" + First you need to install `Byaldi` library by Answer.ai. + + ```bash + uv pip install Byaldi>=0.0.5 + ``` + + Next, you can retrieve the documents by running the following code: + + ```python + import weave + + import wandb + from medrag_multi_modal.retrieval import CalPaliRetriever + + weave.init(project_name="ml-colabs/medrag-multi-modal") + retriever = CalPaliRetriever.from_wandb_artifact( + index_artifact_name="ml-colabs/medrag-multi-modal/grays-anatomy:v0", + metadata_dataset_name="grays-anatomy-images:v0", + data_artifact_name="ml-colabs/medrag-multi-modal/grays-anatomy-images:v1", + ) + ``` + + ??? note "Optional Speedup using Flash Attention" + If you have a GPU with Flash Attention support, you can enable it for ColPali by simply + installing the `flash-attn` package. + + ```bash + uv pip install flash-attn --no-build-isolation + ``` + + Args: + index_artifact_name (str): The name of the wandb artifact containing the index. + metadata_dataset_name (str): The name of the dataset containing metadata. + data_artifact_name (str): The name of the wandb artifact containing the data. + + Returns: + An instance of the class initialized with the retrieved document retrieval model, + metadata dataset name, and data artifact directory. + """ + from byaldi import RAGMultiModalModel + + index_artifact_dir = get_wandb_artifact(index_artifact_name, "colpali-index") + data_artifact_dir = get_wandb_artifact(data_artifact_name, "dataset") + docs_retrieval_model = RAGMultiModalModel.from_index( + index_path=os.path.join(index_artifact_dir, "index") + ) + return cls( + docs_retrieval_model=docs_retrieval_model, + metadata_dataset_name=metadata_dataset_name, + data_artifact_dir=data_artifact_dir, + ) + + @weave.op() + def predict(self, query: str, top_k: int = 3) -> list[dict[str, Any]]: + """ + Predicts and retrieves the top-k most relevant documents/images for a given query + using ColPali. + + This function uses the document retrieval model to search for the most relevant + documents based on the provided query. It returns a list of dictionaries, each + containing the document image, document ID, and the relevance score. + + !!! example "Retrieving Documents" + First you need to install `Byaldi` library by Answer.ai. + + ```bash + uv pip install Byaldi>=0.0.5 + ``` + + Next, you can retrieve the documents by running the following code: + + ```python + import weave + + import wandb + from medrag_multi_modal.retrieval import CalPaliRetriever + + weave.init(project_name="ml-colabs/medrag-multi-modal") + retriever = CalPaliRetriever.from_wandb_artifact( + index_artifact_name="ml-colabs/medrag-multi-modal/grays-anatomy:v0", + metadata_dataset_name="grays-anatomy-images:v0", + data_artifact_name="ml-colabs/medrag-multi-modal/grays-anatomy-images:v1", + ) + retriever.predict( + query="which neurotransmitters convey information between Merkel cells and sensory afferents?", + top_k=3, + ) + ``` + + ??? note "Optional Speedup using Flash Attention" + If you have a GPU with Flash Attention support, you can enable it for ColPali by simply + installing the `flash-attn` package. + + ```bash + uv pip install flash-attn --no-build-isolation + ``` + + Args: + query (str): The search query string. 
+ top_k (int, optional): The number of top results to retrieve. Defaults to 10. + + Returns: + list[dict[str, Any]]: A list of dictionaries where each dictionary contains: + - "doc_image" (PIL.Image.Image): The image of the document. + - "doc_id" (str): The ID of the document. + - "score" (float): The relevance score of the document. + """ + results = self._docs_retrieval_model.search(query=query, k=top_k) + retrieved_results = [] + for result in results: + retrieved_results.append( + { + "doc_image": Image.open( + os.path.join(self._data_artifact_dir, f"{result['doc_id']}.png") + ), + "doc_id": result["doc_id"], + "score": result["score"], + } + ) + return retrieved_results diff --git a/medrag_multi_modal/retrieval/common.py b/medrag_multi_modal/retrieval/common.py new file mode 100644 index 0000000000000000000000000000000000000000..9a0f1244bb83c595a5f2427e6ecf86b9334bc4c7 --- /dev/null +++ b/medrag_multi_modal/retrieval/common.py @@ -0,0 +1,21 @@ +from enum import Enum + + +class SimilarityMetric(Enum): + COSINE = "cosine" + EUCLIDEAN = "euclidean" + + +def mean_pooling(token_embeddings, mask): + token_embeddings = token_embeddings.masked_fill(~mask[..., None].bool(), 0.0) + sentence_embeddings = token_embeddings.sum(dim=1) / mask.sum(dim=1)[..., None] + return sentence_embeddings + + +def argsort_scores(scores: list[float], descending: bool = False): + return [ + {"item": item, "original_index": idx} + for idx, item in sorted( + list(enumerate(scores)), key=lambda x: x[1], reverse=descending + ) + ] diff --git a/medrag_multi_modal/retrieval/text_retrieval/__init__.py b/medrag_multi_modal/retrieval/text_retrieval/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ed8482ab713102ca550af488a830af1aacc7fee3 --- /dev/null +++ b/medrag_multi_modal/retrieval/text_retrieval/__init__.py @@ -0,0 +1,11 @@ +from .bm25s_retrieval import BM25sRetriever +from .contriever_retrieval import ContrieverRetriever +from .medcpt_retrieval import MedCPTRetriever +from .nv_embed_2 import NVEmbed2Retriever + +__all__ = [ + "BM25sRetriever", + "ContrieverRetriever", + "MedCPTRetriever", + "NVEmbed2Retriever", +] diff --git a/medrag_multi_modal/retrieval/text_retrieval/__pycache__/__init__.cpython-310.pyc b/medrag_multi_modal/retrieval/text_retrieval/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44658e870a50a31acf0f6ec9602623480e9bcb7b Binary files /dev/null and b/medrag_multi_modal/retrieval/text_retrieval/__pycache__/__init__.cpython-310.pyc differ diff --git a/medrag_multi_modal/retrieval/text_retrieval/__pycache__/bm25s_retrieval.cpython-310.pyc b/medrag_multi_modal/retrieval/text_retrieval/__pycache__/bm25s_retrieval.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e1a2097604c09f736dfeab88bacc4d4d97def59 Binary files /dev/null and b/medrag_multi_modal/retrieval/text_retrieval/__pycache__/bm25s_retrieval.cpython-310.pyc differ diff --git a/medrag_multi_modal/retrieval/text_retrieval/__pycache__/contriever_retrieval.cpython-310.pyc b/medrag_multi_modal/retrieval/text_retrieval/__pycache__/contriever_retrieval.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a921a70611df50782b1c7bd48c55fded32f6ab05 Binary files /dev/null and b/medrag_multi_modal/retrieval/text_retrieval/__pycache__/contriever_retrieval.cpython-310.pyc differ diff --git a/medrag_multi_modal/retrieval/text_retrieval/__pycache__/medcpt_retrieval.cpython-310.pyc 
b/medrag_multi_modal/retrieval/text_retrieval/__pycache__/medcpt_retrieval.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..087ebe6536a3c552c34f2ea76daff37c66b07a17 Binary files /dev/null and b/medrag_multi_modal/retrieval/text_retrieval/__pycache__/medcpt_retrieval.cpython-310.pyc differ diff --git a/medrag_multi_modal/retrieval/text_retrieval/__pycache__/nv_embed_2.cpython-310.pyc b/medrag_multi_modal/retrieval/text_retrieval/__pycache__/nv_embed_2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1db5aae11e9d1204143660983277301eb8299213 Binary files /dev/null and b/medrag_multi_modal/retrieval/text_retrieval/__pycache__/nv_embed_2.cpython-310.pyc differ diff --git a/medrag_multi_modal/retrieval/text_retrieval/bm25s_retrieval.py b/medrag_multi_modal/retrieval/text_retrieval/bm25s_retrieval.py new file mode 100644 index 0000000000000000000000000000000000000000..b5b528262a62c1ae3c962a42877d4cd26ef6d75d --- /dev/null +++ b/medrag_multi_modal/retrieval/text_retrieval/bm25s_retrieval.py @@ -0,0 +1,238 @@ +import json +import os +import shutil +from typing import Optional, Union + +import bm25s +import huggingface_hub +import weave +from bm25s import BM25 +from datasets import Dataset, load_dataset +from Stemmer import Stemmer + +from medrag_multi_modal.utils import fetch_from_huggingface, save_to_huggingface + +LANGUAGE_DICT = { + "english": "en", + "french": "fr", + "german": "de", +} + + +class BM25sRetriever(weave.Model): + """ + `BM25sRetriever` is a class that provides functionality for indexing and + retrieving documents using the [BM25-Sparse](https://github.com/xhluca/bm25s). + + Args: + language (str): The language of the documents to be indexed and retrieved. + use_stemmer (bool): A flag indicating whether to use stemming during tokenization. + retriever (Optional[bm25s.BM25]): An instance of the BM25 retriever. If not provided, + a new instance is created. + """ + + language: Optional[str] + use_stemmer: bool = True + _retriever: Optional[BM25] + + def __init__( + self, + language: str = "english", + use_stemmer: bool = True, + retriever: Optional[BM25] = None, + ): + super().__init__(language=language, use_stemmer=use_stemmer) + self._retriever = retriever or BM25() + + def index( + self, + chunk_dataset: Union[Dataset, str], + index_repo_id: Optional[str] = None, + cleanup: bool = True, + ): + """ + Indexes a dataset of text chunks using the BM25 algorithm. + + This method retrieves a dataset of text chunks from a specified source, tokenizes + the text using the BM25 tokenizer with optional stemming, and indexes the tokenized + text using the BM25 retriever. If an `index_repo_id` is provided, the index is saved + to disk and optionally logged as a Huggingface artifact. + + !!! example "Example Usage" + ```python + import weave + from dotenv import load_dotenv + + from medrag_multi_modal.retrieval.text_retrieval import BM25sRetriever + + load_dotenv() + weave.init(project_name="ml-colabs/medrag-multi-modal") + retriever = BM25sRetriever() + retriever.index( + chunk_dataset="geekyrakshit/grays-anatomy-chunks-test", + index_repo_id="geekyrakshit/grays-anatomy-index", + ) + ``` + + Args: + chunk_dataset (str): The Huggingface dataset containing the text chunks to be indexed. Either a + dataset repository name or a dataset object can be provided. + index_repo_id (Optional[str]): The Huggingface repository of the index artifact to be saved. 
+ cleanup (bool, optional): Whether to delete the local index directory after saving the vector index. + """ + chunk_dataset = ( + load_dataset(chunk_dataset, split="chunks") + if isinstance(chunk_dataset, str) + else chunk_dataset + ) + corpus = [row["text"] for row in chunk_dataset] + corpus_tokens = bm25s.tokenize( + corpus, + stopwords=LANGUAGE_DICT[self.language], + stemmer=Stemmer(self.language) if self.use_stemmer else None, + ) + self._retriever.index(corpus_tokens) + if index_repo_id: + os.makedirs(".huggingface", exist_ok=True) + index_save_dir = os.path.join(".huggingface", index_repo_id.split("/")[-1]) + self._retriever.save( + index_save_dir, corpus=[dict(row) for row in chunk_dataset] + ) + commit_type = ( + "update" + if huggingface_hub.repo_exists(index_repo_id, repo_type="model") + else "add" + ) + with open(os.path.join(index_save_dir, "config.json"), "w") as config_file: + json.dump( + { + "language": self.language, + "use_stemmer": self.use_stemmer, + }, + config_file, + indent=4, + ) + save_to_huggingface( + index_repo_id, + index_save_dir, + commit_message=f"{commit_type}: BM25s index", + ) + if cleanup: + shutil.rmtree(index_save_dir) + + @classmethod + def from_index(cls, index_repo_id: str): + """ + Creates an instance of the class from a Huggingface repository. + + This class method retrieves a BM25 index artifact from a Huggingface repository, + downloads the artifact, and loads the BM25 retriever with the index and its + associated corpus. The method also extracts metadata from the artifact to + initialize the class instance with the appropriate language and stemming + settings. + + !!! example "Example Usage" + ```python + import weave + from dotenv import load_dotenv + + from medrag_multi_modal.retrieval.text_retrieval import BM25sRetriever + + load_dotenv() + weave.init(project_name="ml-colabs/medrag-multi-modal") + retriever = BM25sRetriever() + retriever = BM25sRetriever().from_index(index_repo_id="geekyrakshit/grays-anatomy-index") + ``` + + Args: + index_repo_id (Optional[str]): The Huggingface repository of the index artifact to be saved. + + Returns: + An instance of the class initialized with the BM25 retriever and metadata + from the artifact. + """ + index_dir = fetch_from_huggingface(index_repo_id, ".huggingface") + retriever = bm25s.BM25.load(index_dir, load_corpus=True) + with open(os.path.join(index_dir, "config.json"), "r") as config_file: + config = json.load(config_file) + return cls(retriever=retriever, **config) + + @weave.op() + def retrieve(self, query: str, top_k: int = 2): + """ + Retrieves the top-k most relevant chunks for a given query using the BM25 algorithm. + + This method tokenizes the input query using the BM25 tokenizer, which takes into + account the language-specific stopwords and optional stemming. It then retrieves + the top-k most relevant chunks from the BM25 index based on the tokenized query. + The results are returned as a list of dictionaries, each containing a chunk and + its corresponding relevance score. + + !!! example "Example Usage" + ```python + import weave + from dotenv import load_dotenv + + from medrag_multi_modal.retrieval.text_retrieval import BM25sRetriever + + load_dotenv() + weave.init(project_name="ml-colabs/medrag-multi-modal") + retriever = BM25sRetriever() + retriever = BM25sRetriever().from_index(index_repo_id="geekyrakshit/grays-anatomy-index") + retrieved_chunks = retriever.retrieve(query="What are Ribosomes?") + ``` + + Args: + query (str): The input query string to search for relevant chunks. 
+ top_k (int, optional): The number of top relevant chunks to retrieve. Defaults to 2. + + Returns: + list: A list of dictionaries, each containing a retrieved chunk and its + relevance score. + """ + query_tokens = bm25s.tokenize( + query, + stopwords=LANGUAGE_DICT[self.language], + stemmer=Stemmer(self.language) if self.use_stemmer else None, + ) + results = self._retriever.retrieve(query_tokens, k=top_k) + retrieved_chunks = [] + for chunk, score in zip( + results.documents.flatten().tolist(), + results.scores.flatten().tolist(), + ): + retrieved_chunks.append({**chunk, **{"score": score}}) + return retrieved_chunks + + @weave.op() + def predict(self, query: str, top_k: int = 2): + """ + Predicts the top-k most relevant chunks for a given query using the BM25 algorithm. + + This function is a wrapper around the `retrieve` method. It takes an input query string, + tokenizes it using the BM25 tokenizer, and retrieves the top-k most relevant chunks from + the BM25 index. The results are returned as a list of dictionaries, each containing a chunk + and its corresponding relevance score. + + !!! example "Example Usage" + ```python + import weave + from dotenv import load_dotenv + + from medrag_multi_modal.retrieval.text_retrieval import BM25sRetriever + + load_dotenv() + weave.init(project_name="ml-colabs/medrag-multi-modal") + retriever = BM25sRetriever() + retriever = BM25sRetriever().from_index(index_repo_id="geekyrakshit/grays-anatomy-index") + retrieved_chunks = retriever.predict(query="What are Ribosomes?") + ``` + + Args: + query (str): The input query string to search for relevant chunks. + top_k (int, optional): The number of top relevant chunks to retrieve. Defaults to 2. + + Returns: + list: A list of dictionaries, each containing a retrieved chunk and its relevance score. + """ + return self.retrieve(query, top_k) diff --git a/medrag_multi_modal/retrieval/text_retrieval/contriever_retrieval.py b/medrag_multi_modal/retrieval/text_retrieval/contriever_retrieval.py new file mode 100644 index 0000000000000000000000000000000000000000..77da42e2693828232048723718ca29ae996c788f --- /dev/null +++ b/medrag_multi_modal/retrieval/text_retrieval/contriever_retrieval.py @@ -0,0 +1,310 @@ +import json +import os +import shutil +from typing import Optional, Union + +import huggingface_hub +import safetensors +import safetensors.torch +import torch +import torch.nn.functional as F +import weave +from datasets import Dataset, load_dataset +from rich.progress import track +from transformers import ( + AutoModel, + AutoTokenizer, + BertPreTrainedModel, + PreTrainedTokenizerFast, +) + +from medrag_multi_modal.retrieval.common import ( + SimilarityMetric, + argsort_scores, + mean_pooling, +) +from medrag_multi_modal.utils import ( + fetch_from_huggingface, + get_torch_backend, + save_to_huggingface, +) + + +class ContrieverRetriever(weave.Model): + """ + `ContrieverRetriever` is a class to perform retrieval tasks using the Contriever model. + + It provides methods to encode text data into embeddings, index a dataset of text chunks, + and retrieve the most relevant chunks for a given query based on similarity metrics. + + Args: + model_name (str): The name of the pre-trained model to use for encoding. + vector_index (Optional[torch.Tensor]): The tensor containing the vector representations + of the indexed chunks. + chunk_dataset (Optional[list[dict]]): The weave dataset of text chunks to be indexed. 
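+
+    !!! example "Example (illustrative sketch)"
+        Besides a HuggingFace dataset repository id, `index` also accepts an in-memory
+        `datasets.Dataset`, which can be handy for quick smoke tests; the two-chunk
+        corpus below is made up purely for illustration.
+
+        ```python
+        from datasets import Dataset
+
+        from medrag_multi_modal.retrieval.text_retrieval import ContrieverRetriever
+
+        chunks = Dataset.from_list(
+            [
+                {"text": "Ribosomes synthesize proteins from messenger RNA."},
+                {"text": "The mitochondrion generates most of the cell's ATP."},
+            ]
+        )
+        retriever = ContrieverRetriever()
+        retriever.index(chunk_dataset=chunks)  # no index_repo_id, so the index stays in memory
+        print(retriever.retrieve(query="What are Ribosomes?", top_k=1))
+        ```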
+ """ + + model_name: str + _chunk_dataset: Optional[list[dict]] + _tokenizer: PreTrainedTokenizerFast + _model: BertPreTrainedModel + _vector_index: Optional[torch.Tensor] + + def __init__( + self, + model_name: str = "facebook/contriever", + vector_index: Optional[torch.Tensor] = None, + chunk_dataset: Optional[list[dict]] = None, + ): + super().__init__(model_name=model_name) + self._tokenizer = AutoTokenizer.from_pretrained(self.model_name) + self._model = AutoModel.from_pretrained(self.model_name).to(get_torch_backend()) + self._vector_index = vector_index + self._chunk_dataset = chunk_dataset + + def encode(self, corpus: list[str], batch_size: int) -> torch.Tensor: + embeddings = [] + iterable = track( + range(0, len(corpus), batch_size), + description=f"Encoding corpus using {self.model_name}", + ) if batch_size > 1 else range(0, len(corpus), batch_size) + for idx in iterable: + batch = corpus[idx : idx + batch_size] + inputs = self._tokenizer( + batch, padding=True, truncation=True, return_tensors="pt" + ).to(get_torch_backend()) + with torch.no_grad(): + outputs = self._model(**inputs) + batch_embeddings = mean_pooling(outputs[0], inputs["attention_mask"]) + embeddings.append(batch_embeddings) + embeddings = torch.cat(embeddings, dim=0) + return embeddings + + def index( + self, + chunk_dataset: Union[str, Dataset], + index_repo_id: Optional[str] = None, + cleanup: bool = True, + batch_size: int = 32, + ): + """ + Indexes a dataset of text chunks and optionally saves the vector index to a file. + + This method retrieves a dataset of text chunks from a Weave reference, encodes the + text chunks into vector representations using the Contriever model, and stores the + resulting vector index. If an index name is provided, the vector index is saved to + a file in the safetensors format. Additionally, if a Weave run is active, the vector + index file is logged as an artifact to Weave. + + !!! example "Example Usage" + ```python + from medrag_multi_modal.retrieval.text_retrieval import ContrieverRetriever + + retriever = ContrieverRetriever() + retriever.index( + chunk_dataset="ashwiniai/medrag-text-corpus-chunks", + index_repo_id="ashwiniai/medrag-text-corpus-chunks-contriever", + batch_size=256, + ) + ``` + + Args: + chunk_dataset (str): The Huggingface dataset containing the text chunks to be indexed. Either a + dataset repository name or a dataset object can be provided. + index_repo_id (Optional[str]): The Huggingface repository of the index artifact to be saved. + cleanup (bool, optional): Whether to delete the local index directory after saving the vector index. + batch_size (int, optional): The batch size to use for encoding the corpus. 
+ """ + self._chunk_dataset = ( + load_dataset(chunk_dataset, split="chunks") + if isinstance(chunk_dataset, str) + else chunk_dataset + ) + corpus = [row["text"] for row in self._chunk_dataset] + with torch.no_grad(): + vector_index = self.encode(corpus, batch_size) + self._vector_index = vector_index + if index_repo_id: + index_save_dir = os.path.join( + ".huggingface", index_repo_id.split("/")[-1] + ) + os.makedirs(index_save_dir, exist_ok=True) + safetensors.torch.save_file( + {"vector_index": vector_index.cpu()}, + os.path.join(index_save_dir, "vector_index.safetensors"), + ) + commit_type = ( + "update" + if huggingface_hub.repo_exists(index_repo_id, repo_type="model") + else "add" + ) + with open( + os.path.join(index_save_dir, "config.json"), "w" + ) as config_file: + json.dump( + {"model_name": self.model_name}, + config_file, + indent=4, + ) + save_to_huggingface( + index_repo_id, + index_save_dir, + commit_message=f"{commit_type}: Contriever index", + ) + if cleanup: + shutil.rmtree(index_save_dir) + + @classmethod + def from_index(cls, chunk_dataset: Union[str, Dataset], index_repo_id: str): + """ + Creates an instance of the class from a Weave artifact. + + This method retrieves a vector index and metadata from a Weave artifact stored in + Weights & Biases (wandb). It also retrieves a dataset of text chunks from a Weave + reference. The vector index is loaded from a safetensors file and moved to the + appropriate device (CPU or GPU). The text chunks are converted into a list of + dictionaries. The method then returns an instance of the class initialized with + the retrieved model name, vector index, and chunk dataset. + + !!! example "Example Usage" + ```python + import weave + from dotenv import load_dotenv + + from medrag_multi_modal.retrieval.text_retrieval import ContrieverRetriever + + load_dotenv() + retriever = ContrieverRetriever().from_index( + index_repo_id="geekyrakshit/grays-anatomy-index-contriever", + chunk_dataset="geekyrakshit/grays-anatomy-chunks-test", + ) + ``` + + Args: + chunk_dataset (str): The Huggingface dataset containing the text chunks to be indexed. Either a + dataset repository name or a dataset object can be provided. + index_repo_id (Optional[str]): The Huggingface repository of the index artifact to be saved. + + Returns: + An instance of the class initialized with the retrieved model name, vector index, + and chunk dataset. + """ + index_dir = fetch_from_huggingface(index_repo_id, ".huggingface") + with safetensors.torch.safe_open( + os.path.join(index_dir, "vector_index.safetensors"), framework="pt" + ) as f: + vector_index = f.get_tensor("vector_index") + device = torch.device(get_torch_backend()) + vector_index = vector_index.to(device) + chunk_dataset = ( + load_dataset(chunk_dataset, split="chunks") + if isinstance(chunk_dataset, str) + else chunk_dataset + ) + with open(os.path.join(index_dir, "config.json"), "r") as config_file: + metadata = json.load(config_file) + return cls( + model_name=metadata["model_name"], + vector_index=vector_index, + chunk_dataset=chunk_dataset, + ) + + @weave.op() + def retrieve( + self, + query: str, + top_k: int = 2, + metric: SimilarityMetric = SimilarityMetric.COSINE, + ): + """ + Retrieves the top-k most relevant chunks for a given query using the specified similarity metric. + + This method encodes the input query into an embedding and computes similarity scores between + the query embedding and the precomputed vector index. The similarity metric can be either + cosine similarity or Euclidean distance. 
The top-k chunks with the highest similarity scores + are returned as a list of dictionaries, each containing a chunk and its corresponding score. + + !!! example "Example Usage" + ```python + import weave + from dotenv import load_dotenv + + from medrag_multi_modal.retrieval.text_retrieval import ContrieverRetriever + + load_dotenv() + weave.init(project_name="ml-colabs/medrag-multi-modal") + retriever = ContrieverRetriever().from_index( + index_repo_id="geekyrakshit/grays-anatomy-index-contriever", + chunk_dataset="geekyrakshit/grays-anatomy-chunks-test", + ) + retrieved_chunks = retriever.retrieve(query="What are Ribosomes?") + ``` + + Args: + query (str): The input query string to search for relevant chunks. + top_k (int, optional): The number of top relevant chunks to retrieve. Defaults to 2. + metric (SimilarityMetric, optional): The similarity metric to use for scoring. + + Returns: + list: A list of dictionaries, each containing a retrieved chunk and its relevance score. + """ + query = [query] + device = torch.device(get_torch_backend()) + with torch.no_grad(): + query_embedding = self.encode(query, batch_size=1).to(device) + if metric == SimilarityMetric.EUCLIDEAN: + scores = torch.squeeze(query_embedding @ self._vector_index.T) + else: + scores = F.cosine_similarity(query_embedding, self._vector_index) + scores = scores.cpu().numpy().tolist() + scores = argsort_scores(scores, descending=True)[:top_k] + retrieved_chunks = [] + for score in scores: + retrieved_chunks.append( + { + **self._chunk_dataset[score["original_index"]], + **{"score": score["item"]}, + } + ) + return retrieved_chunks + + @weave.op() + def predict( + self, + query: str, + top_k: int = 2, + metric: SimilarityMetric = SimilarityMetric.COSINE, + ): + """ + Predicts the top-k most relevant chunks for a given query using the specified similarity metric. + + This function is a wrapper around the `retrieve` method. It takes an input query string, + retrieves the top-k most relevant chunks from the precomputed vector index based on the + specified similarity metric, and returns the results as a list of dictionaries, each containing + a chunk and its corresponding relevance score. + + !!! example "Example Usage" + ```python + import weave + from dotenv import load_dotenv + + from medrag_multi_modal.retrieval.text_retrieval import ContrieverRetriever + + load_dotenv() + weave.init(project_name="ml-colabs/medrag-multi-modal") + retriever = ContrieverRetriever().from_index( + index_repo_id="geekyrakshit/grays-anatomy-index-contriever", + chunk_dataset="geekyrakshit/grays-anatomy-chunks-test", + ) + retrieved_chunks = retriever.predict(query="What are Ribosomes?") + ``` + + Args: + query (str): The input query string to search for relevant chunks. + top_k (int, optional): The number of top relevant chunks to retrieve. Defaults to 2. + metric (SimilarityMetric, optional): The similarity metric to use for scoring. Defaults to cosine similarity. + + Returns: + list: A list of dictionaries, each containing a retrieved chunk and its relevance score. 
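+
+        !!! note "Working with the results (illustrative)"
+            Each returned entry is the original chunk-dataset row augmented with a
+            `score` key. A minimal sketch of consuming the results, assuming the
+            chunk dataset has a `text` column:
+
+            ```python
+            for chunk in retriever.predict(query="What are Ribosomes?", top_k=2):
+                print(round(chunk["score"], 3), chunk["text"][:80])
+            ```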
+ """ + return self.retrieve(query, top_k, metric) diff --git a/medrag_multi_modal/retrieval/text_retrieval/medcpt_retrieval.py b/medrag_multi_modal/retrieval/text_retrieval/medcpt_retrieval.py new file mode 100644 index 0000000000000000000000000000000000000000..ac400dacf329a532e99b3cac7967326b8c3d2c0f --- /dev/null +++ b/medrag_multi_modal/retrieval/text_retrieval/medcpt_retrieval.py @@ -0,0 +1,335 @@ +import json +import os +import shutil +from typing import Optional, Union + +import huggingface_hub +import safetensors +import safetensors.torch +import torch +import torch.nn.functional as F +import weave +from datasets import Dataset, load_dataset +from rich.progress import track +from transformers import ( + AutoModel, + AutoTokenizer, + BertPreTrainedModel, + PreTrainedTokenizerFast, +) + +from medrag_multi_modal.retrieval.common import SimilarityMetric, argsort_scores +from medrag_multi_modal.utils import ( + fetch_from_huggingface, + get_torch_backend, + save_to_huggingface, +) + + +class MedCPTRetriever(weave.Model): + """ + A class to retrieve relevant text chunks using MedCPT models. + + This class provides methods to index a dataset of text chunks and retrieve the most relevant + chunks for a given query using MedCPT models. It uses separate models for encoding queries + and articles, and supports both cosine similarity and Euclidean distance as similarity metrics. + + Args: + query_encoder_model_name (str): The name of the model used for encoding queries. + article_encoder_model_name (str): The name of the model used for encoding articles. + chunk_size (Optional[int]): The maximum length of text chunks. + vector_index (Optional[torch.Tensor]): The vector index of encoded text chunks. + chunk_dataset (Optional[list[dict]]): The dataset of text chunks. + """ + + query_encoder_model_name: str + article_encoder_model_name: str + chunk_size: Optional[int] + _chunk_dataset: Optional[list[dict]] + _query_tokenizer: PreTrainedTokenizerFast + _article_tokenizer: PreTrainedTokenizerFast + _query_encoder_model: BertPreTrainedModel + _article_encoder_model: BertPreTrainedModel + _vector_index: Optional[torch.Tensor] + + def __init__( + self, + query_encoder_model_name: str = "ncbi/MedCPT-Query-Encoder", + article_encoder_model_name: str = "ncbi/MedCPT-Article-Encoder", + chunk_size: Optional[int] = None, + vector_index: Optional[torch.Tensor] = None, + chunk_dataset: Optional[list[dict]] = None, + ): + super().__init__( + query_encoder_model_name=query_encoder_model_name, + article_encoder_model_name=article_encoder_model_name, + chunk_size=chunk_size, + ) + self._query_tokenizer = AutoTokenizer.from_pretrained( + self.query_encoder_model_name + ) + self._article_tokenizer = AutoTokenizer.from_pretrained( + self.article_encoder_model_name + ) + self._query_encoder_model = AutoModel.from_pretrained( + self.query_encoder_model_name + ).to(get_torch_backend()) + self._article_encoder_model = AutoModel.from_pretrained( + self.article_encoder_model_name + ).to(get_torch_backend()) + self._chunk_dataset = chunk_dataset + self._vector_index = vector_index + + def index( + self, + chunk_dataset: Union[str, Dataset], + index_repo_id: Optional[str] = None, + cleanup: bool = True, + batch_size: int = 32, + ): + """ + Indexes a dataset of text chunks using the MedCPT model and optionally saves the vector index. 
+
+        This method retrieves a dataset of text chunks from a specified source, encodes the text
+        chunks into vector representations using the article encoder model, and stores the
+        resulting vector index. If an `index_repo_id` is provided, the vector index is saved
+        to disk in the safetensors format and optionally logged as a Huggingface artifact.
+
+        !!! example "Example Usage"
+            ```python
+            import weave
+            from dotenv import load_dotenv
+
+            from medrag_multi_modal.retrieval.text_retrieval import MedCPTRetriever
+
+            load_dotenv()
+            retriever = MedCPTRetriever()
+            retriever.index(
+                chunk_dataset="geekyrakshit/grays-anatomy-chunks-test",
+                index_repo_id="geekyrakshit/grays-anatomy-index-medcpt",
+            )
+            ```
+
+        Args:
+            chunk_dataset (str): The Huggingface dataset containing the text chunks to be indexed. Either a
+                dataset repository name or a dataset object can be provided.
+            index_repo_id (Optional[str]): The Huggingface repository of the index artifact to be saved.
+            cleanup (bool, optional): Whether to delete the local index directory after saving the vector index.
+            batch_size (int, optional): The batch size to use for encoding the corpus.
+
+        """
+        self._chunk_dataset = (
+            load_dataset(chunk_dataset, split="chunks")
+            if isinstance(chunk_dataset, str)
+            else chunk_dataset
+        )
+        corpus = [row["text"] for row in self._chunk_dataset]
+        vector_indices = []
+        with torch.no_grad():
+            for idx in track(
+                range(0, len(corpus), batch_size),
+                description="Encoding corpus using MedCPT",
+            ):
+                batch = corpus[idx : idx + batch_size]
+                encoded = self._article_tokenizer(
+                    batch,
+                    truncation=True,
+                    padding=True,
+                    return_tensors="pt",
+                    max_length=self.chunk_size,
+                ).to(get_torch_backend())
+                batch_vectors = (
+                    self._article_encoder_model(**encoded)
+                    .last_hidden_state[:, 0, :]
+                    .contiguous()
+                )
+                vector_indices.append(batch_vectors)
+
+        vector_index = torch.cat(vector_indices, dim=0)
+        self._vector_index = vector_index
+        if index_repo_id:
+            index_save_dir = os.path.join(
+                ".huggingface", index_repo_id.split("/")[-1]
+            )
+            os.makedirs(index_save_dir, exist_ok=True)
+            safetensors.torch.save_file(
+                {"vector_index": self._vector_index.cpu()},
+                os.path.join(index_save_dir, "vector_index.safetensors"),
+            )
+            commit_type = (
+                "update"
+                if huggingface_hub.repo_exists(index_repo_id, repo_type="model")
+                else "add"
+            )
+            with open(
+                os.path.join(index_save_dir, "config.json"), "w"
+            ) as config_file:
+                json.dump(
+                    {
+                        "query_encoder_model_name": self.query_encoder_model_name,
+                        "article_encoder_model_name": self.article_encoder_model_name,
+                        "chunk_size": self.chunk_size,
+                    },
+                    config_file,
+                    indent=4,
+                )
+            save_to_huggingface(
+                index_repo_id,
+                index_save_dir,
+                commit_message=f"{commit_type}: MedCPT index",
+            )
+            if cleanup:
+                shutil.rmtree(index_save_dir)
+
+    @classmethod
+    def from_index(cls, chunk_dataset: Union[str, Dataset], index_repo_id: str):
+        """
+        Creates an instance of the class from a Huggingface repository.
+
+        This method retrieves a vector index and metadata from a Huggingface repository.
+        It also retrieves a dataset of text chunks from the specified source. The vector
+        index is loaded from a safetensors file and moved to the appropriate device (CPU or GPU).
+        The method then returns an instance of the class initialized with the retrieved
+        model names, vector index, and chunk dataset.
+
+        !!!
example "Example Usage" + ```python + from medrag_multi_modal.retrieval.text_retrieval import MedCPTRetriever + + retriever = MedCPTRetriever.from_index( + index_repo_id="ashwiniai/medrag-text-corpus-chunks-medcpt", + chunk_dataset="ashwiniai/medrag-text-corpus-chunks", + ) + ``` + + Args: + chunk_dataset (str): The Huggingface dataset containing the text chunks to be indexed. Either a + dataset repository name or a dataset object can be provided. + index_repo_id (Optional[str]): The Huggingface repository of the index artifact to be saved. + + Returns: + An instance of the class initialized with the retrieved model name, vector index, and chunk dataset. + """ + index_dir = fetch_from_huggingface(index_repo_id, ".huggingface") + with safetensors.torch.safe_open( + os.path.join(index_dir, "vector_index.safetensors"), framework="pt" + ) as f: + vector_index = f.get_tensor("vector_index") + device = torch.device(get_torch_backend()) + vector_index = vector_index.to(device) + with open(os.path.join(index_dir, "config.json"), "r") as config_file: + metadata = json.load(config_file) + chunk_dataset = ( + load_dataset(chunk_dataset, split="chunks") + if isinstance(chunk_dataset, str) + else chunk_dataset + ) + return cls( + query_encoder_model_name=metadata["query_encoder_model_name"], + article_encoder_model_name=metadata["article_encoder_model_name"], + chunk_size=metadata["chunk_size"], + vector_index=vector_index, + chunk_dataset=chunk_dataset, + ) + + @weave.op() + def retrieve( + self, + query: str, + top_k: int = 2, + metric: SimilarityMetric = SimilarityMetric.COSINE, + ): + """ + Retrieves the top-k most relevant chunks for a given query using the specified similarity metric. + + This method encodes the input query into an embedding and computes similarity scores between + the query embedding and the precomputed vector index. The similarity metric can be either + cosine similarity or Euclidean distance. The top-k chunks with the highest similarity scores + are returned as a list of dictionaries, each containing a chunk and its corresponding score. + + !!! example "Example Usage" + ```python + import weave + from medrag_multi_modal.retrieval.text_retrieval import MedCPTRetriever + + weave.init(project_name="ml-colabs/medrag-multi-modal") + retriever = MedCPTRetriever.from_index( + index_repo_id="ashwiniai/medrag-text-corpus-chunks-medcpt", + chunk_dataset="ashwiniai/medrag-text-corpus-chunks", + ) + retriever.retrieve(query="What is ribosome?") + ``` + + Args: + query (str): The input query string to search for relevant chunks. + top_k (int, optional): The number of top relevant chunks to retrieve. Defaults to 2. + metric (SimilarityMetric, optional): The similarity metric to use for scoring. Defaults to cosine similarity. + + Returns: + list: A list of dictionaries, each containing a retrieved chunk and its relevance score. 
+ """ + query = [query] + device = torch.device(get_torch_backend()) + with torch.no_grad(): + encoded = self._query_tokenizer( + query, + truncation=True, + padding=True, + return_tensors="pt", + ).to(device) + query_embedding = self._query_encoder_model(**encoded).last_hidden_state[ + :, 0, : + ] + query_embedding = query_embedding.to(device) + if metric == SimilarityMetric.EUCLIDEAN: + scores = torch.squeeze(query_embedding @ self._vector_index.T) + else: + scores = F.cosine_similarity(query_embedding, self._vector_index) + scores = scores.cpu().numpy().tolist() + scores = argsort_scores(scores, descending=True)[:top_k] + retrieved_chunks = [] + for score in scores: + retrieved_chunks.append( + { + **self._chunk_dataset[score["original_index"]], + **{"score": score["item"]}, + } + ) + return retrieved_chunks + + @weave.op() + def predict( + self, + query: str, + top_k: int = 2, + metric: SimilarityMetric = SimilarityMetric.COSINE, + ): + """ + Predicts the most relevant chunks for a given query. + + This function uses the `retrieve` method to find the top-k relevant chunks + from the dataset based on the input query. It allows specifying the number + of top relevant chunks to retrieve and the similarity metric to use for scoring. + + !!! example "Example Usage" + ```python + import weave + from medrag_multi_modal.retrieval.text_retrieval import MedCPTRetriever + + weave.init(project_name="ml-colabs/medrag-multi-modal") + retriever = MedCPTRetriever.from_index( + index_repo_id="ashwiniai/medrag-text-corpus-chunks-medcpt", + chunk_dataset="ashwiniai/medrag-text-corpus-chunks", + ) + retriever.predict(query="What is ribosome?") + ``` + + Args: + query (str): The input query string to search for relevant chunks. + top_k (int, optional): The number of top relevant chunks to retrieve. Defaults to 2. + metric (SimilarityMetric, optional): The similarity metric to use for scoring. Defaults to cosine similarity. + + Returns: + list: A list of dictionaries, each containing a retrieved chunk and its relevance score. + """ + return self.retrieve(query, top_k, metric) diff --git a/medrag_multi_modal/retrieval/text_retrieval/nv_embed_2.py b/medrag_multi_modal/retrieval/text_retrieval/nv_embed_2.py new file mode 100644 index 0000000000000000000000000000000000000000..67e61883f8e801f105a2f7e039e3ed16b89dd2e9 --- /dev/null +++ b/medrag_multi_modal/retrieval/text_retrieval/nv_embed_2.py @@ -0,0 +1,332 @@ +import json +import os +import shutil +from typing import Optional, Union + +import huggingface_hub +import safetensors +import torch +import torch.nn.functional as F +import weave +from datasets import Dataset, load_dataset +from rich.progress import track +from sentence_transformers import SentenceTransformer + +from medrag_multi_modal.retrieval.common import SimilarityMetric, argsort_scores +from medrag_multi_modal.utils import ( + fetch_from_huggingface, + get_torch_backend, + save_to_huggingface, +) + + +class NVEmbed2Retriever(weave.Model): + """ + `NVEmbed2Retriever` is a class for retrieving relevant text chunks from a dataset using the + [NV-Embed-v2](https://huggingface.co/nvidia/NV-Embed-v2) model. + + This class leverages the SentenceTransformer model to encode text chunks into vector representations and + performs similarity-based retrieval. It supports indexing a dataset of text chunks, saving the vector index, + and retrieving the most relevant chunks for a given query. + + Args: + model_name (str): The name of the pre-trained model to use for encoding. 
+ vector_index (Optional[torch.Tensor]): The tensor containing the vector representations of the indexed chunks. + chunk_dataset (Optional[list[dict]]): The dataset of text chunks to be indexed. + """ + + model_name: str + _chunk_dataset: Optional[list[dict]] + _model: SentenceTransformer + _vector_index: Optional[torch.Tensor] + + def __init__( + self, + model_name: str = "nvidia/NV-Embed-v2", + vector_index: Optional[torch.Tensor] = None, + chunk_dataset: Optional[list[dict]] = None, + ): + super().__init__(model_name=model_name) + self._model = SentenceTransformer( + self.model_name, + trust_remote_code=True, + model_kwargs={"torch_dtype": torch.float16}, + device=get_torch_backend(), + ) + self._model.max_seq_length = 32768 + self._model.tokenizer.padding_side = "right" + self._vector_index = vector_index + self._chunk_dataset = chunk_dataset + + def add_eos(self, input_examples): + input_examples = [ + input_example + self._model.tokenizer.eos_token + for input_example in input_examples + ] + return input_examples + + def index( + self, + chunk_dataset: Union[str, Dataset], + index_repo_id: Optional[str] = None, + cleanup: bool = True, + batch_size: int = 8, + ): + """ + Indexes a dataset of text chunks and optionally saves the vector index to a Huggingface repository. + + This method retrieves a dataset of text chunks from a specified source, encodes the + text chunks into vector representations using the NV-Embed-v2 model, and stores the + resulting vector index. If an index repository ID is provided, the vector index is saved to + a file in the safetensors format within the specified Huggingface repository. + + !!! example "Example Usage" + ```python + from medrag_multi_modal.retrieval.text_retrieval import NVEmbed2Retriever + + retriever = NVEmbed2Retriever() + retriever.index( + chunk_dataset="ashwiniai/medrag-text-corpus-chunks", + index_repo_id="ashwiniai/medrag-text-corpus-chunks-nv-embed-2", + ) + ``` + + ??? note "Optional Speedup using Flash Attention" + If you have a GPU with Flash Attention support, you can enable it for NV-Embed-v2 by simply + installing the `flash-attn` package. + + ```bash + uv pip install flash-attn --no-build-isolation + ``` + + Args: + chunk_dataset (str): The Huggingface dataset containing the text chunks to be indexed. Either a + dataset repository name or a dataset object can be provided. + index_repo_id (Optional[str]): The Huggingface repository of the index artifact to be saved. + cleanup (bool, optional): Whether to delete the local index directory after saving the vector index. + batch_size (int, optional): The batch size to use for encoding the corpus. 
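+
+        !!! note "Memory usage (illustrative)"
+            NV-Embed-v2 is a large model loaded in `float16` with a 32k-token maximum
+            sequence length, so the default batch size of 8 may not fit on smaller GPUs.
+            Reducing `batch_size` trades indexing speed for memory; the value below is
+            only an illustration:
+
+            ```python
+            retriever.index(
+                chunk_dataset="ashwiniai/medrag-text-corpus-chunks",
+                index_repo_id="ashwiniai/medrag-text-corpus-chunks-nv-embed-2",
+                batch_size=2,
+            )
+            ```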
+ """ + self._chunk_dataset = ( + load_dataset(chunk_dataset, split="chunks") + if isinstance(chunk_dataset, str) + else chunk_dataset + ) + corpus = [row["text"] for row in self._chunk_dataset] + vector_indices = [] + + for idx in track( + range(0, len(corpus), batch_size), + description="Encoding corpus using NV-Embed-v2", + ): + batch = corpus[idx : idx + batch_size] + batch_embeddings = self._model.encode( + self.add_eos(batch), batch_size=len(batch), normalize_embeddings=True + ) + vector_indices.append(torch.tensor(batch_embeddings)) + + self._vector_index = torch.cat(vector_indices, dim=0) + with torch.no_grad(): + if index_repo_id: + index_save_dir = os.path.join( + ".huggingface", index_repo_id.split("/")[-1] + ) + os.makedirs(index_save_dir, exist_ok=True) + safetensors.torch.save_file( + {"vector_index": self._vector_index.cpu()}, + os.path.join(index_save_dir, "vector_index.safetensors"), + ) + commit_type = ( + "update" + if huggingface_hub.repo_exists(index_repo_id, repo_type="model") + else "add" + ) + with open( + os.path.join(index_save_dir, "config.json"), "w" + ) as config_file: + json.dump( + {"model_name": self.model_name}, + config_file, + indent=4, + ) + save_to_huggingface( + index_repo_id, + index_save_dir, + commit_message=f"{commit_type}: Contriever index", + ) + if cleanup: + shutil.rmtree(index_save_dir) + + @classmethod + def from_index(cls, chunk_dataset: Union[str, Dataset], index_repo_id: str): + """ + Creates an instance of the class from a Huggingface repository. + + This method retrieves a vector index and metadata from a Huggingface repository. It also retrieves a dataset of text chunks from a Huggingface dataset repository. The vector index is loaded from a safetensors file and moved to the appropriate device (CPU or GPU). The text chunks are converted into a list of dictionaries. The method then returns an instance of the class initialized with the retrieved model name, vector index, and chunk dataset. + Weights & Biases (wandb). It also retrieves a dataset of text chunks from a Weave + reference. The vector index is loaded from a safetensors file and moved to the + appropriate device (CPU or GPU). The text chunks are converted into a list of + dictionaries. The method then returns an instance of the class initialized with + the retrieved model name, vector index, and chunk dataset. + + !!! example "Example Usage" + ```python + import weave + from medrag_multi_modal.retrieval.text_retrieval import NVEmbed2Retriever + + retriever = NVEmbed2Retriever.from_index( + index_repo_id="ashwiniai/medrag-text-corpus-chunks-nv-embed-2", + chunk_dataset="ashwiniai/medrag-text-corpus-chunks", + ) + ``` + + ??? note "Optional Speedup using Flash Attention" + If you have a GPU with Flash Attention support, you can enable it for NV-Embed-v2 by simply + installing the `flash-attn` package. + + ```bash + uv pip install flash-attn --no-build-isolation + ``` + + Args: + chunk_dataset (str): The Huggingface dataset containing the text chunks to be indexed. Either a + dataset repository name or a dataset object can be provided. + index_repo_id (Optional[str]): The Huggingface repository of the index artifact to be saved. + + Returns: + An instance of the class initialized with the retrieved model name, vector index, + and chunk dataset. 
+ """ + index_dir = fetch_from_huggingface(index_repo_id, ".huggingface") + with safetensors.torch.safe_open( + os.path.join(index_dir, "vector_index.safetensors"), framework="pt" + ) as f: + vector_index = f.get_tensor("vector_index") + device = torch.device(get_torch_backend()) + vector_index = vector_index.to(device) + chunk_dataset = ( + load_dataset(chunk_dataset, split="chunks") + if isinstance(chunk_dataset, str) + else chunk_dataset + ) + with open(os.path.join(index_dir, "config.json"), "r") as config_file: + metadata = json.load(config_file) + return cls( + model_name=metadata["model_name"], + vector_index=vector_index, + chunk_dataset=chunk_dataset, + ) + + @weave.op() + def retrieve( + self, + query: list[str], + top_k: int = 2, + metric: SimilarityMetric = SimilarityMetric.COSINE, + ): + """ + Retrieves the top-k most relevant chunks for a given query using the specified similarity metric. + + This method encodes the input query into an embedding and computes similarity scores between + the query embedding and the precomputed vector index. The similarity metric can be either + cosine similarity or Euclidean distance. The top-k chunks with the highest similarity scores + are returned as a list of dictionaries, each containing a chunk and its corresponding score. + + !!! example "Example Usage" + ```python + import weave + from medrag_multi_modal.retrieval.text_retrieval import NVEmbed2Retriever + + weave.init(project_name="ml-colabs/medrag-multi-modal") + retriever = NVEmbed2Retriever.from_index( + index_repo_id="ashwiniai/medrag-text-corpus-chunks-nv-embed-2", + chunk_dataset="ashwiniai/medrag-text-corpus-chunks", + ) + retriever.retrieve(query="What is ribosome?") + ``` + + ??? note "Optional Speedup using Flash Attention" + If you have a GPU with Flash Attention support, you can enable it for NV-Embed-v2 by simply + installing the `flash-attn` package. + + ```bash + uv pip install flash-attn --no-build-isolation + ``` + + Args: + query (list[str]): The input query strings to search for relevant chunks. + top_k (int, optional): The number of top relevant chunks to retrieve. + metric (SimilarityMetric, optional): The similarity metric to use for scoring. + + Returns: + list: A list of dictionaries, each containing a retrieved chunk and its relevance score. + """ + device = torch.device(get_torch_backend()) + with torch.no_grad(): + query_embedding = self._model.encode( + self.add_eos(query), normalize_embeddings=True + ) + query_embedding = torch.from_numpy(query_embedding).to(device) + if metric == SimilarityMetric.EUCLIDEAN: + scores = torch.squeeze(query_embedding @ self._vector_index.T) + else: + scores = F.cosine_similarity(query_embedding, self._vector_index) + scores = scores.cpu().numpy().tolist() + scores = argsort_scores(scores, descending=True)[:top_k] + retrieved_chunks = [] + for score in scores: + retrieved_chunks.append( + { + **self._chunk_dataset[score["original_index"]], + **{"score": score["item"]}, + } + ) + return retrieved_chunks + + @weave.op() + def predict( + self, + query: str, + top_k: int = 2, + metric: SimilarityMetric = SimilarityMetric.COSINE, + ): + """ + Predicts the top-k most relevant chunks for a given query using the specified similarity metric. + + This method formats the input query string by prepending an instruction prompt and then calls the + `retrieve` method to get the most relevant chunks. The similarity metric can be either cosine similarity + or Euclidean distance. The top-k chunks with the highest similarity scores are returned. 
+ + !!! example "Example Usage" + ```python + import weave + from medrag_multi_modal.retrieval.text_retrieval import NVEmbed2Retriever + + weave.init(project_name="ml-colabs/medrag-multi-modal") + retriever = NVEmbed2Retriever.from_index( + index_repo_id="ashwiniai/medrag-text-corpus-chunks-nv-embed-2", + chunk_dataset="ashwiniai/medrag-text-corpus-chunks", + ) + retriever.predict(query="What is ribosome?") + ``` + + ??? note "Optional Speedup using Flash Attention" + If you have a GPU with Flash Attention support, you can enable it for NV-Embed-v2 by simply + installing the `flash-attn` package. + + ```bash + uv pip install flash-attn --no-build-isolation + ``` + + Args: + query (str): The input query string to search for relevant chunks. + top_k (int, optional): The number of top relevant chunks to retrieve. + metric (SimilarityMetric, optional): The similarity metric to use for scoring. + + Returns: + list: A list of dictionaries, each containing a retrieved chunk and its relevance score. + """ + query = [ + f"""Instruct: Given a question, retrieve passages that answer the question +Query: {query}""" + ] + return self.retrieve(query, top_k, metric) diff --git a/medrag_multi_modal/semantic_chunking.py b/medrag_multi_modal/semantic_chunking.py new file mode 100644 index 0000000000000000000000000000000000000000..d7b82b3e7c40f9b80cd127a1646c71330f739457 --- /dev/null +++ b/medrag_multi_modal/semantic_chunking.py @@ -0,0 +1,135 @@ +import asyncio +from typing import Callable, Optional, Union + +import huggingface_hub +import semchunk +import tiktoken +import tokenizers +from datasets import Dataset, concatenate_datasets, load_dataset +from rich.progress import track +from transformers import PreTrainedTokenizer + +TOKENIZER_OR_TOKEN_COUNTER = Union[ + str, + tiktoken.Encoding, + PreTrainedTokenizer, + tokenizers.Tokenizer, + Callable[[str], int], +] + + +class SemanticChunker: + """ + SemanticChunker is a class that chunks documents into smaller segments and + publishes them as datasets. + + This class uses the `semchunk` library to break down large documents into + smaller, manageable chunks based on a specified tokenizer or token counter. + This is particularly useful for processing large text datasets where + smaller segments are needed for analysis or other operations. + + !!! example "Example Usage" + ```python + from medrag_multi_modal.semantic_chunking import SemanticChunker + + + chunker = SemanticChunker(chunk_size=256) + chunker.chunk( + document_dataset="geekyrakshit/grays-anatomy-test", + chunk_dataset_repo_id="geekyrakshit/grays-anatomy-chunks-test", + ) + ``` + + Args: + tokenizer_or_token_counter (TOKENIZER_OR_TOKEN_COUNTER): The tokenizer or + token counter to be used for chunking. + chunk_size (Optional[int]): The size of each chunk. If not specified, the + default chunk size from `semchunk` will be used. + max_token_chars (Optional[int]): The maximum number of characters per token. + If not specified, the default value from `semchunk` will be used. + memoize (bool): Whether to memoize the chunking process for efficiency. + Default is True. 
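+
+    !!! note "Chunk schema (illustrative)"
+        Each chunk row carries a `document_idx` pointing back to its source document,
+        the chunk `text`, and any remaining columns copied from the source dataset.
+        A minimal sketch of loading published chunks, reusing the repository from the
+        example above:
+
+        ```python
+        from datasets import load_dataset
+
+        chunks = load_dataset("geekyrakshit/grays-anatomy-chunks-test", split="chunks")
+        print(chunks[0]["document_idx"], chunks[0]["text"][:80])
+        ```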
+ """ + + def __init__( + self, + tokenizer_or_token_counter: TOKENIZER_OR_TOKEN_COUNTER = "o200k_base", + chunk_size: Optional[int] = None, + max_token_chars: Optional[int] = None, + memoize: bool = True, + ) -> None: + self.chunker = semchunk.chunkerify( + tokenizer_or_token_counter, + chunk_size=chunk_size, + max_token_chars=max_token_chars, + memoize=memoize, + ) + + def chunk( + self, + document_dataset: Union[Dataset, str], + chunk_dataset_repo_id: Optional[str] = None, + overwrite_dataset: bool = False, + ) -> Dataset: + """ + Chunks a document dataset into smaller segments and publishes them as a new dataset. + + This function takes a document dataset, either as a HuggingFace Dataset object or a string + representing the dataset repository ID, and chunks the documents into smaller segments using + the specified chunker. The resulting chunks are then optionally published to a HuggingFace + dataset repository. + + Args: + document_dataset (Union[Dataset, str]): The document dataset to be chunked. It can be either + a HuggingFace Dataset object or a string representing the dataset repository ID. + chunk_dataset_repo_id (Optional[str]): The repository ID of the HuggingFace dataset to publish + the chunks to, if provided. Defaults to None. + overwrite_dataset (bool): Whether to overwrite the existing dataset if it exists. Defaults to False. + + Returns: + Dataset: A HuggingFace Dataset object containing the chunks. + """ + document_dataset = ( + load_dataset(document_dataset, split="corpus") + if isinstance(document_dataset, str) + else document_dataset + ).to_list() + + chunks = [] + + async def process_document(idx, document): + document_chunks = self.chunker.chunk(str(document["text"])) + for chunk in document_chunks: + chunk_dict = {"document_idx": idx, "text": chunk} + for key, value in document.items(): + if key not in chunk_dict: + chunk_dict[key] = value + chunks.append(chunk_dict) + + async def process_all_documents(): + tasks = [] + for idx, document in track( + enumerate(document_dataset), + total=len(document_dataset), + description="Chunking documents", + ): + tasks.append(process_document(idx, document)) + await asyncio.gather(*tasks) + + asyncio.run(process_all_documents()) + + chunks.sort(key=lambda x: x["document_idx"]) + + dataset = Dataset.from_list(chunks) + if chunk_dataset_repo_id: + if huggingface_hub.repo_exists(chunk_dataset_repo_id, repo_type="dataset"): + if not overwrite_dataset: + dataset = concatenate_datasets( + [ + dataset, + load_dataset(chunk_dataset_repo_id, split="chunks"), + ] + ) + dataset.push_to_hub(repo_id=chunk_dataset_repo_id, split="chunks") + + return dataset diff --git a/medrag_multi_modal/utils.py b/medrag_multi_modal/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d3b6237a4879199653c77235eb3ee9b0d32c764a --- /dev/null +++ b/medrag_multi_modal/utils.py @@ -0,0 +1,86 @@ +import base64 +import io + +import jsonlines +import torch +import wandb +from huggingface_hub import HfApi +from PIL import Image + + +def get_wandb_artifact( + artifact_name: str, + artifact_type: str, + get_metadata: bool = False, +) -> str: + if wandb.run: + artifact = wandb.use_artifact(artifact_name, type=artifact_type) + artifact_dir = artifact.download() + else: + api = wandb.Api() + artifact = api.artifact(artifact_name) + artifact_dir = artifact.download() + if get_metadata: + return artifact_dir, artifact.metadata + return artifact_dir + + +def get_torch_backend(): + if torch.cuda.is_available(): + if torch.backends.cuda.is_built(): + 
return "cuda" + if torch.backends.mps.is_available(): + if torch.backends.mps.is_built(): + return "mps" + return "cpu" + return "cpu" + + +def base64_encode_image(image: Image.Image, mimetype: str) -> str: + image.load() + if image.mode not in ("RGB", "RGBA"): + image = image.convert("RGB") + byte_arr = io.BytesIO() + image.save(byte_arr, format="PNG") + encoded_string = base64.b64encode(byte_arr.getvalue()).decode("utf-8") + encoded_string = f"data:{mimetype};base64,{encoded_string}" + return str(encoded_string) + + +def read_jsonl_file(file_path: str) -> list[dict[str, any]]: + with jsonlines.open(file_path) as reader: + for obj in reader: + return obj + + +def save_to_huggingface( + repo_id: str, local_dir: str, commit_message: str, private: bool = False +): + api = HfApi() + repo_url = api.create_repo( + repo_id=repo_id, + token=api.token, + private=private, + repo_type="model", + exist_ok=True, + ) + repo_id = repo_url.repo_id + api.upload_folder( + repo_id=repo_id, + commit_message=commit_message, + token=api.token, + folder_path=local_dir, + repo_type=repo_url.repo_type, + ) + + +def fetch_from_huggingface(repo_id: str, local_dir: str) -> str: + api = HfApi() + repo_url = api.repo_info(repo_id) + if repo_url is None: + raise ValueError(f"Model {repo_id} not found on the Hugging Face Hub.") + + snapshot = api.snapshot_download(repo_id, revision=None, local_dir=local_dir) + if snapshot is None: + raise ValueError(f"Model {repo_id} not found on the Hugging Face Hub.") + return snapshot diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..c92b8110073e43000d923779828818dd796c888a --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,101 @@ +[project] +name = "medrag-multi-modal" +version = "0.0.1" +description = "" +readme = "README.md" +requires-python = ">=3.10" +dependencies = [ + "bm25s[full]>=0.2.2", + "datasets>=3.1.0", + "einops>=0.8.0", + "firerequests>=0.0.7", + "pdf2image>=1.17.0", + "python-dotenv>=1.0.1", + "pymupdf4llm>=0.0.17", + "weave>=0.51.14", + "pip>=24.2", + "uv>=0.4.20", + "pytest>=8.3.3", + "PyPDF2>=3.0.1", + "PyStemmer>=2.2.0.3", + "safetensors>=0.4.5", + "isort>=5.13.2", + "black>=24.10.0", + "ruff>=0.6.9", + "marker-pdf>=0.2.17", + "mkdocs>=1.6.1", + "mkdocstrings>=0.26.1", + "mkdocstrings-python>=1.11.1", + "mkdocs-material>=9.5.39", + "mkdocs-minify-plugin>=0.8.0", + "mkdocs-glightbox>=0.4.0", + "mkdocs-jupyter>=0.25.0", + "jupyter>=1.1.1", + "pdfplumber>=0.11.4", + "semchunk>=2.2.0", + "tiktoken>=0.8.0", + "sentence-transformers>=3.2.0", + "google-generativeai>=0.8.3", + "mistralai>=1.1.0", + "instructor>=1.6.3", + "jsonlines>=4.0.0", + "opencv-python>=4.10.0.84", + "openai>=1.52.2", + "streamlit>=1.39.0", +] + +[project.optional-dependencies] +app = [ + "streamlit>=1.39.0", +] +core = [ + "bm25s[full]>=0.2.2", + "datasets>=3.1.0", + "einops>=0.8.0", + "firerequests>=0.0.7", + "marker-pdf>=0.2.17", + "pdf2image>=1.17.0", + "pdfplumber>=0.11.4", + "PyPDF2>=3.0.1", + "PyStemmer>=2.2.0.3", + "python-dotenv>=1.0.1", + "pymupdf4llm>=0.0.17", + "safetensors>=0.4.5", + "semchunk>=2.2.0", + "tiktoken>=0.8.0", + "weave>=0.51.18", + "sentence-transformers>=3.2.0", + "google-generativeai>=0.8.3", + "mistralai>=1.1.0", + "instructor>=1.6.3", + "jsonlines>=4.0.0", + "opencv-python>=4.10.0.84", + "openai>=1.52.2", +] +dev = [ + "pytest>=8.3.3", + "isort>=5.13.2", + "black>=24.10.0", + "ruff>=0.6.9", +] +docs = [ + "mkdocs>=1.6.1", + "mkdocstrings>=0.26.1", + "mkdocstrings-python>=1.11.1", + "mkdocs-material>=9.5.39", + 
"mkdocs-minify-plugin>=0.8.0", + "mkdocs-glightbox>=0.4.0", + "mkdocs-jupyter>=0.25.0", + "jupyter>=1.1.1", +] + +[project.scripts] +medrag = "medrag_multi_modal.cli:main" + +[tool.pytest.ini_options] +pythonpath = "." +testpaths = ["tests"] +filterwarnings = "ignore::DeprecationWarning" + +[tool.setuptools] +py-modules = ["medrag_multi_modal"] \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..d8aa85b060cff542af5b7e00093598f992ee01fd --- /dev/null +++ b/requirements.txt @@ -0,0 +1,38 @@ +bm25s[full]>=0.2.2 +datasets>=3.1.0 +einops>=0.8.0 +firerequests>=0.0.7 +pdf2image>=1.17.0 +python-dotenv>=1.0.1 +pymupdf4llm>=0.0.17 +weave>=0.51.14 +pip>=24.2 +uv>=0.4.20 +pytest>=8.3.3 +PyPDF2>=3.0.1 +PyStemmer>=2.2.0.3 +safetensors>=0.4.5 +isort>=5.13.2 +black>=24.10.0 +ruff>=0.6.9 +marker-pdf>=0.2.17 +mkdocs>=1.6.1 +mkdocstrings>=0.26.1 +mkdocstrings-python>=1.11.1 +mkdocs-material>=9.5.39 +mkdocs-minify-plugin>=0.8.0 +mkdocs-glightbox>=0.4.0 +mkdocs-jupyter>=0.25.0 +jupyter>=1.1.1 +pdfplumber>=0.11.4 +semchunk>=2.2.0 +tiktoken>=0.8.0 +sentence-transformers>=3.2.0 +google-generativeai>=0.8.3 +mistralai>=1.1.0 +instructor>=1.6.3 +jsonlines>=4.0.0 +opencv-python>=4.10.0.84 +openai>=1.52.2 +streamlit>=1.39.0 +torch --index-url https://download.pytorch.org/whl/cpu \ No newline at end of file diff --git a/test.py b/test.py new file mode 100644 index 0000000000000000000000000000000000000000..5e29c897e3a84397f83716b12e60cc33e019e523 --- /dev/null +++ b/test.py @@ -0,0 +1,12 @@ +import weave +from datasets import load_dataset + +weave.init("ml-colabs/medrag-multi-modal") +rows = load_dataset("cais/mmlu", "anatomy", split="test").to_list() +for idx, row in enumerate(rows): + rows[idx] = { + "query": row["question"], + "options": row["choices"], + "answer": row["answer"], + } +weave.publish(weave.Dataset(rows=rows, name="mmlu-anatomy-test"))