Spaces:
Sleeping
Sleeping
Upload 4 files
Browse filesinitial commit for DBToolkit behavior
- app.py +234 -0
- programme.db +0 -0
- requirements.txt +10 -0
- sqlite.py +32 -0
app.py
ADDED
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from langchain.agents import create_sql_agent,create_react_agent
|
3 |
+
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
|
4 |
+
from langchain.agents.agent_types import AgentType
|
5 |
+
from langchain_groq import ChatGroq
|
6 |
+
from langchain_core.prompts import ChatPromptTemplate
|
7 |
+
from langchain.sql_database import SQLDatabase
|
8 |
+
from sqlalchemy import create_engine
|
9 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
10 |
+
from langchain_huggingface import HuggingFaceEmbeddings
|
11 |
+
from langchain_community.vectorstores import FAISS
|
12 |
+
from langchain.chains.combine_documents import create_stuff_documents_chain
|
13 |
+
from langchain.chains.retrieval import create_retrieval_chain
|
14 |
+
from langchain_core.output_parsers import StrOutputParser
|
15 |
+
from sqlalchemy.orm import sessionmaker
|
16 |
+
from sqlalchemy import text
|
17 |
+
import sqlite3
|
18 |
+
from dotenv import load_dotenv
|
19 |
+
from pathlib import Path
|
20 |
+
from PyPDF2 import PdfReader
|
21 |
+
import os
|
22 |
+
import re
|
23 |
+
|
24 |
+
load_dotenv()
|
25 |
+
os.environ['GROQ_API_KEY'] = os.getenv("GROQ_API_KEY")
|
26 |
+
os.environ['HF_TOKEN'] = os.getenv("HF_TOKEN")
|
27 |
+
|
28 |
+
st.set_page_config("Langchain interaction with DB")
|
29 |
+
st.title("Langchain with DB")
|
30 |
+
|
31 |
+
llm = ChatGroq(model="llama3-8b-8192", api_key= os.environ['GROQ_API_KEY'])
|
32 |
+
|
33 |
+
embeddings = HuggingFaceEmbeddings(model_name = "all-MiniLM-L6-v2")
|
34 |
+
|
35 |
+
duration_pattern = re.compile(r"(\d+)\s*(min[s]?|minute[s]?)")
|
36 |
+
|
37 |
+
st.session_state.user_prompt = ""
|
38 |
+
st.session_state.summary = ""
|
39 |
+
|
40 |
+
pdf_prompt_template = ChatPromptTemplate.from_template("""
|
41 |
+
Answer the following question from the provided context only.
|
42 |
+
Please provide the most accurate response based on the question
|
43 |
+
<context>
|
44 |
+
{context}
|
45 |
+
</context>
|
46 |
+
Question : {input}
|
47 |
+
""")
|
48 |
+
|
49 |
+
def get_pdf_text(pdf_docs):
|
50 |
+
text=""
|
51 |
+
for pdf in pdf_docs:
|
52 |
+
pdf_reader= PdfReader(pdf)
|
53 |
+
for page in pdf_reader.pages:
|
54 |
+
text+= page.extract_text()
|
55 |
+
return text
|
56 |
+
|
57 |
+
def create_vector_embeddings(pdfText):
|
58 |
+
if "vectors" not in st.session_state:
|
59 |
+
st.session_state.docs = get_pdf_text(pdfText)
|
60 |
+
st.session_state.splitter = RecursiveCharacterTextSplitter(chunk_size=1200,chunk_overlap=400)
|
61 |
+
st.session_state.final_docs = st.session_state.splitter.split_text(st.session_state.docs)
|
62 |
+
st.session_state.vectors = FAISS.from_texts(st.session_state.final_docs, embeddings)
|
63 |
+
|
64 |
+
def configure():
|
65 |
+
dbfilepath = (Path(__file__).parent /"programme.db").absolute()
|
66 |
+
creator = lambda: sqlite3.connect(f"file:{dbfilepath}",uri= True, check_same_thread=False)
|
67 |
+
return create_engine("sqlite:///", creator= creator)
|
68 |
+
|
69 |
+
engine = configure()
|
70 |
+
db = SQLDatabase(engine)
|
71 |
+
#ChatGroq(model="gemma2-9b-it"
|
72 |
+
sql_toolkit = SQLDatabaseToolkit(db = db, llm = llm , api_key= os.environ['GROQ_API_KEY'])
|
73 |
+
sql_toolkit.get_tools()
|
74 |
+
|
75 |
+
query=st.text_input("ask question here")
|
76 |
+
|
77 |
+
def clear_database():
|
78 |
+
|
79 |
+
connection = engine.raw_connection()
|
80 |
+
try:
|
81 |
+
# Create a cursor from the raw connection
|
82 |
+
cursor = connection.cursor()
|
83 |
+
|
84 |
+
# List of tables to clear
|
85 |
+
tables = ["programme", "episode"]
|
86 |
+
|
87 |
+
# Execute DELETE commands for each table
|
88 |
+
for table in tables:
|
89 |
+
cursor.execute(f"DELETE FROM {table}")
|
90 |
+
|
91 |
+
# Commit the changes to the database
|
92 |
+
connection.commit()
|
93 |
+
finally:
|
94 |
+
# Ensure the connection is closed properly
|
95 |
+
connection.close()
|
96 |
+
|
97 |
+
|
98 |
+
def process_sql_script(sql_script):
|
99 |
+
# Define the keyword to check
|
100 |
+
keyword = 'PACKAGE'
|
101 |
+
|
102 |
+
# Split the script into lines
|
103 |
+
lines = sql_script.strip().split(';')
|
104 |
+
|
105 |
+
programme_line = lines[0]
|
106 |
+
if keyword not in programme_line:
|
107 |
+
filtered_script = "\n".join([lines[0]])
|
108 |
+
else:
|
109 |
+
filtered_script = "\n".join(lines)
|
110 |
+
|
111 |
+
return filtered_script
|
112 |
+
|
113 |
+
import re
|
114 |
+
|
115 |
+
def convert_to_hms(duration):
|
116 |
+
hour_minute_match = re.match(r'(?:(\d+)\s*hour[s]?)?\s*(\d+)\s*min[s]?', duration.lower())
|
117 |
+
|
118 |
+
if hour_minute_match:
|
119 |
+
hours = int(hour_minute_match.group(1) or 0)
|
120 |
+
minutes = int(hour_minute_match.group(2) or 0)
|
121 |
+
else:
|
122 |
+
return duration
|
123 |
+
|
124 |
+
total_seconds = (hours * 60 * 60) + (minutes * 60)
|
125 |
+
hh = total_seconds // 3600
|
126 |
+
mm = (total_seconds % 3600) // 60
|
127 |
+
ss = total_seconds % 60
|
128 |
+
|
129 |
+
return f"{hh:02}:{mm:02}:{ss:02}"
|
130 |
+
|
131 |
+
def handleDurationForEachScript(scripts):
|
132 |
+
filtered_data = ""
|
133 |
+
for script in scripts.split(";"):
|
134 |
+
# Find all matches for durations like '60 minutes' or '60 mins'
|
135 |
+
matches = duration_pattern.findall(script)
|
136 |
+
|
137 |
+
for match in matches:
|
138 |
+
duration = f"{match[0]} {match[1]}" # e.g., '60 mins' or '60 minutes'
|
139 |
+
converted_duration = convert_to_hms(duration) # Convert to hh:mm:ss
|
140 |
+
script = script.replace(duration, converted_duration).replace('utes','') # Replace in script
|
141 |
+
if ('episode' not in filtered_data) & ('programme' not in filtered_data):
|
142 |
+
filtered_data = filtered_data + script
|
143 |
+
|
144 |
+
return filtered_data
|
145 |
+
|
146 |
+
with st.sidebar:
|
147 |
+
st.title("Menu:")
|
148 |
+
#if "uploaded_text" not in st.session_state:
|
149 |
+
st.session_state.uploaded_text = st.file_uploader("Upload your Files and Click on the Submit & Process Button", accept_multiple_files=True)
|
150 |
+
if st.button("Click To Process File"):
|
151 |
+
with st.spinner("Processing..."):
|
152 |
+
create_vector_embeddings(st.session_state.uploaded_text)
|
153 |
+
st.write("Vector Database is ready")
|
154 |
+
|
155 |
+
if query:
|
156 |
+
st.session_state.user_prompt = query
|
157 |
+
document_chain = create_stuff_documents_chain(llm=llm, prompt= pdf_prompt_template)
|
158 |
+
retriever = st.session_state.vectors.as_retriever()
|
159 |
+
retrieval_chain=create_retrieval_chain(retriever,document_chain)
|
160 |
+
response = retrieval_chain.invoke({"input": st.session_state.user_prompt})
|
161 |
+
#st.write(response)
|
162 |
+
if response:
|
163 |
+
st.session_state.summary = response['answer']
|
164 |
+
st.write(response['answer'])
|
165 |
+
|
166 |
+
prompt=ChatPromptTemplate.from_messages(
|
167 |
+
[
|
168 |
+
("system",
|
169 |
+
"""
|
170 |
+
You are a SQL expert. Your task is to generate SQL INSERT scripts based on the provided context.
|
171 |
+
|
172 |
+
1. Generate an `INSERT` statement for the `programme` table using the following values:
|
173 |
+
- `ProgrammeTitle`
|
174 |
+
- `ProgrammeType`
|
175 |
+
- `Genre`
|
176 |
+
- `SubGenre`
|
177 |
+
- `Language`
|
178 |
+
- `Duration`
|
179 |
+
Example:
|
180 |
+
|
181 |
+
|
182 |
+
2. After generating the `programme` statement, check the `ProgrammeTitle`:
|
183 |
+
- If the `ProgrammeTitle` contains the keyword `PACKAGE`, generate an additional `INSERT` statement for the `episode` table.
|
184 |
+
- If the `ProgrammeTitle` does **not** contain the keyword `PACKAGE`, **do not** generate an `INSERT` statement for the `episode` table.
|
185 |
+
|
186 |
+
3. The `episode` INSERT statement should look like this if the condition is met. EpisodeNumber is always 1 and `EpisodeTitle` should take same data from `ProgrammeTitle`.
|
187 |
+
|
188 |
+
|
189 |
+
4. Include only the SQL insert script(s) as final answer, **donot** include any additional details and notes.Return only the necessary SQL INSERT script(s) based on the current input. Ensure that no `episode` INSERT statement is included if the `ProgrammeTitle` does not contain `'PACKAGE'`.
|
190 |
+
|
191 |
+
Your output should strictly follow these conditions. Output **only** the final answer without producing any intermediate actions.
|
192 |
+
|
193 |
+
"""
|
194 |
+
),
|
195 |
+
("user","{question}\ ai: ")
|
196 |
+
])
|
197 |
+
|
198 |
+
agent=create_sql_agent(llm=llm,toolkit=sql_toolkit,agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,verbose=True,max_execution_time=100,max_iterations=1000, handle_parsing_errors=True)
|
199 |
+
|
200 |
+
if st.button("Generate Scripts",type="primary"):
|
201 |
+
try:
|
202 |
+
if st.session_state.summary is not None:
|
203 |
+
response=agent.run(prompt.format_prompt(question=st.session_state.summary))
|
204 |
+
with st.expander("Expand here to view scripts"):
|
205 |
+
if "INSERT" in response:
|
206 |
+
final_response = process_sql_script(response)
|
207 |
+
final_response_new = handleDurationForEachScript(final_response)
|
208 |
+
print(final_response_new)
|
209 |
+
st.code(final_response_new, language = 'sql')
|
210 |
+
clear_database()
|
211 |
+
|
212 |
+
#st.write(response)
|
213 |
+
except Exception as e:
|
214 |
+
st.error(f"Parsing error from LLM.Retry again !!! \n : {str(e)}")
|
215 |
+
|
216 |
+
|
217 |
+
|
218 |
+
# data = {
|
219 |
+
# 'Table': ['Programme', 'Episode'],
|
220 |
+
# 'Columns': ['ProgrammeTitle, ProgrammeType, ...', 'EpisodeTitle, EpisodeNumber, ...'],
|
221 |
+
# 'Values': ['CHAMSARANG PACKAGE, Series, ...', 'CHAMSARANG PACKAGE, 1, ...']
|
222 |
+
# }
|
223 |
+
|
224 |
+
# # Display summary table
|
225 |
+
# st.write("### Script Summary")
|
226 |
+
# st.table(data)
|
227 |
+
|
228 |
+
# # Display expandable sections for each script
|
229 |
+
# st.write("### Full SQL Scripts")
|
230 |
+
# with st.expander("Programme Insert Script"):
|
231 |
+
# st.code("INSERT INTO programme ...")
|
232 |
+
|
233 |
+
# with st.expander("Episode Insert Script"):
|
234 |
+
# st.code("INSERT INTO episode ...")
|
programme.db
ADDED
Binary file (12.3 kB). View file
|
|
requirements.txt
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
streamlit
|
2 |
+
langchain
|
3 |
+
langchain-community
|
4 |
+
langchain_groq
|
5 |
+
langchain-core
|
6 |
+
python-dotenv
|
7 |
+
langchain_huggingface
|
8 |
+
sentence-transformers
|
9 |
+
PyPDF2
|
10 |
+
faiss-cpu
|
sqlite.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sqlite3
|
2 |
+
|
3 |
+
#create connection with sqlite
|
4 |
+
conn = sqlite3.connect("programme.db", check_same_thread= False)
|
5 |
+
|
6 |
+
#create cursor to establish connection
|
7 |
+
cursor = conn.cursor()
|
8 |
+
|
9 |
+
#create the table
|
10 |
+
prog_table_info = """
|
11 |
+
CREATE TABLE programme(ProgrammeTitle nvarchar(500), ProgrammeStatus nvarchar(30), ProgrammeType nvarchar(30), Genre nvarchar(40),
|
12 |
+
SubGenre nvarchar(50), Language nvarchar(30), Duration nvarchar(30))
|
13 |
+
"""
|
14 |
+
cursor.execute(prog_table_info)
|
15 |
+
|
16 |
+
print("Access the table programme")
|
17 |
+
data = cursor.execute('''SELECT * from programme''')
|
18 |
+
for row in data:
|
19 |
+
print(row)
|
20 |
+
|
21 |
+
eps_table_info = """
|
22 |
+
CREATE TABLE episode(EpisodeTitle nvarchar(500), EpisodeNumber int, Duration nvarchar(30))
|
23 |
+
"""
|
24 |
+
cursor.execute(eps_table_info)
|
25 |
+
|
26 |
+
print("Access the table episode")
|
27 |
+
data = cursor.execute('''SELECT * from episode''')
|
28 |
+
for row in data:
|
29 |
+
print(row)
|
30 |
+
|
31 |
+
conn.commit()
|
32 |
+
conn.close()
|