import torch
import transformers
import gradio as gr
from ragatouille import RAGPretrainedModel
import re
from datetime import datetime
import json
import arxiv
from helper import rag_cleaner, get_prompt_text, get_references, get_rag, SaveResponseAndRead, get_md_text_abstract, search_cleaner, get_arxiv_live_search

# Constants
# Result count passed to both retrievers (local ColBERT index and live arXiv search).
RETRIEVE_RESULTS = 20
LLM_MODELS = ['mistralai/Mixtral-8x7B-Instruct-v0.1', 'mistralai/Mistral-7B-Instruct-v0.2', 'google/gemma-7b-it', 'None']
DEFAULT_LLM_MODEL = 'mistralai/Mistral-7B-Instruct-v0.2'
# Sampling is disabled (do_sample=False), so temperature/top_p are left unset.
GENERATE_KWARGS = {
    "temperature": None,
    "max_new_tokens": 512,
    "top_p": None,
    "do_sample": False,
}

# RAG Model setup
# Load the prebuilt ColBERT index and run one warm-up query so a broken index
# is reported at startup.
RAG = RAGPretrainedModel.from_index("colbert/indexes/arxiv_colbert")
try:
    gr.Info("Setting up retriever, please wait...")
    rag_initial_output = RAG.search("What is Generative AI in Healthcare?", k=1)
    gr.Info("Retriever working successfully!")
except Exception as e:
    # Best-effort check: a failure is reported but does not stop the app from starting.
    gr.Warning(f"Retriever not working: {str(e)}")

# Header setup
mark_text = '# 🩺🔍 Search Results\n'
header_text = "## Arxiv Paper Summary With QA Retrieval Augmented Generation \n"


def _read_index_date(readme_path="README.md"):
    """Return the index date recorded in *readme_path*, formatted as 'DD Mon YYYY'.

    Looks for text of the form ``Index Last Updated : YYYY-MM-DD``.

    Returns:
        The formatted date string, or None when the file is missing, the
        marker is absent, or the date is not a valid calendar date.
    """
    try:
        with open(readme_path, "r", encoding="utf-8") as f:
            mdfile = f.read()
    except FileNotFoundError:
        return None
    match = re.search(r'Index Last Updated : (\d{4}-\d{2}-\d{2})', mdfile)
    if match is None:
        # Previously this crashed with AttributeError on match.group().
        return None
    try:
        return datetime.strptime(match.group(1), '%Y-%m-%d').strftime('%d %b %Y')
    except ValueError:
        return None


formatted_date = _read_index_date()
if formatted_date is not None:
    header_text += f'Index Last Updated: {formatted_date}\n'
    index_info = f"Semantic Search - up to {formatted_date}"
else:
    index_info = "Semantic Search"

database_choices = [index_info, 'Arxiv Search - Latest - (EXPERIMENTAL)']

# Arxiv API setup
# Probe the live arXiv search once at startup. If it raises or returns nothing,
# drop the live-search option so only the semantic index is offered.
arx_client = arxiv.Client()
is_arxiv_available = True
try:
    check_arxiv_result = get_arxiv_live_search("What is Self Rewarding AI and how can it be used in Multi-Agent Systems?", arx_client, RETRIEVE_RESULTS)
except Exception as e:
    # A network/API error must not abort app startup; treat it like "no results".
    print(f"Arxiv search check failed: {str(e)}")
    check_arxiv_result = []
if len(check_arxiv_result) == 0:
    is_arxiv_available = False
    print("Arxiv search not working, switching to default search ...")
    database_choices = [index_info]
# Gradio UI setup
# NOTE(review): the component nesting (Group / Accordion / Row) below was
# reconstructed from a whitespace-collapsed source; confirm the intended layout.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    header = gr.Markdown(header_text)

    with gr.Group():
        search_query = gr.Textbox(label='Search', placeholder='What is Generative AI in Healthcare?')

        with gr.Accordion("Advanced Settings", open=False):
            with gr.Row(equal_height=True):
                llm_model = gr.Dropdown(choices=LLM_MODELS, value=DEFAULT_LLM_MODEL, label='LLM Model')
                llm_results = gr.Slider(minimum=4, maximum=10, value=5, step=1, interactive=True, label="Top n results as context")
                database_src = gr.Dropdown(choices=database_choices, value=index_info, label='Search Source')
                stream_results = gr.Checkbox(value=True, label="Stream output", visible=False)

    output_text = gr.Textbox(show_label=True, container=True, label='LLM Answer', visible=True)
    # Hidden textbox that carries the prompt built by update_with_rag_md to ask_llm.
    # (Formerly named `input`, which shadowed the builtin.)
    llm_prompt = gr.Textbox(show_label=False, visible=False)
    gr_md = gr.Markdown(mark_text)

    def update_with_rag_md(search_query, llm_results_use=5, database_choice=index_info, llm_model_picked=DEFAULT_LLM_MODEL):
        """Retrieve papers for *search_query* and build the results markdown and the LLM prompt.

        Args:
            search_query: The user's question.
            llm_results_use: How many of the top results are included as prompt context;
                the remaining results are only listed in the markdown.
            database_choice: Dropdown value. ``index_info`` selects the local ColBERT
                index; anything else selects the live arXiv search, which falls back
                to the local index on error or when it returns nothing.
            llm_model_picked: Model id forwarded to ``get_prompt_text``.

        Returns:
            A ``(markdown, prompt)`` tuple for the results panel and ``ask_llm``.
        """
        database_to_use = database_choice
        if database_choice == index_info:
            rag_out = get_rag(search_query, RAG, RETRIEVE_RESULTS)
        else:
            arxiv_search_success = False
            try:
                rag_out = get_arxiv_live_search(search_query, arx_client, RETRIEVE_RESULTS)
                arxiv_search_success = len(rag_out) > 0
            except Exception as e:
                gr.Warning(f"Arxiv Search not working: {str(e)}, switching to semantic search ...")
            else:
                if not arxiv_search_success:
                    # Previously the fallback on an empty result happened silently.
                    gr.Warning("Arxiv Search returned no results, switching to semantic search ...")
            if not arxiv_search_success:
                rag_out = get_rag(search_query, RAG, RETRIEVE_RESULTS)
                database_to_use = index_info

        md_parts = [mark_text]
        prompt_parts = []
        for i, rag_answer in enumerate(rag_out):
            if i < llm_results_use:
                md_text_paper, prompt_text = get_md_text_abstract(rag_answer, source=database_to_use, return_prompt_formatting=True)
                prompt_parts.append(f"{i+1}. {prompt_text}")
            else:
                md_text_paper = get_md_text_abstract(rag_answer, source=database_to_use)
            md_parts.append(md_text_paper)

        prompt = get_prompt_text(search_query, "".join(prompt_parts), llm_model_picked=llm_model_picked)
        return "".join(md_parts), prompt

    def ask_llm(prompt, llm_model_picked=DEFAULT_LLM_MODEL, stream_outputs=False):
        """Generate an answer for *prompt* with the selected model.

        This is a generator; each yielded value is shown in ``output_text``.
        With ``stream_outputs`` it yields growing partial answers. Otherwise it
        yields the final answer once. Inference errors are reported via
        ``gr.Warning``, and in non-streaming mode the answer is then empty.
        """
        model_disabled_text = "LLM Model is disabled"
        if llm_model_picked == 'None':
            if stream_outputs:
                output = ""
                for out in model_disabled_text:
                    output += out
                    yield output
            else:
                yield model_disabled_text
            # Stop here: previously the streaming branch fell through and tried
            # to run inference with a model named 'None'.
            return

        output = ""
        try:
            # NOTE(review): `InferenceClient` is not imported anywhere in this file.
            # It is presumably huggingface_hub.InferenceClient; add
            # `from huggingface_hub import InferenceClient` to the imports.
            # Creating it inside the try turns a failure into a warning instead of
            # an unhandled error in the event handler.
            client = InferenceClient(llm_model_picked)
            response = client.text_generation(prompt, stream=stream_outputs, details=False, return_full_text=False, **GENERATE_KWARGS)
            if stream_outputs:
                for token in response:
                    output += token
                    yield SaveResponseAndRead(output)
            else:
                output = response
        except Exception as e:
            gr.Warning(f"LLM Inference failed: {str(e)}")
            output = ""
        if not stream_outputs:
            # A `return value` inside a generator is discarded, so the non-streamed
            # answer must be yielded to reach the textbox.
            yield output

    search_query.submit(
        update_with_rag_md,
        [search_query, llm_results, database_src, llm_model],
        [gr_md, llm_prompt],
    ).success(ask_llm, [llm_prompt, llm_model, stream_results], output_text)

demo.queue().launch()