DrishtiSharma committed
Commit f4edd92 · verified · 1 Parent(s): 41d630c

Update interim.py

Files changed (1)
  1. interim.py +145 -46
interim.py CHANGED
@@ -4,12 +4,13 @@ import sqlite3
 import os
 import json
 from pathlib import Path
+import plotly.express as px
 from datetime import datetime, timezone
 from crewai import Agent, Crew, Process, Task
-from crewai_tools import tool
+from crewai.tools import tool
 from langchain_groq import ChatGroq
+from langchain_openai import ChatOpenAI
 from langchain.schema.output import LLMResult
-from langchain_core.callbacks.base import BaseCallbackHandler
 from langchain_community.tools.sql_database.tool import (
     InfoSQLDatabaseTool,
     ListSQLDatabaseTool,
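Note: the switch from `crewai_tools` to `from crewai.tools import tool` tracks the decorator's move into the core `crewai` package; the standalone `crewai_tools` distribution keeps the prebuilt tools, but the decorator itself now ships with `crewai`. A minimal sketch of the decorator contract this file relies on (the `row_count` tool and its return value are invented for illustration):

```python
from crewai.tools import tool

@tool("row_count")
def row_count(table: str) -> str:
    """The docstring becomes the tool description the agent sees."""
    return f"{table}: 42 rows"  # stub value for illustration only
```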
@@ -20,39 +21,41 @@ from langchain_community.utilities.sql_database import SQLDatabase
 from datasets import load_dataset
 import tempfile
 
-# API Key
-os.environ["GROQ_API_KEY"] = st.secrets.get("GROQ_API_KEY", "")
+st.title("SQL-RAG Using CrewAI 🚀")
+st.write("Analyze datasets using natural language queries powered by SQL and CrewAI.")
 
 # Initialize LLM
-class LLMCallbackHandler(BaseCallbackHandler):
-    def __init__(self, log_path: Path):
-        self.log_path = log_path
-
-    def on_llm_start(self, serialized, prompts, **kwargs):
-        with self.log_path.open("a", encoding="utf-8") as file:
-            file.write(json.dumps({"event": "llm_start", "text": prompts[0], "timestamp": datetime.now().isoformat()}) + "\n")
-
-    def on_llm_end(self, response: LLMResult, **kwargs):
-        generation = response.generations[-1][-1].message.content
-        with self.log_path.open("a", encoding="utf-8") as file:
-            file.write(json.dumps({"event": "llm_end", "text": generation, "timestamp": datetime.now().isoformat()}) + "\n")
-
-llm = ChatGroq(
-    temperature=0,
-    model_name="groq/llama-3.3-70b-versatile",
-    max_tokens=1024,
-    callbacks=[LLMCallbackHandler(Path("prompts.jsonl"))],
-)
-
-st.title("SQL-RAG Using CrewAI 🚀")
-st.write("Analyze datasets using natural language queries powered by SQL and CrewAI.")
+llm = None
+
+# Model Selection
+model_choice = st.radio("Select LLM", ["GPT-4o", "llama-3.3-70b"], index=0, horizontal=True)
+
+# API Key Validation and LLM Initialization
+groq_api_key = os.getenv("GROQ_API_KEY")
+openai_api_key = os.getenv("OPENAI_API_KEY")
+
+if model_choice == "llama-3.3-70b":
+    if not groq_api_key:
+        st.error("Groq API key is missing. Please set the GROQ_API_KEY environment variable.")
+        llm = None
+    else:
+        llm = ChatGroq(groq_api_key=groq_api_key, model="groq/llama-3.3-70b-versatile")
+elif model_choice == "GPT-4o":
+    if not openai_api_key:
+        st.error("OpenAI API key is missing. Please set the OPENAI_API_KEY environment variable.")
+        llm = None
+    else:
+        llm = ChatOpenAI(api_key=openai_api_key, model="gpt-4o")
 
 # Initialize session state for data persistence
 if "df" not in st.session_state:
     st.session_state.df = None
+if "show_preview" not in st.session_state:
+    st.session_state.show_preview = False
 
 # Dataset Input
 input_option = st.radio("Select Dataset Input:", ["Use Hugging Face Dataset", "Upload CSV File"])
+
 if input_option == "Use Hugging Face Dataset":
     dataset_name = st.text_input("Enter Hugging Face Dataset Name:", value="Einstellung/demo-salaries")
     if st.button("Load Dataset"):
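Note: when a key is missing, the new selection block reports an error but leaves `llm = None`, and the rest of the script keeps executing; the `QuerySQLCheckerTool(db=db, llm=llm)` call further down would then fail with a less obvious message. A hedged sketch of an early bail-out plus an optional smoke test (the `st.stop()` guard and probe prompt are additions for illustration, not part of this commit):

```python
if llm is None:
    st.stop()  # halt the Streamlit script here instead of failing later inside CrewAI

# Optional smoke test: ChatGroq and ChatOpenAI both accept a plain string
# and return a message object with a .content attribute.
try:
    st.caption(f"LLM ready: {llm.invoke('Reply with OK').content}")
except Exception as e:
    st.error(f"LLM check failed: {e}")
```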
@@ -60,16 +63,25 @@ if input_option == "Use Hugging Face Dataset":
             with st.spinner("Loading dataset..."):
                 dataset = load_dataset(dataset_name, split="train")
                 st.session_state.df = pd.DataFrame(dataset)
+                st.session_state.show_preview = True  # Show preview after loading
                 st.success(f"Dataset '{dataset_name}' loaded successfully!")
-                st.dataframe(st.session_state.df.head())
         except Exception as e:
             st.error(f"Error: {e}")
+
 elif input_option == "Upload CSV File":
     uploaded_file = st.file_uploader("Upload CSV File:", type=["csv"])
     if uploaded_file:
-        st.session_state.df = pd.read_csv(uploaded_file)
-        st.success("File uploaded successfully!")
-        st.dataframe(st.session_state.df.head())
+        try:
+            st.session_state.df = pd.read_csv(uploaded_file)
+            st.session_state.show_preview = True  # Show preview after loading
+            st.success("File uploaded successfully!")
+        except Exception as e:
+            st.error(f"Error loading file: {e}")
+
+# Show Dataset Preview Only After Loading
+if st.session_state.df is not None and st.session_state.show_preview:
+    st.subheader("📂 Dataset Preview")
+    st.dataframe(st.session_state.df.head())
 
 # SQL-RAG Analysis
 if st.session_state.df is not None:
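Note: the code that turns `st.session_state.df` into the SQLite database queried below falls outside this diff's hunks, but the surrounding context (`db=db` in the tools, `temp_dir.cleanup()` at the end) implies wiring roughly like this sketch; the `salaries` table name and `data.db` filename are guesses, not taken from the commit:

```python
import sqlite3
import tempfile
from langchain_community.utilities.sql_database import SQLDatabase

temp_dir = tempfile.TemporaryDirectory()
db_path = f"{temp_dir.name}/data.db"

# Materialize the DataFrame as a SQLite table the SQL tools can query.
connection = sqlite3.connect(db_path)
st.session_state.df.to_sql("salaries", connection, if_exists="replace", index=False)
connection.close()

db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
```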
@@ -86,19 +98,20 @@ if st.session_state.df is not None:
 
     @tool("tables_schema")
     def tables_schema(tables: str) -> str:
-        """Get schema and sample rows for given tables."""
+        """Get the schema and sample rows for the specified tables."""
         return InfoSQLDatabaseTool(db=db).invoke(tables)
 
     @tool("execute_sql")
     def execute_sql(sql_query: str) -> str:
-        """Execute a SQL query against the database."""
+        """Execute a SQL query against the database and return the results."""
         return QuerySQLDataBaseTool(db=db).invoke(sql_query)
 
     @tool("check_sql")
     def check_sql(sql_query: str) -> str:
-        """Check the validity of a SQL query."""
+        """Validate the SQL query syntax and structure before execution."""
         return QuerySQLCheckerTool(db=db, llm=llm).invoke({"query": sql_query})
 
+    # Agents for SQL data extraction and analysis
     sql_dev = Agent(
         role="Senior Database Developer",
         goal="Extract data using optimized SQL queries.",
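Note: `QuerySQLDataBaseTool`, with that capitalization, is the long-standing name in `langchain_community`; newer releases also ship a `QuerySQLDatabaseTool` spelling and may emit a deprecation warning on the old one, so the pinned version is worth checking. Before handing these wrappers to agents, the underlying LangChain tools can be exercised directly; a sketch assuming the hypothetical `salaries` table from the note above:

```python
# Sanity-check the LangChain SQL tools outside of any agent (illustrative).
print(ListSQLDatabaseTool(db=db).invoke(""))                           # list table names
print(InfoSQLDatabaseTool(db=db).invoke("salaries"))                   # schema + sample rows
print(QuerySQLDataBaseTool(db=db).invoke("SELECT COUNT(*) FROM salaries"))
```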
@@ -116,11 +129,19 @@ if st.session_state.df is not None:
 
     report_writer = Agent(
         role="Technical Report Writer",
-        goal="Summarize the insights into a clear report.",
-        backstory="An expert in summarizing data insights into readable reports.",
+        goal="Write a structured report with Key Insights and Analysis. DO NOT include Introduction or Conclusion.",
+        backstory="Specializes in detailed analytical reports without conclusions.",
+        llm=llm,
+    )
+
+    conclusion_writer = Agent(
+        role="Conclusion Specialist",
+        goal="Summarize findings into a clear and concise 3-5 line Conclusion highlighting only the most important insights.",
+        backstory="An expert in crafting impactful and clear conclusions.",
         llm=llm,
     )
 
+    # Define tasks for report and conclusion
     extract_data = Task(
         description="Extract data based on the query: {query}.",
         expected_output="Database results matching the query.",
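Note: the next hunk splits report and conclusion across two crews, with `write_conclusion` reusing `analyze_data` (a task owned by the other crew) as context; that appears to rely on the Task object still carrying its output from the earlier kickoff. A sketch of the more conservative single-crew alternative, using only names defined in this file (an alternative design, not what the commit does):

```python
# Alternative: one sequential crew running all four tasks in order.
crew_all = Crew(
    agents=[sql_dev, data_analyst, report_writer, conclusion_writer],
    tasks=[extract_data, analyze_data, write_report, write_conclusion],
    process=Process.sequential,
    verbose=True,
)
result = crew_all.kickoff(inputs={"query": query})
```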
@@ -129,33 +150,111 @@ if st.session_state.df is not None:
 
     analyze_data = Task(
         description="Analyze the extracted data for query: {query}.",
-        expected_output="Analysis text summarizing findings.",
+        expected_output="Key Insights and Analysis without any Introduction or Conclusion.",
         agent=data_analyst,
         context=[extract_data],
     )
 
     write_report = Task(
-        description="Summarize the analysis into an executive report.",
-        expected_output="Markdown report of insights.",
+        description="Write the analysis report with Key Insights. DO NOT include a Conclusion.",
+        expected_output="Markdown-formatted report excluding Conclusion.",
         agent=report_writer,
         context=[analyze_data],
     )
 
-    crew = Crew(
+    write_conclusion = Task(
+        description="Write a brief and impactful 3-5 line Conclusion summarizing only the most important insights.",
+        expected_output="Markdown-formatted concise Conclusion section.",
+        agent=conclusion_writer,
+        context=[analyze_data],
+    )
+
+    # Separate Crews for report and conclusion
+    crew_report = Crew(
         agents=[sql_dev, data_analyst, report_writer],
         tasks=[extract_data, analyze_data, write_report],
         process=Process.sequential,
         verbose=True,
     )
 
-    query = st.text_area("Enter Query:", placeholder="e.g., 'What is the average salary for senior employees?'")
-    if st.button("Submit Query"):
-        with st.spinner("Processing query..."):
-            inputs = {"query": query}
-            result = crew.kickoff(inputs=inputs)
-            st.markdown("### Analysis Report:")
-            st.markdown(result)
+    crew_conclusion = Crew(
+        agents=[data_analyst, conclusion_writer],
+        tasks=[write_conclusion],
+        process=Process.sequential,
+        verbose=True,
+    )
+
+    # Tabs for Query Results and Visualizations
+    tab1, tab2 = st.tabs(["🔍 Query Insights + Viz", "📊 Full Data Viz"])
+
+    # Query Insights + Visualization
+    with tab1:
+        query = st.text_area("Enter Query:", value="Provide insights into the salary of a Principal Data Scientist.")
+        if st.button("Submit Query"):
+            with st.spinner("Processing query..."):
+                # Step 1: Generate the analysis report
+                report_inputs = {"query": query + " Provide detailed analysis but DO NOT include Conclusion."}
+                report_result = crew_report.kickoff(inputs=report_inputs)
+
+                # Step 2: Generate only the concise conclusion
+                conclusion_inputs = {"query": query + " Provide ONLY the most important insights in 3-5 concise lines."}
+                conclusion_result = crew_conclusion.kickoff(inputs=conclusion_inputs)
+
+                # Step 3: Display the report
+                st.markdown("### Analysis Report:")
+                st.markdown(report_result if report_result else "⚠️ No Report Generated.")
+
+                # Step 4: Generate Visualizations
+                visualizations = []
+
+                fig_salary = px.box(st.session_state.df, x="job_title", y="salary_in_usd",
+                                    title="Salary Distribution by Job Title")
+                visualizations.append(fig_salary)
+
+                fig_experience = px.bar(
+                    st.session_state.df.groupby("experience_level")["salary_in_usd"].mean().reset_index(),
+                    x="experience_level", y="salary_in_usd",
+                    title="Average Salary by Experience Level"
+                )
+                visualizations.append(fig_experience)
+
+                fig_employment = px.box(st.session_state.df, x="employment_type", y="salary_in_usd",
+                                        title="Salary Distribution by Employment Type")
+                visualizations.append(fig_employment)
+
+                # Step 5: Insert Visual Insights
+                st.markdown("## 📊 Visual Insights")
+                for fig in visualizations:
+                    st.plotly_chart(fig, use_container_width=True)
+
+                # Step 6: Display Concise Conclusion
+                st.markdown("## Conclusion")
+                st.markdown(conclusion_result if conclusion_result else "⚠️ No Conclusion Generated.")
+
+    # Full Data Visualization Tab
+    with tab2:
+        st.subheader("📊 Comprehensive Data Visualizations")
+
+        fig1 = px.histogram(st.session_state.df, x="job_title", title="Job Title Frequency")
+        st.plotly_chart(fig1)
+
+        fig2 = px.bar(
+            st.session_state.df.groupby("experience_level")["salary_in_usd"].mean().reset_index(),
+            x="experience_level", y="salary_in_usd",
+            title="Average Salary by Experience Level"
+        )
+        st.plotly_chart(fig2)
+
+        fig3 = px.box(st.session_state.df, x="employment_type", y="salary_in_usd",
+                      title="Salary Distribution by Employment Type")
+        st.plotly_chart(fig3)
 
     temp_dir.cleanup()
 else:
-    st.info("Please load a dataset to proceed.")
+    st.info("Please load a dataset to proceed.")
+
+
+# Sidebar Reference
+with st.sidebar:
+    st.header("📚 Reference:")
+    st.markdown("[SQL Agents w CrewAI & Llama 3 - Plaban Nayak](https://github.com/plaban1981/Agents/blob/main/SQL_Agents_with_CrewAI_and_Llama_3.ipynb)")
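Note: `kickoff()` returns a `CrewOutput` object rather than a plain string in recent CrewAI releases; `st.markdown` usually gets away with it because the object stringifies to its final text, but coercing explicitly is safer. The chart code also hardcodes the `demo-salaries` columns (`job_title`, `salary_in_usd`, `experience_level`, `employment_type`), so an arbitrary uploaded CSV will raise a KeyError. A hedged sketch of both guards (`render_crew_output` and the column check are illustrative helpers, not part of the commit):

```python
def render_crew_output(result, fallback="⚠️ Nothing generated."):
    # CrewOutput exposes the final text as .raw in recent crewai releases
    # (worth verifying against the pinned version); str() covers older ones.
    text = getattr(result, "raw", None) or (str(result) if result else "")
    st.markdown(text if text.strip() else fallback)

REQUIRED_COLUMNS = {"job_title", "salary_in_usd", "experience_level", "employment_type"}
if not REQUIRED_COLUMNS.issubset(st.session_state.df.columns):
    st.warning("Charts assume the demo-salaries schema; skipping visualizations.")
```

`render_crew_output(report_result)` and `render_crew_output(conclusion_result, "⚠️ No Conclusion Generated.")` would then replace the two inline `st.markdown(... if ... else ...)` calls inside tab1.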