File size: 4,663 Bytes
5ea2a69
0d628a0
5ea2a69
 
0d628a0
76b8fa2
fecada4
 
 
 
0d628a0
fecada4
5ea2a69
0d628a0
5ea2a69
e73d501
5ea2a69
0d628a0
5ea2a69
 
 
0d628a0
 
5ea2a69
 
 
0d628a0
1480aa8
 
0d628a0
 
 
 
 
 
 
5ea2a69
1480aa8
 
 
0d628a0
5ea2a69
0d628a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1480aa8
0d628a0
 
1480aa8
 
 
0d628a0
 
 
 
 
5ea2a69
0d628a0
 
1480aa8
5ea2a69
0d628a0
 
 
 
 
 
 
 
 
 
5ea2a69
 
 
 
1480aa8
0d628a0
 
1480aa8
0d628a0
 
 
 
 
 
 
 
 
 
 
 
 
5ea2a69
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import os
import shutil
import gradio as gr
from transformers import ReactCodeAgent, HfEngine, Tool
import pandas as pd

from gradio_agentchatbot import (
    AgentChatbot,
    stream_from_transformers_agent,
    ChatMessage,
    ChatFileMessage,
)
from huggingface_hub import login
from gradio.data_classes import FileData

# Authenticate with the Hugging Face Hub only when a token is configured.
# huggingface_hub.login(None) falls back to an interactive token prompt,
# which would block a headless server; without a token the engine below
# runs unauthenticated (NOTE(review): gated models may then be refused).
hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
if hf_token:
    login(hf_token)

# Remote LLM backing the agent (transformers agents' HfEngine).
llm_engine = HfEngine("meta-llama/Meta-Llama-3-70B-Instruct")

# Code-writing ReAct agent with no extra tools. The code it generates may only
# import the modules listed here, and a run is capped at 10 iterations.
agent = ReactCodeAgent(
    tools=[],
    llm_engine=llm_engine,
    additional_authorized_imports=["numpy", "pandas", "matplotlib.pyplot", "seaborn"],
    max_iterations=10,
)

# Prompt template for the data-analyst agent. It is filled in with str.format()
# using `source_file` (the uploaded CSV) and `structure_notes` (a .describe()
# and dtypes summary); any literal `{` or `}` added here must be doubled.
base_prompt = """You are an expert data analyst.
Please load the source file with pandas (you cannot use 'os' module).
According to the features you have and the data structure given below, determine which feature should be the target.
Then list 3 interesting questions that could be asked on this data, for instance about specific correlations with target variable.
Then answer these questions one by one, by finding the relevant numbers.
Meanwhile, plot some figures using matplotlib/seaborn and save them to the (already existing) folder './figures/': take care to clear each figure with plt.clf() before doing another plot.

In your final answer: summarize these correlations and trends.
After each number, derive real-world insights, for instance: "Correlation between is_december and boredness is 0.3453, which suggests people are more bored in winter".
Your final answer should be a long string with at least 3 numbered and detailed parts.

Source file for the data = {source_file}
Structure of the data:
{structure_notes}
"""

# Description of the bundled Titanic example dataset; used as the
# "Additional notes" input of the gr.Examples entry.
example_notes = """This data is about the Titanic wreck in 1912.
The target variable is the survival of passengers, denoted by 'Survived'.
pclass: A proxy for socio-economic status (SES)
1st = Upper
2nd = Middle
3rd = Lower
age: Age is fractional if less than 1. If the age is estimated, it is in the form of xx.5
sibsp: The dataset defines family relations in this way...
Sibling = brother, sister, stepbrother, stepsister
Spouse = husband, wife (mistresses and fiancés were ignored)
parch: The dataset defines family relations in this way...
Parent = mother, father
Child = daughter, son, stepdaughter, stepson
Some children travelled only with a nanny, therefore parch=0 for them."""

def get_images_in_directory(directory):
    """Recursively list image files found under *directory*.

    A file counts as an image when its extension (compared case-insensitively)
    is one of .png, .jpg, .jpeg, .gif, .bmp or .tiff. Each result is the
    walked folder joined with the file name, in ``os.walk`` order. A missing
    directory yields an empty list.
    """
    allowed_suffixes = {'.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff'}
    return [
        os.path.join(folder, name)
        for folder, _subfolders, names in os.walk(directory)
        for name in names
        if os.path.splitext(name)[1].lower() in allowed_suffixes
    ]

def _collect_new_image_messages(seen_paths):
    """Build chat messages for figures in ./figures that were not shown yet.

    Args:
        seen_paths: set of image paths already turned into messages. It is
            updated in place with every path returned, so repeated calls only
            report newly saved figures.

    Returns:
        A list of ``ChatFileMessage`` objects, one per new image, in the order
        ``get_images_in_directory`` lists them.
    """
    import mimetypes

    new_messages = []
    for image_path in get_images_in_directory("./figures"):
        if image_path in seen_paths:
            continue
        seen_paths.add(image_path)
        # get_images_in_directory also matches .jpg/.gif/.bmp/.tiff, so derive
        # the MIME type from the extension instead of assuming PNG.
        mime_type = mimetypes.guess_type(image_path)[0] or "image/png"
        new_messages.append(
            ChatFileMessage(
                role="assistant",
                file=FileData(path=image_path, mime_type=mime_type),
                content="",
                thought=True,
            )
        )
    return new_messages


def interact_with_agent(file_input, additional_notes):
    """Run the data-analyst agent on an uploaded CSV and stream the chat.

    Args:
        file_input: the upload from ``gr.File``. It is passed straight to
            ``pd.read_csv`` and interpolated into the prompt, so it is presumably
            a filesystem path string (TODO confirm against the Gradio version).
        additional_notes: optional free-text description of the data. It is
            appended to the prompt when non-empty.

    Yields:
        The cumulative list of chat messages. The user prompt comes first, then
        each streamed agent step. A file message is added for every new figure
        the agent saves under ./figures.

    Raises:
        gr.Error: if no file was uploaded.
    """
    if file_input is None:
        # Show a readable message in the UI instead of a pandas traceback.
        raise gr.Error("Please upload a .csv file to analyze.")

    # Start from an empty ./figures so that only this run's plots are shown.
    # ignore_errors / exist_ok: the folder does not exist on the very first run.
    # NOTE(review): ./figures and the module-level agent are shared across
    # sessions; if runs ever execute concurrently they will clobber each other.
    shutil.rmtree("./figures", ignore_errors=True)
    os.makedirs("./figures", exist_ok=True)

    read_file = pd.read_csv(file_input)
    data_structure_notes = f"""- Description (output of .describe()):
    {read_file.describe()}
    - Columns with dtypes:
    {read_file.dtypes}"""

    prompt = base_prompt.format(source_file=file_input, structure_notes=data_structure_notes)

    if additional_notes:
        prompt += "\nAdditional notes on the data:\n" + additional_notes

    messages = [ChatMessage(role="user", content=prompt, thought=True)]
    yield messages

    shown_images = set()
    for msg in stream_from_transformers_agent(agent, prompt):
        messages.append(msg)
        # Interleave any figures the agent saved during this step.
        messages.extend(_collect_new_image_messages(shown_images))
        yield messages

    # Pick up figures written after the last streamed step before the final update.
    messages.extend(_collect_new_image_messages(shown_images))
    yield messages


with gr.Blocks() as demo:
    # Page header. The model named here should match the repo id given to
    # HfEngine (meta-llama/Meta-Llama-3-70B-Instruct).
    gr.Markdown("""# Llama-3 Data analyst

Drop a `.csv` file below, add notes to describe this data if needed, and **Llama-3-70B will analyze the file content and draw figures for you!**""")
    # interact_with_agent parses the upload with pd.read_csv, so only offer CSVs.
    file_input = gr.File(label="Your file to analyze", file_types=[".csv"])
    text_input = gr.Textbox(
        label="Additional notes to support the analysis"
    )
    submit = gr.Button("Run analysis!")
    chatbot = AgentChatbot(label="Agent")
    # Bundled Titanic example: prefills the file and the notes textbox.
    gr.Examples(
        examples=[["./example/titanic.csv", example_notes]],
        inputs=[file_input, text_input],
        cache_examples=False
    )

    # interact_with_agent is a generator, so the chatbot updates as it streams.
    submit.click(interact_with_agent, [file_input, text_input], [chatbot])

if __name__ == "__main__":
    demo.launch()