carlosgonzalezmartinez committed (verified)
Commit 00536ec · 1 Parent(s): 8714c7f
Files changed (2):
  1. app.py +237 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,237 @@
## Setup
# Dependencies are pinned in requirements.txt. A bare `!pip install` is
# notebook syntax and would be a SyntaxError in a plain Python script,
# so the install line is left commented out for reference:
# !pip install -q openai==1.23.2 tiktoken==0.6.0 pypdf==4.0.1 \
#     langchain==0.1.1 langchain-community==0.0.13 chromadb==0.4.22 \
#     sentence-transformers==2.3.1 datasets
# Import the necessary libraries
import json
import os
import uuid
from pathlib import Path

import gradio as gr
import pandas as pd
import tiktoken
from datasets import load_dataset

from openai import OpenAI

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
from langchain_community.document_loaders import PyPDFDirectoryLoader
from langchain_community.embeddings.sentence_transformer import (
    SentenceTransformerEmbeddings
)
from langchain_community.vectorstores import Chroma

# `userdata` is Colab-specific; on a Space the key would normally come
# from os.environ instead.
from google.colab import userdata

from huggingface_hub import CommitScheduler
# Create the OpenAI client
openai_api_key = userdata.get('CarlosGM')
client = OpenAI(api_key=openai_api_key)

model_name = 'gpt-3.5-turbo'
# Define the embedding model and the vectorstore
embedding_model = SentenceTransformerEmbeddings(model_name='thenlper/gte-large')

# Load the persisted vector DB
# persisted_vectordb_location = '/content/drive/MyDrive/dataset_db'
dataset_10k_collection = 'Dataset-IBM-Meta-aws-google-msft'

vectorstore_persisted = Chroma(
    collection_name=dataset_10k_collection,
    persist_directory='./dataset_db',
    embedding_function=embedding_model
)

retriever = vectorstore_persisted.as_retriever(
    search_type='similarity',
    search_kwargs={'k': 5}
)
# Prepare the logging functionality
log_file = Path("logs/") / f"data_{uuid.uuid4()}.json"
log_folder = log_file.parent

scheduler = CommitScheduler(
    repo_id="10k-logs",
    repo_type="dataset",
    folder_path=log_folder,
    path_in_repo="data",
    every=2  # push the accumulated logs every 2 minutes
)
# Define the Q&A system message
qna_system_message = """
You are an assistant to a financial services firm who answers user queries on annual reports.
User input will have the context required by you to answer user questions.
This context will begin with the token: ###Context.
The context contains references to specific portions of a document relevant to the user query.
The source for a context will begin with the token: ###Source.

User questions will begin with the token: ###Question.

Answer only using the context provided in the input, and do not mention the context itself in your final answer.

Please adhere to the following guidelines:
- Your response should only be about the question asked and nothing else.
- Answer only using the context provided.
- Do not mention anything about the context in your final answer.
- If the answer is not found in the context, respond with "I don't know. Please check the docs @ 'Dataset-10k file'".
- Always quote the source when you use the context. Cite the relevant source at the end of your response under the section - Source:
- Do not make up sources. Use only the files provided in the sources section of the context; you are prohibited from providing other sources.

Here is an example of how to structure your response:

Answer:
[Answer]

Source:
[Source]
"""
# Define the user message template
qna_user_message_template = """
###Context
Here are some documents that are relevant to the question.
{context}

###Question
{question}
"""
# Define the predict function that runs when 'Submit' is clicked or when an API request is made
def predict(user_input, company):

    # Restrict retrieval to the selected company's 10-K filing
    source_filter = "dataset/" + company + "-10-k-2023.pdf"
    relevant_document_chunks = vectorstore_persisted.similarity_search(
        user_input,
        k=5,
        filter={"source": source_filter}
    )

    # Create context_for_query from the retrieved chunks
    context_list = [d.page_content for d in relevant_document_chunks]
    context_for_query = ". ".join(context_list)

    # Create the chat messages
    prompt = [
        {'role': 'system', 'content': qna_system_message},
        {'role': 'user', 'content': qna_user_message_template.format(
            context=context_for_query,
            question=user_input
        )}
    ]

    # Get a response from the LLM
    try:
        response = client.chat.completions.create(
            model=model_name,
            messages=prompt,
            temperature=0
        )
        prediction = response.choices[0].message.content.strip()
    except Exception as e:
        prediction = f'Sorry, I encountered the following error: \n {e}'

    # Log both the inputs and outputs to a local log file.
    # While writing to the log file, hold the commit scheduler's lock
    # to avoid parallel access between the app and the scheduled commit.
    with scheduler.lock:
        with log_file.open("a") as f:
            f.write(json.dumps(
                {
                    'user_input': user_input,
                    'retrieved_context': context_for_query,
                    'model_response': prediction
                }
            ))
            f.write("\n")

    return prediction
# Set up the Gradio UI

# Add a text box and a radio button to the interface.
# The radio button selects the company 10-K report from which the context is retrieved.
textbox = gr.Textbox(placeholder='Enter your query here', lines=6)
company = gr.Radio(['aws', 'google', 'ibm', 'meta', 'msft'],
                   label="Select Company 10-K Report")

# Create the interface, passing [textbox, company] as its inputs
demo = gr.Interface(
    fn=predict,
    inputs=[textbox, company],
    outputs='text',
    title='10-K Report Q&A',
    description='This Web API presents an interface to ask questions about the 10-K reports'
)

demo.queue()
demo.launch()
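Note that app.py only loads an already-persisted Chroma collection from ./dataset_db; the otherwise unused imports of PyPDFDirectoryLoader and RecursiveCharacterTextSplitter suggest the index was built in a separate, earlier step. Below is a minimal sketch of what that indexing step could look like, assuming the five PDFs live under dataset/ (e.g. dataset/ibm-10-k-2023.pdf, matching the filter in predict) and assuming chunking parameters the commit does not specify:

# Indexing sketch (assumed, not part of this commit): builds the
# persisted collection that app.py loads from ./dataset_db.
from langchain_community.document_loaders import PyPDFDirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings.sentence_transformer import (
    SentenceTransformerEmbeddings
)
from langchain_community.vectorstores import Chroma

# The loader records each file's path in the "source" metadata field,
# which is exactly what predict() filters on.
pdf_docs = PyPDFDirectoryLoader("dataset/").load()

# chunk_size/chunk_overlap are assumptions, not values from the commit.
splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
chunks = splitter.split_documents(pdf_docs)

embedding_model = SentenceTransformerEmbeddings(model_name='thenlper/gte-large')

# With persist_directory set, Chroma 0.4.x persists to disk automatically.
Chroma.from_documents(
    chunks,
    embedding_model,
    collection_name='Dataset-IBM-Meta-aws-google-msft',
    persist_directory='./dataset_db'
)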
requirements.txt ADDED
@@ -0,0 +1,5 @@
openai==1.23.2
chromadb==0.4.22
langchain==0.1.9
langchain-community==0.0.32
sentence-transformers==2.3.1
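Each predict() call appends one JSON object per line to the log file, and the CommitScheduler pushes the logs/ folder to the 10k-logs dataset repo every 2 minutes. A minimal sketch for inspecting the local records before they are committed (the three keys come straight from predict(); everything else here is illustrative):

# Log-inspection sketch (assumed helper, not part of the commit)
import json
from pathlib import Path

for log_path in Path("logs").glob("*.json"):
    for raw_line in log_path.read_text().splitlines():
        record = json.loads(raw_line)
        print(record['user_input'], '->', record['model_response'][:80])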