Tonic commited on
Commit
5ecade3
·
verified ·
1 Parent(s): d204ef0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -12
app.py CHANGED
@@ -8,24 +8,21 @@ description = """ try this space to build downstream applications with [CohereFo
8
 
9
  checkpoint = "CohereForAI/aya-101"
10
  tokenizer = AutoTokenizer.from_pretrained(checkpoint)
11
- model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
12
-
13
- device = "cuda" if torch.cuda.is_available() else "cpu"
14
- model.to(device)
15
-
16
- if device == "cuda":
17
- model = model.half()
18
 
19
  @spaces.GPU
20
- def translate(text):
21
  """
22
  Translates the input text to English using the Aya model.
23
  Assumes the model can automatically detect the input language.
24
  """
 
 
25
  inputs = tokenizer.encode(text, return_tensors="pt").to(device)
26
 
27
  outputs = model.generate(inputs, max_new_tokens=128)
28
 
 
29
  translation = tokenizer.decode(outputs[0], skip_special_tokens=True)
30
  return translation
31
 
@@ -33,12 +30,13 @@ def main():
33
  with gr.Blocks() as demo:
34
  gr.Markdown(title)
35
  gr.Markdown(description)
 
36
  with gr.Row():
37
- input_text = gr.Textbox(label="Input Text")
38
- output_text = gr.Textbox(label="🌐Aya", interactive=False)
39
- input_text.change(fn=translate, inputs=input_text, outputs=output_text)
40
 
41
  demo.launch()
42
 
43
  if __name__ == "__main__":
44
- main()
 
8
 
9
  checkpoint = "CohereForAI/aya-101"
10
  tokenizer = AutoTokenizer.from_pretrained(checkpoint)
11
+ model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto")
 
 
 
 
 
 
12
 
13
  @spaces.GPU
14
+ def aya(text):
15
  """
16
  Translates the input text to English using the Aya model.
17
  Assumes the model can automatically detect the input language.
18
  """
19
+ model.to(device)
20
+
21
  inputs = tokenizer.encode(text, return_tensors="pt").to(device)
22
 
23
  outputs = model.generate(inputs, max_new_tokens=128)
24
 
25
+
26
  translation = tokenizer.decode(outputs[0], skip_special_tokens=True)
27
  return translation
28
 
 
30
  with gr.Blocks() as demo:
31
  gr.Markdown(title)
32
  gr.Markdown(description)
33
+ output_text = gr.Textbox(label="🌐Aya", interactive=False)
34
  with gr.Row():
35
+ input_text = gr.Textbox(label="🗣️Input Text")
36
+ submit_button = gr.Button("Translate") # Add a button to submit the input
37
+ submit_button.click(fn=aya, inputs=input_text, outputs=output_text)
38
 
39
  demo.launch()
40
 
41
  if __name__ == "__main__":
42
+ main()