davidberenstein1957 HF staff commited on
Commit
54d4d8d
·
1 Parent(s): 2e2beb7

refactor: re-usable gradio component

Browse files
app.py CHANGED
@@ -26,8 +26,8 @@ css = """
26
  """
27
 
28
  demo = gr.TabbedInterface(
29
- [sft_app, textcat_app, faq_app],
30
- ["Supervised Fine-Tuning", "Text Classification", "FAQ"],
31
  css=css,
32
  title="""
33
  <style>
 
26
  """
27
 
28
  demo = gr.TabbedInterface(
29
+ [textcat_app, sft_app, faq_app],
30
+ ["Text Classification", "Supervised Fine-Tuning", "FAQ"],
31
  css=css,
32
  title="""
33
  <style>
src/distilabel_dataset_generator/apps/base.py ADDED
@@ -0,0 +1,498 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import uuid
3
+ from typing import Any, Callable, List, Optional, Tuple, Union
4
+
5
+ import argilla as rg
6
+ import gradio as gr
7
+ import pandas as pd
8
+ from datasets import Dataset
9
+ from distilabel.distiset import Distiset
10
+ from gradio import OAuthToken
11
+ from huggingface_hub import HfApi, upload_file
12
+
13
+ from src.distilabel_dataset_generator.utils import (
14
+ _LOGGED_OUT_CSS,
15
+ get_argilla_client,
16
+ list_orgs,
17
+ )
18
+
19
+
20
+ def swap_visibilty(oauth_token: Optional[OAuthToken] = None):
21
+ if oauth_token:
22
+ return gr.update(elem_classes=["main_ui_logged_in"])
23
+ else:
24
+ return gr.update(elem_classes=["main_ui_logged_out"])
25
+
26
+
27
+ def get_main_ui(
28
+ default_dataset_descriptions: List[str],
29
+ default_system_prompts: List[str],
30
+ default_datasets: List[pd.DataFrame],
31
+ fn_generate_system_prompt: Callable,
32
+ fn_generate_dataset: Callable,
33
+ ):
34
+ def fn_generate_sample_dataset(system_prompt, progress=gr.Progress()):
35
+ if system_prompt in default_system_prompts:
36
+ index = default_system_prompts.index(system_prompt)
37
+ if index < len(default_datasets):
38
+ return default_datasets[index]
39
+ result = fn_generate_dataset(
40
+ system_prompt, num_turns=1, num_rows=1, progress=progress, is_sample=True
41
+ )
42
+ return result
43
+
44
+ with gr.Blocks(
45
+ title="🧬 Synthetic Data Generator",
46
+ head="🧬 Synthetic Data Generator",
47
+ css=_LOGGED_OUT_CSS,
48
+ ) as app:
49
+ with gr.Row():
50
+ gr.Markdown(
51
+ "Want to run this locally or with other LLMs? Take a look at the FAQ tab. distilabel Synthetic Data Generator is free, we use the authentication token to push the dataset to the Hugging Face Hub and not for data generation."
52
+ )
53
+ with gr.Row():
54
+ gr.Column()
55
+ get_login_button()
56
+ gr.Column()
57
+
58
+ gr.Markdown("## Iterate on a sample dataset")
59
+ with gr.Column() as main_ui:
60
+ (
61
+ dataset_description,
62
+ examples,
63
+ btn_generate_system_prompt,
64
+ system_prompt,
65
+ sample_dataset,
66
+ btn_generate_sample_dataset,
67
+ ) = get_iterate_on_sample_dataset_ui(
68
+ default_dataset_descriptions=default_dataset_descriptions,
69
+ default_system_prompts=default_system_prompts,
70
+ default_datasets=default_datasets,
71
+ )
72
+ gr.Markdown("## Generate full dataset")
73
+ gr.Markdown(
74
+ "Once you're satisfied with the sample, generate a larger dataset and push it to Argilla or the Hugging Face Hub."
75
+ )
76
+ with gr.Row(variant="panel") as custom_input_ui:
77
+ pass
78
+
79
+ (
80
+ dataset_name,
81
+ add_to_existing_dataset,
82
+ btn_generate_full_dataset_copy,
83
+ btn_generate_and_push_to_argilla,
84
+ btn_push_to_argilla,
85
+ org_name,
86
+ repo_name,
87
+ private,
88
+ btn_generate_full_dataset,
89
+ btn_generate_and_push_to_hub,
90
+ btn_push_to_hub,
91
+ final_dataset,
92
+ success_message,
93
+ ) = get_push_to_hub_ui(default_datasets)
94
+
95
+ sample_dataset.change(
96
+ fn=lambda x: x,
97
+ inputs=[sample_dataset],
98
+ outputs=[final_dataset],
99
+ )
100
+
101
+ btn_generate_system_prompt.click(
102
+ fn=fn_generate_system_prompt,
103
+ inputs=[dataset_description],
104
+ outputs=[system_prompt],
105
+ show_progress=True,
106
+ ).then(
107
+ fn=fn_generate_sample_dataset,
108
+ inputs=[system_prompt],
109
+ outputs=[sample_dataset],
110
+ show_progress=True,
111
+ )
112
+
113
+ btn_generate_sample_dataset.click(
114
+ fn=fn_generate_sample_dataset,
115
+ inputs=[system_prompt],
116
+ outputs=[sample_dataset],
117
+ show_progress=True,
118
+ )
119
+
120
+ app.load(fn=swap_visibilty, outputs=main_ui)
121
+ app.load(get_org_dropdown, outputs=[org_name])
122
+
123
+ return (
124
+ app,
125
+ main_ui,
126
+ custom_input_ui,
127
+ dataset_description,
128
+ examples,
129
+ btn_generate_system_prompt,
130
+ system_prompt,
131
+ sample_dataset,
132
+ btn_generate_sample_dataset,
133
+ dataset_name,
134
+ add_to_existing_dataset,
135
+ btn_generate_full_dataset_copy,
136
+ btn_generate_and_push_to_argilla,
137
+ btn_push_to_argilla,
138
+ org_name,
139
+ repo_name,
140
+ private,
141
+ btn_generate_full_dataset,
142
+ btn_generate_and_push_to_hub,
143
+ btn_push_to_hub,
144
+ final_dataset,
145
+ success_message,
146
+ )
147
+
148
+
149
+ def validate_argilla_user_workspace_dataset(
150
+ dataset_name: str,
151
+ final_dataset: pd.DataFrame,
152
+ add_to_existing_dataset: bool,
153
+ oauth_token: Union[OAuthToken, None] = None,
154
+ progress=gr.Progress(),
155
+ ) -> str:
156
+ progress(0, desc="Validating dataset configuration")
157
+ hf_user = HfApi().whoami(token=oauth_token.token)["name"]
158
+ client = get_argilla_client()
159
+ if dataset_name is None or dataset_name == "":
160
+ raise gr.Error("Dataset name is required")
161
+ # Create user if it doesn't exist
162
+ rg_user = client.users(username=hf_user)
163
+ if rg_user is None:
164
+ rg_user = client.users.add(
165
+ rg.User(username=hf_user, role="admin", password=str(uuid.uuid4()))
166
+ )
167
+ # Create workspace if it doesn't exist
168
+ workspace = client.workspaces(name=hf_user)
169
+ if workspace is None:
170
+ workspace = client.workspaces.add(rg.Workspace(name=hf_user))
171
+ workspace.add_user(hf_user)
172
+ # Check if dataset exists
173
+ dataset = client.datasets(name=dataset_name, workspace=hf_user)
174
+ if dataset and not add_to_existing_dataset:
175
+ raise gr.Error(f"Dataset {dataset_name} already exists")
176
+ return final_dataset
177
+
178
+
179
+ def get_login_button():
180
+ return gr.LoginButton(
181
+ value="Sign in with Hugging Face!", size="lg", scale=2
182
+ ).activate()
183
+
184
+
185
+ def get_org_dropdown(oauth_token: OAuthToken = None):
186
+ orgs = list_orgs(oauth_token)
187
+ return gr.Dropdown(
188
+ label="Organization",
189
+ choices=orgs,
190
+ value=orgs[0] if orgs else None,
191
+ allow_custom_value=True,
192
+ )
193
+
194
+
195
+ def get_push_to_hub_ui(default_datasets):
196
+ with gr.Column() as push_to_hub_ui:
197
+ (
198
+ dataset_name,
199
+ add_to_existing_dataset,
200
+ btn_generate_full_dataset_copy,
201
+ btn_generate_and_push_to_argilla,
202
+ btn_push_to_argilla,
203
+ ) = get_argilla_tab()
204
+ (
205
+ org_name,
206
+ repo_name,
207
+ private,
208
+ btn_generate_full_dataset,
209
+ btn_generate_and_push_to_hub,
210
+ btn_push_to_hub,
211
+ ) = get_hf_tab()
212
+ final_dataset = get_final_dataset_row(default_datasets)
213
+ success_message = get_success_message_row()
214
+ return (
215
+ dataset_name,
216
+ add_to_existing_dataset,
217
+ btn_generate_full_dataset_copy,
218
+ btn_generate_and_push_to_argilla,
219
+ btn_push_to_argilla,
220
+ org_name,
221
+ repo_name,
222
+ private,
223
+ btn_generate_full_dataset,
224
+ btn_generate_and_push_to_hub,
225
+ btn_push_to_hub,
226
+ final_dataset,
227
+ success_message,
228
+ )
229
+
230
+
231
+ def get_iterate_on_sample_dataset_ui(
232
+ default_dataset_descriptions: List[str],
233
+ default_system_prompts: List[str],
234
+ default_datasets: List[pd.DataFrame],
235
+ ):
236
+ with gr.Column():
237
+ dataset_description = gr.TextArea(
238
+ label="Give a precise description of the assistant or tool. Don't describe the dataset",
239
+ value=default_dataset_descriptions[0],
240
+ lines=2,
241
+ )
242
+ examples = gr.Examples(
243
+ elem_id="system_prompt_examples",
244
+ examples=[[example] for example in default_dataset_descriptions],
245
+ inputs=[dataset_description],
246
+ )
247
+ with gr.Row():
248
+ gr.Column(scale=1)
249
+ btn_generate_system_prompt = gr.Button(
250
+ value="Generate system prompt and sample dataset"
251
+ )
252
+ gr.Column(scale=1)
253
+
254
+ system_prompt = gr.TextArea(
255
+ label="System prompt for dataset generation. You can tune it and regenerate the sample",
256
+ value=default_system_prompts[0],
257
+ lines=5,
258
+ )
259
+
260
+ with gr.Row():
261
+ sample_dataset = gr.Dataframe(
262
+ value=default_datasets[0],
263
+ label="Sample dataset. Prompts and completions truncated to 256 tokens.",
264
+ interactive=False,
265
+ wrap=True,
266
+ )
267
+
268
+ with gr.Row():
269
+ gr.Column(scale=1)
270
+ btn_generate_sample_dataset = gr.Button(
271
+ value="Generate sample dataset",
272
+ )
273
+ gr.Column(scale=1)
274
+
275
+ return (
276
+ dataset_description,
277
+ examples,
278
+ btn_generate_system_prompt,
279
+ system_prompt,
280
+ sample_dataset,
281
+ btn_generate_sample_dataset,
282
+ )
283
+
284
+
285
+ def get_pipeline_code_ui(pipeline_code: str) -> gr.Code:
286
+ gr.Markdown("## Or run this pipeline locally with distilabel")
287
+ gr.Markdown(
288
+ "You can run this pipeline locally with distilabel. For more information, please refer to the [distilabel documentation](https://distilabel.argilla.io/) or go to the FAQ tab at the top of the page for more information."
289
+ )
290
+ with gr.Accordion(
291
+ "Run this pipeline using distilabel",
292
+ open=False,
293
+ ):
294
+ pipeline_code = gr.Code(
295
+ value=pipeline_code,
296
+ language="python",
297
+ label="Distilabel Pipeline Code",
298
+ )
299
+ return pipeline_code
300
+
301
+
302
+ def get_argilla_tab() -> Tuple[Any]:
303
+ with gr.Tab(label="Argilla"):
304
+ if get_argilla_client() is not None:
305
+ with gr.Row(variant="panel"):
306
+ dataset_name = gr.Textbox(
307
+ label="Dataset name",
308
+ placeholder="dataset_name",
309
+ value="my-distiset",
310
+ )
311
+ add_to_existing_dataset = gr.Checkbox(
312
+ label="Allow adding records to existing dataset",
313
+ info="When selected, you do need to ensure the number of turns in the conversation is the same as the number of turns in the existing dataset.",
314
+ value=False,
315
+ interactive=True,
316
+ scale=0.5,
317
+ )
318
+
319
+ with gr.Row(variant="panel"):
320
+ btn_generate_full_dataset_copy = gr.Button(
321
+ value="Generate", variant="primary", scale=2
322
+ )
323
+ btn_generate_and_push_to_argilla = gr.Button(
324
+ value="Generate and Push to Argilla",
325
+ variant="primary",
326
+ scale=2,
327
+ )
328
+ btn_push_to_argilla = gr.Button(
329
+ value="Push to Argilla", variant="primary", scale=2
330
+ )
331
+ else:
332
+ gr.Markdown(
333
+ "Please add `ARGILLA_API_URL` and `ARGILLA_API_KEY` to use Argilla or export the dataset to the Hugging Face Hub."
334
+ )
335
+ return (
336
+ dataset_name,
337
+ add_to_existing_dataset,
338
+ btn_generate_full_dataset_copy,
339
+ btn_generate_and_push_to_argilla,
340
+ btn_push_to_argilla,
341
+ )
342
+
343
+
344
+ def get_hf_tab() -> Tuple[Any]:
345
+ with gr.Tab("Hugging Face Hub"):
346
+ with gr.Row(variant="panel"):
347
+ org_name = get_org_dropdown()
348
+ repo_name = gr.Textbox(
349
+ label="Repo name",
350
+ placeholder="dataset_name",
351
+ value="my-distiset",
352
+ )
353
+ private = gr.Checkbox(
354
+ label="Private dataset",
355
+ value=True,
356
+ interactive=True,
357
+ scale=0.5,
358
+ )
359
+ with gr.Row(variant="panel"):
360
+ btn_generate_full_dataset = gr.Button(
361
+ value="Generate", variant="primary", scale=2
362
+ )
363
+ btn_generate_and_push_to_hub = gr.Button(
364
+ value="Generate and Push to Hub", variant="primary", scale=2
365
+ )
366
+ btn_push_to_hub = gr.Button(value="Push to Hub", variant="primary", scale=2)
367
+ return (
368
+ org_name,
369
+ repo_name,
370
+ private,
371
+ btn_generate_full_dataset,
372
+ btn_generate_and_push_to_hub,
373
+ btn_push_to_hub,
374
+ )
375
+
376
+
377
+ def push_pipeline_code_to_hub(
378
+ pipeline_code: str,
379
+ org_name: str,
380
+ repo_name: str,
381
+ oauth_token: Union[OAuthToken, None] = None,
382
+ progress=gr.Progress(),
383
+ ):
384
+ repo_id = _check_push_to_hub(org_name, repo_name)
385
+ progress(0.1, desc="Uploading pipeline code")
386
+ with io.BytesIO(pipeline_code.encode("utf-8")) as f:
387
+ upload_file(
388
+ path_or_fileobj=f,
389
+ path_in_repo="pipeline.py",
390
+ repo_id=repo_id,
391
+ repo_type="dataset",
392
+ token=oauth_token.token,
393
+ commit_message="Include pipeline script",
394
+ create_pr=False,
395
+ )
396
+ progress(1.0, desc="Pipeline code uploaded")
397
+
398
+
399
+ def push_dataset_to_hub(
400
+ dataframe: pd.DataFrame,
401
+ private: bool = True,
402
+ org_name: str = None,
403
+ repo_name: str = None,
404
+ oauth_token: Union[OAuthToken, None] = None,
405
+ progress=gr.Progress(),
406
+ ) -> pd.DataFrame:
407
+ progress(0.1, desc="Setting up dataset")
408
+ repo_id = _check_push_to_hub(org_name, repo_name)
409
+ distiset = Distiset(
410
+ {
411
+ "default": Dataset.from_pandas(dataframe),
412
+ }
413
+ )
414
+ progress(0.2, desc="Pushing dataset to hub")
415
+ distiset.push_to_hub(
416
+ repo_id=repo_id,
417
+ private=private,
418
+ include_script=False,
419
+ token=oauth_token.token,
420
+ create_pr=False,
421
+ )
422
+ progress(1.0, desc="Dataset pushed to hub")
423
+ return dataframe
424
+
425
+
426
+ def _check_push_to_hub(org_name, repo_name):
427
+ repo_id = (
428
+ f"{org_name}/{repo_name}"
429
+ if repo_name is not None and org_name is not None
430
+ else None
431
+ )
432
+ if repo_id is not None:
433
+ if not all([repo_id, org_name, repo_name]):
434
+ raise gr.Error(
435
+ "Please provide a `repo_name` and `org_name` to push the dataset to."
436
+ )
437
+ return repo_id
438
+
439
+
440
+ def get_final_dataset_row(default_datasets) -> gr.Dataframe:
441
+ with gr.Row():
442
+ final_dataset = gr.Dataframe(
443
+ value=default_datasets[0],
444
+ label="Generated dataset",
445
+ interactive=False,
446
+ wrap=True,
447
+ )
448
+ return final_dataset
449
+
450
+
451
+ def get_success_message_row() -> gr.Markdown:
452
+ with gr.Row():
453
+ success_message = gr.Markdown(visible=False)
454
+ return success_message
455
+
456
+
457
+ def show_success_message_argilla() -> gr.Markdown:
458
+ client = get_argilla_client()
459
+ argilla_api_url = client.api_url
460
+ return gr.Markdown(
461
+ value=f"""
462
+ <div style="padding: 1em; background-color: #e6f3e6; border-radius: 5px; margin-top: 1em;">
463
+ <h3 style="color: #2e7d32; margin: 0;">Dataset Published Successfully!</h3>
464
+ <p style="margin-top: 0.5em;">
465
+ Your dataset is now available at:
466
+ <a href="{argilla_api_url}" target="_blank" style="color: #1565c0; text-decoration: none;">
467
+ {argilla_api_url}
468
+ </a>
469
+ <br>Unfamiliar with Argilla? Here are some docs to help you get started:
470
+ <br>• <a href="https://docs.argilla.io/latest/how_to_guides/annotate/" target="_blank">How to curate data in Argilla</a>
471
+ <br>• <a href="https://docs.argilla.io/latest/how_to_guides/import_export/" target="_blank">How to export data once you have reviewed the dataset</a>
472
+ </p>
473
+ </div>
474
+ """,
475
+ visible=True,
476
+ )
477
+
478
+
479
+ def show_success_message_hub(org_name, repo_name) -> gr.Markdown:
480
+ return gr.Markdown(
481
+ value=f"""
482
+ <div style="padding: 1em; background-color: #e6f3e6; border-radius: 5px; margin-top: 1em;">
483
+ <h3 style="color: #2e7d32; margin: 0;">Dataset Published Successfully!</h3>
484
+ <p style="margin-top: 0.5em;">
485
+ The generated dataset is in the right format for fine-tuning with TRL, AutoTrain or other frameworks.
486
+ Your dataset is now available at:
487
+ <a href="https://huggingface.co/datasets/{org_name}/{repo_name}" target="_blank" style="color: #1565c0; text-decoration: none;">
488
+ https://huggingface.co/datasets/{org_name}/{repo_name}
489
+ </a>
490
+ </p>
491
+ </div>
492
+ """,
493
+ visible=True,
494
+ )
495
+
496
+
497
+ def hide_success_message() -> gr.Markdown:
498
+ return gr.Markdown(visible=False)
src/distilabel_dataset_generator/apps/sft.py CHANGED
@@ -1,6 +1,4 @@
1
  import ast
2
- import io
3
- import uuid
4
  from typing import Dict, List, Union
5
 
6
  import argilla as rg
@@ -8,17 +6,29 @@ import gradio as gr
8
  import pandas as pd
9
  from datasets import Dataset
10
  from distilabel.distiset import Distiset
11
- from distilabel.steps.tasks.text_generation import TextGeneration
12
- from gradio.oauth import OAuthToken
13
- from huggingface_hub import upload_file
14
- from huggingface_hub.hf_api import HfApi
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  from src.distilabel_dataset_generator.pipelines.embeddings import (
17
  get_embeddings,
18
  get_sentence_embedding_dimensions,
19
  )
20
  from src.distilabel_dataset_generator.pipelines.sft import (
21
- DEFAULT_BATCH_SIZE,
22
  DEFAULT_DATASET_DESCRIPTIONS,
23
  DEFAULT_DATASETS,
24
  DEFAULT_SYSTEM_PROMPTS,
@@ -28,222 +38,45 @@ from src.distilabel_dataset_generator.pipelines.sft import (
28
  get_prompt_generator,
29
  get_response_generator,
30
  )
31
- from src.distilabel_dataset_generator.utils import (
32
- get_argilla_client,
33
- get_base_app,
34
- get_org_dropdown,
35
- swap_visibilty,
36
- )
37
-
38
-
39
- def convert_to_list_of_dicts(messages: str) -> List[Dict[str, str]]:
40
- return ast.literal_eval(
41
- messages.replace("'user'}", "'user'},")
42
- .replace("'system'}", "'system'},")
43
- .replace("'assistant'}", "'assistant'},")
44
- )
45
-
46
 
47
- def generate_system_prompt(dataset_description, progress=gr.Progress()):
48
- progress(0.0, desc="Generating system prompt")
49
- if dataset_description in DEFAULT_DATASET_DESCRIPTIONS:
50
- index = DEFAULT_DATASET_DESCRIPTIONS.index(dataset_description)
51
- if index < len(DEFAULT_SYSTEM_PROMPTS):
52
- return DEFAULT_SYSTEM_PROMPTS[index]
53
 
54
- progress(0.3, desc="Initializing text generation")
55
- generate_description: TextGeneration = get_prompt_generator()
56
- progress(0.7, desc="Generating system prompt")
57
- result = next(
58
- generate_description.process(
59
- [
60
- {
61
- "system_prompt": PROMPT_CREATION_PROMPT,
62
- "instruction": dataset_description,
63
- }
64
- ]
65
  )
66
- )[0]["generation"]
67
- progress(1.0, desc="System prompt generated")
68
- return result
69
-
70
-
71
- def generate_sample_dataset(system_prompt, progress=gr.Progress()):
72
- if system_prompt in DEFAULT_SYSTEM_PROMPTS:
73
- index = DEFAULT_SYSTEM_PROMPTS.index(system_prompt)
74
- if index < len(DEFAULT_DATASETS):
75
- return DEFAULT_DATASETS[index]
76
- result = generate_dataset(
77
- system_prompt, num_turns=1, num_rows=1, progress=progress, is_sample=True
78
- )
79
- return result
80
-
81
-
82
- def _check_push_to_hub(org_name, repo_name):
83
- repo_id = (
84
- f"{org_name}/{repo_name}"
85
- if repo_name is not None and org_name is not None
86
- else None
87
- )
88
- if repo_id is not None:
89
- if not all([repo_id, org_name, repo_name]):
90
- raise gr.Error(
91
- "Please provide a `repo_name` and `org_name` to push the dataset to."
92
- )
93
- return repo_id
94
-
95
 
96
- def generate_dataset(
97
- system_prompt: str,
98
- num_turns: int = 1,
99
- num_rows: int = 5,
100
- is_sample: bool = False,
101
- progress=gr.Progress(),
102
- ) -> pd.DataFrame:
103
- progress(0.0, desc="(1/2) Generating instructions")
104
- magpie_generator = get_magpie_generator(
105
- num_turns, num_rows, system_prompt, is_sample
106
- )
107
- response_generator = get_response_generator(num_turns, system_prompt, is_sample)
108
- total_steps: int = num_rows * 2
109
- batch_size = DEFAULT_BATCH_SIZE
110
-
111
- # create instructions
112
- n_processed = 0
113
- magpie_results = []
114
- while n_processed < num_rows:
115
- progress(
116
- 0.5 * n_processed / num_rows,
117
- total=total_steps,
118
- desc="(1/2) Generating instructions",
119
  )
120
- remaining_rows = num_rows - n_processed
121
- batch_size = min(batch_size, remaining_rows)
122
- inputs = [{"system_prompt": system_prompt} for _ in range(batch_size)]
123
- batch = list(magpie_generator.process(inputs=inputs))
124
- magpie_results.extend(batch[0])
125
- n_processed += batch_size
126
- progress(0.5, desc="(1/2) Generating instructions")
127
-
128
- # generate responses
129
- n_processed = 0
130
- response_results = []
131
- if num_turns == 1:
132
- while n_processed < num_rows:
133
- progress(
134
- 0.5 + 0.5 * n_processed / num_rows,
135
- total=total_steps,
136
- desc="(2/2) Generating responses",
137
- )
138
- batch = magpie_results[n_processed : n_processed + batch_size]
139
- responses = list(response_generator.process(inputs=batch))
140
- response_results.extend(responses[0])
141
- n_processed += batch_size
142
- for result in response_results:
143
- result["prompt"] = result["instruction"]
144
- result["completion"] = result["generation"]
145
- result["system_prompt"] = system_prompt
146
- else:
147
- for result in magpie_results:
148
- result["conversation"].insert(
149
- 0, {"role": "system", "content": system_prompt}
150
- )
151
- result["messages"] = result["conversation"]
152
- while n_processed < num_rows:
153
- progress(
154
- 0.5 + 0.5 * n_processed / num_rows,
155
- total=total_steps,
156
- desc="(2/2) Generating responses",
157
- )
158
- batch = magpie_results[n_processed : n_processed + batch_size]
159
- responses = list(response_generator.process(inputs=batch))
160
- response_results.extend(responses[0])
161
- n_processed += batch_size
162
- for result in response_results:
163
- result["messages"].append(
164
- {"role": "assistant", "content": result["generation"]}
165
- )
166
- progress(
167
- 1,
168
- total=total_steps,
169
- desc="(2/2) Generating responses",
170
- )
171
-
172
- # create distiset
173
- distiset_results = []
174
- for result in response_results:
175
- record = {}
176
- for relevant_keys in [
177
- "messages",
178
- "prompt",
179
- "completion",
180
- "model_name",
181
- "system_prompt",
182
- ]:
183
- if relevant_keys in result:
184
- record[relevant_keys] = result[relevant_keys]
185
- distiset_results.append(record)
186
-
187
- distiset = Distiset(
188
- {
189
- "default": Dataset.from_list(distiset_results),
190
- }
191
- )
192
-
193
- # If not pushing to hub generate the dataset directly
194
- distiset = distiset["default"]
195
- if num_turns == 1:
196
- outputs = distiset.to_pandas()[["system_prompt", "prompt", "completion"]]
197
- else:
198
- outputs = distiset.to_pandas()[["messages"]]
199
- dataframe = pd.DataFrame(outputs)
200
- progress(1.0, desc="Dataset generation completed")
201
  return dataframe
202
 
203
 
204
- def push_to_hub(
205
  dataframe: pd.DataFrame,
206
  private: bool = True,
207
  org_name: str = None,
208
  repo_name: str = None,
209
- oauth_token: Union[OAuthToken, None] = None,
210
  progress=gr.Progress(),
211
- ) -> pd.DataFrame:
212
  original_dataframe = dataframe.copy(deep=True)
213
- if "messages" in dataframe.columns:
214
- dataframe["messages"] = dataframe["messages"].apply(
215
- lambda x: convert_to_list_of_dicts(x) if isinstance(x, str) else x
216
- )
217
- progress(0.1, desc="Setting up dataset")
218
- repo_id = _check_push_to_hub(org_name, repo_name)
219
- distiset = Distiset(
220
- {
221
- "default": Dataset.from_pandas(dataframe),
222
- }
223
- )
224
- progress(0.2, desc="Pushing dataset to hub")
225
- distiset.push_to_hub(
226
- repo_id=repo_id,
227
- private=private,
228
- include_script=False,
229
- token=oauth_token.token,
230
- create_pr=False,
231
- )
232
- progress(1.0, desc="Dataset pushed to hub")
233
  return original_dataframe
234
 
235
 
236
- def push_to_argilla(
237
  dataframe: pd.DataFrame,
238
  dataset_name: str,
239
- oauth_token: Union[OAuthToken, None] = None,
240
  progress=gr.Progress(),
241
  ) -> pd.DataFrame:
242
  original_dataframe = dataframe.copy(deep=True)
243
- if "messages" in dataframe.columns:
244
- dataframe["messages"] = dataframe["messages"].apply(
245
- lambda x: convert_to_list_of_dicts(x) if isinstance(x, str) else x
246
- )
247
  try:
248
  progress(0.1, desc="Setting up user and workspace")
249
  client = get_argilla_client()
@@ -363,273 +196,193 @@ def push_to_argilla(
363
  return original_dataframe
364
 
365
 
366
- def validate_argilla_dataset_name(
367
- dataset_name: str,
368
- final_dataset: pd.DataFrame,
369
- add_to_existing_dataset: bool,
370
- oauth_token: Union[OAuthToken, None] = None,
371
- progress=gr.Progress(),
372
- ) -> str:
373
- progress(0, desc="Validating dataset configuration")
374
- hf_user = HfApi().whoami(token=oauth_token.token)["name"]
375
- client = get_argilla_client()
376
- if dataset_name is None or dataset_name == "":
377
- raise gr.Error("Dataset name is required")
378
- # Create user if it doesn't exist
379
- rg_user = client.users(username=hf_user)
380
- if rg_user is None:
381
- rg_user = client.users.add(
382
- rg.User(username=hf_user, role="admin", password=str(uuid.uuid4()))
383
- )
384
- # Create workspace if it doesn't exist
385
- workspace = client.workspaces(name=hf_user)
386
- if workspace is None:
387
- workspace = client.workspaces.add(rg.Workspace(name=hf_user))
388
- workspace.add_user(rg_user)
389
- # Check if dataset exists
390
- dataset = client.datasets(name=dataset_name, workspace=hf_user)
391
- if dataset and not add_to_existing_dataset:
392
- raise gr.Error(f"Dataset {dataset_name} already exists")
393
- return final_dataset
394
-
395
 
396
- def upload_pipeline_code(
397
- pipeline_code,
398
- org_name,
399
- repo_name,
400
- oauth_token: Union[OAuthToken, None] = None,
401
- progress=gr.Progress(),
402
- ):
403
- repo_id = _check_push_to_hub(org_name, repo_name)
404
- progress(0.1, desc="Uploading pipeline code")
405
- with io.BytesIO(pipeline_code.encode("utf-8")) as f:
406
- upload_file(
407
- path_or_fileobj=f,
408
- path_in_repo="pipeline.py",
409
- repo_id=repo_id,
410
- repo_type="dataset",
411
- token=oauth_token.token,
412
- commit_message="Include pipeline script",
413
- create_pr=False,
414
  )
415
- progress(1.0, desc="Pipeline code uploaded")
 
 
416
 
417
 
418
- with get_base_app() as app:
419
- gr.Markdown("## Iterate on a sample dataset")
420
- with gr.Column() as main_ui:
421
- dataset_description = gr.TextArea(
422
- label="Give a precise description of the assistant or tool. Don't describe the dataset",
423
- value=DEFAULT_DATASET_DESCRIPTIONS[0],
424
- lines=2,
425
- )
426
- examples = gr.Examples(
427
- elem_id="system_prompt_examples",
428
- examples=[[example] for example in DEFAULT_DATASET_DESCRIPTIONS],
429
- inputs=[dataset_description],
430
- )
431
- with gr.Row():
432
- gr.Column(scale=1)
433
- btn_generate_system_prompt = gr.Button(
434
- value="Generate system prompt and sample dataset"
435
- )
436
- gr.Column(scale=1)
437
 
438
- system_prompt = gr.TextArea(
439
- label="System prompt for dataset generation. You can tune it and regenerate the sample",
440
- value=DEFAULT_SYSTEM_PROMPTS[0],
441
- lines=5,
 
 
 
 
442
  )
 
 
 
 
 
 
 
443
 
444
- with gr.Row():
445
- sample_dataset = gr.Dataframe(
446
- value=DEFAULT_DATASETS[0],
447
- label="Sample dataset. Prompts and completions truncated to 256 tokens.",
448
- interactive=False,
449
- wrap=True,
 
 
 
450
  )
451
-
452
- with gr.Row():
453
- gr.Column(scale=1)
454
- btn_generate_sample_dataset = gr.Button(
455
- value="Generate sample dataset",
 
 
 
 
 
 
 
456
  )
457
- gr.Column(scale=1)
458
-
459
- result = btn_generate_system_prompt.click(
460
- fn=generate_system_prompt,
461
- inputs=[dataset_description],
462
- outputs=[system_prompt],
463
- show_progress=True,
464
- ).then(
465
- fn=generate_sample_dataset,
466
- inputs=[system_prompt],
467
- outputs=[sample_dataset],
468
- show_progress=True,
469
- )
470
-
471
- btn_generate_sample_dataset.click(
472
- fn=generate_sample_dataset,
473
- inputs=[system_prompt],
474
- outputs=[sample_dataset],
475
- show_progress=True,
476
- )
477
-
478
- # Add a header for the full dataset generation section
479
- gr.Markdown("## Generate full dataset")
480
- gr.Markdown(
481
- "Once you're satisfied with the sample, generate a larger dataset and push it to Argilla or the Hugging Face Hub."
482
- )
483
-
484
- with gr.Column() as push_to_hub_ui:
485
- with gr.Row(variant="panel"):
486
- num_turns = gr.Number(
487
- value=1,
488
- label="Number of turns in the conversation",
489
- minimum=1,
490
- maximum=4,
491
- step=1,
492
- info="Choose between 1 (single turn with 'instruction-response' columns) and 2-4 (multi-turn conversation with a 'messages' column).",
493
- )
494
- num_rows = gr.Number(
495
- value=10,
496
- label="Number of rows in the dataset",
497
- minimum=1,
498
- maximum=500,
499
- info="The number of rows in the dataset. Note that you are able to generate more rows at once but that this will take time.",
500
- )
501
-
502
- with gr.Tab(label="Argilla"):
503
- if get_argilla_client() is not None:
504
- with gr.Row(variant="panel"):
505
- dataset_name = gr.Textbox(
506
- label="Dataset name",
507
- placeholder="dataset_name",
508
- value="my-distiset",
509
- )
510
- add_to_existing_dataset = gr.Checkbox(
511
- label="Allow adding records to existing dataset",
512
- info="When selected, you do need to ensure the number of turns in the conversation is the same as the number of turns in the existing dataset.",
513
- value=False,
514
- interactive=True,
515
- scale=0.5,
516
- )
517
-
518
- with gr.Row(variant="panel"):
519
- btn_generate_full_dataset_copy = gr.Button(
520
- value="Generate", variant="primary", scale=2
521
- )
522
- btn_generate_and_push_to_argilla = gr.Button(
523
- value="Generate and Push to Argilla",
524
- variant="primary",
525
- scale=2,
526
- )
527
- btn_push_to_argilla = gr.Button(
528
- value="Push to Argilla", variant="primary", scale=2
529
- )
530
- else:
531
- gr.Markdown(
532
- "Please add `ARGILLA_API_URL` and `ARGILLA_API_KEY` to use Argilla or export the dataset to the Hugging Face Hub."
533
- )
534
- with gr.Tab("Hugging Face Hub"):
535
- with gr.Row(variant="panel"):
536
- org_name = get_org_dropdown()
537
- repo_name = gr.Textbox(
538
- label="Repo name",
539
- placeholder="dataset_name",
540
- value="my-distiset",
541
- )
542
- private = gr.Checkbox(
543
- label="Private dataset",
544
- value=True,
545
- interactive=True,
546
- scale=0.5,
547
- )
548
- with gr.Row(variant="panel"):
549
- btn_generate_full_dataset = gr.Button(
550
- value="Generate", variant="primary", scale=2
551
- )
552
- btn_generate_and_push_to_hub = gr.Button(
553
- value="Generate and Push to Hub", variant="primary", scale=2
554
- )
555
- btn_push_to_hub = gr.Button(
556
- value="Push to Hub", variant="primary", scale=2
557
- )
558
 
559
- with gr.Row():
560
- final_dataset = gr.Dataframe(
561
- value=DEFAULT_DATASETS[0],
562
- label="Generated dataset",
563
- interactive=False,
564
- wrap=True,
565
- )
 
 
 
 
 
 
 
566
 
567
- with gr.Row():
568
- success_message = gr.Markdown(visible=False)
 
 
 
569
 
570
- def show_success_message_argilla():
571
- client = get_argilla_client()
572
- argilla_api_url = client.api_url
573
- return gr.Markdown(
574
- value=f"""
575
- <div style="padding: 1em; background-color: #e6f3e6; border-radius: 5px; margin-top: 1em;">
576
- <h3 style="color: #2e7d32; margin: 0;">Dataset Published Successfully!</h3>
577
- <p style="margin-top: 0.5em;">
578
- Your dataset is now available at:
579
- <a href="{argilla_api_url}" target="_blank" style="color: #1565c0; text-decoration: none;">
580
- {argilla_api_url}
581
- </a>
582
- <br>Unfamiliar with Argilla? Here are some docs to help you get started:
583
- <br>• <a href="https://docs.argilla.io/latest/how_to_guides/annotate/" target="_blank">How to curate data in Argilla</a>
584
- <br>• <a href="https://docs.argilla.io/latest/how_to_guides/import_export/" target="_blank">How to export data once you have reviewed the dataset</a>
585
- </p>
586
- </div>
587
- """,
588
- visible=True,
589
- )
590
 
591
- def show_success_message_hub(org_name, repo_name):
592
- return gr.Markdown(
593
- value=f"""
594
- <div style="padding: 1em; background-color: #e6f3e6; border-radius: 5px; margin-top: 1em;">
595
- <h3 style="color: #2e7d32; margin: 0;">Dataset Published Successfully!</h3>
596
- <p style="margin-top: 0.5em;">
597
- The generated dataset is in the right format for fine-tuning with TRL, AutoTrain or other frameworks.
598
- Your dataset is now available at:
599
- <a href="https://huggingface.co/datasets/{org_name}/{repo_name}" target="_blank" style="color: #1565c0; text-decoration: none;">
600
- https://huggingface.co/datasets/{org_name}/{repo_name}
601
- </a>
602
- </p>
603
- </div>
604
- """,
605
- visible=True,
606
- )
607
 
608
- def hide_success_message():
609
- return gr.Markdown(visible=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
610
 
611
- gr.Markdown("## Or run this pipeline locally with distilabel")
612
- gr.Markdown(
613
- "You can run this pipeline locally with distilabel. For more information, please refer to the [distilabel documentation](https://distilabel.argilla.io/) or go to the FAQ tab at the top of the page for more information."
614
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
615
 
616
- with gr.Accordion(
617
- "Run this pipeline using distilabel",
618
- open=False,
619
- ):
620
- pipeline_code = gr.Code(
621
- value=generate_pipeline_code(
622
- system_prompt.value, num_turns.value, num_rows.value
623
- ),
624
- language="python",
625
- label="Distilabel Pipeline Code",
626
  )
627
 
628
- sample_dataset.change(
629
- fn=lambda x: x,
630
- inputs=[sample_dataset],
631
- outputs=[final_dataset],
632
- )
633
  gr.on(
634
  triggers=[
635
  btn_generate_full_dataset.click,
@@ -645,7 +398,7 @@ with get_base_app() as app:
645
  )
646
 
647
  btn_generate_and_push_to_argilla.click(
648
- fn=validate_argilla_dataset_name,
649
  inputs=[dataset_name, final_dataset, add_to_existing_dataset],
650
  outputs=[final_dataset],
651
  show_progress=True,
@@ -658,7 +411,7 @@ with get_base_app() as app:
658
  outputs=[final_dataset],
659
  show_progress=True,
660
  ).success(
661
- fn=push_to_argilla,
662
  inputs=[final_dataset, dataset_name],
663
  outputs=[final_dataset],
664
  show_progress=True,
@@ -677,12 +430,12 @@ with get_base_app() as app:
677
  outputs=[final_dataset],
678
  show_progress=True,
679
  ).then(
680
- fn=push_to_hub,
681
  inputs=[final_dataset, private, org_name, repo_name],
682
  outputs=[final_dataset],
683
  show_progress=True,
684
  ).then(
685
- fn=upload_pipeline_code,
686
  inputs=[pipeline_code, org_name, repo_name],
687
  outputs=[],
688
  show_progress=True,
@@ -696,12 +449,12 @@ with get_base_app() as app:
696
  fn=hide_success_message,
697
  outputs=[success_message],
698
  ).then(
699
- fn=push_to_hub,
700
  inputs=[final_dataset, private, org_name, repo_name],
701
  outputs=[final_dataset],
702
  show_progress=True,
703
  ).then(
704
- fn=upload_pipeline_code,
705
  inputs=[pipeline_code, org_name, repo_name],
706
  outputs=[],
707
  show_progress=True,
@@ -715,12 +468,12 @@ with get_base_app() as app:
715
  fn=hide_success_message,
716
  outputs=[success_message],
717
  ).success(
718
- fn=validate_argilla_dataset_name,
719
  inputs=[dataset_name, final_dataset, add_to_existing_dataset],
720
  outputs=[final_dataset],
721
  show_progress=True,
722
  ).success(
723
- fn=push_to_argilla,
724
  inputs=[final_dataset, dataset_name],
725
  outputs=[final_dataset],
726
  show_progress=True,
@@ -745,5 +498,3 @@ with get_base_app() as app:
745
  inputs=[system_prompt, num_turns, num_rows],
746
  outputs=[pipeline_code],
747
  )
748
- app.load(get_org_dropdown, outputs=[org_name])
749
- app.load(fn=swap_visibilty, outputs=main_ui)
 
1
  import ast
 
 
2
  from typing import Dict, List, Union
3
 
4
  import argilla as rg
 
6
  import pandas as pd
7
  from datasets import Dataset
8
  from distilabel.distiset import Distiset
9
+ from huggingface_hub import HfApi
 
 
 
10
 
11
+ from src.distilabel_dataset_generator.apps.base import (
12
+ get_argilla_client,
13
+ get_main_ui,
14
+ get_pipeline_code_ui,
15
+ hide_success_message,
16
+ push_pipeline_code_to_hub,
17
+ show_success_message_argilla,
18
+ show_success_message_hub,
19
+ validate_argilla_user_workspace_dataset,
20
+ )
21
+ from src.distilabel_dataset_generator.apps.base import (
22
+ push_dataset_to_hub as push_to_hub_base,
23
+ )
24
+ from src.distilabel_dataset_generator.pipelines.base import (
25
+ DEFAULT_BATCH_SIZE,
26
+ )
27
  from src.distilabel_dataset_generator.pipelines.embeddings import (
28
  get_embeddings,
29
  get_sentence_embedding_dimensions,
30
  )
31
  from src.distilabel_dataset_generator.pipelines.sft import (
 
32
  DEFAULT_DATASET_DESCRIPTIONS,
33
  DEFAULT_DATASETS,
34
  DEFAULT_SYSTEM_PROMPTS,
 
38
  get_prompt_generator,
39
  get_response_generator,
40
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
 
 
 
 
 
 
42
 
43
+ def convert_dataframe_messages(dataframe: pd.DataFrame) -> pd.DataFrame:
44
+ def convert_to_list_of_dicts(messages: str) -> List[Dict[str, str]]:
45
+ return ast.literal_eval(
46
+ messages.replace("'user'}", "'user'},")
47
+ .replace("'system'}", "'system'},")
48
+ .replace("'assistant'}", "'assistant'},")
 
 
 
 
 
49
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
+ if "messages" in dataframe.columns:
52
+ dataframe["messages"] = dataframe["messages"].apply(
53
+ lambda x: convert_to_list_of_dicts(x) if isinstance(x, str) else x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  return dataframe
56
 
57
 
58
+ def push_dataset_to_hub(
59
  dataframe: pd.DataFrame,
60
  private: bool = True,
61
  org_name: str = None,
62
  repo_name: str = None,
63
+ oauth_token: Union[gr.OAuthToken, None] = None,
64
  progress=gr.Progress(),
65
+ ):
66
  original_dataframe = dataframe.copy(deep=True)
67
+ dataframe = convert_dataframe_messages(dataframe)
68
+ push_to_hub_base(dataframe, private, org_name, repo_name, oauth_token, progress)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  return original_dataframe
70
 
71
 
72
+ def push_dataset_to_argilla(
73
  dataframe: pd.DataFrame,
74
  dataset_name: str,
75
+ oauth_token: Union[gr.OAuthToken, None] = None,
76
  progress=gr.Progress(),
77
  ) -> pd.DataFrame:
78
  original_dataframe = dataframe.copy(deep=True)
79
+ dataframe = convert_dataframe_messages(dataframe)
 
 
 
80
  try:
81
  progress(0.1, desc="Setting up user and workspace")
82
  client = get_argilla_client()
 
196
  return original_dataframe
197
 
198
 
199
+ def generate_system_prompt(dataset_description, progress=gr.Progress()):
200
+ progress(0.0, desc="Generating system prompt")
201
+ if dataset_description in DEFAULT_DATASET_DESCRIPTIONS:
202
+ index = DEFAULT_DATASET_DESCRIPTIONS.index(dataset_description)
203
+ if index < len(DEFAULT_SYSTEM_PROMPTS):
204
+ return DEFAULT_SYSTEM_PROMPTS[index]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
 
206
+ progress(0.3, desc="Initializing text generation")
207
+ generate_description = get_prompt_generator()
208
+ progress(0.7, desc="Generating system prompt")
209
+ result = next(
210
+ generate_description.process(
211
+ [
212
+ {
213
+ "system_prompt": PROMPT_CREATION_PROMPT,
214
+ "instruction": dataset_description,
215
+ }
216
+ ]
 
 
 
 
 
 
 
217
  )
218
+ )[0]["generation"]
219
+ progress(1.0, desc="System prompt generated")
220
+ return result
221
 
222
 
223
+ def generate_dataset(
224
+ system_prompt: str,
225
+ num_turns: int = 1,
226
+ num_rows: int = 5,
227
+ is_sample: bool = False,
228
+ progress=gr.Progress(),
229
+ ) -> pd.DataFrame:
230
+ progress(0.0, desc="(1/2) Generating instructions")
231
+ magpie_generator = get_magpie_generator(
232
+ num_turns, num_rows, system_prompt, is_sample
233
+ )
234
+ response_generator = get_response_generator(num_turns, system_prompt, is_sample)
235
+ total_steps: int = num_rows * 2
236
+ batch_size = DEFAULT_BATCH_SIZE
 
 
 
 
 
237
 
238
+ # create instructions
239
+ n_processed = 0
240
+ magpie_results = []
241
+ while n_processed < num_rows:
242
+ progress(
243
+ 0.5 * n_processed / num_rows,
244
+ total=total_steps,
245
+ desc="(1/2) Generating instructions",
246
  )
247
+ remaining_rows = num_rows - n_processed
248
+ batch_size = min(batch_size, remaining_rows)
249
+ inputs = [{"system_prompt": system_prompt} for _ in range(batch_size)]
250
+ batch = list(magpie_generator.process(inputs=inputs))
251
+ magpie_results.extend(batch[0])
252
+ n_processed += batch_size
253
+ progress(0.5, desc="(1/2) Generating instructions")
254
 
255
+ # generate responses
256
+ n_processed = 0
257
+ response_results = []
258
+ if num_turns == 1:
259
+ while n_processed < num_rows:
260
+ progress(
261
+ 0.5 + 0.5 * n_processed / num_rows,
262
+ total=total_steps,
263
+ desc="(2/2) Generating responses",
264
  )
265
+ batch = magpie_results[n_processed : n_processed + batch_size]
266
+ responses = list(response_generator.process(inputs=batch))
267
+ response_results.extend(responses[0])
268
+ n_processed += batch_size
269
+ for result in response_results:
270
+ result["prompt"] = result["instruction"]
271
+ result["completion"] = result["generation"]
272
+ result["system_prompt"] = system_prompt
273
+ else:
274
+ for result in magpie_results:
275
+ result["conversation"].insert(
276
+ 0, {"role": "system", "content": system_prompt}
277
  )
278
+ result["messages"] = result["conversation"]
279
+ while n_processed < num_rows:
280
+ progress(
281
+ 0.5 + 0.5 * n_processed / num_rows,
282
+ total=total_steps,
283
+ desc="(2/2) Generating responses",
284
+ )
285
+ batch = magpie_results[n_processed : n_processed + batch_size]
286
+ responses = list(response_generator.process(inputs=batch))
287
+ response_results.extend(responses[0])
288
+ n_processed += batch_size
289
+ for result in response_results:
290
+ result["messages"].append(
291
+ {"role": "assistant", "content": result["generation"]}
292
+ )
293
+ progress(
294
+ 1,
295
+ total=total_steps,
296
+ desc="(2/2) Generating responses",
297
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
298
 
299
+ # create distiset
300
+ distiset_results = []
301
+ for result in response_results:
302
+ record = {}
303
+ for relevant_keys in [
304
+ "messages",
305
+ "prompt",
306
+ "completion",
307
+ "model_name",
308
+ "system_prompt",
309
+ ]:
310
+ if relevant_keys in result:
311
+ record[relevant_keys] = result[relevant_keys]
312
+ distiset_results.append(record)
313
 
314
+ distiset = Distiset(
315
+ {
316
+ "default": Dataset.from_list(distiset_results),
317
+ }
318
+ )
319
 
320
+ # If not pushing to hub generate the dataset directly
321
+ distiset = distiset["default"]
322
+ if num_turns == 1:
323
+ outputs = distiset.to_pandas()[["system_prompt", "prompt", "completion"]]
324
+ else:
325
+ outputs = distiset.to_pandas()[["messages"]]
326
+ dataframe = pd.DataFrame(outputs)
327
+ progress(1.0, desc="Dataset generation completed")
328
+ return dataframe
 
 
 
 
 
 
 
 
 
 
 
329
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
330
 
331
+ (
332
+ app,
333
+ main_ui,
334
+ custom_input_ui,
335
+ dataset_description,
336
+ examples,
337
+ btn_generate_system_prompt,
338
+ system_prompt,
339
+ sample_dataset,
340
+ btn_generate_sample_dataset,
341
+ dataset_name,
342
+ add_to_existing_dataset,
343
+ btn_generate_full_dataset_copy,
344
+ btn_generate_and_push_to_argilla,
345
+ btn_push_to_argilla,
346
+ org_name,
347
+ repo_name,
348
+ private,
349
+ btn_generate_full_dataset,
350
+ btn_generate_and_push_to_hub,
351
+ btn_push_to_hub,
352
+ final_dataset,
353
+ success_message,
354
+ ) = get_main_ui(
355
+ default_dataset_descriptions=DEFAULT_DATASET_DESCRIPTIONS,
356
+ default_system_prompts=DEFAULT_SYSTEM_PROMPTS,
357
+ default_datasets=DEFAULT_DATASETS,
358
+ fn_generate_system_prompt=generate_system_prompt,
359
+ fn_generate_dataset=generate_dataset,
360
+ )
361
 
362
+ with app:
363
+ with main_ui:
364
+ with custom_input_ui:
365
+ num_turns = gr.Number(
366
+ value=1,
367
+ label="Number of turns in the conversation",
368
+ minimum=1,
369
+ maximum=4,
370
+ step=1,
371
+ info="Choose between 1 (single turn with 'instruction-response' columns) and 2-4 (multi-turn conversation with a 'messages' column).",
372
+ )
373
+ num_rows = gr.Number(
374
+ value=10,
375
+ label="Number of rows in the dataset",
376
+ minimum=1,
377
+ maximum=500,
378
+ info="The number of rows in the dataset. Note that you are able to generate more rows at once but that this will take time.",
379
+ )
380
 
381
+ pipeline_code = get_pipeline_code_ui(
382
+ generate_pipeline_code(system_prompt.value, num_turns.value, num_rows.value)
 
 
 
 
 
 
 
 
383
  )
384
 
385
+ # define app triggers
 
 
 
 
386
  gr.on(
387
  triggers=[
388
  btn_generate_full_dataset.click,
 
398
  )
399
 
400
  btn_generate_and_push_to_argilla.click(
401
+ fn=validate_argilla_user_workspace_dataset,
402
  inputs=[dataset_name, final_dataset, add_to_existing_dataset],
403
  outputs=[final_dataset],
404
  show_progress=True,
 
411
  outputs=[final_dataset],
412
  show_progress=True,
413
  ).success(
414
+ fn=push_dataset_to_argilla,
415
  inputs=[final_dataset, dataset_name],
416
  outputs=[final_dataset],
417
  show_progress=True,
 
430
  outputs=[final_dataset],
431
  show_progress=True,
432
  ).then(
433
+ fn=push_dataset_to_hub,
434
  inputs=[final_dataset, private, org_name, repo_name],
435
  outputs=[final_dataset],
436
  show_progress=True,
437
  ).then(
438
+ fn=push_pipeline_code_to_hub,
439
  inputs=[pipeline_code, org_name, repo_name],
440
  outputs=[],
441
  show_progress=True,
 
449
  fn=hide_success_message,
450
  outputs=[success_message],
451
  ).then(
452
+ fn=push_dataset_to_hub,
453
  inputs=[final_dataset, private, org_name, repo_name],
454
  outputs=[final_dataset],
455
  show_progress=True,
456
  ).then(
457
+ fn=push_pipeline_code_to_hub,
458
  inputs=[pipeline_code, org_name, repo_name],
459
  outputs=[],
460
  show_progress=True,
 
468
  fn=hide_success_message,
469
  outputs=[success_message],
470
  ).success(
471
+ fn=validate_argilla_user_workspace_dataset,
472
  inputs=[dataset_name, final_dataset, add_to_existing_dataset],
473
  outputs=[final_dataset],
474
  show_progress=True,
475
  ).success(
476
+ fn=push_dataset_to_argilla,
477
  inputs=[final_dataset, dataset_name],
478
  outputs=[final_dataset],
479
  show_progress=True,
 
498
  inputs=[system_prompt, num_turns, num_rows],
499
  outputs=[pipeline_code],
500
  )
 
 
src/distilabel_dataset_generator/pipelines/base.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from src.distilabel_dataset_generator.utils import HF_TOKENS
2
+
3
+ DEFAULT_BATCH_SIZE = 5
4
+ TOKEN_INDEX = 0
5
+ MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct"
6
+
7
+
8
+ def _get_next_api_key():
9
+ global TOKEN_INDEX
10
+ api_key = HF_TOKENS[TOKEN_INDEX % len(HF_TOKENS)]
11
+ TOKEN_INDEX += 1
12
+ return api_key
src/distilabel_dataset_generator/pipelines/sft.py CHANGED
@@ -1,12 +1,11 @@
1
  import pandas as pd
2
- from datasets import Dataset
3
- from distilabel.distiset import Distiset
4
  from distilabel.llms import InferenceEndpointsLLM
5
- from distilabel.pipeline import Pipeline
6
- from distilabel.steps import KeepColumns
7
  from distilabel.steps.tasks import ChatGeneration, Magpie, TextGeneration
8
 
9
- from src.distilabel_dataset_generator.utils import HF_TOKENS
 
 
 
10
 
11
  INFORMATION_SEEKING_PROMPT = (
12
  "You are an AI assistant designed to provide accurate and concise information on a wide"
@@ -120,7 +119,6 @@ The prompt you write should follow the same style and structure as the following
120
  User dataset description:
121
  """
122
 
123
- MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct"
124
  DEFAULT_DATASET_DESCRIPTIONS = (
125
  "rude customer assistant for a phone company",
126
  "assistant that solves math puzzles using python",
@@ -157,8 +155,6 @@ _STOP_SEQUENCES = [
157
  "assistant",
158
  " \n\n",
159
  ]
160
- DEFAULT_BATCH_SIZE = 5
161
- TOKEN_INDEX = 0
162
 
163
 
164
  def _get_output_mappings(num_turns):
@@ -213,13 +209,6 @@ if __name__ == "__main__":
213
  return code
214
 
215
 
216
- def _get_next_api_key():
217
- global TOKEN_INDEX
218
- api_key = HF_TOKENS[TOKEN_INDEX % len(HF_TOKENS)]
219
- TOKEN_INDEX += 1
220
- return api_key
221
-
222
-
223
  def get_magpie_generator(num_turns, num_rows, system_prompt, is_sample):
224
  input_mappings = _get_output_mappings(num_turns)
225
  output_mappings = input_mappings.copy()
@@ -300,12 +289,9 @@ def get_response_generator(num_turns, system_prompt, is_sample):
300
 
301
 
302
  def get_prompt_generator():
303
- global TOKEN_INDEX
304
- api_key = HF_TOKENS[TOKEN_INDEX % len(HF_TOKENS)]
305
- TOKEN_INDEX += 1
306
  prompt_generator = TextGeneration(
307
  llm=InferenceEndpointsLLM(
308
- api_key=api_key,
309
  model_id=MODEL,
310
  tokenizer_id=MODEL,
311
  generation_kwargs={
@@ -318,95 +304,3 @@ def get_prompt_generator():
318
  )
319
  prompt_generator.load()
320
  return prompt_generator
321
-
322
-
323
- def get_pipeline(num_turns, num_rows, system_prompt, is_sample):
324
- input_mappings = _get_output_mappings(num_turns)
325
- output_mappings = input_mappings
326
-
327
- with Pipeline(name="sft") as pipeline:
328
- magpie = get_magpie_generator(num_turns, num_rows, system_prompt, is_sample)
329
- generate_response = get_response_generator(system_prompt, is_sample)
330
-
331
- keep_columns = KeepColumns(
332
- columns=list(output_mappings.values()) + ["model_name"],
333
- )
334
-
335
- magpie.connect(generate_response)
336
- generate_response.connect(keep_columns)
337
- return pipeline
338
-
339
-
340
- if __name__ == "__main__":
341
- prompt_generation_step = get_prompt_generator()
342
- system_prompt = next(
343
- prompt_generation_step.process(
344
- [
345
- {
346
- "system_prompt": PROMPT_CREATION_PROMPT,
347
- "instruction": DEFAULT_DATASET_DESCRIPTIONS[0],
348
- }
349
- ]
350
- )
351
- )[0]["generation"]
352
- num_rows = 2
353
- num_turns = 1
354
- magpie_generator = get_magpie_generator(num_turns, num_rows, system_prompt, False)
355
- response_generator = get_response_generator(num_turns, system_prompt, False)
356
- total_steps = num_rows * 2
357
- batch_size = 5 # Adjust this value as needed
358
-
359
- # create instructions
360
- magpie_results = []
361
- for i in range(0, num_rows, batch_size):
362
- batch = list(magpie_generator.process())[:batch_size]
363
- magpie_results.extend([item[0] for item in batch])
364
-
365
- # generate responses
366
- response_results = []
367
- if num_turns == 1:
368
- for i in range(0, len(magpie_results), batch_size):
369
- batch = magpie_results[i : i + batch_size]
370
- batch = [entry[0] for entry in batch]
371
- responses = list(response_generator.process(inputs=batch))
372
- response_results.extend(responses)
373
- for result in response_results:
374
- result[0]["prompt"] = result[0]["instruction"]
375
- result[0]["completion"] = result[0]["generation"]
376
- result[0]["system_prompt"] = system_prompt
377
- else:
378
- for result in magpie_results:
379
- result[0]["conversation"].insert(
380
- 0, {"role": "system", "content": system_prompt}
381
- )
382
- result[0]["messages"] = result[0]["conversation"]
383
- for i in range(0, len(magpie_results), batch_size):
384
- batch = magpie_results[i : i + batch_size]
385
- batch = [entry[0] for entry in batch]
386
- responses = list(response_generator.process(inputs=batch))
387
- response_results.extend(responses)
388
-
389
- for result in response_results:
390
- result[0]["messages"].append(
391
- {"role": "assistant", "content": result[0]["generation"]}
392
- )
393
-
394
- distiset_results = []
395
- for result in response_results[0]:
396
- record = {}
397
- for relevant_keys in [
398
- "messages",
399
- "prompt",
400
- "completion",
401
- "model_name",
402
- "system_prompt",
403
- ]:
404
- if relevant_keys in result:
405
- record[relevant_keys] = result[relevant_keys]
406
- distiset_results.append(record)
407
-
408
- distiset = Distiset(
409
- {
410
- "default": Dataset.from_list(distiset_results),
411
- }
412
- )
 
1
  import pandas as pd
 
 
2
  from distilabel.llms import InferenceEndpointsLLM
 
 
3
  from distilabel.steps.tasks import ChatGeneration, Magpie, TextGeneration
4
 
5
+ from src.distilabel_dataset_generator.pipelines.base import (
6
+ MODEL,
7
+ _get_next_api_key,
8
+ )
9
 
10
  INFORMATION_SEEKING_PROMPT = (
11
  "You are an AI assistant designed to provide accurate and concise information on a wide"
 
119
  User dataset description:
120
  """
121
 
 
122
  DEFAULT_DATASET_DESCRIPTIONS = (
123
  "rude customer assistant for a phone company",
124
  "assistant that solves math puzzles using python",
 
155
  "assistant",
156
  " \n\n",
157
  ]
 
 
158
 
159
 
160
  def _get_output_mappings(num_turns):
 
209
  return code
210
 
211
 
 
 
 
 
 
 
 
212
  def get_magpie_generator(num_turns, num_rows, system_prompt, is_sample):
213
  input_mappings = _get_output_mappings(num_turns)
214
  output_mappings = input_mappings.copy()
 
289
 
290
 
291
  def get_prompt_generator():
 
 
 
292
  prompt_generator = TextGeneration(
293
  llm=InferenceEndpointsLLM(
294
+ api_key=_get_next_api_key(),
295
  model_id=MODEL,
296
  tokenizer_id=MODEL,
297
  generation_kwargs={
 
304
  )
305
  prompt_generator.load()
306
  return prompt_generator