davidberenstein1957 committed
Commit 76ebdbe
1 Parent(s): 6bb104e

fix src import in textcat file

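The diff below swaps the src.-prefixed imports for plain synthetic_dataset_generator imports. A plausible reading, not stated in the commit itself, is that the project uses a src layout: once the package is installed, the importable top-level name is synthetic_dataset_generator, and the old src. paths only resolve when running from the repository root. A minimal sketch that probes both paths, assuming the package is installed (for example with pip install -e .):

import importlib

# Sketch only: checks the old and the corrected import path for apps.base.
# Assumes synthetic_dataset_generator is installed as a package; the install
# command mentioned above is an assumption, not something stated in the commit.
for name in (
    "src.synthetic_dataset_generator.apps.base",  # old path removed by this commit
    "synthetic_dataset_generator.apps.base",      # new path used by this commit
):
    try:
        importlib.import_module(name)
        print(f"importable:     {name}")
    except ModuleNotFoundError as exc:
        print(f"not importable: {name} ({exc})")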
src/synthetic_dataset_generator/apps/textcat.py CHANGED
@@ -10,7 +10,7 @@ from datasets import ClassLabel, Dataset, Features, Sequence, Value
 from distilabel.distiset import Distiset
 from huggingface_hub import HfApi
 
-from src.synthetic_dataset_generator.apps.base import (
+from synthetic_dataset_generator.apps.base import (
     combine_datasets,
     hide_success_message,
     push_pipeline_code_to_hub,
@@ -19,24 +19,24 @@ from src.synthetic_dataset_generator.apps.base import (
     validate_argilla_user_workspace_dataset,
     validate_push_to_hub,
 )
-from src.synthetic_dataset_generator.pipelines.embeddings import (
+from synthetic_dataset_generator.constants import DEFAULT_BATCH_SIZE
+from synthetic_dataset_generator.pipelines.embeddings import (
     get_embeddings,
     get_sentence_embedding_dimensions,
 )
-from src.synthetic_dataset_generator.pipelines.textcat import (
+from synthetic_dataset_generator.pipelines.textcat import (
     DEFAULT_DATASET_DESCRIPTIONS,
     generate_pipeline_code,
     get_labeller_generator,
     get_prompt_generator,
     get_textcat_generator,
 )
-from src.synthetic_dataset_generator.utils import (
+from synthetic_dataset_generator.utils import (
     get_argilla_client,
     get_org_dropdown,
     get_preprocess_labels,
     swap_visibility,
 )
-from synthetic_dataset_generator.constants import DEFAULT_BATCH_SIZE
 
 
 def _get_dataframe():
@@ -189,7 +189,9 @@ def generate_dataset(
         lambda x: list(
             set(
                 [
-                    label.lower().strip() if (label is not None and label.lower().strip() in labels) else random.choice(labels)
+                    label.lower().strip()
+                    if (label is not None and label.lower().strip() in labels)
+                    else random.choice(labels)
                     for label in x
                 ]
             )
@@ -220,7 +222,10 @@ def push_dataset_to_hub(
     pipeline_code: str = "",
     progress=gr.Progress(),
 ):
-    gr.Info(message=f"Dataframe columns in push dataset to hub: {dataframe.columns}", duration=20)
+    gr.Info(
+        message=f"Dataframe columns in push dataset to hub: {dataframe.columns}",
+        duration=20,
+    )
     progress(0.0, desc="Validating")
     repo_id = validate_push_to_hub(org_name, repo_name)
     progress(0.3, desc="Preprocessing")
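The last two hunks only re-wrap existing statements (the label-cleanup expression in generate_dataset and the gr.Info call in push_dataset_to_hub); behaviour is unchanged. For reference, a self-contained sketch of what that label-cleanup comprehension does; the labels list and the sample row x below are invented for illustration:

import random

labels = ["positive", "negative", "neutral"]          # hypothetical allowed labels
x = ["Positive ", None, "made-up-label", "neutral"]   # hypothetical generator output for one row

# Normalize each predicted label; anything missing or outside labels is
# replaced with a random allowed label, then duplicates are dropped via set().
cleaned = list(
    set(
        [
            label.lower().strip()
            if (label is not None and label.lower().strip() in labels)
            else random.choice(labels)
            for label in x
        ]
    )
)
print(cleaned)  # e.g. ['neutral', 'positive'] -- exact output varies with the random fallback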