darylalim commited on
Commit
ee08279
·
verified ·
1 Parent(s): b970b22

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -24
app.py CHANGED
@@ -1,45 +1,81 @@
1
  import spaces
2
  import gradio as gr
3
  import torch
4
-
5
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
6
- from optimum.bettertransformer import BetterTransformer
7
 
8
- tokenizer = AutoTokenizer.from_pretrained(
9
- "google/madlad400-3b-mt",
10
- use_fast=True
11
- )
12
 
13
- model_hf = AutoModelForSeq2SeqLM.from_pretrained(
14
- "google/madlad400-3b-mt",
15
- torch_dtype=torch.bfloat16
16
- )
 
 
 
 
 
 
 
17
 
18
- model = BetterTransformer.transform(model_hf, keep_original=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  @spaces.GPU
21
- def translate(text):
22
  """
23
- Translates the input text from English to Hawaiian.
24
  """
25
- text = "<2haw> " + text
 
26
 
27
- inputs = tokenizer(
28
- text,
29
- return_tensors="pt"
30
- )
31
 
32
- outputs = model.generate(**inputs, max_new_tokens=1000)
33
  text_translated = tokenizer.batch_decode(outputs, skip_special_tokens=True)
34
-
35
  return text_translated[0]
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  demo = gr.Interface(
38
  fn=translate,
39
- inputs=[gr.Textbox(label="English")],
40
- outputs=[gr.Textbox(label="Hawaiian")],
41
- title="MADLAD-400-3B-MT English-to-Hawaiian Translation",
42
- description="[Code](https://github.com/darylalim/madlad-400-3b-mt-eng-to-haw-translation)")
 
43
 
44
  demo.queue()
45
 
 
1
  import spaces
2
  import gradio as gr
3
  import torch
 
4
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
5
 
6
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
7
 
8
+ tokenizer_3b_mt = AutoTokenizer.from_pretrained("google/madlad400-3b-mt", use_fast=True)
9
+ language_codes = [token for token in tokenizer_3b_mt.get_vocab().keys() if token.startswith("<2")]
10
+ remove_codes = ['<2>', '<2en_xx_simple>', '<2translate>', '<2back_translated>', '<2zxx_xx_dtynoise>', '<2transliterate>']
11
+ language_codes = [token for token in language_codes if token not in remove_codes]
12
+
13
+ model_choices = [
14
+ "google/madlad400-3b-mt",
15
+ "google/madlad400-7b-mt",
16
+ "google/madlad400-10b-mt",
17
+ "google/madlad400-7b-mt-bt"
18
+ ]
19
 
20
+ model_resources = {}
21
+
22
+ def load_tokenizer_model(model_name):
23
+ """
24
+ Load tokenizer and model for a chosen model name.
25
+ """
26
+ if model_name not in model_resources:
27
+ # Load tokenizer and model for first time
28
+ tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
29
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.float16)
30
+ model.to_bettertransformer()
31
+ model.to(device)
32
+ model_resources[model_name] = (tokenizer, model)
33
+ return model_resources[model_name]
34
 
35
  @spaces.GPU
36
+ def translate(text, target_language, model_name):
37
  """
38
+ Translate the input text from English to another language.
39
  """
40
+ # Load tokenizer and model if not already loaded
41
+ tokenizer, model = load_tokenizer_model(model_name)
42
 
43
+ text = target_language + text
44
+ input_ids = tokenizer(text, return_tensors="pt").input_ids.to(device)
 
 
45
 
46
+ outputs = model.generate(input_ids=input_ids, max_new_tokens=128000)
47
  text_translated = tokenizer.batch_decode(outputs, skip_special_tokens=True)
48
+
49
  return text_translated[0]
50
 
51
+ title = "MADLAD-400 Translation"
52
+ description = """
53
+ Translation from English to over 400 languages based on [research](https://arxiv.org/pdf/2309.04662) by Google DeepMind and Google Research
54
+ """
55
+
56
+ input_text = gr.Textbox(
57
+ label="Text",
58
+ placeholder="Enter text here"
59
+ )
60
+ target_language = gr.Dropdown(
61
+ choices=language_codes,
62
+ value="<2haw>",
63
+ label="Target language"
64
+ )
65
+ model_choice = gr.Dropdown(
66
+ choices=model_choices,
67
+ value="google/madlad400-3b-mt",
68
+ label="Model"
69
+ )
70
+ output_text = gr.Textbox(label="Translation")
71
+
72
  demo = gr.Interface(
73
  fn=translate,
74
+ inputs=[input_text, target_language, model_choice],
75
+ outputs=output_text,
76
+ title=title,
77
+ description=description
78
+ )
79
 
80
  demo.queue()
81