DrishtiSharma committed
Commit 7752a10 · verified · 1 Parent(s): a849379

Update app.py

Files changed (1)
app.py +57 -45
app.py CHANGED
@@ -20,10 +20,10 @@ from langchain_community.utilities.sql_database import SQLDatabase
 from datasets import load_dataset
 import tempfile
 
-# Setup API Key
+# Environment setup
 os.environ["GROQ_API_KEY"] = st.secrets.get("GROQ_API_KEY", "")
 
-# LLM Logging
+# LLM Callback Logger
 class LLMCallbackHandler(BaseCallbackHandler):
     def __init__(self, log_path: Path):
         self.log_path = log_path
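
Note on the environment setup above: st.secrets.get("GROQ_API_KEY", "") reads the key from the Space's configured secrets and exports it to the process environment, where langchain_groq's ChatGroq looks for GROQ_API_KEY by default. A minimal local equivalent, as a sketch rather than part of the commit:

import os

# Sketch only (not in this commit): set the key directly when running
# outside a Hugging Face Space; the placeholder value must be replaced.
os.environ.setdefault("GROQ_API_KEY", "<your-groq-api-key>")
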
@@ -37,7 +37,7 @@ class LLMCallbackHandler(BaseCallbackHandler):
         with self.log_path.open("a", encoding="utf-8") as file:
             file.write(json.dumps({"event": "llm_end", "text": generation, "timestamp": datetime.now().isoformat()}) + "\n")
 
-# LLM Setup
+# Initialize the LLM
 llm = ChatGroq(
     temperature=0,
     model_name="mixtral-8x7b-32768",
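
The hunk stops at line 43, before the ChatGroq(...) call is closed, so the logger's attachment is not visible here. As a hypothetical, standalone sketch (not the commit's code): LangChain chat models accept a callbacks list, which is how a BaseCallbackHandler such as LLMCallbackHandler receives the llm_start/llm_end events it writes to JSONL:

from pathlib import Path
from langchain_groq import ChatGroq

# Sketch only: attach the file-logging handler defined above;
# "prompts.jsonl" is a hypothetical log path, not taken from the diff.
llm = ChatGroq(
    temperature=0,
    model_name="mixtral-8x7b-32768",
    callbacks=[LLMCallbackHandler(Path("prompts.jsonl"))],
)
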
@@ -45,98 +45,111 @@ llm = ChatGroq(
 )
 
 st.title("SQL-RAG Using CrewAI 🚀")
-st.write("Analyze and summarize data using natural language queries with SQL-based retrieval.")
-
-# Primary Option: Hugging Face Dataset
-st.subheader("Option 1: Use a Hugging Face Dataset")
-default_dataset = "Einstellung/demo-salaries"
-dataset_name = st.text_input("Enter Hugging Face dataset name:", value=default_dataset)
+st.write("Analyze datasets using natural language queries powered by SQL and CrewAI.")
 
+# Input Options
+input_option = st.radio("Select Dataset Input:", ["Use Hugging Face Dataset", "Upload CSV File"])
 df = None
-if dataset_name:
-    try:
-        with st.spinner("Loading Hugging Face dataset..."):
-            dataset = load_dataset(dataset_name, split="train")
-            df = pd.DataFrame(dataset)
-            st.success(f"Dataset '{dataset_name}' loaded successfully!")
-            st.dataframe(df.head())
-    except Exception as e:
-        st.error(f"Error loading Hugging Face dataset: {e}")
-
-# Secondary Option: File Upload
-st.subheader("Option 2: Upload Your CSV File")
-uploaded_file = st.file_uploader("Upload your dataset (CSV format):", type=["csv"])
-if uploaded_file and df is None:
-    with st.spinner("Loading uploaded file..."):
+
+if input_option == "Use Hugging Face Dataset":
+    dataset_name = st.text_input("Enter Hugging Face Dataset Name:", value="Einstellung/demo-salaries")
+    if st.button("Load Dataset"):
+        try:
+            with st.spinner("Loading Hugging Face dataset..."):
+                dataset = load_dataset(dataset_name, split="train")
+                df = pd.DataFrame(dataset)
+                st.success(f"Dataset '{dataset_name}' loaded successfully!")
+                st.dataframe(df.head())
+        except Exception as e:
+            st.error(f"Error loading dataset: {e}")
+else:
+    uploaded_file = st.file_uploader("Upload CSV File:", type=["csv"])
+    if uploaded_file:
         df = pd.read_csv(uploaded_file)
         st.success("File uploaded successfully!")
         st.dataframe(df.head())
 
+# SQL-RAG Analysis
 if df is not None:
-    # Create SQLite database
     temp_dir = tempfile.TemporaryDirectory()
     db_path = os.path.join(temp_dir.name, "data.db")
     connection = sqlite3.connect(db_path)
-    df.to_sql("data_table", connection, if_exists="replace", index=False)
+    df.to_sql("salaries", connection, if_exists="replace", index=False)
     db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
 
-    # Tools
+    # Tools with proper docstrings
    @tool("list_tables")
     def list_tables() -> str:
+        """List all tables in the SQLite database."""
         return ListSQLDatabaseTool(db=db).invoke("")
 
     @tool("tables_schema")
     def tables_schema(tables: str) -> str:
+        """
+        Get the schema and sample rows for specific tables in the database.
+
+        Input: Comma-separated table names.
+        Example: 'salaries'
+        """
         return InfoSQLDatabaseTool(db=db).invoke(tables)
 
     @tool("execute_sql")
     def execute_sql(sql_query: str) -> str:
+        """
+        Execute a valid SQL query on the database and return the results.
+
+        Input: A SQL query string.
+        Example: 'SELECT * FROM salaries LIMIT 5;'
+        """
         return QuerySQLDataBaseTool(db=db).invoke(sql_query)
 
     @tool("check_sql")
     def check_sql(sql_query: str) -> str:
+        """
+        Check the validity of a SQL query before execution.
+
+        Input: A SQL query string.
+        Example: 'SELECT salary FROM salaries WHERE salary > 10000;'
+        """
         return QuerySQLCheckerTool(db=db, llm=llm).invoke({"query": sql_query})
 
     # Agents
     sql_dev = Agent(
         role="Database Developer",
-        goal="Extract data from the database.",
+        goal="Extract relevant data by executing SQL queries.",
         llm=llm,
         tools=[list_tables, tables_schema, execute_sql, check_sql],
-        allow_delegation=False,
     )
 
     data_analyst = Agent(
         role="Data Analyst",
-        goal="Analyze and provide insights.",
+        goal="Analyze the extracted data and generate detailed insights.",
         llm=llm,
-        allow_delegation=False,
     )
 
     report_writer = Agent(
-        role="Report Editor",
-        goal="Summarize the analysis.",
+        role="Report Writer",
+        goal="Summarize the analysis into an executive report.",
         llm=llm,
-        allow_delegation=False,
     )
 
     # Tasks
     extract_data = Task(
-        description="Extract data required for the query: {query}.",
-        expected_output="Database result for the query",
+        description="Extract data for the query: {query}.",
+        expected_output="Database query results.",
         agent=sql_dev,
     )
 
     analyze_data = Task(
-        description="Analyze the data for: {query}.",
-        expected_output="Detailed analysis text",
+        description="Analyze the query results for: {query}.",
+        expected_output="Analysis report.",
         agent=data_analyst,
         context=[extract_data],
     )
 
     write_report = Task(
-        description="Summarize the analysis into a short report.",
-        expected_output="Markdown report",
+        description="Summarize the analysis into an executive summary.",
+        expected_output="Markdown-formatted report.",
         agent=report_writer,
         context=[analyze_data],
     )
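
Two changes in this hunk are linked: the SQLite table is renamed from data_table to salaries, and the docstrings added to each @tool function cite that table in their examples (the decorator, whose import lies outside this diff, typically surfaces the docstring as the tool description the agents read). A hypothetical sanity check for the rename, not part of the commit:

import sqlite3
import pandas as pd

# Sketch only: confirm df.to_sql(..., if_exists="replace") created the
# "salaries" table in the temporary database before the agents query it.
connection = sqlite3.connect(db_path)  # db_path as built in app.py
print(pd.read_sql_query("SELECT COUNT(*) AS n_rows FROM salaries;", connection))
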
@@ -146,12 +159,11 @@ if df is not None:
         tasks=[extract_data, analyze_data, write_report],
         process=Process.sequential,
         verbose=2,
-        memory=False,
     )
 
-    query = st.text_input("Enter your query:", placeholder="e.g., 'What is the average salary by experience level?'")
-    if query:
-        with st.spinner("Processing your query..."):
+    query = st.text_area("Enter Query:", placeholder="e.g., 'What is the average salary by experience level?'")
+    if st.button("Submit Query"):
+        with st.spinner("Processing your query with CrewAI..."):
             inputs = {"query": query}
             result = crew.kickoff(inputs=inputs)
             st.markdown("### Analysis Report:")
@@ -159,4 +171,4 @@
 
     temp_dir.cleanup()
 else:
-    st.warning("Please load a Hugging Face dataset or upload a CSV file to proceed.")
+    st.info("Load a dataset to proceed.")
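
For reference, CrewAI fills the {query} placeholder in each Task description from the inputs dict passed to crew.kickoff, so the pipeline assembled above can also be exercised without the Streamlit UI. A minimal sketch, assuming the objects defined in app.py are in scope:

# Sketch only, not part of the commit: run the crew headlessly with a
# fixed query; CrewAI interpolates {query} from the `inputs` mapping.
result = crew.kickoff(inputs={"query": "What is the average salary by experience level?"})
print(result)
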
 