cstr committed on
Commit
55f037d
Β·
verified Β·
1 Parent(s): 964e0c7

Update app.py

Browse files

+other target languages

Files changed (1) hide show
  1. app.py +17 -10
app.py CHANGED
@@ -550,7 +550,7 @@ logger = logging.getLogger(__name__)
550
 
551
  # Main function to handle the translation workflow
552
  # Main function to handle the translation workflow
553
- def main(dataset_url, model_type, output_dataset_name, range_specification, token: gr.OAuthToken | None, profile: gr.OAuthProfile | None):
554
  try:
555
  # Login to Hugging Face
556
  if token is None or profile is None or token.token is None or profile.username is None:
@@ -574,6 +574,7 @@ def main(dataset_url, model_type, output_dataset_name, range_specification, toke
574
  # Load the tokenizer
575
  tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_name)
576
  tokenizer.src_lang = "en"
 
577
  logger.info("Tokenizer loaded successfully.")
578
 
579
  # Define the task based on user input
@@ -581,10 +582,11 @@ def main(dataset_url, model_type, output_dataset_name, range_specification, toke
581
  "url": dataset_url,
582
  "local_path": "train.parquet",
583
  "input_file": f"{model_type}_en.jsonl",
584
- "output_file": f"{model_type}_de.jsonl",
585
- "raw_file": f"{model_type}_de_raw.jsonl",
586
  "range_spec": range_specification,
587
- "model_type": model_type
 
588
  }
589
 
590
  # Call the translate_dataset function with the provided parameters
@@ -601,6 +603,7 @@ def main(dataset_url, model_type, output_dataset_name, range_specification, toke
601
  model_type=task["model_type"],
602
  translator=translator,
603
  tokenizer=tokenizer,
 
604
  )
605
  logger.info("Dataset translation completed!")
606
  return "Dataset translation completed!\n\n### Logs:\n" + log_stream.getvalue()
@@ -608,15 +611,17 @@ def main(dataset_url, model_type, output_dataset_name, range_specification, toke
608
  return "Login failed. Please try again."
609
  except Exception as e:
610
  logger.error(f"An error occurred in the main function: {e}")
611
- # Ensure logs are flushed and captured
612
  return f"An error occurred: {e}\n\n### Logs:\n{log_stream.getvalue()}"
613
 
 
614
  # Gradio interface setup
615
  gradio_title = "🧐 WMT21 Dataset Translation"
616
- gradio_desc = """This tool translates datasets using the WMT21 translation model.
617
  ## πŸ’­ What Does This Tool Do:
618
- - Translates datasets based on the selected model type.
619
- - Uploads the translated dataset to Hugging Face."""
 
 
620
  datasets_desc = """## πŸ“Š Dataset Types:
621
  - **mix**:
622
  - `prompt`: List of dictionaries with 'content' and 'role' fields (multi-turn conversation).
@@ -650,12 +655,14 @@ with gr.Blocks(theme=theme) as demo:
650
  model_type = gr.Dropdown(choices=["mix", "ufb_cached", "ufb"], label="Dataset Type")
651
  output_dataset_name = gr.Textbox(label="Output Dataset Name", lines=1, placeholder = "cstr/translated_datasets")
652
  range_specification = gr.Textbox(label="Range Specification", lines=1, placeholder="e.g., 1-100")
653
-
 
654
  with gr.Column():
655
  output = gr.Markdown(label="Output")
656
 
657
  submit_btn = gr.Button("Translate Dataset", variant="primary")
658
- submit_btn.click(main, inputs=[dataset_url, model_type, output_dataset_name, range_specification], outputs=output)
 
659
 
660
  gr.Markdown(datasets_desc)
661
 
 
550
 
551
  # Main function to handle the translation workflow
552
  # Main function to handle the translation workflow
553
+ def main(dataset_url, model_type, output_dataset_name, range_specification, target_language, token: gr.OAuthToken | None, profile: gr.OAuthProfile | None):
554
  try:
555
  # Login to Hugging Face
556
  if token is None or profile is None or token.token is None or profile.username is None:
 
574
  # Load the tokenizer
575
  tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_name)
576
  tokenizer.src_lang = "en"
577
+ tokenizer.tgt_lang = target_language # Set target language
578
  logger.info("Tokenizer loaded successfully.")
579
 
580
  # Define the task based on user input
 
582
  "url": dataset_url,
583
  "local_path": "train.parquet",
584
  "input_file": f"{model_type}_en.jsonl",
585
+ "output_file": f"{model_type}_{target_language}.jsonl", # Include target language in the filename
586
+ "raw_file": f"{model_type}_{target_language}_raw.jsonl",
587
  "range_spec": range_specification,
588
+ "model_type": model_type,
589
+ "target_language": target_language # Include target language in the task
590
  }
591
 
592
  # Call the translate_dataset function with the provided parameters
 
603
  model_type=task["model_type"],
604
  translator=translator,
605
  tokenizer=tokenizer,
606
+ target_language=task["target_language"] # Pass the target language
607
  )
608
  logger.info("Dataset translation completed!")
609
  return "Dataset translation completed!\n\n### Logs:\n" + log_stream.getvalue()
 
611
  return "Login failed. Please try again."
612
  except Exception as e:
613
  logger.error(f"An error occurred in the main function: {e}")
 
614
  return f"An error occurred: {e}\n\n### Logs:\n{log_stream.getvalue()}"
615
 
616
+
617
  # Gradio interface setup
618
  gradio_title = "🧐 WMT21 Dataset Translation"
619
+ gradio_desc = """This tool translates english datasets using the WMT21 translation model.
620
  ## πŸ’­ What Does This Tool Do:
621
+ - Translates datasets with structures based on the selected model type.
622
+ - The translation model (facebook/wmt21-dense-24-wide-en-x) supports as target languages: Hausa (ha), Icelandic (is), Japanese (ja), Czech (cs), Russian (ru), Chinese (zh), German (de)
623
+ - Uploads the translated dataset to Hugging Face.
624
+ - At the moment, this works only on CPU, and therefore is very very slow (>1 minute per item depending on string lengths)."""
625
  datasets_desc = """## πŸ“Š Dataset Types:
626
  - **mix**:
627
  - `prompt`: List of dictionaries with 'content' and 'role' fields (multi-turn conversation).
 
655
  model_type = gr.Dropdown(choices=["mix", "ufb_cached", "ufb"], label="Dataset Type")
656
  output_dataset_name = gr.Textbox(label="Output Dataset Name", lines=1, placeholder = "cstr/translated_datasets")
657
  range_specification = gr.Textbox(label="Range Specification", lines=1, placeholder="e.g., 1-100")
658
+ target_language = gr.Dropdown(choices=["ha", "is", "ja", "cs", "ru", "zh", "de"], label="Target Language") # New dropdown for target language
659
+
660
  with gr.Column():
661
  output = gr.Markdown(label="Output")
662
 
663
  submit_btn = gr.Button("Translate Dataset", variant="primary")
664
+ submit_btn.click(main, inputs=[dataset_url, model_type, output_dataset_name, range_specification, target_language], outputs=output)
665
+
666
 
667
  gr.Markdown(datasets_desc)
668