Shreyas094 committed on
Commit 6bdbafb · verified · 1 Parent(s): ed2e431

Update app.py

Files changed (1)
  1. app.py +147 -680
app.py CHANGED
@@ -1,602 +1,142 @@
  import os
- import json
- import re
  import gradio as gr
- import requests
- from duckduckgo_search import DDGS
- from typing import List
- from pydantic import BaseModel, Field
- from tempfile import NamedTemporaryFile
- from langchain_community.vectorstores import FAISS
- from langchain_core.vectorstores import VectorStore
- from langchain_core.documents import Document
- from langchain_community.document_loaders import PyPDFLoader
- from langchain_community.embeddings import HuggingFaceEmbeddings
- from llama_parse import LlamaParse
- from langchain_core.documents import Document
  from huggingface_hub import InferenceClient
- import inspect
- import logging
- import shutil
-
- # Set up basic configuration for logging
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

  # Environment variables and configurations
- huggingface_token = os.environ.get("HUGGINGFACE_TOKEN")
- llama_cloud_api_key = os.environ.get("LLAMA_CLOUD_API_KEY")
- ACCOUNT_ID = os.environ.get("CLOUDFARE_ACCOUNT_ID")
- API_TOKEN = os.environ.get("CLOUDFLARE_AUTH_TOKEN")
- API_BASE_URL = "https://api.cloudflare.com/client/v4/accounts/a17f03e0f049ccae0c15cdcf3b9737ce/ai/run/"
-
- print(f"ACCOUNT_ID: {ACCOUNT_ID}")
- print(f"CLOUDFLARE_AUTH_TOKEN: {API_TOKEN[:5]}..." if API_TOKEN else "Not set")
-
  MODELS = [
      "mistralai/Mistral-7B-Instruct-v0.3",
      "mistralai/Mixtral-8x7B-Instruct-v0.1",
-     "@cf/meta/llama-3.1-8b-instruct",
-     "mistralai/Mistral-Nemo-Instruct-2407"
  ]

- # Initialize LlamaParse
- llama_parser = LlamaParse(
-     api_key=llama_cloud_api_key,
-     result_type="markdown",
-     num_workers=4,
-     verbose=True,
-     language="en",
- )
-
- def load_document(file: NamedTemporaryFile, parser: str = "llamaparse") -> List[Document]:
-     """Loads and splits the document into pages."""
-     if parser == "pypdf":
-         loader = PyPDFLoader(file.name)
-         return loader.load_and_split()
-     elif parser == "llamaparse":
          try:
-             documents = llama_parser.load_data(file.name)
-             return [Document(page_content=doc.text, metadata={"source": file.name}) for doc in documents]
          except Exception as e:
-             print(f"Error using Llama Parse: {str(e)}")
-             print("Falling back to PyPDF parser")
-             loader = PyPDFLoader(file.name)
-             return loader.load_and_split()
-     else:
-         raise ValueError("Invalid parser specified. Use 'pypdf' or 'llamaparse'.")

  def get_embeddings():
      return HuggingFaceEmbeddings(model_name="sentence-transformers/stsb-roberta-large")

- # Add this at the beginning of your script, after imports
- DOCUMENTS_FILE = "uploaded_documents.json"
-
- def load_documents():
-     if os.path.exists(DOCUMENTS_FILE):
-         with open(DOCUMENTS_FILE, "r") as f:
-             return json.load(f)
-     return []
-
- def save_documents(documents):
-     with open(DOCUMENTS_FILE, "w") as f:
-         json.dump(documents, f)
-
- # Replace the global uploaded_documents with this
- uploaded_documents = load_documents()
-
- # Modify the update_vectors function
- def update_vectors(files, parser):
-     global uploaded_documents
-     logging.info(f"Entering update_vectors with {len(files)} files and parser: {parser}")
-
-     if not files:
-         logging.warning("No files provided for update_vectors")
-         return "Please upload at least one PDF file.", display_documents()
-
-     embed = get_embeddings()
-     total_chunks = 0
-
-     all_data = []
-     for file in files:
-         logging.info(f"Processing file: {file.name}")
-         try:
-             data = load_document(file, parser)
-             if not data:
-                 logging.warning(f"No chunks loaded from {file.name}")
-                 continue
-             logging.info(f"Loaded {len(data)} chunks from {file.name}")
-             all_data.extend(data)
-             total_chunks += len(data)
-             if not any(doc["name"] == file.name for doc in uploaded_documents):
-                 uploaded_documents.append({"name": file.name, "selected": True})
-                 logging.info(f"Added new document to uploaded_documents: {file.name}")
-             else:
-                 logging.info(f"Document already exists in uploaded_documents: {file.name}")
-         except Exception as e:
-             logging.error(f"Error processing file {file.name}: {str(e)}")
-
-     logging.info(f"Total chunks processed: {total_chunks}")
-
-     if not all_data:
-         logging.warning("No valid data extracted from uploaded files")
-         return "No valid data could be extracted from the uploaded files. Please check the file contents and try again.", display_documents()
-
-     try:
-         if os.path.exists("faiss_database"):
-             logging.info("Updating existing FAISS database")
-             database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
-             database.add_documents(all_data)
-         else:
-             logging.info("Creating new FAISS database")
-             database = FAISS.from_documents(all_data, embed)
-
-         database.save_local("faiss_database")
-         logging.info("FAISS database saved")
-     except Exception as e:
-         logging.error(f"Error updating FAISS database: {str(e)}")
-         return f"Error updating vector store: {str(e)}", display_documents()
-
-     # Save the updated list of documents
-     save_documents(uploaded_documents)
-
-     return f"Vector store updated successfully. Processed {total_chunks} chunks from {len(files)} files using {parser}.", display_documents()
-
- def delete_documents(selected_docs):
-     global uploaded_documents
-
-     if not selected_docs:
-         return "No documents selected for deletion.", display_documents()
-
      embed = get_embeddings()
-     database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
-
-     deleted_docs = []
-     docs_to_keep = []
-     for doc in database.docstore._dict.values():
-         if doc.metadata.get("source") not in selected_docs:
-             docs_to_keep.append(doc)
-         else:
-             deleted_docs.append(doc.metadata.get("source", "Unknown"))
-
-     # Print debugging information
-     logging.info(f"Total documents before deletion: {len(database.docstore._dict)}")
-     logging.info(f"Documents to keep: {len(docs_to_keep)}")
-     logging.info(f"Documents to delete: {len(deleted_docs)}")
-
-     if not docs_to_keep:
-         # If all documents are deleted, remove the FAISS database directory
-         if os.path.exists("faiss_database"):
-             shutil.rmtree("faiss_database")
-             logging.info("All documents deleted. Removed FAISS database directory.")
-     else:
-         # Create new FAISS index with remaining documents
-         new_database = FAISS.from_documents(docs_to_keep, embed)
-         new_database.save_local("faiss_database")
-         logging.info(f"Created new FAISS index with {len(docs_to_keep)} documents.")
-
-     # Update uploaded_documents list
-     uploaded_documents = [doc for doc in uploaded_documents if doc["name"] not in deleted_docs]
-     save_documents(uploaded_documents)
-
-     return f"Deleted documents: {', '.join(deleted_docs)}", display_documents()
-
- def generate_chunked_response(prompt, model, max_tokens=10000, num_calls=3, temperature=0.2, should_stop=False):
-     print(f"Starting generate_chunked_response with {num_calls} calls")
-     full_response = ""
-     messages = [{"role": "user", "content": prompt}]
-
-     if model == "@cf/meta/llama-3.1-8b-instruct":
-         # Cloudflare API
-         for i in range(num_calls):
-             print(f"Starting Cloudflare API call {i+1}")
-             if should_stop:
-                 print("Stop clicked, breaking loop")
-                 break
-             try:
-                 response = requests.post(
-                     f"https://api.cloudflare.com/client/v4/accounts/{ACCOUNT_ID}/ai/run/@cf/meta/llama-3.1-8b-instruct",
-                     headers={"Authorization": f"Bearer {API_TOKEN}"},
-                     json={
-                         "stream": True,
-                         "messages": [
-                             {"role": "system", "content": "You are a friendly assistant"},
-                             {"role": "user", "content": prompt}
-                         ],
-                         "max_tokens": max_tokens,
-                         "temperature": temperature
-                     },
-                     stream=True
-                 )
-
-                 for line in response.iter_lines():
-                     if should_stop:
-                         print("Stop clicked during streaming, breaking")
-                         break
-                     if line:
-                         try:
-                             json_data = json.loads(line.decode('utf-8').split('data: ')[1])
-                             chunk = json_data['response']
-                             full_response += chunk
-                         except json.JSONDecodeError:
-                             continue
-                 print(f"Cloudflare API call {i+1} completed")
-             except Exception as e:
-                 print(f"Error in generating response from Cloudflare: {str(e)}")
-     else:
-         # Original Hugging Face API logic
-         client = InferenceClient(model, token=huggingface_token)
-
-         for i in range(num_calls):
-             print(f"Starting Hugging Face API call {i+1}")
-             if should_stop:
-                 print("Stop clicked, breaking loop")
-                 break
-             try:
-                 for message in client.chat_completion(
-                     messages=messages,
-                     max_tokens=max_tokens,
-                     temperature=temperature,
-                     stream=True,
-                 ):
-                     if should_stop:
-                         print("Stop clicked during streaming, breaking")
-                         break
-                     if message.choices and message.choices[0].delta and message.choices[0].delta.content:
-                         chunk = message.choices[0].delta.content
-                         full_response += chunk
-                 print(f"Hugging Face API call {i+1} completed")
-             except Exception as e:
-                 print(f"Error in generating response from Hugging Face: {str(e)}")
-
-     # Clean up the response
-     clean_response = re.sub(r'<s>\[INST\].*?\[/INST\]\s*', '', full_response, flags=re.DOTALL)
-     clean_response = clean_response.replace("Using the following context:", "").strip()
-     clean_response = clean_response.replace("Using the following context from the PDF documents:", "").strip()
-
-     # Remove duplicate paragraphs and sentences
-     paragraphs = clean_response.split('\n\n')
-     unique_paragraphs = []
-     for paragraph in paragraphs:
-         if paragraph not in unique_paragraphs:
-             sentences = paragraph.split('. ')
-             unique_sentences = []
-             for sentence in sentences:
-                 if sentence not in unique_sentences:
-                     unique_sentences.append(sentence)
-             unique_paragraphs.append('. '.join(unique_sentences))
-
-     final_response = '\n\n'.join(unique_paragraphs)
-
-     print(f"Final clean response: {final_response[:100]}...")
-     return final_response
-
- def duckduckgo_search(query):
-     with DDGS() as ddgs:
-         results = ddgs.text(query, max_results=5)
-         return results
-
- class CitingSources(BaseModel):
-     sources: List[str] = Field(
-         ...,
-         description="List of sources to cite. Should be a URL of the source."
-     )
- def chatbot_interface(message, history, use_web_search, model, temperature, num_calls):
-     if not message.strip():
-         return "", history
-
-     history = history + [(message, "")]
-
-     try:
-         for response in respond(message, history, model, temperature, num_calls, use_web_search):
-             history[-1] = (message, response)
-             yield history
-     except gr.CancelledError:
-         yield history
-     except Exception as e:
-         logging.error(f"Unexpected error in chatbot_interface: {str(e)}")
-         history[-1] = (message, f"An unexpected error occurred: {str(e)}")
-         yield history

- def retry_last_response(history, use_web_search, model, temperature, num_calls):
-     if not history:
-         return history
-
-     last_user_msg = history[-1][0]
-     history = history[:-1]  # Remove the last response
-
-     return chatbot_interface(last_user_msg, history, use_web_search, model, temperature, num_calls)
-
- def respond(message, history, model, temperature, num_calls, use_web_search, selected_docs, instruction_key):
-     logging.info(f"User Query: {message}")
-     logging.info(f"Model Used: {model}")
-     logging.info(f"Search Type: {'Web Search' if use_web_search else 'PDF Search'}")
-     logging.info(f"Selected Documents: {selected_docs}")
-     logging.info(f"Instruction Key: {instruction_key}")
-
-     try:
-         if instruction_key and instruction_key != "None":
-             # This is a summary generation request
-             instruction = INSTRUCTION_PROMPTS[instruction_key]
-             context_str = get_context_for_summary(selected_docs)
-             message = f"{instruction}\n\nUsing the following context from the PDF documents:\n{context_str}\nGenerate a detailed summary."
-             use_web_search = False  # Ensure we use PDF search for summaries
-
-         if use_web_search:
-             for main_content, sources in get_response_with_search(message, model, num_calls=num_calls, temperature=temperature):
-                 response = f"{main_content}\n\n{sources}"
-                 first_line = response.split('\n')[0] if response else ''
-                 # logging.info(f"Generated Response (first line): {first_line}")
-                 yield response
-         else:
-             embed = get_embeddings()
-             if os.path.exists("faiss_database"):
-                 database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
-                 retriever = database.as_retriever()
-
-                 # Filter relevant documents based on user selection
-                 all_relevant_docs = retriever.get_relevant_documents(message)
-                 relevant_docs = [doc for doc in all_relevant_docs if doc.metadata["source"] in selected_docs]
-
-                 if not relevant_docs:
-                     yield "No relevant information found in the selected documents. Please try selecting different documents or rephrasing your query."
-                     return
-
-                 context_str = "\n".join([doc.page_content for doc in relevant_docs])
-             else:
-                 context_str = "No documents available."
-                 yield "No documents available. Please upload PDF documents to answer questions."
-                 return
-
-             if model == "@cf/meta/llama-3.1-8b-instruct":
-                 # Use Cloudflare API
-                 for partial_response in get_response_from_cloudflare(prompt="", context=context_str, query=message, num_calls=num_calls, temperature=temperature, search_type="pdf"):
-                     first_line = partial_response.split('\n')[0] if partial_response else ''
-                     # logging.info(f"Generated Response (first line): {first_line}")
-                     yield partial_response
-             else:
-                 # Use Hugging Face API
-                 for partial_response in get_response_from_pdf(message, model, selected_docs, num_calls=num_calls, temperature=temperature):
-                     first_line = partial_response.split('\n')[0] if partial_response else ''
-                     # logging.info(f"Generated Response (first line): {first_line}")
-                     yield partial_response
-
-     except Exception as e:
-         logging.error(f"Error with {model}: {str(e)}")
-         if "microsoft/Phi-3-mini-4k-instruct" in model:
-             logging.info("Falling back to Mistral model due to Phi-3 error")
-             fallback_model = "mistralai/Mistral-7B-Instruct-v0.3"
-             yield from respond(message, history, fallback_model, temperature, num_calls, use_web_search, selected_docs, instruction_key)
-         else:
-             yield f"An error occurred with the {model} model: {str(e)}. Please try again or select a different model."
-
- logging.basicConfig(level=logging.DEBUG)

- def get_context_for_summary(selected_docs):
-     embed = get_embeddings()
-     if os.path.exists("faiss_database"):
-         database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
-         retriever = database.as_retriever(search_kwargs={"k": 5})  # Retrieve top 5 most relevant chunks
-
-         # Create a generic query that covers common financial summary topics
-         generic_query = "financial performance revenue profit assets liabilities cash flow key metrics highlights"
-
-         relevant_docs = retriever.get_relevant_documents(generic_query)
-         filtered_docs = [doc for doc in relevant_docs if doc.metadata["source"] in selected_docs]
-
-         if not filtered_docs:
-             return "No relevant information found in the selected documents for summary generation."
-
-         context_str = "\n".join([doc.page_content for doc in filtered_docs])
-         return context_str
-     else:
-         return "No documents available for summary generation."

- def get_context_for_query(query, selected_docs):
-     embed = get_embeddings()
-     if os.path.exists("faiss_database"):
-         database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
-         retriever = database.as_retriever(search_kwargs={"k": 3})  # Retrieve top 3 most relevant chunks
-
          relevant_docs = retriever.get_relevant_documents(query)
-         filtered_docs = [doc for doc in relevant_docs if doc.metadata["source"] in selected_docs]
-
-         if not filtered_docs:
-             return "No relevant information found in the selected documents for the given query."
-
-         context_str = "\n".join([doc.page_content for doc in filtered_docs])
-         return context_str
      else:
-         return "No documents available to answer the query."

- def get_response_from_cloudflare(prompt, context, query, num_calls=3, temperature=0.2, search_type="pdf"):
-     headers = {
-         "Authorization": f"Bearer {API_TOKEN}",
-         "Content-Type": "application/json"
-     }
-     model = "@cf/meta/llama-3.1-8b-instruct"
-
-     if search_type == "pdf":
-         instruction = f"""Using the following context from the PDF documents:
  {context}
- Write a detailed and complete response that answers the following user question: '{query}'"""
-     else:  # web search
-         instruction = f"""Using the following context:
- {context}
- Write a detailed and complete research document that fulfills the following user request: '{query}'
- After writing the document, please provide a list of sources used in your response."""
-
-     inputs = [
-         {"role": "system", "content": instruction},
-         {"role": "user", "content": query}
-     ]
-
-     payload = {
-         "messages": inputs,
-         "stream": True,
-         "temperature": temperature,
-         "max_tokens": 32000
-     }
-
-     full_response = ""
-     for i in range(num_calls):
-         try:
-             with requests.post(f"{API_BASE_URL}{model}", headers=headers, json=payload, stream=True) as response:
-                 if response.status_code == 200:
-                     for line in response.iter_lines():
-                         if line:
-                             try:
-                                 json_response = json.loads(line.decode('utf-8').split('data: ')[1])
-                                 if 'response' in json_response:
-                                     chunk = json_response['response']
-                                     full_response += chunk
-                                     yield full_response
-                             except (json.JSONDecodeError, IndexError) as e:
-                                 logging.error(f"Error parsing streaming response: {str(e)}")
-                                 continue
-                 else:
-                     logging.error(f"HTTP Error: {response.status_code}, Response: {response.text}")
-                     yield f"I apologize, but I encountered an HTTP error: {response.status_code}. Please try again later."
-         except Exception as e:
-             logging.error(f"Error in generating response from Cloudflare: {str(e)}")
-             yield f"I apologize, but an error occurred: {str(e)}. Please try again later."
-
-     if not full_response:
-         yield "I apologize, but I couldn't generate a response at this time. Please try again later."
-
- def create_web_search_vectors(search_results):
-     embed = get_embeddings()
-
-     documents = []
-     for result in search_results:
-         if 'body' in result:
-             content = f"{result['title']}\n{result['body']}\nSource: {result['href']}"
-             documents.append(Document(page_content=content, metadata={"source": result['href']}))
-
-     return FAISS.from_documents(documents, embed)
-
- def get_response_with_search(query, model, num_calls=3, temperature=0.2):
-     search_results = duckduckgo_search(query)
-     web_search_database = create_web_search_vectors(search_results)
-
-     if not web_search_database:
-         yield "No web search results available. Please try again.", ""
-         return
-
-     retriever = web_search_database.as_retriever(search_kwargs={"k": 5})
-     relevant_docs = retriever.get_relevant_documents(query)
-
-     context = "\n".join([doc.page_content for doc in relevant_docs])
-
-     prompt = f"""Using the following context from web search results:
- {context}
- Write a detailed and complete research document that fulfills the following user request: '{query}'
- After writing the document, please provide a list of sources used in your response."""
-
-     if model == "@cf/meta/llama-3.1-8b-instruct":
-         # Use Cloudflare API
-         for response in get_response_from_cloudflare(prompt="", context=context, query=query, num_calls=num_calls, temperature=temperature, search_type="web"):
-             yield response, ""  # Yield streaming response without sources
-     else:
-         # Use Hugging Face API
-         client = InferenceClient(model, token=huggingface_token)
-
-         main_content = ""
-         for i in range(num_calls):
-             for message in client.chat_completion(
-                 messages=[{"role": "user", "content": prompt}],
-                 max_tokens=10000,
-                 temperature=temperature,
-                 stream=True,
-             ):
-                 if message.choices and message.choices[0].delta and message.choices[0].delta.content:
-                     chunk = message.choices[0].delta.content
-                     main_content += chunk
-                     yield main_content, ""  # Yield partial main content without sources

- INSTRUCTION_PROMPTS = {
-     "Asset Managers": "Summarize the key financial metrics, assets under management, and performance highlights for this asset management company.",
-     "Consumer Finance Companies": "Provide a summary of the company's loan portfolio, interest income, credit quality, and key operational metrics.",
-     "Mortgage REITs": "Summarize the REIT's mortgage-backed securities portfolio, net interest income, book value per share, and dividend yield.",
-     # Add more instruction prompts as needed
- }
-
- def get_response_from_pdf(query, model, selected_docs, num_calls=3, temperature=0.2):
-     logging.info(f"Entering get_response_from_pdf with query: {query}, model: {model}, selected_docs: {selected_docs}")
-
-     embed = get_embeddings()
-     if os.path.exists("faiss_database"):
-         logging.info("Loading FAISS database")
-         database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
-     else:
-         logging.warning("No FAISS database found")
-         yield "No documents available. Please upload PDF documents to answer questions."
-         return
-
-     # Pre-filter the documents
-     filtered_docs = []
-     for doc_id, doc in database.docstore._dict.items():
-         if isinstance(doc, Document) and doc.metadata.get("source") in selected_docs:
-             filtered_docs.append(doc)
-
-     logging.info(f"Number of documents after pre-filtering: {len(filtered_docs)}")
-
-     if not filtered_docs:
-         logging.warning(f"No documents found for the selected sources: {selected_docs}")
-         yield "No relevant information found in the selected documents. Please try selecting different documents or rephrasing your query."
-         return
-
-     # Create a new FAISS index with only the selected documents
-     filtered_db = FAISS.from_documents(filtered_docs, embed)
-
-     retriever = filtered_db.as_retriever(search_kwargs={"k": 10})
-     logging.info(f"Retrieving relevant documents for query: {query}")
-     relevant_docs = retriever.get_relevant_documents(query)
-     logging.info(f"Number of relevant documents retrieved: {len(relevant_docs)}")
-
-     for doc in relevant_docs:
-         logging.info(f"Document source: {doc.metadata['source']}")
-         logging.info(f"Document content preview: {doc.page_content[:100]}...")  # Log first 100 characters of each document
-
-     context_str = "\n".join([doc.page_content for doc in relevant_docs])
-     logging.info(f"Total context length: {len(context_str)}")
-
-     if model == "@cf/meta/llama-3.1-8b-instruct":
-         logging.info("Using Cloudflare API")
-         # Use Cloudflare API with the retrieved context
-         for response in get_response_from_cloudflare(prompt="", context=context_str, query=query, num_calls=num_calls, temperature=temperature, search_type="pdf"):
-             yield response
-     else:
-         logging.info("Using Hugging Face API")
-         # Use Hugging Face API
-         prompt = f"""Using the following context from the PDF documents:
- {context_str}
- Write a detailed and complete response that answers the following user question: '{query}'"""
-
-         client = InferenceClient(model, token=huggingface_token)
-
-         response = ""
-         for i in range(num_calls):
-             logging.info(f"API call {i+1}/{num_calls}")
-             for message in client.chat_completion(
-                 messages=[{"role": "user", "content": prompt}],
-                 max_tokens=10000,
-                 temperature=temperature,
-                 stream=True,
-             ):
-                 if message.choices and message.choices[0].delta and message.choices[0].delta.content:
-                     chunk = message.choices[0].delta.content
-                     response += chunk
-                     yield response  # Yield partial response
-
-     logging.info("Finished generating response")

- def vote(data: gr.LikeData):
596
- if data.liked:
597
- print(f"You upvoted this response: {data.value}")
598
- else:
599
- print(f"You downvoted this response: {data.value}")
 
 
 
 
600
 
601
  css = """
602
  /* Fine-tune chatbox size */
@@ -610,127 +150,54 @@ css = """
  }
  """

- uploaded_documents = []
-
- def display_documents():
-     return gr.CheckboxGroup(
-         choices=[doc["name"] for doc in uploaded_documents],
-         value=[doc["name"] for doc in uploaded_documents if doc["selected"]],
-         label="Select documents to query or delete"
      )

- def initial_conversation():
-     return [
-         (None, "Welcome! I'm your AI assistant for web search and PDF analysis. Here's how you can use me:\n\n"
-         "1. Toggle between Web Search and PDF Search with the checkbox in the Additional Inputs dropdown\n"
-         "2. Use web search to find information\n"
-         "3. Upload PDF documents, select them, and ask questions about their contents\n"
-         "4. For any queries feel free to reach out @[email protected] or discord - shreyas094\n\n"
-         "To get started, upload some PDFs or ask me a question!")
-     ]
- # Add this new function
- def refresh_documents():
-     global uploaded_documents
-     uploaded_documents = load_documents()
-     return display_documents()
-
- # Define the checkbox outside the demo block
- document_selector = gr.CheckboxGroup(label="Select documents to query")
-
- use_web_search = gr.Checkbox(label="Use Web Search", value=True)
-
- custom_placeholder = "Ask a question (Note: You can toggle between Web Search and PDF Chat in Additional Inputs below)"
-
- instruction_choices = ["None"] + list(INSTRUCTION_PROMPTS.keys())
-
- demo = gr.ChatInterface(
-     respond,
-     additional_inputs=[
-         gr.Dropdown(choices=MODELS, label="Select Model", value=MODELS[3]),
-         gr.Slider(minimum=0.1, maximum=1.0, value=0.2, step=0.1, label="Temperature"),
-         gr.Slider(minimum=1, maximum=5, value=1, step=1, label="Number of API Calls"),
-         use_web_search,
-         document_selector,
-         gr.Dropdown(choices=instruction_choices, label="Select Entity Type for Summary", value="None")
-     ],
-     title="AI-powered Web Search and PDF Chat Assistant",
-     description="Chat with your PDFs, use web search to answer questions, or generate summaries. Select an Entity Type for Summary to generate a specific summary.",
-     theme=gr.themes.Soft(
-         primary_hue="orange",
-         secondary_hue="amber",
-         neutral_hue="gray",
-         font=[gr.themes.GoogleFont("Exo"), "ui-sans-serif", "system-ui", "sans-serif"]
-     ).set(
-         body_background_fill_dark="#0c0505",
-         block_background_fill_dark="#0c0505",
-         block_border_width="1px",
-         block_title_background_fill_dark="#1b0f0f",
-         input_background_fill_dark="#140b0b",
-         button_secondary_background_fill_dark="#140b0b",
-         border_color_accent_dark="#1b0f0f",
-         border_color_primary_dark="#1b0f0f",
-         background_fill_secondary_dark="#0c0505",
-         color_accent_soft_dark="transparent",
-         code_background_fill_dark="#140b0b"
-     ),
-     css=css,
-     examples=[
-         ["Tell me about the contents of the uploaded PDFs."],
-         ["What are the main topics discussed in the documents?"],
-         ["Can you summarize the key points from the PDFs?"]
-     ],
-     cache_examples=False,
-     analytics_enabled=False,
-     textbox=gr.Textbox(placeholder=custom_placeholder, container=False, scale=7),
-     chatbot=gr.Chatbot(
-         show_copy_button=True,
-         likeable=True,
-         layout="bubble",
-         height=400,
-         value=initial_conversation()
-     )
- )
-
- # Add file upload functionality
- with demo:
-     gr.Markdown("## Upload and Manage PDF Documents")
-
-     with gr.Row():
-         file_input = gr.Files(label="Upload your PDF documents", file_types=[".pdf"])
-         parser_dropdown = gr.Dropdown(choices=["pypdf", "llamaparse"], label="Select PDF Parser", value="llamaparse")
-         update_button = gr.Button("Upload Document")
-         refresh_button = gr.Button("Refresh Document List")
-
-     update_output = gr.Textbox(label="Update Status")
-     delete_button = gr.Button("Delete Selected Documents")
-
-     # Update both the output text and the document selector
-     update_button.click(update_vectors,
-                         inputs=[file_input, parser_dropdown],
-                         outputs=[update_output, document_selector])
-
-     # Add the refresh button functionality
-     refresh_button.click(refresh_documents,
-                          inputs=[],
-                          outputs=[document_selector])
-
-     # Add the delete button functionality
-     delete_button.click(delete_documents,
-                         inputs=[document_selector],
-                         outputs=[update_output, document_selector])
-
-     gr.Markdown(
-         """
-         ## How to use
-         1. Upload PDF documents using the file input at the top.
-         2. Select the PDF parser (pypdf or llamaparse) and click "Upload Document" to update the vector store.
-         3. Select the documents you want to query using the checkboxes.
-         4. Ask questions in the chat interface.
-         5. Toggle "Use Web Search" to switch between PDF chat and web search.
-         6. Adjust Temperature and Number of API Calls to fine-tune the response generation.
-         7. Use the provided examples or ask your own questions.
-         """
-     )

  if __name__ == "__main__":
      demo.launch(share=True)
 
  import os
+ import logging
+ import asyncio
  import gradio as gr
  from huggingface_hub import InferenceClient
+ from huggingface_hub import AsyncInferenceClient
+ from langchain.embeddings import HuggingFaceEmbeddings
+ from langchain.vectorstores import FAISS
+ from langchain.schema import Document
+ from duckduckgo_search import DDGS
+ from dotenv import load_dotenv
+ from functools import lru_cache

+ # Load environment variables
+ load_dotenv()

+ # Configure logging
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+ logger = logging.getLogger(__name__)

  # Environment variables and configurations
+ HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")

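+ # Model ids below are served through the Hugging Face Inference API; the
+ # Gradio dropdown defaults to MODELS[3] (Meta-Llama-3.1-8B-Instruct).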
  MODELS = [
      "mistralai/Mistral-7B-Instruct-v0.3",
      "mistralai/Mixtral-8x7B-Instruct-v0.1",
+     "mistralai/Mistral-Nemo-Instruct-2407",
+     "meta-llama/Meta-Llama-3.1-8B-Instruct",
+     "meta-llama/Meta-Llama-3.1-70B-Instruct",
+     "google/gemma-2-9b-it",
+     "google/gemma-2-27b-it"
  ]

+ DEFAULT_SYSTEM_PROMPT = """You are a world-class financial AI assistant, capable of complex reasoning and reflection.
+ Reason through the query inside <thinking> tags, and then provide your final response inside <output> tags.
+ Providing comprehensive and accurate information based on web search results is essential.
+ Your goal is to synthesize the given context into a coherent and detailed response that directly addresses the user's query.
+ Please ensure that your response is well-structured and factual.
+ If you detect that you made a mistake in your reasoning at any point, correct yourself inside <reflection> tags."""

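+ # WebSearcher wraps duckduckgo_search.DDGS. Note that lru_cache keys include
+ # self, so cached hits only occur when the same WebSearcher instance is
+ # reused (illustrative, hypothetical query):
+ #     ws = WebSearcher()
+ #     ws.search("latest Fed rate decision")   # network call
+ #     ws.search("latest Fed rate decision")   # served from the cache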
+ class WebSearcher:
+     def __init__(self):
+         self.ddgs = DDGS()
+
+     @lru_cache(maxsize=100)
+     def search(self, query, max_results=5):
          try:
+             results = list(self.ddgs.text(query, max_results=max_results))
+             logger.info(f"Search completed for query: {query}")
+             return results
          except Exception as e:
+             logger.error(f"Error during DuckDuckGo search: {str(e)}")
+             return []

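+ # maxsize=1 below memoizes get_embeddings(), so the sentence-transformers
+ # model is loaded once per process rather than once per request.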
+ @lru_cache(maxsize=1)
  def get_embeddings():
      return HuggingFaceEmbeddings(model_name="sentence-transformers/stsb-roberta-large")

+ def create_web_search_vectors(search_results):
      embed = get_embeddings()
+     documents = [
+         Document(
+             page_content=f"{result['title']}\n{result['body']}\nSource: {result['href']}",
+             metadata={"source": result['href']}
+         )
+         for result in search_results if 'body' in result
+     ]
+     logger.info(f"Created vectors for {len(documents)} search results.")
+     return FAISS.from_documents(documents, embed)

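+ # Flow: DuckDuckGo search -> optional FAISS retrieval over the hits ->
+ # prompt assembly -> streamed chat completion. Yields (partial_text, "")
+ # tuples so callers can render tokens as they arrive.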
+ async def get_response_with_search(query, system_prompt, model, use_embeddings, num_calls=3, temperature=0.2):
+     searcher = WebSearcher()
+     search_results = searcher.search(query)
+
+     if not search_results:
+         logger.warning(f"No web search results found for query: {query}")
+         yield "No web search results available. Please try again.", ""
+         return
+
+     sources = [result['href'] for result in search_results if 'href' in result]
+     source_list_str = "\n".join(sources)
+
+     if use_embeddings:
+         web_search_database = create_web_search_vectors(search_results)
+         retriever = web_search_database.as_retriever(search_kwargs={"k": 5})
          relevant_docs = retriever.get_relevant_documents(query)
+         context = "\n".join([doc.page_content for doc in relevant_docs])
      else:
+         context = "\n".join([f"{result['title']}\n{result['body']}" for result in search_results])
+
+     logger.info(f"Context created for query: {query}")
+
+     user_message = f"""Using the following context from web search results:
  {context}

+ Write a detailed and complete research document that fulfills the following user request: '{query}'."""

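+ # Streaming: AsyncInferenceClient.chat_completion(..., stream=True) returns
+ # an async iterator of OpenAI-style chunks; each delta is appended to
+ # full_response, so every yield carries the full text generated so far.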
+     client = AsyncInferenceClient(model, token=HUGGINGFACE_TOKEN)
+     full_response = ""
      try:
+         for _ in range(num_calls):
+             async for chunk in await client.chat_completion(
+                 messages=[
+                     {"role": "system", "content": system_prompt},
+                     {"role": "user", "content": user_message}
+                 ],
+                 max_tokens=6000,
+                 temperature=temperature,
+                 top_p=0.8,
+                 stream=True,
+             ):
+                 if chunk.choices and chunk.choices[0].delta.content:
+                     full_response += chunk.choices[0].delta.content
+                     yield full_response, ""
      except Exception as e:
+         logger.error(f"Error in get_response_with_search: {str(e)}")
+         yield f"An error occurred while processing your request: {str(e)}", ""

+     if not full_response:
+         logger.warning("No response generated from the model")
+         yield "No response generated from the model.", ""

+     yield f"{full_response}\n\nSources:\n{source_list_str}", ""

+ async def respond(message, history, system_prompt, model, temperature, num_calls, use_embeddings):
+     logger.info(f"User Query: {message}")
+     logger.info(f"Model Used: {model}")
+     logger.info(f"Temperature: {temperature}")
+     logger.info(f"Number of API Calls: {num_calls}")
+     logger.info(f"Use Embeddings: {use_embeddings}")
+     logger.info(f"System Prompt: {system_prompt}")

+     try:
+         async for main_content, sources in get_response_with_search(message, system_prompt, model, use_embeddings, num_calls=num_calls, temperature=temperature):
+             yield main_content
+     except asyncio.CancelledError:
+         logger.warning("The operation was cancelled.")
+         yield "The operation was cancelled. Please try again."
+     except Exception as e:
+         logger.error(f"Error in respond function: {str(e)}")
+         yield f"An error occurred: {str(e)}"

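+ # gr.ChatInterface accepts async-generator callbacks, so respond() streams
+ # partial answers directly into the chatbot component.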
  css = """
  /* Fine-tune chatbox size */
  }
  """

+ def create_gradio_interface():
+     custom_placeholder = "Enter your question here for web search."
+
+     demo = gr.ChatInterface(
+         fn=respond,
+         additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=True, render=False),
+         additional_inputs=[
+             gr.Textbox(value=DEFAULT_SYSTEM_PROMPT, lines=6, label="System Prompt", placeholder="Enter your system prompt here"),
+             gr.Dropdown(choices=MODELS, label="Select Model", value=MODELS[3]),
+             gr.Slider(minimum=0.1, maximum=1.0, value=0.2, step=0.1, label="Temperature"),
+             gr.Slider(minimum=1, maximum=5, value=1, step=1, label="Number of API Calls"),
+             gr.Checkbox(label="Use Embeddings", value=False),
+         ],
+         title="AI-powered Web Search Assistant",
+         description="Use web search to answer questions or generate summaries.",
+         theme=gr.Theme.from_hub("allenai/gradio-theme"),
+         css=css,
+         examples=[
+             ["What are the latest developments in artificial intelligence?"],
+             ["Explain the concept of quantum computing."],
+             ["What are the environmental impacts of renewable energy?"]
+         ],
+         cache_examples=False,
+         analytics_enabled=False,
+         textbox=gr.Textbox(placeholder=custom_placeholder, container=False, scale=7),
+         chatbot=gr.Chatbot(
+             show_copy_button=True,
+             likeable=True,
+             layout="bubble",
+             height=400,
+         )
      )
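+     # ChatInterface passes additional_inputs to fn after (message, history),
+     # so the list above must stay in respond()'s parameter order.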

+     with demo:
+         gr.Markdown("""
+         ## How to use
+         1. Enter your question in the chat interface.
+         2. Optionally, modify the System Prompt to guide the AI's behavior.
+         3. Select the model you want to use from the dropdown.
+         4. Adjust the Temperature to control the randomness of the response.
+         5. Set the Number of API Calls to determine how many times the model will be queried.
+         6. Check or uncheck the "Use Embeddings" box to toggle between FAISS retrieval over the search hits and passing them directly as context.
+         7. Press Enter or click the submit button to get your answer.
+         8. Use the provided examples or ask your own questions.
+         """)
+
+     return demo

  if __name__ == "__main__":
+     demo = create_gradio_interface()
  demo.launch(share=True)
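
A minimal sketch for exercising the new search-and-stream path outside Gradio (illustrative only; assumes HUGGINGFACE_TOKEN is set and the chosen model is served by the HF Inference API):

import asyncio

async def _demo():
    last = ""
    async for partial in respond(
        "What moved US equity markets this week?",  # message (hypothetical query)
        [],                                         # history
        DEFAULT_SYSTEM_PROMPT,                      # system_prompt
        MODELS[0],                                  # model
        0.2,                                        # temperature
        1,                                          # num_calls
        False,                                      # use_embeddings
    ):
        last = partial
    print(last)

asyncio.run(_demo())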