Spaces:
Runtime error
Runtime error
Caleb Fahlgren
commited on
Commit
·
e915c68
1
Parent(s):
853c083
add plotting capabilities
Browse files
app.py
CHANGED
@@ -1,13 +1,16 @@
|
|
1 |
from gradio_huggingfacehub_search import HuggingfaceHubSearch
|
2 |
from huggingface_hub import HfApi
|
|
|
|
|
3 |
import pandas as pd
|
4 |
import gradio as gr
|
5 |
import duckdb
|
6 |
import requests
|
7 |
import llama_cpp
|
8 |
import instructor
|
|
|
9 |
|
10 |
-
from pydantic import BaseModel
|
11 |
|
12 |
BASE_DATASETS_SERVER_URL = "https://datasets-server.huggingface.co"
|
13 |
view_name = "dataset_view"
|
@@ -21,6 +24,7 @@ llama = llama_cpp.Llama(
|
|
21 |
chat_format="chatml",
|
22 |
n_ctx=2048,
|
23 |
verbose=False,
|
|
|
24 |
)
|
25 |
|
26 |
create = instructor.patch(
|
@@ -29,8 +33,23 @@ create = instructor.patch(
|
|
29 |
)
|
30 |
|
31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
class SQLResponse(BaseModel):
|
33 |
sql: str
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
|
35 |
|
36 |
def get_dataset_ddl(dataset_id: str) -> str:
|
@@ -63,7 +82,7 @@ CREATE TABLE {} (
|
|
63 |
return sql_ddl
|
64 |
|
65 |
|
66 |
-
def
|
67 |
ddl = get_dataset_ddl(dataset_id)
|
68 |
|
69 |
system_prompt = f"""
|
@@ -76,6 +95,8 @@ def generate_sql(dataset_id: str, query: str) -> str:
|
|
76 |
Please assist the user by writing a SQL query that answers the user's question.
|
77 |
"""
|
78 |
|
|
|
|
|
79 |
resp: SQLResponse = create(
|
80 |
model="Hermes-2-Pro-Llama-3-8B",
|
81 |
messages=[
|
@@ -88,15 +109,28 @@ def generate_sql(dataset_id: str, query: str) -> str:
|
|
88 |
response_model=SQLResponse,
|
89 |
)
|
90 |
|
91 |
-
|
92 |
|
|
|
93 |
|
94 |
-
def query_dataset(dataset_id: str, query: str) -> tuple[pd.DataFrame, str]:
|
95 |
-
sql_query = generate_sql(dataset_id, query)
|
96 |
-
df = conn.execute(sql_query).fetchdf()
|
97 |
|
98 |
-
|
99 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
|
101 |
|
102 |
with gr.Blocks() as demo:
|
@@ -105,19 +139,20 @@ with gr.Blocks() as demo:
|
|
105 |
label="Hub Dataset ID",
|
106 |
placeholder="Find your favorite dataset...",
|
107 |
search_type="dataset",
|
108 |
-
value="
|
109 |
)
|
110 |
user_query = gr.Textbox("", label="Ask anything...")
|
111 |
|
112 |
btn = gr.Button("Ask 🪄")
|
113 |
|
114 |
-
df = gr.DataFrame()
|
115 |
sql_query = gr.Markdown(label="Output SQL Query")
|
|
|
|
|
116 |
|
117 |
btn.click(
|
118 |
query_dataset,
|
119 |
inputs=[dataset_id, user_query],
|
120 |
-
outputs=[df, sql_query],
|
121 |
)
|
122 |
|
123 |
|
|
|
1 |
from gradio_huggingfacehub_search import HuggingfaceHubSearch
|
2 |
from huggingface_hub import HfApi
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
+
from typing import Tuple, Optional
|
5 |
import pandas as pd
|
6 |
import gradio as gr
|
7 |
import duckdb
|
8 |
import requests
|
9 |
import llama_cpp
|
10 |
import instructor
|
11 |
+
import enum
|
12 |
|
13 |
+
from pydantic import BaseModel, Field
|
14 |
|
15 |
BASE_DATASETS_SERVER_URL = "https://datasets-server.huggingface.co"
|
16 |
view_name = "dataset_view"
|
|
|
24 |
chat_format="chatml",
|
25 |
n_ctx=2048,
|
26 |
verbose=False,
|
27 |
+
temperature=0.1,
|
28 |
)
|
29 |
|
30 |
create = instructor.patch(
|
|
|
33 |
)
|
34 |
|
35 |
|
36 |
+
class OutputTypes(str, enum.Enum):
|
37 |
+
TABLE = "table"
|
38 |
+
BARCHART = "barchart"
|
39 |
+
LINECHART = "linechart"
|
40 |
+
|
41 |
+
|
42 |
class SQLResponse(BaseModel):
|
43 |
sql: str
|
44 |
+
visualization_type: Optional[OutputTypes] = Field(
|
45 |
+
None, description="The type of visualization to display"
|
46 |
+
)
|
47 |
+
data_key: Optional[str] = Field(
|
48 |
+
None, description="The column name that contains the data for chart responses"
|
49 |
+
)
|
50 |
+
label_key: Optional[str] = Field(
|
51 |
+
None, description="The column name that contains the labels for chart responses"
|
52 |
+
)
|
53 |
|
54 |
|
55 |
def get_dataset_ddl(dataset_id: str) -> str:
|
|
|
82 |
return sql_ddl
|
83 |
|
84 |
|
85 |
+
def generate_query(dataset_id: str, query: str) -> str:
|
86 |
ddl = get_dataset_ddl(dataset_id)
|
87 |
|
88 |
system_prompt = f"""
|
|
|
95 |
Please assist the user by writing a SQL query that answers the user's question.
|
96 |
"""
|
97 |
|
98 |
+
print("Calling LLM with system prompt: ", system_prompt)
|
99 |
+
|
100 |
resp: SQLResponse = create(
|
101 |
model="Hermes-2-Pro-Llama-3-8B",
|
102 |
messages=[
|
|
|
109 |
response_model=SQLResponse,
|
110 |
)
|
111 |
|
112 |
+
print("Received Response: ", resp)
|
113 |
|
114 |
+
return resp
|
115 |
|
|
|
|
|
|
|
116 |
|
117 |
+
def query_dataset(dataset_id: str, query: str) -> Tuple[pd.DataFrame, str, plt.Figure]:
|
118 |
+
response: SQLResponse = generate_query(dataset_id, query)
|
119 |
+
df = conn.execute(response.sql).fetchdf()
|
120 |
+
|
121 |
+
plot = None
|
122 |
+
|
123 |
+
if response.visualization_type == OutputTypes.LINECHART:
|
124 |
+
plot = df.plot(
|
125 |
+
kind="line", x=response.data_key, y=response.label_key
|
126 |
+
).get_figure()
|
127 |
+
elif response.visualization_type == OutputTypes.BARCHART:
|
128 |
+
plot = df.plot(
|
129 |
+
kind="bar", x=response.data_key, y=response.label_key
|
130 |
+
).get_figure()
|
131 |
+
|
132 |
+
markdown_output = f"""```sql\n{response.sql}\n```"""
|
133 |
+
return df, markdown_output, plot
|
134 |
|
135 |
|
136 |
with gr.Blocks() as demo:
|
|
|
139 |
label="Hub Dataset ID",
|
140 |
placeholder="Find your favorite dataset...",
|
141 |
search_type="dataset",
|
142 |
+
value="teknium/OpenHermes-2.5",
|
143 |
)
|
144 |
user_query = gr.Textbox("", label="Ask anything...")
|
145 |
|
146 |
btn = gr.Button("Ask 🪄")
|
147 |
|
|
|
148 |
sql_query = gr.Markdown(label="Output SQL Query")
|
149 |
+
df = gr.DataFrame()
|
150 |
+
plot = gr.Plot()
|
151 |
|
152 |
btn.click(
|
153 |
query_dataset,
|
154 |
inputs=[dataset_id, user_query],
|
155 |
+
outputs=[df, sql_query, plot],
|
156 |
)
|
157 |
|
158 |
|