import streamlit as st
from langchain.agents import create_sql_agent,create_react_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.agents.agent_types import AgentType
from langchain_groq import ChatGroq
from langchain_core.prompts import ChatPromptTemplate
from langchain.sql_database import SQLDatabase
from sqlalchemy import create_engine
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains.retrieval import create_retrieval_chain
from langchain_core.output_parsers import StrOutputParser
from sqlalchemy.orm import sessionmaker
from sqlalchemy import text
import sqlite3
from dotenv import load_dotenv
from pathlib import Path
from PyPDF2 import PdfReader
import os
import re
load_dotenv()
os.environ['GROQ_API_KEY'] = os.getenv("GROQ_API_KEY")
os.environ['HF_TOKEN'] = os.getenv("HF_TOKEN")
st.set_page_config("Langchain interaction with DB")
st.title("Document QnA with DB interaction")
llm = ChatGroq(model="llama3-8b-8192", api_key= os.environ['GROQ_API_KEY'])
embeddings = HuggingFaceEmbeddings(model_name = "all-MiniLM-L6-v2")
duration_pattern = re.compile(r"(\d+)\s*(min[s]?|minute[s]?)")
st.session_state.user_prompt = ""
st.session_state.summary = ""
pdf_prompt_template = ChatPromptTemplate.from_template("""
Answer the following question from the provided context only.
Please provide the most accurate response based on the question
<context>
{context}
</context>
Question : {input}
""")
def get_pdf_text(pdf_docs):
    text = ""
    for pdf in pdf_docs:
        pdf_reader = PdfReader(pdf)
        for page in pdf_reader.pages:
            # extract_text() can return None for image-only pages
            text += page.extract_text() or ""
    return text
def create_vector_embeddings(pdf_docs):
    # Build the FAISS index once and keep it in session state
    if "vectors" not in st.session_state:
        st.session_state.docs = get_pdf_text(pdf_docs)
        st.session_state.splitter = RecursiveCharacterTextSplitter(chunk_size=1200, chunk_overlap=400)
        st.session_state.final_docs = st.session_state.splitter.split_text(st.session_state.docs)
        st.session_state.vectors = FAISS.from_texts(st.session_state.final_docs, embeddings)
def configure():
    dbfilepath = (Path(__file__).parent / "programme.db").absolute()
    creator = lambda: sqlite3.connect(f"file:{dbfilepath}", uri=True, check_same_thread=False)
    return create_engine("sqlite:///", creator=creator)
engine = configure()
db = SQLDatabase(engine)
sql_toolkit = SQLDatabaseToolkit(db=db, llm=llm)
sql_toolkit.get_tools()
prefilled_prompt = ""
# if "uploaded_text" in st.session_state:
# for m in st.session_state.uploaded_text:
# st.error(m)
# if 'PACKAGE' in st.session_state.uploaded_text:
# prefilled_prompt = "get the entire programme details linked to the package"
# else:
# prefilled_prompt = "get the entire programme details linked to the document"
# query=st.text_input("ask question here", value = prefilled_prompt)
def clear_database():
    connection = engine.raw_connection()
    try:
        # Create a cursor from the raw connection
        cursor = connection.cursor()
        # List of tables to clear
        tables = ["programme", "episode"]
        # Execute DELETE commands for each table
        for table in tables:
            cursor.execute(f"DELETE FROM {table}")
        # Commit the changes to the database
        connection.commit()
    finally:
        # Ensure the connection is closed properly
        connection.close()
def process_sql_script(sql_script):
    # Define the keyword to check
    keyword = 'PACKAGE'
    # Split the script into statements (note: the semicolons themselves are dropped)
    lines = sql_script.strip().split(';')
    programme_line = lines[0]
    if keyword not in programme_line:
        # Keep only the programme statement
        filtered_script = lines[0]
    else:
        # Keep every statement
        filtered_script = "\n".join(lines)
    return filtered_script
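# Sketch of the expected behaviour, assuming the agent returns two statements separated by ';':
#   "INSERT INTO programme (...) VALUES ('DEMO PACKAGE', ...); INSERT INTO episode (...) VALUES (...)"
#       -> both statements kept, joined by newlines (first statement contains 'PACKAGE')
#   "INSERT INTO programme (...) VALUES ('DEMO SHOW', ...); INSERT INTO episode (...) VALUES (...)"
#       -> only the programme statement is kept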
def convert_to_hms(duration):
    hour_minute_match = re.match(r'(?:(\d+)\s*hour[s]?)?\s*(\d+)\s*min[s]?', duration.lower())
    if hour_minute_match:
        hours = int(hour_minute_match.group(1) or 0)
        minutes = int(hour_minute_match.group(2) or 0)
    else:
        # Return the value unchanged if it is not a recognisable duration
        return duration
    total_seconds = (hours * 60 * 60) + (minutes * 60)
    hh = total_seconds // 3600
    mm = (total_seconds % 3600) // 60
    ss = total_seconds % 60
    return f"{hh:02}:{mm:02}:{ss:02}"
def handleDurationForEachScript(scripts):
    filtered_data = ""
    # Match quoted durations like '60 mins' or '60 minutes'
    pattern = r"'(\d+\s*(?:mins|minutes))'"
    for script in scripts.split(";"):
        match = re.search(pattern, script)
        if match:
            duration = match.group(1)
            converted_duration = convert_to_hms(duration)  # Convert to hh:mm:ss
            script = script.replace(duration, converted_duration).replace('utes', '')  # Replace in script
        if ('episode' not in filtered_data) and ('programme' not in filtered_data):
            filtered_data = filtered_data + script
    return filtered_data
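# Example with an assumed upstream statement (the duration must be quoted for the pattern to match):
#   handleDurationForEachScript("INSERT INTO programme (\"Duration\") VALUES ('90 mins')")
#       -> "INSERT INTO programme (\"Duration\") VALUES ('01:30:00')"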
def parse_insert_statement(insert_statement):
    # Extract the table name
    table_match = re.search(r'INSERT INTO (\w+)', insert_statement)
    if not table_match:
        return None, None, None
    table = table_match.group(1)
    # Extract columns and values
    columns_match = re.search(r'\((.*?)\)', insert_statement, re.DOTALL)
    values_match = re.search(r'VALUES\s*\((.*?)\)', insert_statement, re.DOTALL)
    if not columns_match or not values_match:
        return None, None, None
    columns = columns_match.group(1).replace('"', '').replace('\n', ' ').strip()
    values = values_match.group(1).replace("'", "").replace('\n', ' ').strip()
    return table, columns, values
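# Example parse, assuming a single-row INSERT like the ones the prompt asks for:
#   parse_insert_statement('INSERT INTO programme ("ProgrammeTitle", "Duration") VALUES (\'DEMO PACKAGE\', \'01:30:00\')')
#       -> ('programme', 'ProgrammeTitle, Duration', 'DEMO PACKAGE, 01:30:00')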
def build_data_from_sql(programme_sql, episode_sql=None):
    data = {
        'Table': [],
        'Columns': [],
        'Values': []
    }
    # Parse the programme insert statement
    programme_table, programme_columns, programme_values = parse_insert_statement(programme_sql)
    if programme_table and programme_columns and programme_values:
        data['Table'].append(programme_table.capitalize())
        data['Columns'].append(programme_columns)
        data['Values'].append(programme_values)
    # Parse the episode insert statement, if it exists
    if episode_sql:
        episode_table, episode_columns, episode_values = parse_insert_statement(episode_sql)
        if episode_table and episode_columns and episode_values:
            data['Table'].append(episode_table.capitalize())
            data['Columns'].append(episode_columns)
            data['Values'].append(episode_values)
    return data
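# With the assumed programme statement above and no episode statement, the returned dict
# feeds st.table() directly, e.g.:
#   {'Table': ['Programme'], 'Columns': ['ProgrammeTitle, Duration'], 'Values': ['DEMO PACKAGE, 01:30:00']}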
with st.sidebar:
    st.title("Menu:")
    st.session_state.uploaded_text = st.file_uploader("Upload your Files and Click on the Submit & Process Button", accept_multiple_files=True)
    if st.button("Click To Process File"):
        with st.spinner("Processing..."):
            create_vector_embeddings(st.session_state.uploaded_text)
            st.write("Vector Database is ready")
# if "uploaded_text" in st.session_state and st.session_state.uploaded_text is not None:
# uploaded_file_names = [file.name for file in st.session_state.uploaded_text]
# if any('PACKAGE' in file_name.upper() for file_name in uploaded_file_names):
# prefilled_prompt = "get the entire programme details linked to the package"
# else:
# prefilled_prompt = "get the entire programme details linked to the document"
query = st.text_input("ask question here")
if query and "vectors" in st.session_state:
    st.session_state.user_prompt = query
    document_chain = create_stuff_documents_chain(llm=llm, prompt=pdf_prompt_template)
    retriever = st.session_state.vectors.as_retriever()
    retrieval_chain = create_retrieval_chain(retriever, document_chain)
    response = retrieval_chain.invoke({"input": st.session_state.user_prompt})
    if response:
        st.session_state.summary = response['answer']
        st.write(response['answer'])
prompt = ChatPromptTemplate.from_messages(
    [
        ("system",
         """
You are a SQL expert. Your task is to generate SQL INSERT scripts based on the provided context.
1. Generate an `INSERT` statement for the `programme` table using the following values:
   - `ProgrammeTitle`
   - `ProgrammeType`
   - `Genre`
   - `SubGenre`
   - `Language`
   - `Duration`
2. After generating the `programme` statement, check the `ProgrammeTitle`:
   - If the `ProgrammeTitle` contains the keyword `PACKAGE`, generate an additional `INSERT` statement for the `episode` table.
   - If the `ProgrammeTitle` does **not** contain the keyword `PACKAGE`, **do not** generate an `INSERT` statement for the `episode` table.
3. If the condition is met, the `episode` INSERT statement follows the same pattern: `EpisodeNumber` is always 1 and `EpisodeTitle` should take the same value as `ProgrammeTitle`.
4. Include only the SQL INSERT script(s) as the final answer; **do not** include any additional details or notes. Return only the necessary SQL INSERT script(s) based on the current input. Ensure that no `episode` INSERT statement is included if the `ProgrammeTitle` does not contain `'PACKAGE'`.
Your output should strictly follow these conditions. Output **only** the final answer without producing any intermediate actions.
"""
         ),
        ("user", "{question} ai: ")
    ])
agent=create_sql_agent(llm=llm,toolkit=sql_toolkit,agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,verbose=True,max_execution_time=100,max_iterations=1000, handle_parsing_errors=True)
if st.button("Generate Scripts",type="primary"):
try:
if st.session_state.summary is not None:
response=agent.run(prompt.format_prompt(question=st.session_state.summary))
#with st.expander("Expand here to view scripts"):
if "INSERT" in response:
final_response = process_sql_script(response)
final_response_new = handleDurationForEachScript(final_response)
episode_sql = ""
splitted_data = []
if "uploaded_text" in st.session_state and st.session_state.uploaded_text is not None:
uploaded_file_names = [file.name for file in st.session_state.uploaded_text]
if any('PACKAGE' in file_name.upper() for file_name in uploaded_file_names):
if ";" in final_response_new:
splitted_data = [stmt.strip() for stmt in final_response_new.strip().split(';') if stmt.strip()]
elif "\n" in final_response_new:
splitted_data = [stmt.strip() for stmt in final_response_new.strip().split('\n') if stmt.strip()]
elif "," in final_response_new:
splitted_data = [stmt.strip() for stmt in final_response_new.strip().split(',') if stmt.strip()]
else:
if final_response_new is list:
splitted_data = final_response_new
else:
splitted_data.append(final_response_new)
print(splitted_data)
if len(splitted_data) > 0:
programme_sql = splitted_data[0] + ';' # Re-add semicolon to the programme SQL statement
print(f"prog{programme_sql}")
if len(splitted_data) > 1:
episode_sql = splitted_data[1]
#print(f"eps{episode_sql}")
data = build_data_from_sql(programme_sql, episode_sql)
st.write("### Script Summary")
st.table(data)
st.write("### Full SQL Scripts")
with st.expander("Insert Scripts"):
st.code(programme_sql, language='sql')
st.code(episode_sql, language='sql')
#if episode_sql:
#with st.expander("Episode Insert Script"):
#st.code(episode_sql, language='sql')
#st.code(final_response_new, language = 'sql')
clear_database()
#st.write(response)
except Exception as e:
st.error(f"Parsing error from LLM.Retry again !!! \n : {str(e)}")