Spaces:
Sleeping
Sleeping
import streamlit as st | |
from langchain.agents import create_sql_agent,create_react_agent | |
from langchain.agents.agent_toolkits import SQLDatabaseToolkit | |
from langchain.agents.agent_types import AgentType | |
from langchain_groq import ChatGroq | |
from langchain_core.prompts import ChatPromptTemplate | |
from langchain.sql_database import SQLDatabase | |
from sqlalchemy import create_engine | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain_huggingface import HuggingFaceEmbeddings | |
from langchain_community.vectorstores import FAISS | |
from langchain.chains.combine_documents import create_stuff_documents_chain | |
from langchain.chains.retrieval import create_retrieval_chain | |
from langchain_core.output_parsers import StrOutputParser | |
from sqlalchemy.orm import sessionmaker | |
from sqlalchemy import text | |
import sqlite3 | |
from dotenv import load_dotenv | |
from pathlib import Path | |
from PyPDF2 import PdfReader | |
import os | |
import re | |
# Load variables from a local .env file (if any) into the process environment.
load_dotenv()

# Page chrome comes first so that a configuration error below can be rendered.
st.set_page_config("Langchain interaction with DB")
st.title("Document QnA with DB interaction")

# The previous `os.environ[name] = os.getenv(name)` round-trip was a no-op when
# the variable existed and raised a TypeError (environment values must be str)
# when it did not.  Check the one key this block actually reads and stop with a
# readable message instead.  HF_TOKEN is simply left as loaded.
if not os.getenv("GROQ_API_KEY"):
    st.error("GROQ_API_KEY is not set. Add it to the environment or to a .env file.")
    st.stop()

# Chat model used for document QnA and for the SQL agent.
llm = ChatGroq(model="llama3-8b-8192", api_key=os.environ['GROQ_API_KEY'])
# Sentence-embedding model used to build the FAISS index.
embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
# Matches durations such as "60 mins" / "60 minutes" (number, then unit).
duration_pattern = re.compile(r"(\d+)\s*(min[s]?|minute[s]?)")

# NOTE: both values are assigned unconditionally, i.e. every time this script
# body executes; the question handler further down fills them in again.
st.session_state.user_prompt = ""
st.session_state.summary = ""

# Prompt for the retrieval chain: {context} receives the retrieved chunks and
# {input} the user's question.
pdf_prompt_template = ChatPromptTemplate.from_template("""
Answer the following question from the provided context only.
Please provide the most accurate response based on the question
<context>
{context}
</context>
Question : {input}
""")
def get_pdf_text(pdf_docs):
    """Return the text of every page of every given PDF, concatenated.

    Args:
        pdf_docs: Iterable of objects accepted by ``PdfReader`` (for example
            uploaded file objects).  ``None`` or an empty iterable gives "".

    Returns:
        A single string with all page texts joined in reading order, with no
        separator added between pages or documents.
    """
    page_texts = []
    for pdf in pdf_docs or ():
        pdf_reader = PdfReader(pdf)
        # A page without extractable text may give back nothing; `or ""`
        # keeps the join below from failing on a non-string value.
        page_texts.extend(page.extract_text() or "" for page in pdf_reader.pages)
    return "".join(page_texts)
def create_vector_embeddings(pdfText):
    """Build the FAISS index for the uploaded PDFs once per session.

    Does nothing when ``st.session_state`` already holds a "vectors" entry.
    Otherwise the extracted text, the splitter, the resulting chunks and the
    vector store are all stored on ``st.session_state``.

    Args:
        pdfText: The uploaded PDF files, forwarded to ``get_pdf_text``.
    """
    state = st.session_state
    if "vectors" in state:
        return

    state.docs = get_pdf_text(pdfText)
    state.splitter = RecursiveCharacterTextSplitter(
        chunk_size=1200,
        chunk_overlap=400,
    )
    state.final_docs = state.splitter.split_text(state.docs)
    state.vectors = FAISS.from_texts(state.final_docs, embeddings)
def configure():
    """Create the SQLAlchemy engine for ``programme.db`` next to this file.

    Returns:
        An engine whose connections are opened through ``sqlite3.connect``
        using a ``file:`` URI, with ``check_same_thread`` disabled.
    """
    db_location = Path(__file__).parent.joinpath("programme.db").absolute()

    def _open_connection():
        # One raw sqlite3 connection per call made by the engine.
        return sqlite3.connect(
            f"file:{db_location}",
            uri=True,
            check_same_thread=False,
        )

    return create_engine("sqlite:///", creator=_open_connection)
# Shared engine and LangChain database wrapper for programme.db.
engine = configure()
db = SQLDatabase(engine)
# Alternative model that was tried here: ChatGroq(model="gemma2-9b-it")
# The toolkit is built from the database and the model only; the `api_key`
# keyword passed here before does not appear to be a toolkit field (the key is
# already held by `llm`), and the discarded `get_tools()` call had no effect.
sql_toolkit = SQLDatabaseToolkit(db=db, llm=llm)
prefilled_prompt = ""
# if "uploaded_text" in st.session_state: | |
# for m in st.session_state.uploaded_text: | |
# st.error(m) | |
# if 'PACKAGE' in st.session_state.uploaded_text: | |
# prefilled_prompt = "get the entire programme details linked to the package" | |
# else: | |
# prefilled_prompt = "get the entire programme details linked to the document" | |
# query=st.text_input("ask question here", value = prefilled_prompt) | |
def clear_database():
    """Delete every row from the ``programme`` and ``episode`` tables.

    Uses a raw DB-API connection obtained from the module-level ``engine``.
    The deletes are committed together; if any statement fails, the pending
    work is rolled back and the exception is re-raised.  The cursor and the
    connection are always closed.
    """
    connection = engine.raw_connection()
    try:
        cursor = connection.cursor()
        try:
            # Table names are fixed literals, so the f-string is not exposed
            # to external input.
            for table in ("programme", "episode"):
                cursor.execute(f"DELETE FROM {table}")
            connection.commit()
        except Exception:
            # Do not leave a half-finished delete pending on the connection.
            connection.rollback()
            raise
        finally:
            cursor.close()
    finally:
        connection.close()
def process_sql_script(sql_script):
    """Keep only the first statement unless it mentions ``PACKAGE``.

    The script is stripped and split on ";".  When the first piece contains
    the keyword ``PACKAGE`` all pieces are returned joined by newlines (the
    semicolons are not restored); otherwise only the first piece is returned.

    Args:
        sql_script: Text containing one or more ";"-separated statements.

    Returns:
        The filtered script as a string.
    """
    statements = sql_script.strip().split(';')
    first_statement = statements[0]
    if 'PACKAGE' in first_statement:
        return "\n".join(statements)
    return first_statement
import re | |
def convert_to_hms(duration):
    """Convert a textual duration to ``hh:mm:ss``.

    Recognises an optional hours part followed by an optional minutes part at
    the start of the text, case-insensitively: "60 mins", "90 minutes",
    "1 hour 30 mins" and also hours-only input such as "2 hours" (which was
    previously returned unconverted because the minutes part was mandatory).

    Args:
        duration: The duration text.

    Returns:
        The duration formatted as ``hh:mm:ss`` (seconds are always ``00``),
        or ``duration`` unchanged when neither hours nor minutes are found.
    """
    # Both parts are optional, so this pattern always matches; what matters
    # is whether either group actually captured a number.
    parsed = re.match(
        r'\s*(?:(\d+)\s*hours?)?\s*(?:(\d+)\s*mins?)?',
        duration.lower(),
    )
    hours_text, minutes_text = parsed.groups()
    if hours_text is None and minutes_text is None:
        return duration

    total_minutes = int(hours_text or 0) * 60 + int(minutes_text or 0)
    # Minutes of 60 or more roll over into hours ("90 mins" -> 01:30:00).
    hh, mm = divmod(total_minutes, 60)
    return f"{hh:02}:{mm:02}:00"
def handleDurationForEachScript(scripts):
    """Rewrite quoted minute durations as ``hh:mm:ss`` in a SQL script.

    The script is split on ";".  In each piece the first quoted duration of
    the form ``'<n> mins'`` / ``'<n> minutes'`` is located and every
    occurrence of that duration text inside the piece is replaced with the
    value returned by ``convert_to_hms``.  A piece is appended to the result
    only while the text collected so far contains neither "episode" nor
    "programme".

    The earlier blanket ``.replace('utes', '')`` is gone: the whole duration
    text is already replaced, and that call also mangled unrelated text
    (for example a title containing "Tributes").

    Args:
        scripts: SQL text, optionally made of several ";"-separated pieces.

    Returns:
        The collected pieces concatenated without a separator.
    """
    quoted_minutes = re.compile(r"'(\d+\s*(?:mins|minutes))'")
    collected = ""
    for script in scripts.split(";"):
        found = quoted_minutes.search(script)
        if found:
            duration = found.group(1)  # e.g. "60 mins"
            script = script.replace(duration, convert_to_hms(duration))
        # Once a programme/episode statement has been collected, later pieces
        # are dropped.
        if 'episode' not in collected and 'programme' not in collected:
            collected += script
    return collected
# INSERT INTO <table> [(<columns>)] VALUES (   -- keywords in any letter case.
_INSERT_HEAD_RE = re.compile(
    r'INSERT\s+INTO\s+(\w+)\s*(?:\((.*?)\))?\s*VALUES\s*\(',
    re.IGNORECASE | re.DOTALL,
)
# Body of the VALUES list up to its closing parenthesis; parentheses inside
# single-quoted strings do not end the list.
_QUOTE_AWARE_VALUES_RE = re.compile(r"((?:'(?:[^']|'')*'|[^()'])*)\)")
# Fallback: everything up to the first closing parenthesis.
_SIMPLE_VALUES_RE = re.compile(r'(.*?)\)', re.DOTALL)


def parse_insert_statement(insert_statement):
    """Split an ``INSERT ... (columns) VALUES (...)`` statement into parts.

    Compared with the earlier version, the column list must be the
    parenthesised group directly after the table name (a statement without a
    column list no longer has its VALUES list reported as columns), keywords
    match case-insensitively, and a ")" inside a quoted value no longer cuts
    the value list short.

    Args:
        insert_statement: Text containing a single-row INSERT statement.

    Returns:
        ``(table, columns, values)`` as strings, with double quotes removed
        from the columns, single quotes removed from the values, and newlines
        turned into spaces.  ``(None, None, None)`` when the text is not an
        INSERT with both a column list and a VALUES list.
    """
    head = _INSERT_HEAD_RE.search(insert_statement)
    if head is None or head.group(2) is None:
        return None, None, None
    table, raw_columns = head.groups()

    # The VALUES body starts right after the "(" matched by the head pattern.
    body = (
        _QUOTE_AWARE_VALUES_RE.match(insert_statement, head.end())
        or _SIMPLE_VALUES_RE.match(insert_statement, head.end())
    )
    if body is None:
        return None, None, None

    columns = raw_columns.replace('"', '').replace('\n', ' ').strip()
    values = body.group(1).replace("'", "").replace('\n', ' ').strip()
    return table, columns, values
def build_data_from_sql(programme_sql, episode_sql=None):
    """Collect table/columns/values rows for the script summary table.

    ``programme_sql`` is always parsed; ``episode_sql`` only when it is
    truthy.  A statement adds a row only if ``parse_insert_statement`` gives
    a truthy table, column list and value list for it.

    Args:
        programme_sql: INSERT statement for the programme table.
        episode_sql: Optional INSERT statement for the episode table.

    Returns:
        A dict with the keys 'Table', 'Columns' and 'Values', each mapping to
        a list with one entry per accepted statement; table names are
        capitalised.
    """
    summary = {'Table': [], 'Columns': [], 'Values': []}

    statements = [programme_sql]
    if episode_sql:
        statements.append(episode_sql)

    for statement in statements:
        table, columns, values = parse_insert_statement(statement)
        if not (table and columns and values):
            continue
        summary['Table'].append(table.capitalize())
        summary['Columns'].append(columns)
        summary['Values'].append(values)

    return summary
# Sidebar: upload the PDF files and build the vector index on demand.
with st.sidebar:
    st.title("Menu:")
    st.session_state.uploaded_text = st.file_uploader(
        "Upload your Files and Click on the Submit & Process Button",
        accept_multiple_files=True,
    )
    if st.button("Click To Process File"):
        if not st.session_state.uploaded_text:
            # Without files there is no text to index; say so instead of
            # handing an empty list of chunks to the vector store.
            st.warning("Please upload at least one file before processing.")
        else:
            with st.spinner("Processing..."):
                create_vector_embeddings(st.session_state.uploaded_text)
            st.write("Vector Database is ready")
# if "uploaded_text" in st.session_state and st.session_state.uploaded_text is not None: | |
# uploaded_file_names = [file.name for file in st.session_state.uploaded_text] | |
# if any('PACKAGE' in file_name.upper() for file_name in uploaded_file_names): | |
# prefilled_prompt = "get the entire programme details linked to the package" | |
# else: | |
# prefilled_prompt = "get the entire programme details linked to the document" | |
# Question box: answered only once the vector index exists for this session.
query = st.text_input("ask question here")
if query and "vectors" in st.session_state:
    st.session_state.user_prompt = query

    # Retrieval chain = retriever over the session's FAISS index feeding a
    # "stuff" documents chain driven by the PDF prompt.
    document_chain = create_stuff_documents_chain(
        llm=llm,
        prompt=pdf_prompt_template,
    )
    retriever = st.session_state.vectors.as_retriever()
    retrieval_chain = create_retrieval_chain(retriever, document_chain)

    response = retrieval_chain.invoke(
        {"input": st.session_state.user_prompt}
    )
    if response:
        # The answer is kept so the script generator below can reuse it.
        st.session_state.summary = response['answer']
        st.write(response['answer'])
# Instructions for the SQL agent: turn the document summary into INSERT
# statements for `programme` (and `episode` for PACKAGE titles).
prompt = ChatPromptTemplate.from_messages(
    [
        ("system",
         """
You are a SQL expert. Your task is to generate SQL INSERT scripts based on the provided context.
1. Generate an `INSERT` statement for the `programme` table using the following values:
- `ProgrammeTitle`
- `ProgrammeType`
- `Genre`
- `SubGenre`
- `Language`
- `Duration`
Example:
2. After generating the `programme` statement, check the `ProgrammeTitle`:
- If the `ProgrammeTitle` contains the keyword `PACKAGE`, generate an additional `INSERT` statement for the `episode` table.
- If the `ProgrammeTitle` does **not** contain the keyword `PACKAGE`, **do not** generate an `INSERT` statement for the `episode` table.
3. The `episode` INSERT statement should look like this if the condition is met. EpisodeNumber is always 1 and `EpisodeTitle` should take same data from `ProgrammeTitle`.
4. Include only the SQL insert script(s) as final answer, **donot** include any additional details and notes.Return only the necessary SQL INSERT script(s) based on the current input. Ensure that no `episode` INSERT statement is included if the `ProgrammeTitle` does not contain `'PACKAGE'`.
Your output should strictly follow these conditions. Output **only** the final answer without producing any intermediate actions.
"""
         ),
        # "\\ " is the same backslash-plus-space text the message had before;
        # it is now spelled as a valid escape so it no longer relies on an
        # unrecognised "\ " sequence.
        ("user", "{question}\\ ai: "),
    ]
)

# ReAct-style SQL agent over the programme database.
agent = create_sql_agent(
    llm=llm,
    toolkit=sql_toolkit,
    agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    verbose=True,
    max_execution_time=100,
    max_iterations=1000,
    handle_parsing_errors=True,
)
# "Generate Scripts": ask the SQL agent for INSERT statements derived from the
# stored summary, post-process them and show them.
# NOTE(review): the nesting of this section was reconstructed from flattened
# source; confirm the branch structure against the deployed file.
if st.button("Generate Scripts", type="primary"):
    try:
        # The summary starts out as "", so test for content rather than for
        # `is not None` (which was always true).
        if st.session_state.summary:
            response = agent.run(prompt.format_prompt(question=st.session_state.summary))
            if "INSERT" in response:
                final_response = process_sql_script(response)
                final_response_new = handleDurationForEachScript(final_response)
                # Defaults so neither name can be unbound further down.
                programme_sql = ""
                episode_sql = ""
                splitted_data = []
                if "uploaded_text" in st.session_state and st.session_state.uploaded_text is not None:
                    uploaded_file_names = [file.name for file in st.session_state.uploaded_text]
                    if any('PACKAGE' in file_name.upper() for file_name in uploaded_file_names):
                        # Package uploads may carry two statements; split on
                        # the first separator that is present.
                        # NOTE(review): the "," fallback also splits a single
                        # INSERT at its column commas - confirm it is wanted.
                        if ";" in final_response_new:
                            splitted_data = [stmt.strip() for stmt in final_response_new.strip().split(';') if stmt.strip()]
                        elif "\n" in final_response_new:
                            splitted_data = [stmt.strip() for stmt in final_response_new.strip().split('\n') if stmt.strip()]
                        elif "," in final_response_new:
                            splitted_data = [stmt.strip() for stmt in final_response_new.strip().split(',') if stmt.strip()]
                    else:
                        # `x is list` compared against the type object and was
                        # never true; isinstance is the intended check.
                        if isinstance(final_response_new, list):
                            splitted_data = final_response_new
                        else:
                            splitted_data.append(final_response_new)
                print(splitted_data)
                if splitted_data:
                    programme_sql = splitted_data[0] + ';'  # Re-add semicolon to the programme SQL statement
                    print(f"prog{programme_sql}")
                    if len(splitted_data) > 1:
                        episode_sql = splitted_data[1]
                    data = build_data_from_sql(programme_sql, episode_sql)
                    st.write("### Script Summary")
                    st.table(data)
                    st.write("### Full SQL Scripts")
                    with st.expander("Insert Scripts"):
                        st.code(programme_sql, language='sql')
                        st.code(episode_sql, language='sql')
                else:
                    st.warning("No SQL statement could be extracted from the response.")
                # NOTE(review): original placement of this cleanup is unclear;
                # it is kept on the path where INSERT scripts were produced.
                clear_database()
        else:
            st.warning("Ask a question about the uploaded document first.")
    except Exception as e:
        st.error(f"Parsing error from LLM.Retry again !!! \n : {str(e)}")
# data = { | |
# 'Table': ['Programme', 'Episode'], | |
# 'Columns': ['ProgrammeTitle, ProgrammeType, ...', 'EpisodeTitle, EpisodeNumber, ...'], | |
# 'Values': ['CHAMSARANG PACKAGE, Series, ...', 'CHAMSARANG PACKAGE, 1, ...'] | |
# } | |
# # Display summary table | |
# st.write("### Script Summary") | |
# st.table(data) | |
# # Display expandable sections for each script | |
# st.write("### Full SQL Scripts") | |
# with st.expander("Programme Insert Script"): | |
# st.code("INSERT INTO programme ...") | |
# with st.expander("Episode Insert Script"): | |
# st.code("INSERT INTO episode ...") |