DrishtiSharma committed (verified)
Commit b6bff6f · 1 Parent(s): 7e4c68c

Create patentwiz/qa_agent.py

Files changed (1):
  patentwiz/qa_agent.py  +333 -0
patentwiz/qa_agent.py ADDED
@@ -0,0 +1,333 @@
import os
import json
import nltk
import openai
import chromadb
from langchain.document_loaders import UnstructuredXMLLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from langchain.chat_models import ChatOpenAI
from langchain.chains import RetrievalQA
from langchain.document_loaders import TextLoader
from langchain.prompts import PromptTemplate
from langchain.chains import AnalyzeDocumentChain
from langchain.chains.question_answering import load_qa_chain
from langchain.callbacks import get_openai_callback
from langchain.llms import OpenAI
from langchain.vectorstores import FAISS
from langchain.text_splitter import CharacterTextSplitter

# Clear the ChromaDB cache to fix the tenant issue
chromadb.api.client.SharedSystemClient.clear_system_cache()

# Module-level setup that does not need to live inside the functions
nltk.download("punkt", quiet=True)

from nltk import word_tokenize, sent_tokenize


openai.api_key = os.getenv("OPENAI_API_KEY")
if openai.api_key is None:
    raise Exception("OPENAI_API_KEY not found in environment variables")

embeddings = OpenAIEmbeddings()


def split_docs(documents, chunk_size=1000, chunk_overlap=0):
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    return text_splitter.split_documents(documents)


def call_QA_to_json(
    prompt, year, month, day, saved_patent_names, index=0, logging=True, model_name="gpt-3.5-turbo"
):
    """
    Generate embeddings from a patent txt document, retrieve data based on the provided
    prompt, and return the result as a JSON string.

    Parameters:
        prompt (str): The input prompt for the retrieval process.
        year (int): The year part of the data folder name.
        month (int): The month part of the data folder name.
        day (int): The day part of the data folder name.
        saved_patent_names (list): A list of strings containing the names of saved patent text files.
        index (int): The index of the saved patent text file to process. Default is 0.
        logging (bool): If True, print logs to the console. Default is True.
        model_name (str): The OpenAI chat model to use. Default is "gpt-3.5-turbo".

    Returns:
        tuple: A tuple containing two elements:
            - The cost of the OpenAI API calls (USD).
            - A JSON string representing the output from the retrieval chain.

    This function loads the specified txt file, generates embeddings from its content,
    and uses a retrieval chain to answer the provided prompt. The output is also written
    to a file in the 'output' directory named '{patent_name}_{model_name}.json'.
    """

    llm = ChatOpenAI(model_name=model_name, temperature=0, cache=False)
    file_path = os.path.join(
        os.getcwd(),
        "data",
        "ipa" + str(year)[2:] + f"{month:02d}" + f"{day:02d}",
        saved_patent_names[index],
    )

    if logging:
        print(f"Loading documents from: {file_path}")
    loader = TextLoader(file_path)
    documents_raw = loader.load()

    documents = split_docs(documents_raw)

    if logging:
        print("Generating embeddings and persisting...")

    vectordb = Chroma.from_documents(
        documents=documents,
        embedding=embeddings,
    )

    # vectordb.persist()
    PROMPT_FORMAT = """
    Task: Use the following pieces of context to answer the question at the end.
    {context}
    Question: {question}
    """

    PROMPT = PromptTemplate(
        template=PROMPT_FORMAT, input_variables=["context", "question"]
    )

    chain_type_kwargs = {"prompt": PROMPT}

    retrieval_chain = RetrievalQA.from_chain_type(
        llm,
        chain_type="stuff",
        retriever=vectordb.as_retriever(),
        chain_type_kwargs=chain_type_kwargs,
        # return_source_documents=True
    )

    if logging:
        print("Running retrieval chain...")

    with get_openai_callback() as cb:
        output = retrieval_chain.run(prompt)
        if logging:
            print(f"Total Tokens: {cb.total_tokens}")
            print(f"Prompt Tokens: {cb.prompt_tokens}")
            print(f"Completion Tokens: {cb.completion_tokens}")
            print(f"Successful Requests: {cb.successful_requests}")
            print(f"Total Cost (USD): ${cb.total_cost}")
        cost = cb.total_cost

    try:
        # Convert the output to a dictionary
        output_dict = json.loads(output)

        # Manually assign the Patent Identifier
        output_dict["Patent Identifier"] = saved_patent_names[index].split("-")[0]

        # Create the 'output' directory if it does not exist
        if not os.path.exists("output"):
            os.makedirs("output")

        if logging:
            print("Writing the output to a file...")

        with open(
            f"output/{saved_patent_names[index]}_{model_name}.json", "w", encoding="utf-8"
        ) as json_file:
            json.dump(output_dict, json_file, indent=4, ensure_ascii=False)

        if logging:
            print("Call to 'call_QA_to_json' completed.")

    except Exception as e:
        print("An error occurred while processing the output.")
        print("Error message:", str(e))

    try:
        # Clear the Chroma collection so repeated calls do not accumulate documents
        vectordb.delete_collection()
    except Exception as e:
        print(f"Error deleting vector database: {str(e)}")

    return cost, output
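

# Illustrative usage sketch for call_QA_to_json. The data folder ("data/ipa230907/"),
# the patent file name, and the prompt below are hypothetical placeholders, not values
# from the repository.
def _example_call_QA_to_json():
    saved_patent_names = ["US20230012345A1-20230907.txt"]  # hypothetical file name
    example_prompt = (
        "Extract the patent title, inventors, and field of the invention, "
        "and answer strictly as a JSON object."
    )
    cost, output = call_QA_to_json(
        example_prompt, 2023, 9, 7, saved_patent_names, index=0, logging=True
    )
    print(f"OpenAI API cost (USD): {cost}")
    return output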


def call_TA_to_json(
    prompt, year, month, day, saved_patent_names, index=0, logging=True
):
    """
    Retrieve text analytics (TA) data from a specified patent file and convert the output to JSON format.

    This function reads a text document from the patent file specified by the year, month, day,
    and file name parameters, then applies a QA process to the whole document using the provided
    prompt. The result is converted to a JSON object, a patent identifier is manually assigned,
    and the object is written to a file.

    Parameters:
        prompt (str): The input prompt for the retrieval process.
        year (int): The year part of the data folder name.
        month (int): The month part of the data folder name.
        day (int): The day part of the data folder name.
        saved_patent_names (list): A list of strings containing the names of saved patent text files.
        index (int, optional): The index of the saved patent text file to process. Default is 0.
        logging (bool, optional): If True, print logs to the console. Default is True.

    Returns:
        tuple: A tuple containing two elements:
            - documents_raw (str): The raw document content loaded from the specified patent file.
            - output (str): A JSON string representing the output from the TA retrieval process.

    Note:
        The output is also written to a file in the 'output' directory with the same name as the
        input file and a '.json' extension.
    """

    llm = ChatOpenAI(model_name="gpt-3.5-turbo", cache=False)

    file_path = os.path.join(
        os.getcwd(),
        "data",
        "ipa" + str(year)[2:] + f"{month:02d}" + f"{day:02d}",
        saved_patent_names[index],
    )

    if logging:
        print(f"Loading documents from: {file_path}")

    with open(file_path, "r") as f:
        documents_raw = f.read()

    PROMPT_FORMAT = """
    Task: Use the following pieces of context to answer the question at the end.
    Question:
    """

    prompt = PROMPT_FORMAT + prompt

    qa_chain = load_qa_chain(llm, chain_type="map_reduce")

    qa_document_chain = AnalyzeDocumentChain(combine_docs_chain=qa_chain)

    if logging:
        print("Running Analyze Document chain...")

    output = qa_document_chain.run(input_document=documents_raw, question=prompt)

    try:
        # Convert the output to a dictionary
        output_dict = json.loads(output)

        # Manually assign the Patent Identifier
        output_dict["Patent Identifier"] = saved_patent_names[index].split("-")[0]

        # Create the 'output' directory if it does not exist
        if not os.path.exists("output"):
            os.makedirs("output")

        if logging:
            print("Writing the output to a file...")

        # Write the output to a file in the 'output' directory
        with open(f"output/{saved_patent_names[index]}.json", "w", encoding="utf-8") as json_file:
            json.dump(output_dict, json_file, indent=4, ensure_ascii=False)

        if logging:
            print("Call to 'call_TA_to_json' completed.")
    except Exception as e:
        print("An error occurred while processing the output.")
        print("Error message:", str(e))

    return documents_raw, output
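

# Illustrative usage sketch for call_TA_to_json, under the same hypothetical folder and
# file-name assumptions as above. Unlike call_QA_to_json, this path feeds the whole
# document through a map_reduce QA chain instead of a vector-store retriever.
def _example_call_TA_to_json():
    saved_patent_names = ["US20230012345A1-20230907.txt"]  # hypothetical file name
    example_prompt = "List the independent claims, answered as a JSON array."
    documents_raw, output = call_TA_to_json(
        example_prompt, 2023, 9, 7, saved_patent_names, index=0, logging=True
    )
    return output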


def call_QA_faiss_to_json(
    prompt, year, month, day, saved_patent_names, index=0, logging=True, model_name="gpt-3.5-turbo"
):
    """
    Generate FAISS embeddings from a patent txt document, retrieve data based on the provided
    prompt, and return the result as a JSON string.

    Parameters:
        prompt (str): The input prompt for the retrieval process.
        year (int): The year part of the data folder name.
        month (int): The month part of the data folder name.
        day (int): The day part of the data folder name.
        saved_patent_names (list): A list of strings containing the names of saved patent text files.
        index (int): The index of the saved patent text file to process. Default is 0.
        logging (bool): If True, print logs to the console. Default is True.
        model_name (str): The OpenAI chat model to use. Default is "gpt-3.5-turbo".

    Returns:
        str: A JSON string representing the output from the QA chain.

    This function loads the specified txt file, generates embeddings from its content,
    and uses a FAISS similarity search plus a QA chain to answer the provided prompt.
    The output is also written to a file in the 'output' directory named
    '{patent_name}_{model_name}.json'.
    """

    llm = ChatOpenAI(model_name=model_name, cache=False)
    chain = load_qa_chain(llm, chain_type="stuff")

    file_path = os.path.join(
        os.getcwd(),
        "data",
        "ipa" + str(year)[2:] + f"{month:02d}" + f"{day:02d}",
        saved_patent_names[index],
    )

    if logging:
        print(f"Loading documents from: {file_path}")
    loader = TextLoader(file_path)
    documents_raw = loader.load()

    text_splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=0)

    documents = text_splitter.split_documents(documents_raw)

    docsearch = FAISS.from_documents(documents, embeddings)

    docs = docsearch.similarity_search(prompt)

    if logging:
        print("Running chain...")

    with get_openai_callback() as cb:
        output = chain.run(input_documents=docs, question=prompt)
        if logging:
            print(f"Total Tokens: {cb.total_tokens}")
            print(f"Prompt Tokens: {cb.prompt_tokens}")
            print(f"Completion Tokens: {cb.completion_tokens}")
            print(f"Successful Requests: {cb.successful_requests}")
            print(f"Total Cost (USD): ${cb.total_cost}")

    try:
        # Convert the output to a dictionary
        output_dict = json.loads(output)

        # Manually assign the Patent Identifier
        output_dict["Patent Identifier"] = saved_patent_names[index].split("-")[0]

        # Create the 'output' directory if it does not exist
        if not os.path.exists("output"):
            os.makedirs("output")

        if logging:
            print("Writing the output to a file...")

        # Write the output to a file in the 'output' directory
        with open(f"output/{saved_patent_names[index]}_{model_name}.json", "w", encoding="utf-8") as json_file:
            json.dump(output_dict, json_file, indent=4, ensure_ascii=False)

        if logging:
            print("Call to 'call_QA_faiss_to_json' completed.")

    except Exception as e:
        print("An error occurred while processing the output.")
        print("Error message:", str(e))

    # Drop the in-memory FAISS index so it can be garbage collected
    del docsearch

    return output
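

# Illustrative usage sketch for the FAISS-based variant, again with a hypothetical file
# name. Note that call_QA_faiss_to_json returns only the JSON string, not the API cost.
def _example_call_QA_faiss_to_json():
    saved_patent_names = ["US20230012345A1-20230907.txt"]  # hypothetical file name
    example_prompt = "Summarize the abstract in one sentence, answered as a JSON object."
    return call_QA_faiss_to_json(
        example_prompt, 2023, 9, 7, saved_patent_names, index=0, model_name="gpt-3.5-turbo"
    )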