# app.py — Gradio translation demo built on NLLB-200.
# Last change: "Update app.py" by Lianglan (commit 5342b28, 1.75 kB).
import functools

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

from langs import LANGS
from ui import title, description, examples
# Hugging Face pipeline task name and the NLLB-200 checkpoint served by this app.
TASK = "translation"
CKPT = "facebook/nllb-200-distilled-600M"

# Load the seq2seq weights and the matching tokenizer once, at import time,
# so translate() can hand the same objects to the pipeline on every request.
model = AutoModelForSeq2SeqLM.from_pretrained(CKPT)
tokenizer = AutoTokenizer.from_pretrained(CKPT)

# NOTE: GPU placement is currently disabled — pipelines are built without a
# `device` argument. The previously used selection was:
#   device = 0 if torch.cuda.is_available() else -1
@functools.lru_cache(maxsize=1)
def _get_translation_pipeline():
    """Build the shared translation pipeline on first use and reuse it afterwards."""
    # Language pair and max_length are supplied per call in translate(), so a
    # single pipeline serves every request. It is built without a `device`
    # argument, matching the previous behaviour.
    return pipeline(TASK, model=model, tokenizer=tokenizer)


def translate(text, src_lang, tgt_lang, max_length=400):
    """
    Translate ``text`` from ``src_lang`` to ``tgt_lang``.

    Args:
        text: Input text from the UI. Blank, whitespace-only or ``None`` input
            returns ``""`` immediately without running the model.
        src_lang: Source language code, forwarded to the pipeline as
            ``src_lang`` (presumably one of ``LANGS`` — confirm with the UI).
        tgt_lang: Target language code, forwarded as ``tgt_lang``.
        max_length: Maximum generation length forwarded to the model; coerced
            to ``int`` in case the UI slider delivers a float.

    Returns:
        The ``translation_text`` of the first (only) pipeline result.
    """
    # Nothing to translate: avoid running generation on empty input.
    if not text or not text.strip():
        return ""

    translator = _get_translation_pipeline()
    result = translator(
        text,
        src_lang=src_lang,
        tgt_lang=tgt_lang,
        max_length=int(max_length),
    )
    return result[0]["translation_text"]
# Form inputs, in the order translate() expects its positional arguments:
# text, source language, target language, max generation length.
input_components = [
    gr.components.Textbox(label="Text"),
    gr.components.Dropdown(label="Source Language", choices=LANGS),
    gr.components.Dropdown(label="Target Language", choices=LANGS),
    gr.components.Slider(8, 512, value=400, step=8, label="Max Length"),
]

# Single text output; examples are run live rather than precomputed.
demo = gr.Interface(
    translate,
    input_components,
    ["text"],
    examples=examples,
    cache_examples=False,
    title=title,
    description=description,
)
demo.launch()