davidberenstein1957 HF staff committed on
Commit
adc79ce
·
1 Parent(s): 54d4d8d

feat: Add support for textcat

Browse files
src/distilabel_dataset_generator/apps/textcat.py CHANGED
@@ -1,4 +1,205 @@
1
- from src.distilabel_dataset_generator.utils import get_base_app
2
 
3
- with get_base_app() as app:
4
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
 
3
+ import gradio as gr
4
+ import pandas as pd
5
+
6
+ from src.distilabel_dataset_generator.apps.base import (
7
+ get_main_ui,
8
+ get_pipeline_code_ui,
9
+ hide_success_message,
10
+ push_dataset_to_hub,
11
+ push_pipeline_code_to_hub,
12
+ show_success_message_argilla,
13
+ show_success_message_hub,
14
+ validate_argilla_user_workspace_dataset,
15
+ )
16
+ from src.distilabel_dataset_generator.pipelines.textcat import (
17
+ DEFAULT_DATASET_DESCRIPTIONS,
18
+ DEFAULT_DATASETS,
19
+ DEFAULT_SYSTEM_PROMPTS,
20
+ generate_pipeline_code,
21
+ )
22
+
23
+
24
def push_dataset_to_argilla(dataset: pd.DataFrame, dataset_name: str) -> pd.DataFrame:
    """Return ``dataset`` unchanged (placeholder for the Argilla upload step).

    Args:
        dataset: The generated dataset to push.
        dataset_name: Target dataset name; currently unused.

    Returns:
        The very same ``dataset`` object that was passed in.
    """
    # TODO: implement the actual upload; nothing is sent anywhere yet.
    return dataset
26
+
27
+
28
def generate_system_prompt(dataset_description: str) -> str:
    """Return the system prompt for ``dataset_description``.

    Placeholder implementation: the description is echoed back verbatim.

    Args:
        dataset_description: Free-text description of the desired dataset.

    Returns:
        ``dataset_description`` unchanged.
    """
    return dataset_description
30
+
31
+
32
def generate_dataset(
    system_prompt: str, labels: List[str], multi_label: bool
) -> pd.DataFrame:
    """Build a one-row placeholder dataset from ``system_prompt``.

    Args:
        system_prompt: Prompt text; copied into both output columns.
        labels: Candidate labels; currently unused.
        multi_label: Whether several labels may apply; currently unused.

    Returns:
        A single-row DataFrame with columns ``prompt`` and ``completion``,
        both holding ``system_prompt``.
    """
    # NOTE(review): the default datasets use "text"/"label" columns, while this
    # stub emits "prompt"/"completion" — confirm which schema the UI expects.
    return pd.DataFrame({"prompt": [system_prompt], "completion": [system_prompt]})
36
+
37
+
38
# Build the shared layout. get_main_ui hands back the Blocks app together with
# every container/component that the text-classification wiring below refers to.
(
    app,
    main_ui,
    custom_input_ui,
    dataset_description,
    examples,
    btn_generate_system_prompt,
    system_prompt,
    sample_dataset,
    btn_generate_sample_dataset,
    dataset_name,
    add_to_existing_dataset,
    btn_generate_full_dataset_copy,
    btn_generate_and_push_to_argilla,
    btn_push_to_argilla,
    org_name,
    repo_name,
    private,
    btn_generate_full_dataset,
    btn_generate_and_push_to_hub,
    btn_push_to_hub,
    final_dataset,
    success_message,
) = get_main_ui(
    default_dataset_descriptions=DEFAULT_DATASET_DESCRIPTIONS,
    default_system_prompts=DEFAULT_SYSTEM_PROMPTS,
    default_datasets=DEFAULT_DATASETS,
    fn_generate_system_prompt=generate_system_prompt,
    fn_generate_dataset=generate_dataset,
)

with app:
    with main_ui:
        # Task-specific inputs rendered inside the slot provided by get_main_ui.
        with custom_input_ui:
            labels = gr.Dropdown(
                choices=[],
                allow_custom_value=True,
                interactive=True,
                label="Labels",
                multiselect=True,
            )
            # Flag consumed by generate_dataset / generate_pipeline_code; it
            # must exist as a component because it is used both for its
            # initial ``.value`` and as an event input/trigger below.
            multi_label = gr.Checkbox(
                label="Multi-label",
                value=False,
                interactive=True,
            )
            num_labels = gr.Number(
                label="Number of labels", value=2, minimum=1, maximum=10
            )
            num_rows = gr.Number(
                label="Number of rows", value=10, minimum=1, maximum=500
            )

        # Initial pipeline-code preview, rendered from the components' defaults.
        pipeline_code = get_pipeline_code_ui(
            generate_pipeline_code(system_prompt.value, labels.value, multi_label.value)
        )

    # define app triggers

    # Either "generate full dataset" button: clear the banner, then generate.
    gr.on(
        triggers=[
            btn_generate_full_dataset.click,
            btn_generate_full_dataset_copy.click,
        ],
        fn=hide_success_message,
        outputs=[success_message],
    ).then(
        fn=generate_dataset,
        inputs=[system_prompt, labels, multi_label],
        outputs=[final_dataset],
        show_progress=True,
    )

    # Generate + push to Argilla: each step runs only if the previous succeeded,
    # so a failed validation stops the chain before any generation happens.
    btn_generate_and_push_to_argilla.click(
        fn=validate_argilla_user_workspace_dataset,
        inputs=[dataset_name, final_dataset, add_to_existing_dataset],
        outputs=[final_dataset],
        show_progress=True,
    ).success(
        fn=hide_success_message,
        outputs=[success_message],
    ).success(
        fn=generate_dataset,
        inputs=[system_prompt, labels, multi_label],
        outputs=[final_dataset],
        show_progress=True,
    ).success(
        fn=push_dataset_to_argilla,
        inputs=[final_dataset, dataset_name],
        outputs=[final_dataset],
        show_progress=True,
    ).success(
        fn=show_success_message_argilla,
        inputs=[],
        outputs=[success_message],
    )

    # Generate + push to the Hub: dataset first, then the pipeline code; the
    # success banner is shown only if the final push succeeded.
    btn_generate_and_push_to_hub.click(
        fn=hide_success_message,
        outputs=[success_message],
    ).then(
        fn=generate_dataset,
        inputs=[system_prompt, labels, multi_label],
        outputs=[final_dataset],
        show_progress=True,
    ).then(
        fn=push_dataset_to_hub,
        inputs=[final_dataset, private, org_name, repo_name],
        outputs=[final_dataset],
        show_progress=True,
    ).then(
        fn=push_pipeline_code_to_hub,
        inputs=[pipeline_code, org_name, repo_name],
        outputs=[],
        show_progress=True,
    ).success(
        fn=show_success_message_hub,
        inputs=[org_name, repo_name],
        outputs=[success_message],
    )

    # Push the already-generated dataset (and pipeline code) to the Hub.
    btn_push_to_hub.click(
        fn=hide_success_message,
        outputs=[success_message],
    ).then(
        fn=push_dataset_to_hub,
        inputs=[final_dataset, private, org_name, repo_name],
        outputs=[final_dataset],
        show_progress=True,
    ).then(
        fn=push_pipeline_code_to_hub,
        inputs=[pipeline_code, org_name, repo_name],
        outputs=[],
        show_progress=True,
    ).success(
        fn=show_success_message_hub,
        inputs=[org_name, repo_name],
        outputs=[success_message],
    )

    # Push the already-generated dataset to Argilla after validating the target.
    btn_push_to_argilla.click(
        fn=hide_success_message,
        outputs=[success_message],
    ).success(
        fn=validate_argilla_user_workspace_dataset,
        inputs=[dataset_name, final_dataset, add_to_existing_dataset],
        outputs=[final_dataset],
        show_progress=True,
    ).success(
        fn=push_dataset_to_argilla,
        inputs=[final_dataset, dataset_name],
        outputs=[final_dataset],
        show_progress=True,
    ).success(
        fn=show_success_message_argilla,
        inputs=[],
        outputs=[success_message],
    )

    # Keep the pipeline-code preview in sync with each input that feeds it.
    system_prompt.change(
        fn=generate_pipeline_code,
        inputs=[system_prompt, labels, multi_label],
        outputs=[pipeline_code],
    )
    labels.change(
        fn=generate_pipeline_code,
        inputs=[system_prompt, labels, multi_label],
        outputs=[pipeline_code],
    )
    multi_label.change(
        fn=generate_pipeline_code,
        inputs=[system_prompt, labels, multi_label],
        outputs=[pipeline_code],
    )
src/distilabel_dataset_generator/pipelines/textcat.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ import pandas as pd
4
+
5
# Example inputs for the text-classification app. The three lists have the same
# length; presumably entry i of each list belongs to the same example (reviews
# first, news second) — confirm against how get_main_ui consumes them.

# Free-text descriptions of the example datasets.
DEFAULT_DATASET_DESCRIPTIONS = [
    "A dataset covering customer reviews for an e-commerce website.",
    "A dataset covering news articles about various topics.",
]

# Sample rows per example: the first frame holds one label string per text,
# the second holds a list of labels per text.
DEFAULT_DATASETS = [
    pd.DataFrame.from_dict(
        {
            "text": [
                "I love the product! It's amazing and I'll buy it again.",
                "The product was okay, but I wouldn't buy it again.",
            ],
            "label": ["positive", "negative"],
        }
    ),
    pd.DataFrame.from_dict(
        {
            "text": [
                "Yesterday, the US stock market had a significant increase.",
                "New research suggests that the Earth is not a perfect sphere.",
            ],
            "label": [["economy", "politics"], ["science", "environment"]],
        }
    ),
]

# System prompts for the example datasets.
DEFAULT_SYSTEM_PROMPTS = [
    "Classify the following customer review as positive or negative.",
    "Classify the following news article into one or more categories.",
]
35
+
36
+
37
def generate_pipeline_code(
    system_prompt: str, labels: List[str], multi_label: bool
) -> str:
    """Return the pipeline source code to display for the current settings.

    Placeholder implementation: a fixed template is returned and none of the
    arguments influence the result yet.

    Args:
        system_prompt: System prompt chosen in the UI; currently unused.
        labels: Candidate labels; currently unused.
        multi_label: Whether several labels may apply; currently unused.

    Returns:
        The static pipeline-code template as a string.
    """
    # The template is kept flush-left so the displayed code carries no stray
    # leading indentation.
    return """
from distilabel import Distilabel

#### PIPELINE CODE HERE
"""