from gradio_huggingfacehub_search import HuggingfaceHubSearch
from llama_cpp.llama_speculative import LlamaPromptLookupDecoding
from llama_cpp_cuda_tensorcores import Llama
from huggingface_hub import hf_hub_download
from huggingface_hub import HfApi
import matplotlib.pyplot as plt
from typing import Tuple, Optional
import pandas as pd
import gradio as gr
import duckdb
import requests
import instructor
import spaces
import enum
import os

from pydantic import BaseModel, Field

BASE_DATASETS_SERVER_URL = "https://datasets-server.huggingface.co"
view_name = "dataset_view"

hf_api = HfApi()
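# Single in-memory DuckDB connection reused across requests; each dataset is exposed as a SQL view.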
conn = duckdb.connect()

gpu_layers = int(os.environ.get("GPU_LAYERS", 0))  # model layers offloaded to the GPU (0 = CPU only)
draft_pred_tokens = int(os.environ.get("DRAFT_PRED_TOKENS", 2))  # draft tokens per step for prompt-lookup decoding

repo_id = os.getenv("MODEL_REPO_ID", "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF")
model_file_name = os.getenv("MODEL_FILE_NAME", "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf")

hf_hub_download(
    repo_id=repo_id,
    filename=model_file_name,
    local_dir="./models",
)


class OutputTypes(str, enum.Enum):
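    """Presentation formats the model may request for the query result."""
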
    TABLE = "table"
    BARCHART = "barchart"
    LINECHART = "linechart"


class SQLResponse(BaseModel):
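    """Structured LLM output: the SQL to run plus optional charting hints."""
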
    sql: str
    visualization_type: Optional[OutputTypes] = Field(
        None, description="The type of visualization to display"
    )
    data_key: Optional[str] = Field(
        None,
        description="The column name from the sql query that contains the data for chart responses",
    )
    label_key: Optional[str] = Field(
        None,
        description="The column name from the sql query that contains the labels for chart responses",
    )


def get_dataset_ddl(dataset_id: str) -> str:
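    """Build a CREATE TABLE-style DDL string for a Hub dataset.

    Fetches the dataset's parquet file list from the datasets-server API,
    registers the first parquet file as a DuckDB view, and returns the view's
    column names and types formatted as DDL for the LLM prompt.
    """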
    response = requests.get(f"{BASE_DATASETS_SERVER_URL}/parquet?dataset={dataset_id}")
    response.raise_for_status()  # Check if the request was successful

    first_parquet = response.json().get("parquet_files", [])[0]
    first_parquet_url = first_parquet.get("url")

    if not first_parquet_url:
        raise ValueError("No valid URL found for the first parquet file.")

    conn.execute(
        f"CREATE OR REPLACE VIEW {view_name} as SELECT * FROM read_parquet('{first_parquet_url}');"
    )
    dataset_ddl = conn.execute(f"PRAGMA table_info('{view_name}');").fetchall()

    column_data_types = ",\n\t".join(
        [f"{column[1]} {column[2]}" for column in dataset_ddl]
    )

    sql_ddl = """
CREATE TABLE {} (
    {}
);
    """.format(
        view_name, column_data_types
    )

    return sql_ddl


@spaces.GPU(duration=120)
def generate_query(ddl: str, query: str) -> dict:
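    """Ask the LLM to translate a natural-language question into SQL.

    Loads the GGUF model with llama-cpp (prompt-lookup speculative decoding,
    optional GPU offload) and patches its OpenAI-compatible chat endpoint with
    instructor so the completion is parsed and validated as a SQLResponse.
    The prompt describes the table as PostgreSQL; the query is later executed
    with DuckDB, whose dialect is largely Postgres-compatible.
    """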
    llama = Llama(
        model_path=f"models/{model_file_name}",
        n_gpu_layers=gpu_layers,
        chat_format="chatml",
        draft_model=LlamaPromptLookupDecoding(num_pred_tokens=draft_pred_tokens),
        logits_all=True,
        n_ctx=2048,
        verbose=True,
    )

    create = instructor.patch(
        create=llama.create_chat_completion_openai_v1,
        mode=instructor.Mode.JSON_SCHEMA,
    )

    system_prompt = f"""
    You are an expert SQL assistant with access to the following PostgreSQL Table:
    
    ```sql
    {ddl.strip()}
    ```
    
    Please assist the user by writing a SQL query that answers the user's question.
    """

    print("Calling LLM with system prompt: ", system_prompt, query)

    resp: SQLResponse = create(
        model="Hermes-2-Pro-Llama-3-8B",
        messages=[
            {"role": "system", "content": system_prompt},
            {
                "role": "user",
                "content": query,
            },
        ],
        response_model=SQLResponse,
        temperature=0.1,  # sampling temperature is a per-call setting, not a Llama constructor argument
    )

    print("Received Response: ", resp)

    return resp.model_dump()


def query_dataset(dataset_id: str, query: str) -> Tuple[pd.DataFrame, str, plt.Figure]:
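    """End-to-end handler for the Gradio app.

    Builds the dataset DDL, asks the model for a SQL query plus optional chart
    hints, executes the query against the DuckDB view, and returns the result
    dataframe, the SQL as fenced markdown, and an optional matplotlib figure.
    """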
    ddl = get_dataset_ddl(dataset_id)
    response = generate_query(ddl, query)

    print("Querying Parquet...")
    df = conn.execute(response.get("sql")).fetchdf()

    plot = None

    label_key = response.get("label_key")
    data_key = response.get("data_key")
    viz_type = response.get("visualization_type")
    sql = response.get("sql")
    markdown_output = f"""```sql\n{sql}\n```"""

    # handle incorrect data and label keys
    if label_key and label_key not in df.columns:
        label_key = None
    if data_key and data_key not in df.columns:
        data_key = None

    if df.empty:
        return df, markdown_output, plot

    if viz_type == OutputTypes.LINECHART:
        plot = df.plot(kind="line", x=label_key, y=data_key).get_figure()
        plt.xticks(rotation=45, ha="right")
        plt.tight_layout()
    elif viz_type == OutputTypes.BARCHART:
        plot = df.plot(kind="bar", x=label_key, y=data_key).get_figure()
        plt.xticks(rotation=45, ha="right")
        plt.tight_layout()

    return df, markdown_output, plot


with gr.Blocks() as demo:
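    # UI: dataset picker, free-text question box, and the three outputs (SQL, dataframe, plot).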
    gr.Markdown("# Query your HF Datasets with Natural Language πŸ“ˆπŸ“Š")
    dataset_id = HuggingfaceHubSearch(
        label="Hub Dataset ID",
        placeholder="Find your favorite dataset...",
        search_type="dataset",
        value="gretelai/synthetic_text_to_sql",
    )
    user_query = gr.Textbox("", label="Ask anything...")
    examples = [
        ["Show me a preview of the data"],
        ["Show me something interesting"],
        ["Which row has the longest description length?"],
        ["Find the average length of the SQL query context"],
    ]
    gr.Examples(examples=examples, inputs=[user_query], outputs=[])

    btn = gr.Button("Ask πŸͺ„")

    sql_query = gr.Markdown(label="Output SQL Query")
    df = gr.DataFrame()
    plot = gr.Plot()

    btn.click(
        query_dataset,
        inputs=[dataset_id, user_query],
        outputs=[df, sql_query, plot],
    )


if __name__ == "__main__":
    demo.launch(
        show_error=True,
        quiet=False,
        debug=True,
    )