Caleb Fahlgren committed on
Commit
e915c68
·
1 Parent(s): 853c083

add plotting capabilities

Browse files
Files changed (1) hide show
  1. app.py +46 -11
app.py CHANGED
@@ -1,13 +1,16 @@
1
  from gradio_huggingfacehub_search import HuggingfaceHubSearch
2
  from huggingface_hub import HfApi
 
 
3
  import pandas as pd
4
  import gradio as gr
5
  import duckdb
6
  import requests
7
  import llama_cpp
8
  import instructor
 
9
 
10
- from pydantic import BaseModel
11
 
12
  BASE_DATASETS_SERVER_URL = "https://datasets-server.huggingface.co"
13
  view_name = "dataset_view"
@@ -21,6 +24,7 @@ llama = llama_cpp.Llama(
21
  chat_format="chatml",
22
  n_ctx=2048,
23
  verbose=False,
 
24
  )
25
 
26
  create = instructor.patch(
@@ -29,8 +33,23 @@ create = instructor.patch(
29
  )
30
 
31
 
 
 
 
 
 
 
32
  class SQLResponse(BaseModel):
33
  sql: str
 
 
 
 
 
 
 
 
 
34
 
35
 
36
  def get_dataset_ddl(dataset_id: str) -> str:
@@ -63,7 +82,7 @@ CREATE TABLE {} (
63
  return sql_ddl
64
 
65
 
66
- def generate_sql(dataset_id: str, query: str) -> str:
67
  ddl = get_dataset_ddl(dataset_id)
68
 
69
  system_prompt = f"""
@@ -76,6 +95,8 @@ def generate_sql(dataset_id: str, query: str) -> str:
76
  Please assist the user by writing a SQL query that answers the user's question.
77
  """
78
 
 
 
79
  resp: SQLResponse = create(
80
  model="Hermes-2-Pro-Llama-3-8B",
81
  messages=[
@@ -88,15 +109,28 @@ def generate_sql(dataset_id: str, query: str) -> str:
88
  response_model=SQLResponse,
89
  )
90
 
91
- return resp.sql
92
 
 
93
 
94
- def query_dataset(dataset_id: str, query: str) -> tuple[pd.DataFrame, str]:
95
- sql_query = generate_sql(dataset_id, query)
96
- df = conn.execute(sql_query).fetchdf()
97
 
98
- markdown_output = f"""```sql\n{sql_query}```"""
99
- return df, markdown_output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
 
102
  with gr.Blocks() as demo:
@@ -105,19 +139,20 @@ with gr.Blocks() as demo:
105
  label="Hub Dataset ID",
106
  placeholder="Find your favorite dataset...",
107
  search_type="dataset",
108
- value="jamescalam/world-cities-geo",
109
  )
110
  user_query = gr.Textbox("", label="Ask anything...")
111
 
112
  btn = gr.Button("Ask 🪄")
113
 
114
- df = gr.DataFrame()
115
  sql_query = gr.Markdown(label="Output SQL Query")
 
 
116
 
117
  btn.click(
118
  query_dataset,
119
  inputs=[dataset_id, user_query],
120
- outputs=[df, sql_query],
121
  )
122
 
123
 
 
1
  from gradio_huggingfacehub_search import HuggingfaceHubSearch
2
  from huggingface_hub import HfApi
3
+ import matplotlib.pyplot as plt
4
+ from typing import Tuple, Optional
5
  import pandas as pd
6
  import gradio as gr
7
  import duckdb
8
  import requests
9
  import llama_cpp
10
  import instructor
11
+ import enum
12
 
13
+ from pydantic import BaseModel, Field
14
 
15
  BASE_DATASETS_SERVER_URL = "https://datasets-server.huggingface.co"
16
  view_name = "dataset_view"
 
24
  chat_format="chatml",
25
  n_ctx=2048,
26
  verbose=False,
27
+ temperature=0.1,
28
  )
29
 
30
  create = instructor.patch(
 
33
  )
34
 
35
 
36
+ class OutputTypes(str, enum.Enum):
37
+ TABLE = "table"
38
+ BARCHART = "barchart"
39
+ LINECHART = "linechart"
40
+
41
+
42
  class SQLResponse(BaseModel):
43
  sql: str
44
+ visualization_type: Optional[OutputTypes] = Field(
45
+ None, description="The type of visualization to display"
46
+ )
47
+ data_key: Optional[str] = Field(
48
+ None, description="The column name that contains the data for chart responses"
49
+ )
50
+ label_key: Optional[str] = Field(
51
+ None, description="The column name that contains the labels for chart responses"
52
+ )
53
 
54
 
55
  def get_dataset_ddl(dataset_id: str) -> str:
 
82
  return sql_ddl
83
 
84
 
85
+ def generate_query(dataset_id: str, query: str) -> str:
86
  ddl = get_dataset_ddl(dataset_id)
87
 
88
  system_prompt = f"""
 
95
  Please assist the user by writing a SQL query that answers the user's question.
96
  """
97
 
98
+ print("Calling LLM with system prompt: ", system_prompt)
99
+
100
  resp: SQLResponse = create(
101
  model="Hermes-2-Pro-Llama-3-8B",
102
  messages=[
 
109
  response_model=SQLResponse,
110
  )
111
 
112
+ print("Received Response: ", resp)
113
 
114
+ return resp
115
 
 
 
 
116
 
117
+ def query_dataset(dataset_id: str, query: str) -> Tuple[pd.DataFrame, str, plt.Figure]:
118
+ response: SQLResponse = generate_query(dataset_id, query)
119
+ df = conn.execute(response.sql).fetchdf()
120
+
121
+ plot = None
122
+
123
+ if response.visualization_type == OutputTypes.LINECHART:
124
+ plot = df.plot(
125
+ kind="line", x=response.data_key, y=response.label_key
126
+ ).get_figure()
127
+ elif response.visualization_type == OutputTypes.BARCHART:
128
+ plot = df.plot(
129
+ kind="bar", x=response.data_key, y=response.label_key
130
+ ).get_figure()
131
+
132
+ markdown_output = f"""```sql\n{response.sql}\n```"""
133
+ return df, markdown_output, plot
134
 
135
 
136
  with gr.Blocks() as demo:
 
139
  label="Hub Dataset ID",
140
  placeholder="Find your favorite dataset...",
141
  search_type="dataset",
142
+ value="teknium/OpenHermes-2.5",
143
  )
144
  user_query = gr.Textbox("", label="Ask anything...")
145
 
146
  btn = gr.Button("Ask 🪄")
147
 
 
148
  sql_query = gr.Markdown(label="Output SQL Query")
149
+ df = gr.DataFrame()
150
+ plot = gr.Plot()
151
 
152
  btn.click(
153
  query_dataset,
154
  inputs=[dataset_id, user_query],
155
+ outputs=[df, sql_query, plot],
156
  )
157
 
158