sdiazlor HF staff committed on
Commit
07a8bbc
·
verified ·
1 Parent(s): 3c2fc33

textcat-review (#12)

Browse files

- fix: apply feedback (3f2128047133bf339ddb5d4c4a6d6d41edb368da)
- fix: remove extra args (d27c1e6872d8be8b25d12d2b7e4baa5c075ed8c5)
- fix: add seed for more randomized samples (46f00bc57d59efb6274c287aa3b3ab0046d4d64e)
- fix: typo (a3f4be77171e6db5c804079dc82e0e30354bcec9)
- fix: correction label or labels (d59361703bc3414d4d5845cbe57ae760b52be7fc)
- fix: duplicated labels in labels and number of rows update listener in raw pipeline (b92482822c81a2a4330d54fd35640c6984ae8bda)

src/distilabel_dataset_generator/apps/base.py CHANGED
@@ -38,8 +38,8 @@ def get_main_ui(
38
  if task == TEXTCAT_TASK:
39
  result = fn_generate_dataset(
40
  system_prompt=system_prompt,
41
- difficulty="mixed",
42
- clarity="mixed",
43
  labels=[],
44
  num_labels=1,
45
  num_rows=1,
@@ -271,7 +271,11 @@ def get_iterate_on_sample_dataset_ui(
271
  with gr.Row():
272
  sample_dataset = gr.Dataframe(
273
  value=default_datasets[0],
274
- label="Sample dataset. Prompts and completions truncated to 256 tokens.",
 
 
 
 
275
  interactive=False,
276
  wrap=True,
277
  )
 
38
  if task == TEXTCAT_TASK:
39
  result = fn_generate_dataset(
40
  system_prompt=system_prompt,
41
+ difficulty="high school",
42
+ clarity="clear",
43
  labels=[],
44
  num_labels=1,
45
  num_rows=1,
 
271
  with gr.Row():
272
  sample_dataset = gr.Dataframe(
273
  value=default_datasets[0],
274
+ label=(
275
+ "Sample dataset. Text truncated to 256 tokens."
276
+ if task == TEXTCAT_TASK
277
+ else "Sample dataset. Prompts and completions truncated to 256 tokens."
278
+ ),
279
  interactive=False,
280
  wrap=True,
281
  )
src/distilabel_dataset_generator/apps/textcat.py CHANGED
@@ -215,7 +215,6 @@ def generate_dataset(
215
  system_prompt=system_prompt,
216
  labels=labels,
217
  num_labels=num_labels,
218
- is_sample=is_sample,
219
  )
220
  total_steps: int = num_rows * 2
221
  batch_size = DEFAULT_BATCH_SIZE
@@ -280,11 +279,13 @@ def generate_dataset(
280
  else:
281
  dataframe["labels"] = dataframe["labels"].apply(
282
  lambda x: (
283
- [
284
- label.lower().strip()
285
- for label in x
286
- if label.lower().strip() in labels
287
- ]
 
 
288
  if isinstance(x, list)
289
  else None
290
  )
@@ -309,6 +310,9 @@ def validate_input_labels(labels):
309
  )
310
  return labels
311
 
 
 
 
312
 
313
  (
314
  app,
@@ -354,7 +358,7 @@ with app:
354
  ],
355
  value="mixed",
356
  label="Difficulty",
357
- info="The difficulty of the text to be generated.",
358
  )
359
  clarity = gr.Dropdown(
360
  choices=[
@@ -368,7 +372,7 @@ with app:
368
  ],
369
  value="mixed",
370
  label="Clarity",
371
- info="The clarity of the text to be generated.",
372
  )
373
  with gr.Column():
374
  labels = gr.Dropdown(
@@ -385,18 +389,18 @@ with app:
385
  size="sm",
386
  )
387
  num_labels = gr.Number(
388
- label="Number of labels",
389
  value=1,
390
  minimum=1,
391
  maximum=10,
392
- info="The number of labels to classify the text.",
393
  )
394
  num_rows = gr.Number(
395
  label="Number of rows",
396
  value=10,
397
  minimum=1,
398
  maximum=500,
399
- info="More rows will take longer to generate.",
400
  )
401
 
402
  pipeline_code = get_pipeline_code_ui(
@@ -415,6 +419,10 @@ with app:
415
  fn=update_suggested_labels,
416
  inputs=[system_prompt],
417
  outputs=labels,
 
 
 
 
418
  )
419
 
420
  gr.on(
@@ -540,9 +548,18 @@ with app:
540
  fn=generate_pipeline_code,
541
  inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
542
  outputs=[pipeline_code],
 
 
 
 
543
  )
544
  num_labels.change(
545
  fn=generate_pipeline_code,
546
  inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
547
  outputs=[pipeline_code],
548
  )
 
 
 
 
 
 
215
  system_prompt=system_prompt,
216
  labels=labels,
217
  num_labels=num_labels,
 
218
  )
219
  total_steps: int = num_rows * 2
220
  batch_size = DEFAULT_BATCH_SIZE
 
279
  else:
280
  dataframe["labels"] = dataframe["labels"].apply(
281
  lambda x: (
282
+ list(
283
+ set(
284
+ label.lower().strip()
285
+ for label in x
286
+ if label.lower().strip() in labels
287
+ )
288
+ )
289
  if isinstance(x, list)
290
  else None
291
  )
 
310
  )
311
  return labels
312
 
313
+ def update_max_num_labels(labels):
314
+ return gr.update(maximum=len(labels) if labels else 1)
315
+
316
 
317
  (
318
  app,
 
358
  ],
359
  value="mixed",
360
  label="Difficulty",
361
+ info="Select the comprehension level for the text. Ensure it matches the task context.",
362
  )
363
  clarity = gr.Dropdown(
364
  choices=[
 
372
  ],
373
  value="mixed",
374
  label="Clarity",
375
+ info="Set how easily the correct label or labels can be identified.",
376
  )
377
  with gr.Column():
378
  labels = gr.Dropdown(
 
389
  size="sm",
390
  )
391
  num_labels = gr.Number(
392
+ label="Number of labels per text",
393
  value=1,
394
  minimum=1,
395
  maximum=10,
396
+ info="Select 1 for single-label and >1 for multi-label.",
397
  )
398
  num_rows = gr.Number(
399
  label="Number of rows",
400
  value=10,
401
  minimum=1,
402
  maximum=500,
403
+ info="Select the number of rows in the dataset. More rows will take more time.",
404
  )
405
 
406
  pipeline_code = get_pipeline_code_ui(
 
419
  fn=update_suggested_labels,
420
  inputs=[system_prompt],
421
  outputs=labels,
422
+ ).then(
423
+ fn=update_max_num_labels,
424
+ inputs=[labels],
425
+ outputs=[num_labels],
426
  )
427
 
428
  gr.on(
 
548
  fn=generate_pipeline_code,
549
  inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
550
  outputs=[pipeline_code],
551
+ ).then(
552
+ fn=update_max_num_labels,
553
+ inputs=[labels],
554
+ outputs=[num_labels],
555
  )
556
  num_labels.change(
557
  fn=generate_pipeline_code,
558
  inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
559
  outputs=[pipeline_code],
560
  )
561
+ num_rows.change(
562
+ fn=generate_pipeline_code,
563
+ inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
564
+ outputs=[pipeline_code],
565
+ )
src/distilabel_dataset_generator/pipelines/textcat.py CHANGED
@@ -1,6 +1,7 @@
1
  from typing import List
2
 
3
  import pandas as pd
 
4
  from distilabel.llms import InferenceEndpointsLLM
5
  from distilabel.steps.tasks import (
6
  GenerateTextClassificationData,
@@ -88,6 +89,7 @@ def generate_pipeline_code(
88
  base_code = f"""
89
  # Requirements: `pip install distilabel[hf-inference-endpoints]`
90
  import os
 
91
  from distilabel.llms import InferenceEndpointsLLM
92
  from distilabel.pipeline import Pipeline
93
  from distilabel.steps import LoadDataFromDicts, KeepColumns
@@ -111,6 +113,8 @@ with Pipeline(name="textcat") as pipeline:
111
  generation_kwargs={{
112
  "temperature": 0.8,
113
  "max_new_tokens": 2048,
 
 
114
  }},
115
  ),
116
  difficulty={None if difficulty == "mixed" else repr(difficulty)},
@@ -175,8 +179,10 @@ def get_textcat_generator(difficulty, clarity, is_sample):
175
  tokenizer_id=MODEL,
176
  api_key=_get_next_api_key(),
177
  generation_kwargs={
178
- "temperature": 0.8,
179
- "max_new_tokens": 256 if is_sample else 1024,
 
 
180
  },
181
  ),
182
  difficulty=None if difficulty == "mixed" else difficulty,
@@ -186,15 +192,15 @@ def get_textcat_generator(difficulty, clarity, is_sample):
186
  return textcat_generator
187
 
188
 
189
- def get_labeller_generator(system_prompt, labels, num_labels, is_sample):
190
  labeller_generator = TextClassification(
191
  llm=InferenceEndpointsLLM(
192
  model_id=MODEL,
193
  tokenizer_id=MODEL,
194
  api_key=_get_next_api_key(),
195
  generation_kwargs={
196
- "temperature": 0.8,
197
- "max_new_tokens": 256 if is_sample else 1024,
198
  },
199
  ),
200
  context=system_prompt,
 
1
  from typing import List
2
 
3
  import pandas as pd
4
+ import random
5
  from distilabel.llms import InferenceEndpointsLLM
6
  from distilabel.steps.tasks import (
7
  GenerateTextClassificationData,
 
89
  base_code = f"""
90
  # Requirements: `pip install distilabel[hf-inference-endpoints]`
91
  import os
92
+ import random
93
  from distilabel.llms import InferenceEndpointsLLM
94
  from distilabel.pipeline import Pipeline
95
  from distilabel.steps import LoadDataFromDicts, KeepColumns
 
113
  generation_kwargs={{
114
  "temperature": 0.8,
115
  "max_new_tokens": 2048,
116
+ "do_sample": True,
117
+ "seed": random.randint(0, 2**32 - 1),
118
  }},
119
  ),
120
  difficulty={None if difficulty == "mixed" else repr(difficulty)},
 
179
  tokenizer_id=MODEL,
180
  api_key=_get_next_api_key(),
181
  generation_kwargs={
182
+ "temperature": 0.9,
183
+ "max_new_tokens": 256 if is_sample else 2048,
184
+ "do_sample": True,
185
+ "seed": random.randint(0, 2**32 - 1),
186
  },
187
  ),
188
  difficulty=None if difficulty == "mixed" else difficulty,
 
192
  return textcat_generator
193
 
194
 
195
+ def get_labeller_generator(system_prompt, labels, num_labels):
196
  labeller_generator = TextClassification(
197
  llm=InferenceEndpointsLLM(
198
  model_id=MODEL,
199
  tokenizer_id=MODEL,
200
  api_key=_get_next_api_key(),
201
  generation_kwargs={
202
+ "temperature": 0.7,
203
+ "max_new_tokens": 2048,
204
  },
205
  ),
206
  context=system_prompt,