DrishtiSharma committed · verified
Commit 9320eb6 · 1 Parent(s): 564add3

Update interim_radio.py

Files changed (1)
  1. interim_radio.py +41 -51
interim_radio.py CHANGED
@@ -20,10 +20,10 @@ from langchain_community.utilities.sql_database import SQLDatabase
 from datasets import load_dataset
 import tempfile
 
-# Environment setup
+# API Key
 os.environ["GROQ_API_KEY"] = st.secrets.get("GROQ_API_KEY", "")
 
-# LLM Callback Logger
+# Initialize LLM
 class LLMCallbackHandler(BaseCallbackHandler):
     def __init__(self, log_path: Path):
         self.log_path = log_path
@@ -37,116 +37,106 @@ class LLMCallbackHandler(BaseCallbackHandler):
         with self.log_path.open("a", encoding="utf-8") as file:
             file.write(json.dumps({"event": "llm_end", "text": generation, "timestamp": datetime.now().isoformat()}) + "\n")
 
-# Initialize the LLM
 llm = ChatGroq(
     temperature=0,
-    model_name="mixtral-8x7b-32768",
+    model_name="groq/llama-3.3-70b-versatile",
+    max_tokens=500,
     callbacks=[LLMCallbackHandler(Path("prompts.jsonl"))],
 )
 
 st.title("SQL-RAG Using CrewAI 🚀")
 st.write("Analyze datasets using natural language queries powered by SQL and CrewAI.")
 
-# Input Options
-input_option = st.radio("Select Dataset Input:", ["Use Hugging Face Dataset", "Upload CSV File"])
-df = None
+# Initialize session state for data persistence
+if "df" not in st.session_state:
+    st.session_state.df = None
 
+# Dataset Input
+input_option = st.radio("Select Dataset Input:", ["Use Hugging Face Dataset", "Upload CSV File"])
 if input_option == "Use Hugging Face Dataset":
     dataset_name = st.text_input("Enter Hugging Face Dataset Name:", value="Einstellung/demo-salaries")
     if st.button("Load Dataset"):
         try:
-            with st.spinner("Loading Hugging Face dataset..."):
+            with st.spinner("Loading dataset..."):
                 dataset = load_dataset(dataset_name, split="train")
-                df = pd.DataFrame(dataset)
+                st.session_state.df = pd.DataFrame(dataset)
                 st.success(f"Dataset '{dataset_name}' loaded successfully!")
-                st.dataframe(df.head())
+                st.dataframe(st.session_state.df.head())
         except Exception as e:
-            st.error(f"Error loading dataset: {e}")
-else:
+            st.error(f"Error: {e}")
+elif input_option == "Upload CSV File":
     uploaded_file = st.file_uploader("Upload CSV File:", type=["csv"])
     if uploaded_file:
-        df = pd.read_csv(uploaded_file)
+        st.session_state.df = pd.read_csv(uploaded_file)
         st.success("File uploaded successfully!")
-        st.dataframe(df.head())
+        st.dataframe(st.session_state.df.head())
 
 # SQL-RAG Analysis
-if df is not None:
+if st.session_state.df is not None:
     temp_dir = tempfile.TemporaryDirectory()
     db_path = os.path.join(temp_dir.name, "data.db")
     connection = sqlite3.connect(db_path)
-    df.to_sql("salaries", connection, if_exists="replace", index=False)
+    st.session_state.df.to_sql("salaries", connection, if_exists="replace", index=False)
     db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
 
-    # Tools with proper docstrings
     @tool("list_tables")
     def list_tables() -> str:
-        """List all tables in the SQLite database."""
+        """List all tables in the database."""
        return ListSQLDatabaseTool(db=db).invoke("")
 
     @tool("tables_schema")
     def tables_schema(tables: str) -> str:
-        """
-        Get the schema and sample rows for specific tables in the database.
-        Input: Comma-separated table names.
-        Example: 'salaries'
-        """
+        """Get schema and sample rows for given tables."""
         return InfoSQLDatabaseTool(db=db).invoke(tables)
 
     @tool("execute_sql")
     def execute_sql(sql_query: str) -> str:
-        """
-        Execute a valid SQL query on the database and return the results.
-        Input: A SQL query string.
-        Example: 'SELECT * FROM salaries LIMIT 5;'
-        """
+        """Execute a SQL query against the database."""
         return QuerySQLDataBaseTool(db=db).invoke(sql_query)
 
     @tool("check_sql")
     def check_sql(sql_query: str) -> str:
-        """
-        Check the validity of a SQL query before execution.
-        Input: A SQL query string.
-        Example: 'SELECT salary FROM salaries WHERE salary > 10000;'
-        """
+        """Check the validity of a SQL query."""
         return QuerySQLCheckerTool(db=db, llm=llm).invoke({"query": sql_query})
 
-    # Agents
     sql_dev = Agent(
-        role="Database Developer",
-        goal="Extract relevant data by executing SQL queries.",
+        role="Senior Database Developer",
+        goal="Extract data using optimized SQL queries.",
+        backstory="An expert in writing optimized SQL queries for complex databases.",
         llm=llm,
         tools=[list_tables, tables_schema, execute_sql, check_sql],
     )
 
     data_analyst = Agent(
-        role="Data Analyst",
-        goal="Analyze the extracted data and generate detailed insights.",
+        role="Senior Data Analyst",
+        goal="Analyze the data and produce insights.",
+        backstory="A seasoned analyst who identifies trends and patterns in datasets.",
         llm=llm,
     )
 
     report_writer = Agent(
-        role="Report Writer",
-        goal="Summarize the analysis into an executive report.",
+        role="Technical Report Writer",
+        goal="Summarize the insights into a clear report.",
+        backstory="An expert in summarizing data insights into readable reports.",
         llm=llm,
     )
 
-    # Tasks
     extract_data = Task(
-        description="Extract data for the query: {query}.",
-        expected_output="Database query results.",
+        description="Extract data based on the query: {query}.",
+        expected_output="Database results matching the query.",
         agent=sql_dev,
     )
 
     analyze_data = Task(
-        description="Analyze the query results for: {query}.",
-        expected_output="Analysis report.",
+        description="Analyze the extracted data for query: {query}.",
+        expected_output="Analysis text summarizing findings.",
         agent=data_analyst,
         context=[extract_data],
     )
 
     write_report = Task(
-        description="Summarize the analysis into an executive summary.",
-        expected_output="Markdown-formatted report.",
+        description="Summarize the analysis into an executive report.",
+        expected_output="Markdown report of insights.",
         agent=report_writer,
         context=[analyze_data],
     )
@@ -155,12 +145,12 @@ if df is not None:
         agents=[sql_dev, data_analyst, report_writer],
         tasks=[extract_data, analyze_data, write_report],
         process=Process.sequential,
-        verbose=2,
+        verbose=True,
     )
 
-    query = st.text_area("Enter Query:", placeholder="e.g., 'What is the average salary by experience level?'")
+    query = st.text_area("Enter Query:", placeholder="e.g., 'What is the average salary for senior employees?'")
     if st.button("Submit Query"):
-        with st.spinner("Processing your query with CrewAI..."):
+        with st.spinner("Processing query..."):
             inputs = {"query": query}
             result = crew.kickoff(inputs=inputs)
             st.markdown("### Analysis Report:")
@@ -168,4 +158,4 @@ if df is not None:
 
     temp_dir.cleanup()
 else:
-    st.info("Load a dataset to proceed.")
+    st.info("Please load a dataset to proceed.")