# datasets-ai / app.py
# Commit e735a4c by Caleb Fahlgren: "use correct llama"
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from llama_cpp.llama_speculative import LlamaPromptLookupDecoding
from llama_cpp_cuda_tensorcores import Llama
from huggingface_hub import hf_hub_download
from huggingface_hub import HfApi
import matplotlib.pyplot as plt
from typing import Tuple, Optional
import pandas as pd
import gradio as gr
import duckdb
import requests
import instructor
import spaces
import enum
import os
from pydantic import BaseModel, Field
# Base URL of the Hugging Face datasets-server API; get_dataset_ddl queries its
# /parquet endpoint to locate a dataset's parquet files.
BASE_DATASETS_SERVER_URL = "https://datasets-server.huggingface.co"
# Name of the DuckDB view that exposes the selected dataset to generated SQL.
view_name = "dataset_view"
hf_api = HfApi()
# Module-level DuckDB connection (no path given, so an in-memory database) used
# by both get_dataset_ddl and query_dataset.
conn = duckdb.connect()

# Model / inference configuration, overridable through environment variables.
gpu_layers = int(os.environ.get("GPU_LAYERS", 0))  # layers offloaded to the GPU
draft_pred_tokens = int(os.environ.get("DRAFT_PRED_TOKENS", 2))  # prompt-lookup draft length
repo_id = os.getenv("MODEL_REPO_ID", "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF")
model_file_name = os.getenv("MODEL_FILE_NAME", "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf")
# Fetch the GGUF weights at import time into ./models; generate_query later
# loads them from models/<model_file_name>.
hf_hub_download(
    local_dir="./models",
    filename=model_file_name,
    repo_id=repo_id,
)
# Result rendering the model may request for a query. Mixing in ``str`` makes
# members compare equal to their raw values (OutputTypes.TABLE == "table").
# Documented with comments rather than a docstring: pydantic uses class
# docstrings as JSON-schema descriptions, and this enum is part of the
# SQLResponse schema handed to instructor.
class OutputTypes(str, enum.Enum):
    TABLE = "table"  # query_dataset has no chart branch for this; only the DataFrame is shown
    BARCHART = "barchart"  # query_dataset draws DataFrame.plot(kind="bar")
    LINECHART = "linechart"  # query_dataset draws DataFrame.plot(kind="line")
# Structured answer expected from the LLM; generate_query passes this class to
# instructor as ``response_model``. The Field descriptions become part of the
# generated JSON schema, so comments (not a class docstring) document it here
# to keep that schema unchanged.
class SQLResponse(BaseModel):
    # SQL statement that query_dataset executes against the dataset view.
    sql: str
    # Requested chart kind; None or TABLE means no chart is drawn.
    visualization_type: Optional[OutputTypes] = Field(
        default=None,
        description="The type of visualization to display",
    )
    # Result column plotted as the chart's y values.
    data_key: Optional[str] = Field(
        default=None,
        description="The column name from the sql query that contains the data for chart responses",
    )
    # Result column used as the chart's x values / labels.
    label_key: Optional[str] = Field(
        default=None,
        description="The column name from the sql query that contains the labels for chart responses",
    )
def get_dataset_ddl(dataset_id: str) -> str:
    """Expose a Hub dataset as a DuckDB view and describe its columns as DDL.

    Looks up the dataset's parquet files on the datasets-server, (re)creates the
    module-level view ``view_name`` over the first file on ``conn``, and renders
    the view's columns as a ``CREATE TABLE`` statement for the LLM prompt.

    Args:
        dataset_id: Hub dataset id, e.g. ``"gretelai/synthetic_text_to_sql"``.

    Returns:
        A ``CREATE TABLE`` statement listing each column's name and DuckDB type.

    Raises:
        requests.HTTPError: if the datasets-server returns an error status.
        requests.RequestException: on connection errors or timeout.
        ValueError: if the dataset has no parquet files or the first has no URL.
    """
    response = requests.get(
        f"{BASE_DATASETS_SERVER_URL}/parquet",
        # Let requests URL-encode the user-supplied id instead of splicing it in.
        params={"dataset": dataset_id},
        timeout=30,
    )
    response.raise_for_status()  # Check if the request was successful

    parquet_files = response.json().get("parquet_files") or []
    if not parquet_files:
        raise ValueError(f"No parquet files found for dataset {dataset_id!r}.")
    first_parquet_url = parquet_files[0].get("url")
    if not first_parquet_url:
        raise ValueError("No valid URL found for the first parquet file.")

    # The URL is embedded in a SQL string literal, so double any single quotes.
    escaped_url = first_parquet_url.replace("'", "''")
    conn.execute(
        f"CREATE OR REPLACE VIEW {view_name} as "
        f"SELECT * FROM read_parquet('{escaped_url}');"
    )

    # DuckDB's table_info rows are (cid, name, type, notnull, dflt_value, pk).
    dataset_ddl = conn.execute(f"PRAGMA table_info('{view_name}');").fetchall()
    column_data_types = ",\n\t".join(
        f"{column[1]} {column[2]}" for column in dataset_ddl
    )
    return f"\nCREATE TABLE {view_name} (\n\t{column_data_types}\n);\n"
@spaces.GPU(duration=120)
def generate_query(ddl: str, query: str) -> dict:
    """Ask the local LLM to translate a natural-language question into SQL.

    Args:
        ddl: ``CREATE TABLE`` statement describing the queryable table
            (as produced by ``get_dataset_ddl``).
        query: The user's natural-language question.

    Returns:
        ``SQLResponse.model_dump()``: a dict with keys ``sql``,
        ``visualization_type``, ``data_key`` and ``label_key``.
    """
    # NOTE(review): the model is loaded on every call inside the GPU-decorated
    # function, presumably because the GPU is only attached while it runs —
    # confirm before hoisting/caching it.
    llama = Llama(
        model_path=f"models/{model_file_name}",
        n_gpu_layers=gpu_layers,
        chat_format="chatml",
        # Speculative decoding that drafts tokens by looking them up in the prompt.
        draft_model=LlamaPromptLookupDecoding(num_pred_tokens=draft_pred_tokens),
        logits_all=True,
        n_ctx=2048,
        verbose=True,
    )
    # Wrap the OpenAI-compatible completion so responses are parsed and
    # validated into SQLResponse.
    create = instructor.patch(
        create=llama.create_chat_completion_openai_v1,
        mode=instructor.Mode.JSON_SCHEMA,
    )
    system_prompt = f"""
You are an expert SQL assistant with access to the following PostgreSQL Table:
```sql
{ddl.strip()}
```
Please assist the user by writing a SQL query that answers the user's question.
"""
    print("Calling LLM with system prompt: ", system_prompt, query)
    resp: SQLResponse = create(
        model="Hermes-2-Pro-Llama-3-8B",
        messages=[
            {"role": "system", "content": system_prompt},
            {
                "role": "user",
                "content": query,
            },
        ],
        response_model=SQLResponse,
        # Sampling temperature is a completion argument: Llama.__init__ has no
        # temperature parameter, so passing it to the constructor had no effect.
        temperature=0.1,
    )
    print("Received Response: ", resp)
    return resp.model_dump()
def _render_chart(
    df: pd.DataFrame, kind: str, x: Optional[str], y: Optional[str]
) -> Optional[plt.Figure]:
    """Plot *df* with pandas as a *kind* chart; return the figure or None on failure."""
    try:
        ax = df.plot(kind=kind, x=x, y=y)
    except TypeError as err:
        # pandas raises TypeError (e.g. "no numeric data to plot") when the
        # selected columns are not plottable; keep the table result instead of
        # failing the whole request.
        print("Could not render chart: ", err)
        return None
    fig = ax.get_figure()
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
    fig.tight_layout()
    # Drop the figure from pyplot's global registry so figures do not pile up
    # across requests; the Figure object itself can still be rendered by gr.Plot.
    plt.close(fig)
    return fig


def query_dataset(
    dataset_id: str, query: str
) -> Tuple[pd.DataFrame, str, Optional[plt.Figure]]:
    """Answer a natural-language question about a Hub dataset.

    Builds the dataset's DDL, asks the LLM for a SQL query plus an optional
    chart spec, runs the query on DuckDB and optionally plots the result.

    Args:
        dataset_id: Hub dataset id to query.
        query: The user's natural-language question.

    Returns:
        ``(result_df, sql_markdown, figure)``. ``figure`` is None when no chart
        was requested, the result is empty, or the result could not be plotted.
    """
    ddl = get_dataset_ddl(dataset_id)
    response = generate_query(ddl, query)
    sql = response.get("sql")
    markdown_output = f"""```sql\n{sql}\n```"""

    print("Querying Parquet...")
    # SECURITY NOTE(review): this executes LLM-generated SQL verbatim on the
    # shared DuckDB connection; nothing restricts it to read-only statements.
    df = conn.execute(sql).fetchdf()
    if df.empty:
        return df, markdown_output, None

    # The model may name columns the query did not return (or leave them
    # empty); fall back to pandas' defaults for x/y in that case.
    label_key = response.get("label_key")
    if label_key not in df.columns:
        label_key = None
    data_key = response.get("data_key")
    if data_key not in df.columns:
        data_key = None

    viz_type = response.get("visualization_type")
    if viz_type == OutputTypes.LINECHART:
        plot = _render_chart(df, "line", label_key, data_key)
    elif viz_type == OutputTypes.BARCHART:
        plot = _render_chart(df, "bar", label_key, data_key)
    else:
        plot = None
    return df, markdown_output, plot
# Gradio UI: pick a Hub dataset, ask a question, and show the generated SQL,
# the query result table and an optional chart.
with gr.Blocks() as demo:
    gr.Markdown("# Query your HF Datasets with Natural Language 📈📊")
    dataset_id = HuggingfaceHubSearch(
        label="Hub Dataset ID",
        placeholder="Find your favorite dataset...",
        search_type="dataset",
        value="gretelai/synthetic_text_to_sql",
    )
    user_query = gr.Textbox("", label="Ask anything...")
    examples = [
        ["Show me a preview of the data"],
        ["Show me something interesting"],
        ["Which row has longest description length?"],
        ["find the average length of sql query context"],
    ]
    # No fn is attached, so choosing an example only fills in the question box.
    gr.Examples(examples=examples, inputs=[user_query], outputs=[])
    btn = gr.Button("Ask 🪄")
    sql_query = gr.Markdown(label="Output SQL Query")
    df = gr.DataFrame()
    plot = gr.Plot()
    # Output order matches query_dataset's return tuple: (DataFrame, SQL markdown, figure).
    btn.click(
        query_dataset,
        inputs=[dataset_id, user_query],
        outputs=[df, sql_query, plot],
    )
if __name__ == "__main__":
    # Launch the Gradio app in debug mode, showing errors in the UI and
    # keeping console output enabled.
    demo.launch(debug=True, quiet=False, show_error=True)