PabloVD committed
Commit 114ce4a · 1 Parent(s): ab23f0c

First commit
Files changed (3)
  1. app.py +34 -0
  2. requirements.txt +4 -0
  3. worker.py +93 -0
app.py ADDED
@@ -0,0 +1,34 @@
+ import gradio as gr
+ import worker
+ import requests
+ from pathlib import Path
+ import torchvision
+ torchvision.disable_beta_transforms_warning()
+
+ # Download the CAMELS documentation PDF
+ url = 'https://camels.readthedocs.io/_/downloads/en/latest/pdf/'
+ r = requests.get(url, stream=True)
+ document_path = Path('metadata.pdf')
+ document_path.write_bytes(r.content)
+
+ # Build the retrieval chain over the downloaded document
+ worker.process_document(document_path)
+
+ # Gradio callback: answer each user message with the retrieval chain
+ def handle_prompt(message, history):
+     bot_response = worker.process_prompt(message)
+     return bot_response
+
+ greetingsmessage = "Hi, I'm the CAMELS DocBot, I'm here to assist you with any question related to the CAMELS simulations documentation"
+ example_questions = [
+     "How can I read a halo file?",
+     "Which simulation suites are included in CAMELS?",
+     "Which are the largest volumes in CAMELS simulations?",
+     "How can I get the power spectrum of a simulation?"
+ ]
+ # chatbot = gr.Chatbot(value=[{"role": "assistant", "content": greetingsmessage}])
+ # chatbot = gr.Chatbot(value=[[None, greetingsmessage]])
+ # chatbot = gr.Chatbot(value=gr.ChatMessage(role="assistant", content="How can I help you?"))
+ # chatbot = gr.Chatbot(placeholder=greetingsmessage)
+
+ demo = gr.ChatInterface(handle_prompt, type="messages", title="CAMELS DocBot", examples=example_questions, theme=gr.themes.Soft(), description=greetingsmessage)  # , chatbot=chatbot
+
+ demo.launch()
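Since handle_prompt ignores the history argument and simply forwards the message to worker.process_prompt, it can be smoke-tested outside the Gradio UI. A minimal sketch (not part of the commit; assumes the module-level download and worker.process_document call above have already run):

    # Hypothetical local check, bypassing the ChatInterface
    print(handle_prompt("Which simulation suites are included in CAMELS?", []))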
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ langchain
+ langchain-community
+ langchain-huggingface
+ chromadb
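Note that the code also imports gradio, requests, torch, and torchvision, PyPDFLoader pulls in pypdf, and HuggingFaceInstructEmbeddings needs sentence-transformers and InstructorEmbedding at run time. The Gradio Space image ships some of these already, but a local run would likely need extra entries, roughly:

    gradio
    requests
    torch
    torchvision
    pypdf
    sentence-transformers
    InstructorEmbedding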
worker.py ADDED
@@ -0,0 +1,93 @@
+ import os
+ import torch
+ from langchain.chains import RetrievalQA
+ from langchain_community.embeddings import HuggingFaceInstructEmbeddings
+ from langchain_community.document_loaders import PyPDFLoader
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_community.vectorstores import Chroma
+ from langchain_huggingface import HuggingFaceEndpoint
+
+ # Check for GPU availability and set the appropriate device for computation.
+ DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
+
+ # Global state shared between init_llm, process_document and process_prompt
+ conversation_retrieval_chain = None
+ chat_history = []
+ llm_hub = None
+ embeddings = None
+
+ # Initialize the language model and the embeddings model
+ def init_llm():
+     global llm_hub, embeddings
+     # Set the Hugging Face API token from a local file, if needed:
+     # tokenfile = open("api_token.txt")
+     # api_token = tokenfile.readline().replace("\n", "")
+     # tokenfile.close()
+     # os.environ["HUGGINGFACEHUB_API_TOKEN"] = api_token
+
+     # Repo name of the model
+     # model_id = "tiiuae/falcon-7b-instruct"
+     model_id = "microsoft/Phi-3.5-mini-instruct"
+     # model_id = "meta-llama/Llama-3.2-1B-Instruct"
+
+     # Load the model through the Hugging Face Inference API
+     llm_hub = HuggingFaceEndpoint(repo_id=model_id, temperature=0.1, max_new_tokens=600, model_kwargs={"max_length": 600})
+     llm_hub.client.api_url = 'https://api-inference.huggingface.co/models/' + model_id
+     # llm_hub.invoke('foo bar')
+
+     # Initialize embeddings using a pre-trained model to represent the text data.
+     embeddings_model = "sentence-transformers/multi-qa-distilbert-cos-v1"
+     # embeddings_model = "sentence-transformers/all-MiniLM-L6-v2"
+     embeddings = HuggingFaceInstructEmbeddings(
+         model_name=embeddings_model,
+         model_kwargs={"device": DEVICE}
+     )
+
+
+ # Process a PDF document and build the retrieval chain over it
+ def process_document(document_path):
+     global conversation_retrieval_chain
+
+     # Load the document (PyPDFLoader expects a string path)
+     loader = PyPDFLoader(str(document_path))
+     documents = loader.load()
+
+     # Split the document into chunks
+     text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
+     texts = text_splitter.split_documents(documents)
+
+     # Create an embeddings database using Chroma from the split text chunks.
+     db = Chroma.from_documents(texts, embedding=embeddings)
+
+     # Build the QA chain, which uses the LLM and the retriever to answer questions.
+     # By default, the vectorstore retriever uses similarity search.
+     # If the underlying vectorstore supports maximum marginal relevance search, you can specify that as the search type (search_type="mmr").
+     # You can also pass search kwargs such as k, which sets how many retrieved chunks are sent to the LLM.
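+     # For comparison, a plain similarity retriever would look like this
+     # (a hedged alternative, not part of this commit):
+     # retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 4})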
+     conversation_retrieval_chain = RetrievalQA.from_chain_type(
+         llm=llm_hub,
+         chain_type="stuff",
+         retriever=db.as_retriever(search_type="mmr", search_kwargs={'k': 6, 'lambda_mult': 0.25}),
+         return_source_documents=False,
+         input_key="question"
+         # chain_type_kwargs={"prompt": prompt}  # uncomment if using a custom prompt template
+     )
+
+
+ # Process a user prompt with the retrieval chain
+ def process_prompt(prompt):
+     global conversation_retrieval_chain
+     global chat_history
+
+     # Query the model
+     output = conversation_retrieval_chain.invoke({"question": prompt, "chat_history": chat_history})
+     answer = output["result"]
+
+     # Update the chat history
+     chat_history.append((prompt, answer))
+
+     # Return the model's response
+     return answer
+
+ # Initialize the language model at import time
+ init_llm()
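Taken together, worker.py exposes a two-step API: process_document builds the Chroma index and the RetrievalQA chain, and process_prompt queries it while accumulating chat_history. A minimal usage sketch outside the Gradio app (the docs.pdf filename is illustrative; init_llm() runs automatically on import):

    import worker

    worker.process_document("docs.pdf")  # hypothetical local PDF
    answer = worker.process_prompt("Which simulation suites are included in CAMELS?")
    print(answer)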