davidberenstein1957 HF staff commited on
Commit
85b97c4
·
1 Parent(s): d129960

fix model validation when using ollama

Browse files
src/synthetic_dataset_generator/apps/chat.py CHANGED
@@ -527,77 +527,76 @@ with gr.Blocks() as app:
527
  label="Distilabel Pipeline Code",
528
  )
529
 
530
- load_btn.click(
531
- fn=generate_system_prompt,
532
- inputs=[dataset_description],
533
- outputs=[system_prompt],
534
- show_progress=True,
535
- ).then(
536
- fn=generate_sample_dataset,
537
- inputs=[system_prompt, num_turns],
538
- outputs=[dataframe],
539
- show_progress=True,
540
- )
541
-
542
- btn_apply_to_sample_dataset.click(
543
- fn=generate_sample_dataset,
544
- inputs=[system_prompt, num_turns],
545
- outputs=[dataframe],
546
- show_progress=True,
547
- )
548
 
549
- btn_push_to_hub.click(
550
- fn=validate_argilla_user_workspace_dataset,
551
- inputs=[repo_name],
552
- outputs=[success_message],
553
- show_progress=True,
554
- ).then(
555
- fn=validate_push_to_hub,
556
- inputs=[org_name, repo_name],
557
- outputs=[success_message],
558
- show_progress=True,
559
- ).success(
560
- fn=hide_success_message,
561
- outputs=[success_message],
562
- show_progress=True,
563
- ).success(
564
- fn=hide_pipeline_code_visibility,
565
- inputs=[],
566
- outputs=[pipeline_code_ui],
567
- show_progress=True,
568
- ).success(
569
- fn=push_dataset,
570
- inputs=[
571
- org_name,
572
- repo_name,
573
- system_prompt,
574
- num_turns,
575
- num_rows,
576
- private,
577
- temperature,
578
- pipeline_code,
579
- ],
580
- outputs=[success_message],
581
- show_progress=True,
582
- ).success(
583
- fn=show_success_message,
584
- inputs=[org_name, repo_name],
585
- outputs=[success_message],
586
- ).success(
587
- fn=generate_pipeline_code,
588
- inputs=[system_prompt, num_turns, num_rows, temperature],
589
- outputs=[pipeline_code],
590
- ).success(
591
- fn=show_pipeline_code_visibility,
592
- inputs=[],
593
- outputs=[pipeline_code_ui],
594
- )
595
- gr.on(
596
- triggers=[clear_btn_part.click, clear_btn_full.click],
597
- fn=lambda _: ("", "", 1, _get_dataframe()),
598
- inputs=[dataframe],
599
- outputs=[dataset_description, system_prompt, num_turns, dataframe],
600
- )
601
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
602
  app.load(fn=swap_visibility, outputs=main_ui)
603
- app.load(fn=get_org_dropdown, outputs=[org_name])
 
527
  label="Distilabel Pipeline Code",
528
  )
529
 
530
+ load_btn.click(
531
+ fn=generate_system_prompt,
532
+ inputs=[dataset_description],
533
+ outputs=[system_prompt],
534
+ show_progress=True,
535
+ ).then(
536
+ fn=generate_sample_dataset,
537
+ inputs=[system_prompt, num_turns],
538
+ outputs=[dataframe],
539
+ show_progress=True,
540
+ )
 
 
 
 
 
 
 
541
 
542
+ btn_apply_to_sample_dataset.click(
543
+ fn=generate_sample_dataset,
544
+ inputs=[system_prompt, num_turns],
545
+ outputs=[dataframe],
546
+ show_progress=True,
547
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
548
 
549
+ btn_push_to_hub.click(
550
+ fn=validate_argilla_user_workspace_dataset,
551
+ inputs=[repo_name],
552
+ outputs=[success_message],
553
+ show_progress=True,
554
+ ).then(
555
+ fn=validate_push_to_hub,
556
+ inputs=[org_name, repo_name],
557
+ outputs=[success_message],
558
+ show_progress=True,
559
+ ).success(
560
+ fn=hide_success_message,
561
+ outputs=[success_message],
562
+ show_progress=True,
563
+ ).success(
564
+ fn=hide_pipeline_code_visibility,
565
+ inputs=[],
566
+ outputs=[pipeline_code_ui],
567
+ show_progress=True,
568
+ ).success(
569
+ fn=push_dataset,
570
+ inputs=[
571
+ org_name,
572
+ repo_name,
573
+ system_prompt,
574
+ num_turns,
575
+ num_rows,
576
+ private,
577
+ temperature,
578
+ pipeline_code,
579
+ ],
580
+ outputs=[success_message],
581
+ show_progress=True,
582
+ ).success(
583
+ fn=show_success_message,
584
+ inputs=[org_name, repo_name],
585
+ outputs=[success_message],
586
+ ).success(
587
+ fn=generate_pipeline_code,
588
+ inputs=[system_prompt, num_turns, num_rows, temperature],
589
+ outputs=[pipeline_code],
590
+ ).success(
591
+ fn=show_pipeline_code_visibility,
592
+ inputs=[],
593
+ outputs=[pipeline_code_ui],
594
+ )
595
+ gr.on(
596
+ triggers=[clear_btn_part.click, clear_btn_full.click],
597
+ fn=lambda _: ("", "", 1, _get_dataframe()),
598
+ inputs=[dataframe],
599
+ outputs=[dataset_description, system_prompt, num_turns, dataframe],
600
+ )
601
+ app.load(fn=get_org_dropdown, outputs=[org_name])
602
  app.load(fn=swap_visibility, outputs=main_ui)
 
src/synthetic_dataset_generator/constants.py CHANGED
@@ -46,10 +46,14 @@ if os.getenv("MAGPIE_PRE_QUERY_TEMPLATE"):
46
  raise ValueError(
47
  f"MAGPIE_PRE_QUERY_TEMPLATE must be either {llama_options} or {qwen_options}."
48
  )
49
- elif MODEL.lower() in llama_options:
 
 
50
  SFT_AVAILABLE = True
51
  MAGPIE_PRE_QUERY_TEMPLATE = "llama3"
52
- elif MODEL.lower() in qwen_options:
 
 
53
  SFT_AVAILABLE = True
54
  MAGPIE_PRE_QUERY_TEMPLATE = "qwen2"
55
  else:
 
46
  raise ValueError(
47
  f"MAGPIE_PRE_QUERY_TEMPLATE must be either {llama_options} or {qwen_options}."
48
  )
49
+ elif MODEL.lower() in llama_options or any(
50
+ option in MODEL.lower() for option in llama_options
51
+ ):
52
  SFT_AVAILABLE = True
53
  MAGPIE_PRE_QUERY_TEMPLATE = "llama3"
54
+ elif MODEL.lower() in qwen_options or any(
55
+ option in MODEL.lower() for option in qwen_options
56
+ ):
57
  SFT_AVAILABLE = True
58
  MAGPIE_PRE_QUERY_TEMPLATE = "qwen2"
59
  else: