# Hugging Face Space status: Runtime error
import enum
import functools
from typing import Tuple, Optional

import duckdb
import gradio as gr
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from huggingface_hub import HfApi
import instructor
import llama_cpp
import matplotlib.pyplot as plt
import pandas as pd
from pydantic import BaseModel, Field
import requests
import spaces
# Base URL of the Hugging Face datasets server; get_dataset_ddl queries its
# /parquet endpoint to list a dataset's parquet files.
BASE_DATASETS_SERVER_URL = "https://datasets-server.huggingface.co"
# Name of the DuckDB view that exposes the selected dataset to generated SQL.
view_name = "dataset_view"
# NOTE(review): not referenced elsewhere in the visible source.
hf_api = HfApi()
# Single in-memory DuckDB connection shared by every request.
conn = duckdb.connect()
class OutputTypes(str, enum.Enum):
    """Ways the answer to a question can be rendered.

    Mixing in ``str`` makes every member compare equal to its raw value,
    e.g. ``OutputTypes.TABLE == "table"``.
    """

    # Plain tabular result, no chart.
    TABLE = "table"
    # Bar chart of the result.
    BARCHART = "barchart"
    # Line chart of the result.
    LINECHART = "linechart"
class SQLResponse(BaseModel):
    """Structured reply the LLM is asked to produce (``response_model`` for instructor).

    ``sql`` is the query to execute; the optional fields are hints used to
    chart the query result.
    """

    # Query text; query_dataset passes it straight to conn.execute.
    sql: str
    # Requested rendering; None means no chart is drawn.
    visualization_type: Optional[OutputTypes] = Field(
        default=None,
        description="The type of visualization to display",
    )
    # Result column to plot on the y-axis.
    data_key: Optional[str] = Field(
        default=None,
        description="The column name from the sql query that contains the data for chart responses",
    )
    # Result column to plot on the x-axis.
    label_key: Optional[str] = Field(
        default=None,
        description="The column name from the sql query that contains the labels for chart responses",
    )
def get_dataset_ddl(dataset_id: str) -> str:
    """Expose a Hub dataset as the DuckDB view ``view_name`` and describe its schema.

    Looks up the dataset's parquet files on the datasets server, (re)creates
    the view over the first file, and renders the view's columns as a
    ``CREATE TABLE`` statement suitable for an LLM prompt.

    Args:
        dataset_id: Hub dataset identifier, e.g. ``"gretelai/synthetic_text_to_sql"``.

    Returns:
        A ``CREATE TABLE`` statement listing each column name and its DuckDB type.

    Raises:
        requests.HTTPError: if the datasets server answers with an error status.
        requests.Timeout: if the datasets server does not answer in time.
        ValueError: if the dataset lists no parquet file, or the first one has no URL.
    """
    # Let requests encode the id so characters such as "&" or "#" cannot
    # corrupt the query string; the timeout keeps a stalled server from
    # hanging the request forever.
    response = requests.get(
        f"{BASE_DATASETS_SERVER_URL}/parquet",
        params={"dataset": dataset_id},
        timeout=30,
    )
    response.raise_for_status()

    parquet_files = response.json().get("parquet_files") or []
    if not parquet_files:
        raise ValueError(f"No parquet files found for dataset {dataset_id!r}.")
    first_parquet_url = parquet_files[0].get("url")
    if not first_parquet_url:
        raise ValueError("No valid URL found for the first parquet file.")

    # Double any single quote so the URL stays one SQL string literal.
    quoted_url = first_parquet_url.replace("'", "''")
    conn.execute(
        f"CREATE OR REPLACE VIEW {view_name} as SELECT * FROM read_parquet('{quoted_url}');"
    )

    # PRAGMA table_info rows are (cid, name, type, notnull, dflt_value, pk).
    table_info = conn.execute(f"PRAGMA table_info('{view_name}');").fetchall()
    column_data_types = ",\n\t".join(f"{row[1]} {row[2]}" for row in table_info)
    return f"""
CREATE TABLE {view_name} (
\t{column_data_types}
);
"""
@functools.lru_cache(maxsize=1)
def _get_structured_create():
    """Load the local GGUF model once and return an instructor-patched completion function.

    Loading the model file is expensive, so it is done once per process
    rather than on every request.
    """
    # NOTE(review): the single Llama instance is shared by all requests; this
    # assumes Gradio does not run this handler concurrently — confirm the
    # app's concurrency settings.
    llama = llama_cpp.Llama(
        model_path="Hermes-2-Pro-Llama-3-8B-Q8_0.gguf",
        n_gpu_layers=50,
        chat_format="chatml",
        n_ctx=2048,
        verbose=False,
    )
    return instructor.patch(
        create=llama.create_chat_completion_openai_v1,
        mode=instructor.Mode.JSON_SCHEMA,
    )


def generate_query(dataset_id: str, query: str) -> dict:
    """Ask the local LLM for a SQL query (plus optional chart hints) answering *query*.

    Args:
        dataset_id: Hub dataset whose schema is put in the prompt (see get_dataset_ddl).
        query: The user's natural-language question.

    Returns:
        ``SQLResponse.model_dump()``: a dict with keys ``sql``,
        ``visualization_type``, ``data_key`` and ``label_key``.
    """
    # Fetch the schema first: a bad dataset id fails before any model work.
    ddl = get_dataset_ddl(dataset_id)
    create = _get_structured_create()
    # NOTE(review): the prompt says PostgreSQL but the SQL is executed by
    # DuckDB (query_dataset); presumably relying on the dialects' overlap.
    system_prompt = f"""
You are an expert SQL assistant with access to the following PostgreSQL Table:
```sql
{ddl}
```
Please assist the user by writing a SQL query that answers the user's question.
Use Label Key as the column name for the x-axis and Data Key as the column name for the y-axis for chart responses. The
label key and data key must be present in the SQL output.
"""
    print("Calling LLM with system prompt: ", system_prompt)
    resp: SQLResponse = create(
        model="Hermes-2-Pro-Llama-3-8B",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": query},
        ],
        response_model=SQLResponse,
        # Temperature is a sampling option of the completion request; the
        # Llama constructor (where it was previously passed) does not use it.
        temperature=0.1,
    )
    print("Received Response: ", resp)
    return resp.model_dump()
def _draw_chart(
    df: pd.DataFrame,
    kind: str,
    label_key: Optional[str],
    data_key: Optional[str],
) -> Optional[plt.Figure]:
    """Plot *df* as a pandas ``kind`` chart on its own figure; return None if it cannot be drawn."""
    # Build the Figure directly rather than through pyplot: it is then not
    # kept alive in pyplot's global figure registry after the request, and
    # concurrent requests do not share pyplot's "current figure".
    fig = plt.Figure()
    ax = fig.subplots()
    try:
        df.plot(kind=kind, x=label_key, y=data_key, ax=ax)
    except (TypeError, ValueError, KeyError) as err:
        # e.g. pandas raises TypeError when the result has no numeric data to
        # plot; the table and SQL are still worth returning without a chart.
        print("Could not draw chart: ", err)
        return None
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
    fig.tight_layout()
    return fig


def query_dataset(
    dataset_id: str, query: str
) -> Tuple[pd.DataFrame, str, Optional[plt.Figure]]:
    """Answer *query* over *dataset_id*: generate SQL, run it, and chart it if requested.

    Returns:
        The query result, the generated SQL as a fenced Markdown code block,
        and a matplotlib figure when the LLM asked for a line or bar chart that
        could be drawn (otherwise ``None``).
    """
    response = generate_query(dataset_id, query)
    sql = response.get("sql")
    print("Querying Parquet...")
    # NOTE(review): this executes LLM-generated SQL, steered by user input,
    # verbatim on the shared connection. It is not limited to the dataset
    # view (it could e.g. call read_parquet on other paths, as
    # get_dataset_ddl does); consider sandboxing before public exposure.
    df = conn.execute(sql).fetchdf()

    label_key = response.get("label_key")
    data_key = response.get("data_key")
    # Drop chart hints that name columns the query did not return, so pandas
    # uses its own default for that axis instead of raising.
    if label_key and label_key not in df.columns:
        label_key = None
    if data_key and data_key not in df.columns:
        data_key = None

    viz_type = response.get("visualization_type")
    plot = None
    if viz_type == OutputTypes.LINECHART:
        plot = _draw_chart(df, "line", label_key, data_key)
    elif viz_type == OutputTypes.BARCHART:
        plot = _draw_chart(df, "bar", label_key, data_key)

    markdown_output = f"""```sql\n{sql}\n```"""
    return df, markdown_output, plot
# Gradio UI: choose a Hub dataset, ask a question, then show the generated
# SQL, the query result, and an optional chart (all produced by query_dataset).
with gr.Blocks() as demo:
    gr.Markdown("# Query your HF Datasets with Natural Language")
    dataset_id = HuggingfaceHubSearch(
        label="Hub Dataset ID",
        placeholder="Find your favorite dataset...",
        search_type="dataset",
        value="gretelai/synthetic_text_to_sql",
    )
    user_query = gr.Textbox("", label="Ask anything...")
    # Sample questions; clicking one only fills the textbox (no outputs).
    examples = [
        ["Show me a preview of the data"],
        ["Show me something interesting"],
        ["What is the largest length of sql query context?"],
        ["show me counts by sql_query_type in a bar chart"],
    ]
    gr.Examples(examples=examples, inputs=[user_query], outputs=[])
    btn = gr.Button("Ask")
    sql_query = gr.Markdown(label="Output SQL Query")
    df = gr.DataFrame()
    plot = gr.Plot()
    # Output order must match query_dataset's return tuple (df, sql, plot).
    btn.click(
        query_dataset,
        inputs=[dataset_id, user_query],
        outputs=[df, sql_query, plot],
    )

if __name__ == "__main__":
    demo.launch()