|
import gradio as gr |
|
from gradio_huggingfacehub_search import HuggingfaceHubSearch |
|
import nbformat as nbf |
|
from huggingface_hub import HfApi |
|
from httpx import Client |
|
import logging |
|
from huggingface_hub import InferenceClient |
|
import json |
|
import re |
|
import pandas as pd |
|
from gradio.data_classes import FileData |
|
from utils.prompts import ( |
|
generate_mapping_prompt, |
|
generate_user_prompt, |
|
generate_rag_system_prompt, |
|
generate_eda_system_prompt, |
|
generate_embedding_system_prompt, |
|
) |
|
from dotenv import load_dotenv |
|
import os |
|
|
|
""" |
|
TODOs: |
|
- Need feedback on the output commands to validate whether operations are appropriate for the data types
|
- Refactor |
|
- Make the notebook generation more dynamic; add loading components so the UI does not freeze
|
- Fix errors: |
|
- When generating output |
|
- When parsing output |
|
- When pushing notebook |
|
- Add target tasks to choose for the notebook: |
|
- Exploratory data analysis |
|
- Auto training |
|
- RAG |
|
- etc. |
|
- Enable the 'generate notebook' button only if the dataset is available and supports the library
|
- First get compatible-libraries and let user choose the library |
|
""" |
|
|
|
|
|
|
|
load_dotenv()

HF_TOKEN = os.getenv("HF_TOKEN")
NOTEBOOKS_REPOSITORY = os.getenv("NOTEBOOKS_REPOSITORY")

# Explicit checks rather than `assert`: asserts are removed when Python runs
# with -O, which would let the app start without the required settings.
if HF_TOKEN is None:
    raise RuntimeError("You need to set HF_TOKEN in your environment variables")
if NOTEBOOKS_REPOSITORY is None:
    raise RuntimeError(
        "You need to set NOTEBOOKS_REPOSITORY in your environment variables"
    )

# Base URL of the Hugging Face dataset viewer API; the helpers below append
# endpoint paths such as /compatible-libraries and /first-rows.
BASE_DATASETS_SERVER_URL = "https://datasets-server.huggingface.co"
HEADERS = {"Accept": "application/json", "Content-Type": "application/json"}

# Shared HTTP client for all dataset viewer API requests.
client = Client(headers=HEADERS)
# LLM used both to draft the notebook and to map the draft into JSON cells.
inference_client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct")

logging.basicConfig(level=logging.INFO)
|
|
|
|
|
def get_compatible_libraries(dataset: str):
    """Fetch the dataset viewer's ``/compatible-libraries`` payload for *dataset*.

    Args:
        dataset: Hub dataset id, e.g. ``"user/name"``.

    Returns:
        The decoded JSON response body.

    Raises:
        Exception: any HTTP, transport or decoding error; it is logged and
            re-raised unchanged.
    """
    try:
        # Let httpx URL-encode the query string: the id comes from a free-text
        # search box, so characters like "&" or "#" must not leak into the URL.
        response = client.get(
            f"{BASE_DATASETS_SERVER_URL}/compatible-libraries",
            params={"dataset": dataset},
        )
        response.raise_for_status()
        return response.json()
    except Exception as e:
        logging.error(f"Error fetching compatible libraries: {e}")
        raise
|
|
|
|
|
def create_notebook_file(cell_commands, notebook_name):
    """Write *cell_commands* to *notebook_name* as an nbformat v4 notebook.

    Args:
        cell_commands: Iterable of dicts with ``"cell_type"`` and ``"source"``
            keys. ``"source"`` may be a string or a list of lines. Cells whose
            type is ``"code"`` become code cells; any other type becomes a
            markdown cell.
        notebook_name: Path of the ``.ipynb`` file to write (overwritten if it
            already exists).
    """

    def _as_text(source):
        # Model output sometimes gives a cell's source as a list of lines. Join
        # them for every cell type: a list source is stored verbatim and its
        # items are concatenated without separators when the notebook is read.
        return source if isinstance(source, str) else "\n".join(source)

    nb = nbf.v4.new_notebook()
    nb["cells"] = [
        nbf.v4.new_code_cell(_as_text(cmd["source"]))
        if cmd["cell_type"] == "code"
        else nbf.v4.new_markdown_cell(_as_text(cmd["source"]))
        for cmd in cell_commands
    ]

    # Notebooks are UTF-8 JSON; do not depend on the platform default encoding
    # (e.g. cp1252 on Windows would fail on emoji or non-Latin text).
    with open(notebook_name, "w", encoding="utf-8") as f:
        nbf.write(nb, f)
    logging.info(f"Notebook {notebook_name} created successfully")
|
|
|
|
|
def get_first_rows_as_df(dataset: str, config: str, split: str, limit: int):
    """Fetch sample rows and the feature schema for one config/split of *dataset*.

    Args:
        dataset: Hub dataset id.
        config: Config (subset) name.
        split: Split name.
        limit: Maximum number of rows to keep.

    Returns:
        ``(features_dict, df)``: a mapping from each feature name to the
        ``"type"`` entry the ``/first-rows`` endpoint reports for it, and a
        DataFrame with at most *limit* of the returned rows in shuffled order.

    Raises:
        Exception: any HTTP, transport or parsing error; it is logged and
            re-raised unchanged.
    """
    try:
        # httpx encodes the query parameters; config/split names are not
        # guaranteed to be URL-safe.
        resp = client.get(
            f"{BASE_DATASETS_SERVER_URL}/first-rows",
            params={"dataset": dataset, "config": config, "split": split},
        )
        resp.raise_for_status()
        content = resp.json()

        rows = [row["row"] for row in content["rows"]]
        # sample(frac=1) shuffles all returned rows; head() then keeps `limit`.
        first_rows_df = pd.DataFrame(rows).sample(frac=1).head(limit)

        features_dict = {
            feature["name"]: feature["type"] for feature in content["features"]
        }
        return features_dict, first_rows_df
    except Exception as e:
        logging.error(f"Error fetching first rows: {e}")
        raise
|
|
|
|
|
def get_txt_from_output(output):
    """Decode the JSON payload embedded in the model response *output*.

    The payload text is located with ``extract_content_from_output``, logged,
    and parsed with ``json.loads``, which raises on malformed JSON.
    """
    json_block = extract_content_from_output(output)
    logging.info("--> Extracted text between json block")
    logging.info(json_block)
    return json.loads(json_block)
|
|
|
|
|
def extract_content_from_output(output):
    """Return the text of the JSON code block embedded in a model response.

    Candidates are tried in order:

    1. A block opened with a triple-backtick ``json`` fence, up to the next
       triple-backtick fence.
    2. Text between a single backtick followed by ``json`` and the next
       single backtick.
    3. The first generic triple-backtick fenced block.
    4. Everything after a triple-backtick ``json`` opener that is never closed
       (e.g. truncated output).

    Raises:
        ValueError: if *output* matches none of the above.
    """
    # The fenced pattern must come first: the single-backtick pattern stops at
    # the first backtick, truncating JSON whose strings contain inline code.
    patterns = [r"```json(.*?)```", r"`json(.*?)`", r"```(.*?)```"]
    for pattern in patterns:
        match = re.search(pattern, output, re.DOTALL)
        if match:
            return match.group(1)

    try:
        index = output.index("```json")
        logging.info(f"Index: {index}")
        return output[index + len("```json") :]
    except ValueError:
        logging.error("Unable to generate Jupyter notebook.")
        raise
|
|
|
|
|
def content_from_output(output):
    """Return the text of the JSON code block embedded in a model response.

    Tries, in order: a triple-backtick ``json`` fenced block, a single-backtick
    ``json ...`` span, any triple-backtick fenced block, and finally the text
    after an unterminated triple-backtick ``json`` opener.

    Raises:
        Exception: ``"Unable to generate jupyter notebook."`` when nothing
            matches.
    """
    # Fenced form first: the single-backtick pattern would stop at the first
    # inline backtick inside the JSON.
    for pattern in (r"```json(.*?)```", r"`json(.*?)`", r"```(.*?)```"):
        match = re.search(pattern, output, re.DOTALL)
        if match:
            return match.group(1)

    # Truncated output: opening fence present but never closed.
    marker = "```json"
    index = output.find(marker)
    if index != -1:
        logging.info(f"Index: {index}")
        return output[index + len(marker) :]

    raise Exception("Unable to generate jupyter notebook.")
|
|
|
|
|
def generate_eda_cells(dataset_id):
    """Stream generation of an exploratory-data-analysis notebook.

    Yields ``(chat_messages, notebook_file)`` pairs for the Chatbot and File
    outputs. ``notebook_file`` is ``None`` while generation is in progress. On
    the final update it is the generated ``.ipynb`` filename, or ``None`` if
    ``generate_cells`` produced nothing.
    """
    # Pre-bind so the final yield is valid even when generate_cells stops
    # without yielding (the loop variable would otherwise be unbound).
    messages = []
    for messages in generate_cells(dataset_id, generate_eda_system_prompt, "eda"):
        yield messages, None

    notebook_file = f"{dataset_id.replace('/', '-')}-eda.ipynb" if messages else None
    yield messages, notebook_file
|
|
|
|
|
def generate_rag_cells(dataset_id):
    """Stream generation of a RAG (retrieval-augmented generation) notebook.

    Yields ``(chat_messages, notebook_file)`` pairs for the Chatbot and File
    outputs. ``notebook_file`` is ``None`` while generation is in progress. On
    the final update it is the generated ``.ipynb`` filename, or ``None`` if
    ``generate_cells`` produced nothing.
    """
    # Pre-bind so the final yield is valid even when generate_cells stops
    # without yielding (the loop variable would otherwise be unbound).
    messages = []
    for messages in generate_cells(dataset_id, generate_rag_system_prompt, "rag"):
        yield messages, None

    notebook_file = f"{dataset_id.replace('/', '-')}-rag.ipynb" if messages else None
    yield messages, notebook_file
|
|
|
|
|
def generate_embedding_cells(dataset_id):
    """Stream generation of an embeddings notebook.

    Yields ``(chat_messages, notebook_file)`` pairs for the Chatbot and File
    outputs. ``notebook_file`` is ``None`` while generation is in progress. On
    the final update it is the generated ``.ipynb`` filename, or ``None`` if
    ``generate_cells`` produced nothing.
    """
    # Pre-bind so the final yield is valid even when generate_cells stops
    # without yielding (the loop variable would otherwise be unbound).
    messages = []
    for messages in generate_cells(
        dataset_id, generate_embedding_system_prompt, "embedding"
    ):
        yield messages, None

    notebook_file = (
        f"{dataset_id.replace('/', '-')}-embedding.ipynb" if messages else None
    )
    yield messages, notebook_file
|
|
|
|
|
def _push_to_hub(
    history,
    dataset_id,
    notebook_file,
):
    """Upload *notebook_file* to the ``NOTEBOOKS_REPOSITORY`` dataset repo.

    Generator that yields once. It yields *history* plus one chat message: a
    markdown link to the uploaded notebook on success, or the error text on
    failure.

    Args:
        history: Chat messages produced so far.
        dataset_id: Source dataset id (used only in log messages).
        notebook_file: Local path of the notebook. Its last ``/``-separated
            component is used as the path inside the repository.
    """
    logging.info(f"Pushing notebook to hub: {dataset_id} on file {notebook_file}")

    notebook_name = notebook_file.split("/")[-1]
    api = HfApi(token=HF_TOKEN)
    try:
        logging.info(f"About to push {notebook_file} - {dataset_id}")
        api.upload_file(
            path_or_fileobj=notebook_file,
            path_in_repo=notebook_name,
            repo_id=NOTEBOOKS_REPOSITORY,
            repo_type="dataset",
        )
    except Exception as e:
        # Use a placeholder for the error: an extra logging argument without
        # one makes the log record itself fail to format.
        logging.exception("Failed to push notebook: %s", e)
        # Chat message content must be text, not an exception object.
        yield history + [gr.ChatMessage(role="assistant", content=str(e))]
    else:
        link = f"https://huggingface.co/datasets/{NOTEBOOKS_REPOSITORY}/blob/main/{notebook_name}"
        logging.info(f"Notebook pushed to hub: {link}")
        yield history + [
            gr.ChatMessage(
                role="user",
                content=f"[{notebook_name}]({link})",
            )
        ]
|
|
|
|
|
def generate_cells(dataset_id, prompt_fn, notebook_type="eda"):
    """Generate a notebook for *dataset_id* with the LLM and push it to the Hub.

    Generator yielding the growing list of ``gr.ChatMessage`` objects shown in
    the chatbot. Steps:

    1. Look up the dataset's pandas loading code via ``/compatible-libraries``
       and fetch a few sample rows of its first config/split.
    2. Stream a draft from the LLM, using ``prompt_fn()`` as the system prompt.
    3. Ask the LLM to map the draft to a JSON list of cells (up to 3 attempts).
    4. Prepend a dataset-viewer cell, write
       ``<dataset>-<notebook_type>.ipynb`` and upload it via ``_push_to_hub``.

    Args:
        dataset_id: Hub dataset id.
        prompt_fn: Zero-argument callable returning the system prompt.
        notebook_type: Suffix used in the notebook filename.

    Raises:
        gr.Error: if the dataset info cannot be fetched, the dataset has no
            pandas loading code, or the cell mapping cannot be parsed.
    """
    try:
        libraries = get_compatible_libraries(dataset_id)
    except Exception as err:
        logging.error(f"Failed to fetch compatible libraries: {err}")
        # gr.Error only reaches the UI when it is raised, not just constructed.
        raise gr.Error("Unable to retrieve dataset info from HF Hub.") from err

    # An empty/None payload simply yields no pandas entry.
    pandas_library = next(
        (
            lib
            for lib in (libraries or {}).get("libraries", [])
            if lib["library"] == "pandas"
        ),
        None,
    )
    if not pandas_library:
        logging.error("Dataset not compatible with pandas library")
        raise gr.Error("Dataset not compatible with pandas library.")

    first_config_loading_code = pandas_library["loading_codes"][0]
    first_code = first_config_loading_code["code"]
    first_config = first_config_loading_code["config_name"]
    first_split = list(first_config_loading_code["arguments"]["splits"].keys())[0]
    features, df = get_first_rows_as_df(dataset_id, first_config, first_split, 3)
    prompt = generate_user_prompt(
        features, df.head(5).to_dict(orient="records"), first_code
    )
    messages = [gr.ChatMessage(role="user", content=prompt)]
    yield messages + [gr.ChatMessage(role="assistant", content="⏳ _Starting task..._")]

    # --- Step 2: stream the draft, surfacing it to the chat line by line. ---
    prompt_messages = [
        {"role": "system", "content": prompt_fn()},
        {"role": "user", "content": prompt},
    ]
    output = inference_client.chat_completion(
        messages=prompt_messages,
        stream=True,
        max_tokens=2500,
        top_p=0.8,
        seed=42,
    )

    generated_text = ""
    current_line = ""
    for chunk in output:
        # A streamed delta can carry no content (None), e.g. the final chunk.
        current_line += chunk.choices[0].delta.content or ""
        if current_line.endswith("\n"):
            generated_text += current_line
            messages.append(gr.ChatMessage(role="assistant", content=current_line))
            current_line = ""
            yield messages
    # Flush a trailing line without "\n"; otherwise it is lost from the draft.
    if current_line:
        generated_text += current_line
        messages.append(gr.ChatMessage(role="assistant", content=current_line))
    yield messages

    logging.info("---> Notebook markdown code output")
    logging.info(generated_text)

    # --- Step 3: map the draft to a JSON list of cells, retrying on failure. ---
    retry_limit = 3
    for attempt in range(retry_limit):
        yield messages + [
            gr.ChatMessage(role="assistant", content="⏳ _Generating notebook..._")
        ]
        try:
            formatted_prompt = generate_mapping_prompt(generated_text)
            prompt_messages = [{"role": "user", "content": formatted_prompt}]
            output = inference_client.chat_completion(
                messages=prompt_messages,
                stream=False,
                max_tokens=2500,
                top_p=0.8,
                # Same prompt + same seed would most likely reproduce the same
                # unparseable answer; vary the seed per attempt (first stays 42).
                seed=42 + attempt,
            )
            cells_txt = output.choices[0].message.content
            logging.info(f"---> Mapping to json output attempt {attempt}")
            logging.info(cells_txt)
            commands = get_txt_from_output(cells_txt)
            # The cells are prepended to below, so anything but a list is a
            # malformed answer worth retrying.
            if not isinstance(commands, list):
                raise ValueError(
                    f"Expected a JSON list of cells, got {type(commands).__name__}"
                )
            break
        except Exception as e:
            logging.warning("Error when parsing output, retrying ..")
            if attempt == retry_limit - 1:
                logging.error(f"Unable to parse output after {retry_limit} retries")
                raise gr.Error("Unable to generate notebook. Try again please") from e

    # --- Step 4: add the viewer cell, write the notebook and push it. ---
    html_code = f"<iframe src='https://huggingface.co/datasets/{dataset_id}/embed/viewer' width='80%' height='560px'></iframe>"
    commands = [
        {"cell_type": "markdown", "source": "# Dataset Viewer"},
        {
            "cell_type": "code",
            "source": f'from IPython.display import HTML\n\ndisplay(HTML("{html_code}"))',
        },
        *commands,
    ]
    notebook_name = f"{dataset_id.replace('/', '-')}-{notebook_type}.ipynb"
    create_notebook_file(commands, notebook_name=notebook_name)
    messages.append(
        gr.ChatMessage(role="user", content="See the generated notebook on the Hub")
    )
    yield messages
    yield from _push_to_hub(messages, dataset_id, notebook_name)
|
|
|
|
|
def coming_soon_message():
    """Emit a "Coming soon" notice through ``gr.Info`` and return its result."""
    notice = gr.Info("Coming soon")
    return notice
|
|
|
|
|
def handle_example(example, button_action):
    """Invoke the *button_action* callback with *example* as its only argument.

    Returns:
        Whatever ``button_action(example)`` returns.
    """
    result = button_action(example)
    return result
|
|
|
|
|
# Gradio UI. Blocks places components in the order they are created inside the
# nested `with` contexts, so statement order below defines the page layout.
with gr.Blocks(fill_width=True) as demo:
    gr.Markdown("# 🤖 Dataset notebook creator 🕵️")
    with gr.Row(equal_height=True):
        # Left column: dataset picker, examples, viewer preview, task buttons.
        with gr.Column(scale=2):
            # Hidden; it only receives the second value of each example row below.
            text_input = gr.Textbox(label="Suggested notebook type", visible=False)

            dataset_name = HuggingfaceHubSearch(
                label="Hub Dataset ID",
                placeholder="Search for dataset id on Huggingface",
                search_type="dataset",
                value="",
            )

            # Each example row is [dataset id, hint text] -> [dataset_name, text_input].
            dataset_samples = gr.Examples(
                examples=[
                    [
                        "infinite-dataset-hub/WorldPopCounts",
                        "Try this dataset for Exploratory Data Analysis",
                    ],
                    [
                        "infinite-dataset-hub/GlobaleCuisineRecipes",
                        "Try this dataset for Embeddings generation",
                    ],
                    [
                        "infinite-dataset-hub/GlobalBestSellersSummaries",
                        "Try this dataset for RAG generation",
                    ],
                ],
                inputs=[dataset_name, text_input],
                cache_examples=False,
            )

            # Re-rendered from the current dataset_name value: shows either a
            # placeholder message or an embedded Hub dataset viewer.
            @gr.render(inputs=dataset_name)
            def embed(name):
                if not name:
                    return gr.Markdown("### No dataset provided")
                # NOTE(review): the viewer URL is pinned to the "default" config
                # and "train" split, which not every dataset has — confirm intended.
                html_code = f"""
                <iframe
                    src="https://huggingface.co/datasets/{name}/embed/viewer/default/train"
                    frameborder="0"
                    width="100%"
                    height="350px"
                ></iframe>
                """
                return gr.HTML(value=html_code)

            with gr.Row():
                generate_eda_btn = gr.Button("Exploratory Data Analysis")
                generate_embedding_btn = gr.Button("Embeddings")
                generate_rag_btn = gr.Button("RAG")
                generate_training_btn = gr.Button(
                    "Training - Coming soon", interactive=False
                )
        # Right column: chat log streamed by the generate_*_cells handlers.
        with gr.Column(scale=1):
            with gr.Row():
                chatbot = gr.Chatbot(
                    label="Results",
                    type="messages",
                    height=650,
                    avatar_images=(
                        None,
                        None,
                    ),
                )

    # Hidden output receiving the second value each generate_*_cells handler
    # yields (the notebook filename, or None while generating).
    notebook_file = gr.File(visible=False)
    generate_eda_btn.click(
        generate_eda_cells,
        inputs=[dataset_name],
        outputs=[chatbot, notebook_file],
    )

    generate_embedding_btn.click(
        generate_embedding_cells,
        inputs=[dataset_name],
        outputs=[chatbot, notebook_file],
    )

    generate_rag_btn.click(
        generate_rag_cells,
        inputs=[dataset_name],
        outputs=[chatbot, notebook_file],
    )

    # The training button is created with interactive=False, so this handler
    # presumably cannot fire until the button is enabled.
    generate_training_btn.click(coming_soon_message, inputs=[], outputs=[])

demo.launch()
|
|