"""Gradio demo: multi-modal (page-screenshot) vs. text retrieval with LlamaIndex.

The module builds two embedding models at import time, exposes helpers to
load or create a ``MultiModalVectorStoreIndex``, and wires them into a
Gradio ``Blocks`` UI stored in ``demo``.
"""

import os

import gradio as gr
import torch
from llama_index.core import StorageContext, load_index_from_storage
from llama_index.core.indices import MultiModalVectorStoreIndex
from llama_index.core.schema import Document, ImageDocument
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_parse import LlamaParse

# Display name -> persisted index directory for the pre-built example indexes.
example_indexes = {
    "IONIQ 2024": "./iconiq_report_index",
    "Uber 10k 2021": "./uber_index",
}

# The first example is both the dropdown default and the index preloaded at
# startup; deriving both from one name keeps them from drifting apart.
DEFAULT_INDEX_NAME = next(iter(example_indexes))

# Pick the best available torch device: CUDA, then Apple MPS, then CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"

# Embedding model used for the image (screenshot) side of the index.
image_embed_model = HuggingFaceEmbedding(
    model_name="llamaindex/vdr-2b-multi-v1",
    device=device,
    trust_remote_code=True,
    token=os.getenv("HUGGINGFACE_TOKEN"),
    model_kwargs={"torch_dtype": torch.float16},
    embed_batch_size=2,
)

# Embedding model used for the text side of the index.
text_embed_model = HuggingFaceEmbedding(
    model_name="BAAI/bge-small-en",
    device=device,
    trust_remote_code=True,
    token=os.getenv("HUGGINGFACE_TOKEN"),
    embed_batch_size=2,
)


def load_index(index_path: str) -> MultiModalVectorStoreIndex:
    """Load a persisted index from ``index_path`` using the module's embed models.

    Args:
        index_path: Directory passed to ``StorageContext.from_defaults`` as
            ``persist_dir``.

    Returns:
        The index returned by ``load_index_from_storage``.
    """
    storage_context = StorageContext.from_defaults(persist_dir=index_path)
    return load_index_from_storage(
        storage_context,
        embed_model=text_embed_model,
        image_embed_model=image_embed_model,
    )


def create_index(file, llama_parse_key, progress=gr.Progress()):
    """Parse an uploaded document with LlamaParse and build a multi-modal index.

    Args:
        file: The uploaded file. Either an object exposing the path as
            ``.name`` or a plain path string.
        llama_parse_key: LlamaParse API key.
        progress: Gradio progress tracker, called with a fraction and a
            description at each stage.

    Returns:
        A ``(index, status_message)`` tuple. ``index`` is ``None`` when an
        input is missing or any step fails; failures are reported through
        the status message instead of being raised.
    """
    if not file or not llama_parse_key:
        return None, "Please provide both a file and LlamaParse API key"

    try:
        # Accept both file-like uploads (path in ``.name``) and plain paths.
        file_path = getattr(file, "name", file)

        progress(0, desc="Initializing LlamaParse...")
        parser = LlamaParse(
            api_key=llama_parse_key,
            take_screenshot=True,
        )

        # Process document
        progress(0.2, desc="Processing document with LlamaParse...")
        json_results = parser.get_json_result(file_path)
        if not json_results:
            # Report an empty parse explicitly rather than as an IndexError.
            return None, "Error creating index: LlamaParse returned no results"
        md_json_obj = json_results[0]

        progress(0.4, desc="Downloading and processing images...")
        # Images go to a sibling "<file>_images" directory. Built directly
        # from the file path so a relative path does not get its directory
        # component duplicated.
        image_dicts = parser.get_images(
            [md_json_obj],
            download_path=f"{file_path}_images",
        )

        # Create text document: all page markdown joined into one Document.
        progress(0.6, desc="Creating text documents...")
        text = "\n\n".join(page["md"] for page in md_json_obj["pages"])
        text_docs = [Document(text=text.strip())]

        # Create image documents
        progress(0.8, desc="Creating image documents...")
        image_docs = [
            ImageDocument(text=image_dict["name"], image_path=image_dict["path"])
            for image_dict in image_dicts
        ]

        # Create index
        progress(0.9, desc="Creating final index...")
        index = MultiModalVectorStoreIndex.from_documents(
            text_docs + image_docs,
            embed_model=text_embed_model,
            image_embed_model=image_embed_model,
        )

        progress(1.0, desc="Complete!")
        return index, "Index created successfully!"
    except Exception as e:
        # UI boundary: surface the failure in the status box.
        return None, f"Error creating index: {str(e)}"


def _format_score(score) -> str:
    """Format a similarity score to three decimals, or "n/a" when it is None."""
    return "n/a" if score is None else f"{score:.3f}"


def run_search(index, query, text_top_k, image_top_k):
    """Run text and text-to-image retrieval for ``query`` against ``index``.

    Args:
        index: The index to search; a falsy value means none is loaded.
        query: The search query string.
        text_top_k: Number of text results to retrieve.
        image_top_k: Number of image results to retrieve.

    Returns:
        A ``(status, text_results, image_results)`` tuple. ``text_results``
        is a list of ``{"text", "score"}`` dicts and ``image_results`` a list
        of ``(image_path, caption)`` tuples for images that exist on disk.
    """
    if not index:
        return "Please create or select an index first.", [], []

    # Slider values may arrive as floats; top-k must be an integer.
    retriever = index.as_retriever(
        similarity_top_k=int(text_top_k),
        image_similarity_top_k=int(image_top_k),
    )

    image_nodes = retriever.text_to_image_retrieve(query)
    text_nodes = retriever.text_retrieve(query)

    # Extract text and scores from nodes
    text_results = [
        {"text": node.text, "score": _format_score(node.score)}
        for node in text_nodes
    ]

    # Collect image paths and scores, skipping nodes without a usable file.
    image_results = []
    for node in image_nodes:
        image_path = getattr(node.node, "image_path", None)
        if image_path and os.path.exists(image_path):
            image_results.append(
                (image_path, f"Similarity: {_format_score(node.score)}")
            )

    return "Search completed!", text_results, image_results


# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Multi-Modal Retrieval with LlamaIndex and llamaindex/vdr-2b-multi-v1")
    gr.Markdown(
        """
    This demo shows how to use the new `llamaindex/vdr-2b-multi-v1` model for multi-modal document search.
    Using this model, we can index images and perform text-to-image retrieval.
    This demo compares to pure text retrieval using the `BAAI/bge-small-en` model. Is this a fair comparison? Not really,
    but it's the easiest to run in a limited huggingface space, and shows the strengths of screenshot-based retrieval.
    """
    )

    with gr.Row():
        with gr.Column():
            # Index selection/creation
            with gr.Tab("Use Existing Index"):
                existing_index_dropdown = gr.Dropdown(
                    choices=list(example_indexes.keys()),
                    label="Select Pre-made Index",
                    value=DEFAULT_INDEX_NAME,
                )

            with gr.Tab("Create New Index"):
                gr.Markdown(
                    """
                To create a new index, enter your LlamaParse API key and upload a PDF.
                You can get a free API key by signing up [here](https://cloud.llamaindex.ai).
                Processing will take a few minutes when creating a new index, depending on the size of the document.
                """
                )
                file_upload = gr.File(label="Upload PDF")
                llama_parse_key = gr.Textbox(
                    label="LlamaParse API Key",
                    type="password",
                )
                create_btn = gr.Button("Create Index")
                create_status = gr.Textbox(label="Status", interactive=False)

            # Search controls
            query_input = gr.Textbox(
                label="Search Query",
                value="What is the Executive Summary?",
            )
            with gr.Row():
                text_top_k = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=2,
                    step=1,
                    label="Text Top-K",
                )
                image_top_k = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=2,
                    step=1,
                    label="Image Top-K",
                )
            search_btn = gr.Button("Search")

        with gr.Column():
            # Results display
            status_output = gr.Textbox(label="Search Status")
            image_output = gr.Gallery(
                label="Retrieved Images",
                show_label=True,  # This will show the similarity score captions
                elem_id="gallery",
            )
            text_output = gr.JSON(
                label="Retrieved Text with Similarity Scores",
                elem_id="text_results",
            )

    # State
    index_state = gr.State()

    # Load default index on startup. The value is assigned on the component
    # after construction rather than passed to gr.State(...).
    default_index = load_index(example_indexes[DEFAULT_INDEX_NAME])
    index_state.value = default_index

    # Event handlers
    def load_existing_index(index_name):
        """Load the named example index; return ``(index, status_message)``."""
        if not index_name:
            return None, "No index selected"
        try:
            index = load_index(example_indexes[index_name])
        except Exception as e:
            return None, f"Error loading index: {str(e)}"
        return index, f"Loaded index: {index_name}"

    existing_index_dropdown.change(
        fn=load_existing_index,
        inputs=[existing_index_dropdown],
        outputs=[index_state, create_status],
        api_name=False,
    )

    create_btn.click(
        fn=create_index,
        inputs=[file_upload, llama_parse_key],
        outputs=[index_state, create_status],
        api_name=False,
        show_progress=True,  # Enable progress bar
    )

    search_btn.click(
        fn=run_search,
        inputs=[index_state, query_input, text_top_k, image_top_k],
        outputs=[status_output, text_output, image_output],
        api_name=False,
    )

    gr.Markdown(
        """
    This demo was built with [LlamaIndex](https://docs.llamaindex.ai) and [LlamaParse](https://cloud.llamaindex.ai).
    To see more multi-modal demos, check out the [llama parse examples](https://github.com/run-llama/llama_parse/tree/main/examples/multimodal).
    """
    )

if __name__ == "__main__":
    demo.launch()