import gradio as gr
import os
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import HuggingFacePipeline
from langchain.chains import ConversationChain
from langchain.memory import ConversationBufferMemory
from langchain.llms import HuggingFaceHub
from langchain.memory import ConversationBufferWindowMemory
from langchain_community.document_loaders import TextLoader
from langchain_community.document_loaders import DirectoryLoader
from langchain_community.document_loaders import UnstructuredHTMLLoader
from pathlib import Path
import chromadb
from transformers import AutoTokenizer
import transformers
import torch
import tqdm
import accelerate

from transformers import MBartForConditionalGeneration, MBart50TokenizerFast

# mBART-50 many-to-many model used to translate between English and the supported Indian languages
translation_model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
translation_tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")

# (label, mBART language code) pairs offered in the language dropdown
languages_list = [("Gujarati", "gu_IN"), ("Hindi", "hi_IN"), ("Bengali", "bn_IN"), ("Malayalam", "ml_IN"),
                  ("Marathi", "mr_IN"), ("Tamil", "ta_IN"), ("Telugu", "te_IN")]

# Currently selected target language; defaults to Hindi to match the dropdown's initial value
lang_global = 'hi_IN'


def initialize_lang(language):
    global lang_global
    lang_global = language
    print("initialize_lang: " + lang_global)


def english_to_indian(sentence):
    #print("english_to_indian" + lang_global)
    translated_sentence = ''
    translation_tokenizer.src_lang = "en_XX"
    # Translate in 500-character chunks to stay within the model's input limits
    chunks = [sentence[i:i+500] for i in range(0, len(sentence), 500)]
    for chunk in chunks:
        encoded_hi = translation_tokenizer(chunk, return_tensors="pt")
        generated_tokens = translation_model.generate(**encoded_hi,
            forced_bos_token_id=translation_tokenizer.lang_code_to_id[lang_global])
        x = translation_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        translated_sentence = translated_sentence + x[0]
    print(translated_sentence)
    return translated_sentence


def indian_to_english(sentence):
    translated_sentence = ''
    translation_tokenizer.src_lang = lang_global
    # Translate in 500-character chunks to stay within the model's input limits
    chunks = [sentence[i:i+500] for i in range(0, len(sentence), 500)]
    for chunk in chunks:
        encoded_hi = translation_tokenizer(chunk, return_tensors="pt")
        generated_tokens = translation_model.generate(**encoded_hi,
            forced_bos_token_id=translation_tokenizer.lang_code_to_id["en_XX"])
        x = translation_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        translated_sentence = translated_sentence + x[0]
    print(translated_sentence)
    return translated_sentence
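
# Illustrative usage of the translation helpers (not executed by the app; the
# sample sentences are assumptions and the mBART-50 weights above must be downloaded):
#   initialize_lang("hi_IN")
#   english_to_indian("Hello, how are you?")   # -> Hindi text
#   indian_to_english("नमस्ते, आप कैसे हैं?")       # -> English text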

llm_model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
tokenizer_name = "thenlper/gte-small"
# default_persist_directory = './chroma_HF/'
list_llm = ["mistralai/Mistral-7B-Instruct-v0.2", "mistralai/Mixtral-8x7B-Instruct-v0.1", "mistralai/Mistral-7B-Instruct-v0.1", \
    "google/gemma-7b-it", "google/gemma-2b-it", \
    "HuggingFaceH4/zephyr-7b-beta", "meta-llama/Llama-2-7b-chat-hf", "microsoft/phi-2", \
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "mosaicml/mpt-7b-instruct", "tiiuae/falcon-7b-instruct", \
    "google/flan-t5-xxl"
]
list_llm_simple = [os.path.basename(llm) for llm in list_llm]


# Load HTML documents and create doc splits
def load_doc(list_file_path, chunk_size, chunk_overlap):
    # Processing for one document only
    # loader = PyPDFLoader(file_path)
    # pages = loader.load()
    loaders = [UnstructuredHTMLLoader(x) for x in list_file_path]
    pages = []
    for loader in loaders:
        pages.extend(loader.load())
    # text_splitter = RecursiveCharacterTextSplitter(chunk_size = 600, chunk_overlap = 50)
    # Token-aware splitter built from the embedding model's tokenizer
    text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
        AutoTokenizer.from_pretrained(tokenizer_name),
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        strip_whitespace=True)
    doc_splits = text_splitter.split_documents(pages)
    return doc_splits


# Create vector database
def create_db(splits, collection_name):
    embedding = HuggingFaceEmbeddings()
    new_client = chromadb.EphemeralClient()
    vectordb = Chroma.from_documents(
        documents=splits,
        embedding=embedding,
        client=new_client,
        collection_name=collection_name,
        # persist_directory=default_persist_directory
    )
    return vectordb


# Load vector database
def load_db():
    embedding = HuggingFaceEmbeddings()
    vectordb = Chroma(
        # persist_directory=default_persist_directory,
        embedding_function=embedding)
    return vectordb


# Initialize langchain LLM chain
def initialize_llmchain(temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    progress(0.1, desc="Initializing HF tokenizer...")
    # HuggingFaceHub uses HF inference endpoints
    progress(0.5, desc="Initializing HF Hub...")
    llm = HuggingFaceHub(repo_id=llm_model,
        model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "load_in_8bit": True})
    progress(0.75, desc="Defining buffer memory...")
    # memory = ConversationBufferMemory(memory_key="chat_history", output_key='answer', return_messages=True)
    # Keep only the last k=3 exchanges in memory
    memory = ConversationBufferWindowMemory(memory_key='chat_history', k=3, output_key='answer', return_messages=True)
    retriever = vector_db.as_retriever()
    progress(0.8, desc="Defining retrieval chain...")
    qa_chain = ConversationalRetrievalChain.from_llm(llm, retriever=retriever, chain_type="stuff",
        memory=memory, return_source_documents=True, verbose=False)
    progress(0.9, desc="Done!")
    return qa_chain


# Initialize database
def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()):
    # Create list of documents (when valid)
    list_file_path = [x.name for x in list_file_obj if x is not None]
    # Create collection_name for vector database
    progress(0.1, desc="Creating collection name...")
    collection_name = Path(list_file_path[0]).stem
    # Fix potential issues from naming convention
    ## Remove spaces
    collection_name = collection_name.replace(" ", "-")
    ## Replace periods (also prevents names that look like IPv4 addresses, which ChromaDB rejects)
    collection_name = collection_name.replace(".", "_")
    ## Limit length to 50 characters
    collection_name = collection_name[:50]
    ## Enforce alphanumeric start and end characters
    if not collection_name[0].isalnum():
        collection_name = 'A' + collection_name[1:]
    if not collection_name[-1].isalnum():
        collection_name = collection_name[:-1] + 'Z'
    # print('list_file_path: ', list_file_path)
    print('Collection name: ', collection_name)
    progress(0.25, desc="Loading document...")
    # Load document and create splits
    doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
    # Create or load vector database
    progress(0.5, desc="Generating vector database...")
    # global vector_db
    vector_db = create_db(doc_splits, collection_name)
    progress(0.9, desc="Done!")
    return vector_db, collection_name, "Complete!"


def initialize_LLM(llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    # print("llm_option", llm_option)
    llm_name = llm_model
    print("llm_name: ", llm_name)
    qa_chain = initialize_llmchain(llm_temperature, max_tokens, top_k, vector_db, progress)
    return qa_chain, "Complete!"
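
# A minimal sketch of exercising the pipeline above outside the Gradio UI
# (illustrative only; the file path is an assumption and HuggingFaceHub expects the
# HUGGINGFACEHUB_API_TOKEN environment variable to be set):
#   splits = load_doc(["/home/user/app/htmls/sample.html"], chunk_size=2000, chunk_overlap=256)
#   db = create_db(splits, collection_name="sample")
#   chain = initialize_llmchain(temperature=0.1, max_tokens=4000, top_k=3, vector_db=db)
#   result = chain({"question": "What is this document about?", "chat_history": []})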

def format_chat_history(message, chat_history):
    formatted_chat_history = []
    for user_message, bot_message in chat_history:
        formatted_chat_history.append(f"User: {user_message}")
        formatted_chat_history.append(f"Assistant: {bot_message}")
    return formatted_chat_history


def conversation(qa_chain, message, history):
    formatted_chat_history = format_chat_history(message, history)
    #print("formatted_chat_history", formatted_chat_history)

    # Translate the user's message to English so the retrieval chain works on the English document corpus
    english_message = indian_to_english(message)
    # Generate response using QA chain
    response = qa_chain({"question": english_message, "chat_history": formatted_chat_history})
    response_answer = response["answer"]
    if response_answer.find("Helpful Answer:") != -1:
        response_answer = response_answer.split("Helpful Answer:")[-1]
    # Translate the answer back into the selected language for display
    response_answer = english_to_indian(response_answer)
    response_sources = response["source_documents"]
    response_source1 = response_sources[0].page_content.strip()
    response_source2 = response_sources[1].page_content.strip()
    response_source3 = response_sources[2].page_content.strip()
    # Langchain sources are zero-based; HTML documents may not carry page metadata
    response_source1_page = response_sources[0].metadata.get("page", 0) + 1
    response_source2_page = response_sources[1].metadata.get("page", 0) + 1
    response_source3_page = response_sources[2].metadata.get("page", 0) + 1
    # print('chat response: ', response_answer)
    # print('DB source', response_sources)

    # Append user message and response to chat history
    new_history = history + [(message, response_answer)]
    # return gr.update(value=""), new_history, response_sources[0], response_sources[1]
    return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page


def upload_file(file_obj):
    list_file_path = []
    for file in file_obj:
        file_path = file.name
        list_file_path.append(file_path)
        # print(file_path)
    # initialize_database(file_path, progress)
    return list_file_path


def demo():
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        vector_db = gr.State()
        qa_chain = gr.State()
        collection_name = gr.State()
        # Directory containing the HTML corpus shipped with the app
        pdf_directory = '/home/user/app/htmls/'

        def process_pdfs():
            # List all document files in the directory
            # pdf_files = [os.path.join(pdf_directory, file) for file in os.listdir(pdf_directory) if file.endswith(".html")]
            pdf_files = [os.path.join(pdf_directory, file) for file in os.listdir(pdf_directory)]
            print('pdf files: ', len(pdf_files))
            return pdf_files

        # Configuration for the (hidden) file component holding the pre-loaded documents
        pdf_dict = {"value": process_pdfs, "height": 100, "file_count": "multiple",
                    "visible": False, "file_types": ["html"], "interactive": True,
                    "label": "Uploaded HTML documents"}
        # Create a gr.Files component with the dictionary
        # document_files = gr.Files(**pdf_dict)

        with gr.Row():
            # document = gr.Files(value=process_pdfs, height=100, file_count="multiple", visible=True,
            #     file_types=["pdf"], interactive=True, label="Upload your PDF documents (single or multiple)")
            document = gr.Files(**pdf_dict)
        with gr.Row():
            db_choice = gr.Radio(["ChromaDB"], label="Vector database type", value="ChromaDB", type="index",
                info="Choose your vector database", visible=False)
        with gr.Accordion("Advanced options - Document text splitter", open=False, visible=False):
            with gr.Row():
                slider_chunk_size = gr.Slider(minimum=100, maximum=4000, value=2000, label="Chunk size",
                    info="Chunk size", interactive=False, visible=False)
            with gr.Row():
                slider_chunk_overlap = gr.Slider(minimum=10, maximum=1000, value=256, label="Chunk overlap",
                    info="Chunk overlap", interactive=False, visible=False)
        with gr.Accordion("Advanced options - LLM model", open=False, visible=False):
            with gr.Row():
                slider_temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.1, visible=False)
            with gr.Row():
                slider_maxtokens = gr.Slider(minimum=128, maximum=8192, value=4000, visible=False)
            with gr.Row():
                slider_topk = gr.Slider(minimum=1, maximum=10, value=3, visible=False)
        with gr.Row():
            lang_btn = gr.Dropdown(languages_list, label="Languages", value=languages_list[1][1], type="value",
                info="Choose your language", interactive=True)
            lang_btn.change(initialize_lang, inputs=lang_btn)
        with gr.Row():
            db_progress = gr.Textbox(label="Vector database initialization", value="None", visible=True)
            llm_progress = gr.Textbox(value="None", label="QA chain initialization", visible=True)
        with gr.Row():
            db_btn = gr.Button("Generate vector database")
            qachain_btn = gr.Button("Initialize model")
        # with gr.Row():
        # with gr.Row():
        chatbot = gr.Chatbot(height=300, bubble_full_width=False, layout='panel')
        with gr.Row():
            msg = gr.Textbox(placeholder="Type message", container=True)
        with gr.Accordion("References", open=False):
            with gr.Row():
                doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
                source1_page = gr.Number(label="Page", scale=1)
            with gr.Row():
                doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
                source2_page = gr.Number(label="Page", scale=1)
            with gr.Row():
                doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
                source3_page = gr.Number(label="Page", scale=1)
        with gr.Row():
            submit_btn = gr.Button("Submit")
            clear_btn = gr.ClearButton([msg, chatbot])

        # Preprocessing events
        # upload_btn.upload(upload_file, inputs=[upload_btn], outputs=[document])
        db_btn.click(initialize_database, \
            inputs=[document, slider_chunk_size, slider_chunk_overlap], \
            outputs=[vector_db, collection_name, db_progress])
        qachain_btn.click(initialize_LLM, \
            inputs=[slider_temperature, slider_maxtokens, slider_topk, vector_db], \
            outputs=[qa_chain, llm_progress]).then(lambda: [None, "", 0, "", 0, "", 0], \
            inputs=None, \
            outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
            queue=False)

        # Chatbot events
        msg.submit(conversation, \
            inputs=[qa_chain, msg, chatbot], \
            outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
            queue=False)
        submit_btn.click(conversation, \
            inputs=[qa_chain, msg, chatbot], \
            outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
            queue=False)
        clear_btn.click(lambda: [None, "", 0, "", 0, "", 0], \
            inputs=None, \
            outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
            queue=False)
    demo.queue().launch(debug=True)


if __name__ == "__main__":
    demo()
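
# A minimal sketch of running this app locally (assumptions: the script is saved as
# app.py and an HTML corpus exists under /home/user/app/htmls/):
#   export HUGGINGFACEHUB_API_TOKEN=hf_...   # token read by HuggingFaceHub
#   python app.py                            # Gradio serves the UI on port 7860 by default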