SoumyaJ committed on
Commit
ae95391
·
verified ·
1 Parent(s): 9784fb6

Upload 4 files

Browse files

initial commit for DBToolkit behavior

Files changed (4) hide show
  1. app.py +234 -0
  2. programme.db +0 -0
  3. requirements.txt +10 -0
  4. sqlite.py +32 -0
app.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from langchain.agents import create_sql_agent,create_react_agent
3
+ from langchain.agents.agent_toolkits import SQLDatabaseToolkit
4
+ from langchain.agents.agent_types import AgentType
5
+ from langchain_groq import ChatGroq
6
+ from langchain_core.prompts import ChatPromptTemplate
7
+ from langchain.sql_database import SQLDatabase
8
+ from sqlalchemy import create_engine
9
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
10
+ from langchain_huggingface import HuggingFaceEmbeddings
11
+ from langchain_community.vectorstores import FAISS
12
+ from langchain.chains.combine_documents import create_stuff_documents_chain
13
+ from langchain.chains.retrieval import create_retrieval_chain
14
+ from langchain_core.output_parsers import StrOutputParser
15
+ from sqlalchemy.orm import sessionmaker
16
+ from sqlalchemy import text
17
+ import sqlite3
18
+ from dotenv import load_dotenv
19
+ from pathlib import Path
20
+ from PyPDF2 import PdfReader
21
+ import os
22
+ import re
23
+
24
# Pull variables from a local .env file into the process environment.
load_dotenv()

# Page configuration is issued before any other Streamlit call in this script.
st.set_page_config("Langchain interaction with DB")
st.title("Langchain with DB")

# os.environ only accepts str values, so copying a missing variable back into
# it (os.environ[k] = os.getenv(k)) raises TypeError. Check explicitly and
# stop the run with a readable message instead.
if not os.getenv("GROQ_API_KEY"):
    st.error("GROQ_API_KEY is not set. Add it to the environment or to a .env file.")
    st.stop()
# HF_TOKEN needs no handling here: load_dotenv() has already exported it when
# it is present in .env, and nothing in this block reads it directly.

# Chat model used by the retrieval chain and the SQL agent.
llm = ChatGroq(model="llama3-8b-8192", api_key=os.environ['GROQ_API_KEY'])

# Sentence-embedding model used to build the FAISS index.
embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

# Matches "<digits> min", "<digits> mins", "<digits> minute(s)".
duration_pattern = re.compile(r"(\d+)\s*(min[s]?|minute[s]?)")

# Initialise the per-session values only when they are missing, so that a
# previously stored prompt/summary is not wiped each time this script runs.
if "user_prompt" not in st.session_state:
    st.session_state.user_prompt = ""
if "summary" not in st.session_state:
    st.session_state.summary = ""

# Prompt for question answering over the retrieved PDF chunks.
pdf_prompt_template = ChatPromptTemplate.from_template("""
Answer the following question from the provided context only.
Please provide the most accurate response based on the question
<context>
{context}
</context>
Question : {input}
""")
48
+
49
def get_pdf_text(pdf_docs):
    """Return the concatenated text of every page of every given PDF.

    Args:
        pdf_docs: Iterable of objects that ``PdfReader`` can open. ``None``
            or an empty iterable yields an empty string.

    Returns:
        str: Page texts joined in document/page order with no separator.
    """
    page_texts = []
    for pdf in pdf_docs or ():
        for page in PdfReader(pdf).pages:
            # Guard against a page for which no text comes back (None or "")
            # so the join below never sees a non-string.
            page_texts.append(page.extract_text() or "")
    return "".join(page_texts)
56
+
57
def create_vector_embeddings(pdfText):
    """Build the FAISS index for the uploaded PDFs once per session.

    Does nothing when ``st.session_state`` already holds ``vectors``.
    Otherwise stores, in order, the extracted text (``docs``), the text
    splitter (``splitter``), the resulting chunks (``final_docs``) and the
    FAISS store built from those chunks (``vectors``) in the session state.

    Args:
        pdfText: The uploaded PDF objects, forwarded to ``get_pdf_text``.
    """
    state = st.session_state
    if "vectors" in state:
        return

    state.docs = get_pdf_text(pdfText)
    state.splitter = RecursiveCharacterTextSplitter(chunk_size=1200, chunk_overlap=400)
    state.final_docs = state.splitter.split_text(state.docs)
    state.vectors = FAISS.from_texts(state.final_docs, embeddings)
63
+
64
def configure():
    """Create the SQLAlchemy engine for ``programme.db`` next to this file.

    The engine opens its connections through a custom creator that calls
    ``sqlite3.connect`` with a ``file:`` URI and ``check_same_thread=False``.

    Returns:
        The engine returned by ``create_engine``.
    """
    db_location = Path(__file__).parent.joinpath("programme.db").absolute()

    def open_connection():
        # URI form of the path; same-thread checking is switched off.
        return sqlite3.connect(f"file:{db_location}", uri=True, check_same_thread=False)

    return create_engine("sqlite:///", creator=open_connection)
68
+
69
# Engine and LangChain wrapper around the local SQLite database.
engine = configure()
db = SQLDatabase(engine)

# The toolkit is built from the database and the LLM only; the former
# `api_key` keyword and the discarded `get_tools()` call were dropped.
sql_toolkit = SQLDatabaseToolkit(db=db, llm=llm)

# Free-text question answered from the uploaded document.
query = st.text_input("ask question here")
76
+
77
def clear_database():
    """Delete every row from the ``programme`` and ``episode`` tables.

    Works on a raw DB-API connection obtained from the module-level
    ``engine``; the deletes are committed together and the connection is
    closed even when a statement fails.
    """
    raw_conn = engine.raw_connection()
    try:
        db_cursor = raw_conn.cursor()
        # Fixed table names only, so interpolating them into the SQL is safe.
        for table_name in ("programme", "episode"):
            db_cursor.execute(f"DELETE FROM {table_name}")
        raw_conn.commit()
    finally:
        raw_conn.close()
96
+
97
+
98
def process_sql_script(sql_script):
    """Keep only the statements that the PACKAGE rule allows.

    The stripped script is split on ``;``. When the first statement contains
    the (case-sensitive) keyword ``PACKAGE`` all pieces are returned joined
    by newlines; otherwise only the first statement is returned. The ``;``
    separators are not part of the result in either case.

    Args:
        sql_script: Raw SQL text containing one or more statements.

    Returns:
        str: The filtered script.
    """
    statements = sql_script.strip().split(';')
    first_statement = statements[0]

    if 'PACKAGE' in first_statement:
        return "\n".join(statements)
    return first_statement
112
+
113
+ import re
114
+
115
# Leading "<n> hour(s)/hr(s)" and/or "<n> min..." at the start of the text.
# The look-ahead after the hour unit keeps words such as "hourly" from
# being read as an hour count; digits may follow directly ("1hour30min").
_HMS_INPUT_RE = re.compile(
    r"\s*(?:(\d+)\s*(?:hours?|hrs?)(?![a-z]))?\s*(?:(\d+)\s*min)?"
)


def convert_to_hms(duration):
    """Convert a textual duration to ``hh:mm:ss``.

    Recognises, case-insensitively and at the start of the text, an optional
    hour part ("2 hours", "1 hr") followed by an optional minute part
    ("30 min", "45 mins", "60 minutes"). At least one part must be present.
    Minutes of 60 or more roll over into hours.

    Args:
        duration: Text such as "90 mins", "1 hour 30 mins" or "2 hours".

    Returns:
        str: The zero-padded ``hh:mm:ss`` value (seconds are always ``00``),
        or ``duration`` unchanged when neither part is found.
    """
    # Every part of the pattern is optional, so match() always succeeds;
    # whether anything useful was found is decided by the groups.
    hours_text, minutes_text = _HMS_INPUT_RE.match(duration.lower()).groups()
    if hours_text is None and minutes_text is None:
        return duration

    total_minutes = int(hours_text or 0) * 60 + int(minutes_text or 0)
    hh, mm = divmod(total_minutes, 60)
    return f"{hh:02}:{mm:02}:00"
130
+
131
# A whole minutes token: "<n> min", "<n> mins", "<n> minute", "<n> minutes",
# with any amount of whitespace (or none) between number and unit.
_MINUTES_TOKEN_RE = re.compile(r"(\d+)\s*(?:minutes?|mins?)\b", re.IGNORECASE)


def handleDurationForEachScript(scripts):
    """Rewrite minute durations as ``hh:mm:ss`` and drop trailing statements.

    The input is split on ``;``. In each piece every minutes token is
    replaced by the value ``convert_to_hms`` returns for it. Pieces are
    appended to the result (without separator) only while the accumulated
    result contains neither ``episode`` nor ``programme``; everything after
    that point is discarded.

    The whole token is replaced in one step, so no leftover suffix has to be
    stripped afterwards and unrelated text containing "utes" stays intact.

    Args:
        scripts: SQL text, possibly holding several ``;``-separated pieces.

    Returns:
        str: The concatenated, duration-normalised pieces that were kept.
    """
    filtered_data = ""
    for script in scripts.split(";"):
        # Once a table name is present nothing further is ever appended.
        if 'episode' in filtered_data or 'programme' in filtered_data:
            break
        filtered_data += _MINUTES_TOKEN_RE.sub(
            lambda found: convert_to_hms(f"{found.group(1)} mins"),
            script,
        )
    return filtered_data
145
+
146
# Sidebar: file upload and on-demand index creation.
with st.sidebar:
    st.title("Menu:")
    st.session_state.uploaded_text = st.file_uploader("Upload your Files and Click on the Submit & Process Button", accept_multiple_files=True)
    if st.button("Click To Process File"):
        if not st.session_state.uploaded_text:
            # Nothing to index; do not call the embedding step and do not
            # report a ready index.
            st.warning("Please upload at least one file before processing.")
        else:
            with st.spinner("Processing..."):
                create_vector_embeddings(st.session_state.uploaded_text)
            st.write("Vector Database is ready")
154
+
155
# Answer the question from the indexed document and remember the answer as
# the summary that the script generator consumes.
if query:
    st.session_state.user_prompt = query
    if "vectors" not in st.session_state:
        # The index only exists after a file has been processed; reading
        # st.session_state.vectors before that raises AttributeError.
        st.warning("Upload and process a file before asking a question.")
    else:
        document_chain = create_stuff_documents_chain(llm=llm, prompt=pdf_prompt_template)
        retriever = st.session_state.vectors.as_retriever()
        retrieval_chain = create_retrieval_chain(retriever, document_chain)
        response = retrieval_chain.invoke({"input": st.session_state.user_prompt})
        if response:
            st.session_state.summary = response['answer']
            st.write(response['answer'])
165
+
166
# Instruction prompt for turning the document summary into INSERT scripts.
# The user message keeps its literal backslash; it is written as "\\" because
# a bare "\ " is an invalid escape sequence.
prompt = ChatPromptTemplate.from_messages(
    [
        ("system",
         """
You are a SQL expert. Your task is to generate SQL INSERT scripts based on the provided context.

1. Generate an `INSERT` statement for the `programme` table using the following values:
- `ProgrammeTitle`
- `ProgrammeType`
- `Genre`
- `SubGenre`
- `Language`
- `Duration`
Example:


2. After generating the `programme` statement, check the `ProgrammeTitle`:
- If the `ProgrammeTitle` contains the keyword `PACKAGE`, generate an additional `INSERT` statement for the `episode` table.
- If the `ProgrammeTitle` does **not** contain the keyword `PACKAGE`, **do not** generate an `INSERT` statement for the `episode` table.

3. The `episode` INSERT statement should look like this if the condition is met. EpisodeNumber is always 1 and `EpisodeTitle` should take same data from `ProgrammeTitle`.


4. Include only the SQL insert script(s) as final answer, **donot** include any additional details and notes.Return only the necessary SQL INSERT script(s) based on the current input. Ensure that no `episode` INSERT statement is included if the `ProgrammeTitle` does not contain `'PACKAGE'`.

Your output should strictly follow these conditions. Output **only** the final answer without producing any intermediate actions.

"""
         ),
        ("user", "{question}\\ ai: "),
    ])
197
+
198
# ReAct-style SQL agent working through the toolkit's database tools.
agent = create_sql_agent(llm=llm, toolkit=sql_toolkit, agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True, max_execution_time=100, max_iterations=1000, handle_parsing_errors=True)

if st.button("Generate Scripts", type="primary"):
    if not st.session_state.summary:
        # The summary starts out as "" and is never None, so test for
        # content rather than `is not None`.
        st.warning("Ask a question about the uploaded document first.")
    else:
        try:
            response = agent.run(prompt.format_prompt(question=st.session_state.summary))
            with st.expander("Expand here to view scripts"):
                if "INSERT" in response:
                    final_response = process_sql_script(response)
                    final_response_new = handleDurationForEachScript(final_response)
                    print(final_response_new)
                    st.code(final_response_new, language='sql')
        except Exception as e:
            st.error(f"Parsing error from LLM.Retry again !!! \n : {str(e)}")
        finally:
            # Presumably the agent may have written rows while working;
            # empty both tables after every run, including failed ones.
            clear_database()
215
+
216
+
217
+
218
+ # data = {
219
+ # 'Table': ['Programme', 'Episode'],
220
+ # 'Columns': ['ProgrammeTitle, ProgrammeType, ...', 'EpisodeTitle, EpisodeNumber, ...'],
221
+ # 'Values': ['CHAMSARANG PACKAGE, Series, ...', 'CHAMSARANG PACKAGE, 1, ...']
222
+ # }
223
+
224
+ # # Display summary table
225
+ # st.write("### Script Summary")
226
+ # st.table(data)
227
+
228
+ # # Display expandable sections for each script
229
+ # st.write("### Full SQL Scripts")
230
+ # with st.expander("Programme Insert Script"):
231
+ # st.code("INSERT INTO programme ...")
232
+
233
+ # with st.expander("Episode Insert Script"):
234
+ # st.code("INSERT INTO episode ...")
programme.db ADDED
Binary file (12.3 kB). View file
 
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ streamlit
2
+ langchain
3
+ langchain-community
4
+ langchain_groq
5
+ langchain-core
6
+ python-dotenv
7
+ langchain_huggingface
8
+ sentence-transformers
9
+ PyPDF2
10
+ faiss-cpu
sqlite.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlite3
2
+
3
# Default database file, created in the current working directory.
DB_PATH = "programme.db"

# IF NOT EXISTS makes the script safe to run again on an existing database.
prog_table_info = """
CREATE TABLE IF NOT EXISTS programme(ProgrammeTitle nvarchar(500), ProgrammeStatus nvarchar(30), ProgrammeType nvarchar(30), Genre nvarchar(40),
SubGenre nvarchar(50), Language nvarchar(30), Duration nvarchar(30))
"""

eps_table_info = """
CREATE TABLE IF NOT EXISTS episode(EpisodeTitle nvarchar(500), EpisodeNumber int, Duration nvarchar(30))
"""


def create_schema(db_path=DB_PATH):
    """Create the ``programme`` and ``episode`` tables and print their rows.

    Args:
        db_path: SQLite database file to open (created when missing).

    The connection is committed after both tables are handled and is closed
    even when a statement raises.
    """
    conn = sqlite3.connect(db_path, check_same_thread=False)
    try:
        cursor = conn.cursor()
        for table_name, ddl in (("programme", prog_table_info),
                                ("episode", eps_table_info)):
            cursor.execute(ddl)
            print(f"Access the table {table_name}")
            # Fixed table names only, so the interpolation is safe.
            for row in cursor.execute(f"SELECT * from {table_name}"):
                print(row)
        conn.commit()
    finally:
        conn.close()


if __name__ == "__main__":
    create_schema()