File size: 3,102 Bytes
38ca9aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
482a4be
38ca9aa
 
d712cdc
38ca9aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import streamlit as st
from transformers import (
    PreTrainedTokenizerBase,
    PreTrainedTokenizerFast,
    AutoModelForCausalLM,
)

# Maps each model name shown in the sidebar picker to its Hugging Face Hub
# repository id. All checkpoints live under the "Mxode" organisation and the
# repo name equals the display name; insertion order is the menu order.
model_dict = {
    name: f"Mxode/{name}"
    for name in (
        "NanoTranslator-XS",
        "NanoTranslator-S",
        "NanoTranslator-M",
        "NanoTranslator-M2",
        "NanoTranslator-L",
        "NanoTranslator-XL",
        "NanoTranslator-XXL",
        "NanoTranslator-XXL2",
    )
}


# initialize model
@st.cache_resource
def load_model(model_path: str):
    """Load a NanoTranslator checkpoint together with its tokenizer.

    Wrapped in ``st.cache_resource`` so Streamlit reuses the loaded objects for
    repeated calls with the same ``model_path`` instead of reloading them on
    every script rerun.

    Args:
        model_path: Hugging Face Hub repo id (or local path) of the checkpoint.

    Returns:
        A ``(model, tokenizer)`` tuple; the model is loaded first.
    """
    return (
        AutoModelForCausalLM.from_pretrained(model_path),
        PreTrainedTokenizerFast.from_pretrained(model_path),
    )


def translate(text: str, model, tokenizer: PreTrainedTokenizerBase, **kwargs):
    """Translate ``text`` with a NanoTranslator causal LM.

    The text is wrapped as ``<|im_start|>{text}<|endoftext|>``, generated from,
    and only the newly generated tokens (everything after the prompt) are
    decoded.

    Args:
        text: Source text to translate.
        model: Causal LM exposing ``generate`` and ``device``.
        tokenizer: Tokenizer matching ``model``.
        **kwargs: Generation options forwarded to ``model.generate``.
            ``max_new_tokens`` (default 64), ``do_sample`` (default True),
            ``temperature`` (0.55), ``top_p`` (0.8) and ``top_k`` (40) have
            defaults here; any other option is passed through unchanged.
            The three sampling options are only forwarded when ``do_sample``
            is true, because greedy decoding ignores them.

    Returns:
        The decoded translation of the first (only) sequence, with special
        tokens skipped.
    """
    do_sample = kwargs.pop("do_sample", True)
    sampling_args = dict(
        temperature=kwargs.pop("temperature", 0.55),
        top_p=kwargs.pop("top_p", 0.8),
        top_k=kwargs.pop("top_k", 40),
    )
    generation_args = dict(
        max_new_tokens=kwargs.pop("max_new_tokens", 64),
        do_sample=do_sample,
        **kwargs,
    )
    # transformers warns for every sampling flag set alongside do_sample=False
    # (they have no effect on greedy decoding), so only send them when sampling.
    if do_sample:
        generation_args.update(sampling_args)

    prompt = "<|im_start|>" + text + "<|endoftext|>"
    model_inputs = tokenizer([prompt], return_tensors="pt").to(model.device)

    # Forward the tokenizer's attention mask explicitly; otherwise generate()
    # infers one from the pad token, which transformers warns may be unreliable.
    # setdefault keeps a caller-supplied attention_mask working as before.
    generation_args.setdefault("attention_mask", model_inputs.get("attention_mask"))

    generated_ids = model.generate(model_inputs.input_ids, **generation_args)
    # generate() returns prompt + continuation; strip the prompt tokens.
    generated_ids = [
        output_ids[len(input_ids) :]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]

    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return response


st.title("NanoTranslator Demo")

# Sidebar: model choice and generation settings.
st.sidebar.title("Options")
model_names = list(model_dict.keys())
model_choice = st.sidebar.selectbox(
    "Model", model_names, index=model_names.index("NanoTranslator-XXL2")
)
do_sample = st.sidebar.checkbox("do_sample", value=True)
max_new_tokens = st.sidebar.slider(
    "max_new_tokens", min_value=1, max_value=256, value=64
)
# temperature / top_p / top_k only matter when sampling; grey them out for
# greedy decoding so the UI doesn't suggest they change the output.
sampling_disabled = not do_sample
temperature = st.sidebar.slider(
    "temperature",
    min_value=0.01,
    max_value=1.5,
    value=0.55,
    step=0.01,
    disabled=sampling_disabled,
)
top_p = st.sidebar.slider(
    "top_p",
    min_value=0.01,
    max_value=1.0,
    value=0.8,
    step=0.01,
    disabled=sampling_disabled,
)
top_k = st.sidebar.slider(
    "top_k",
    min_value=1,
    max_value=100,
    value=40,
    step=1,
    disabled=sampling_disabled,
)

# Load the selected model (load_model is wrapped in st.cache_resource).
model_path = model_dict[model_choice]
model, tokenizer = load_model(model_path)

input_text = st.text_area(
    "Please input the text to be translated (Currently supports only English to Chinese):",
    "Each step of the cell cycle is monitored by internal.",
)

if st.button("translate"):
    # Translate the same stripped text that the emptiness check looks at, so
    # stray leading/trailing whitespace or newlines don't end up in the prompt.
    source_text = input_text.strip()
    if source_text:
        with st.spinner("Translating..."):
            translation = translate(
                source_text,
                model,
                tokenizer,
                max_new_tokens=max_new_tokens,
                do_sample=do_sample,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
            )
        st.success("Translated successfully!")
        st.write(translation)
    else:
        st.warning("Please input text before translation!")