Mxode commited on
Commit
38ca9aa
·
verified ·
1 Parent(s): 5121e26

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +90 -90
app.py CHANGED
@@ -1,90 +1,90 @@
1
- import streamlit as st
2
- from transformers import (
3
- PreTrainedTokenizerBase,
4
- PreTrainedTokenizerFast,
5
- AutoModelForCausalLM,
6
- )
7
-
8
- model_dict = {
9
- "NanoTranslator-XS": "Mxode/NanoTranslator-XS",
10
- "NanoTranslator-S": "Mxode/NanoTranslator-S",
11
- "NanoTranslator-M": "Mxode/NanoTranslator-M",
12
- "NanoTranslator-M2": "Mxode/NanoTranslator-M2",
13
- "NanoTranslator-L": "Mxode/NanoTranslator-L",
14
- "NanoTranslator-XL": "Mxode/NanoTranslator-XL",
15
- "NanoTranslator-XXL": "Mxode/NanoTranslator-XXL",
16
- "NanoTranslator-XXL2": "Mxode/NanoTranslator-XXL2",
17
- }
18
-
19
-
20
- # initialize model
21
- @st.cache_resource
22
- def load_model(model_path: str):
23
- model = AutoModelForCausalLM.from_pretrained(model_path)
24
- tokenizer = PreTrainedTokenizerFast.from_pretrained(model_path)
25
- return model, tokenizer
26
-
27
-
28
- def translate(text: str, model, tokenizer: PreTrainedTokenizerBase, **kwargs):
29
- generation_args = dict(
30
- max_new_tokens=kwargs.pop("max_new_tokens", 64),
31
- do_sample=kwargs.pop("do_sample", True),
32
- temperature=kwargs.pop("temperature", 0.55),
33
- top_p=kwargs.pop("top_p", 0.8),
34
- top_k=kwargs.pop("top_k", 40),
35
- **kwargs
36
- )
37
-
38
- prompt = "<|im_start|>" + text + "<|endoftext|>"
39
- model_inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
40
-
41
- generated_ids = model.generate(model_inputs.input_ids, **generation_args)
42
- generated_ids = [
43
- output_ids[len(input_ids) :]
44
- for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
45
- ]
46
-
47
- response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
48
- return response
49
-
50
-
51
- st.title("NanoTranslator-Demo")
52
-
53
- st.sidebar.title("Options")
54
- model_choice = st.sidebar.selectbox("Model", list(model_dict.keys()))
55
- do_sample = st.sidebar.checkbox("do_sample", value=True)
56
- max_new_tokens = st.sidebar.slider(
57
- "max_new_tokens", min_value=1, max_value=256, value=64
58
- )
59
- temperature = st.sidebar.slider(
60
- "temperature", min_value=0.01, max_value=1.5, value=0.55, step=0.01
61
- )
62
- top_p = st.sidebar.slider("top_p", min_value=0.01, max_value=1.0, value=0.8, step=0.01)
63
- top_k = st.sidebar.slider("top_k", min_value=1, max_value=100, value=40, step=1)
64
-
65
- # 根据选择的模型加载
66
- model_path = model_dict[model_choice]
67
- model, tokenizer = load_model(model_path)
68
-
69
- input_text = st.text_area(
70
- "Please input the text to be translated (Currently supports only English to Chinese):",
71
- "Each step of the cell cycle is monitored by internal.",
72
- )
73
-
74
- if st.button("translate"):
75
- if input_text.strip():
76
- with st.spinner("Translating..."):
77
- translation = translate(
78
- input_text,
79
- model,
80
- tokenizer,
81
- max_new_tokens=max_new_tokens,
82
- do_sample=do_sample,
83
- temperature=temperature,
84
- top_p=top_p,
85
- top_k=top_k,
86
- )
87
- st.success("Translated successfully!")
88
- st.write(translation)
89
- else:
90
- st.warning("Please input text before translation!")
 
1
+ import streamlit as st
2
+ from transformers import (
3
+ PreTrainedTokenizerBase,
4
+ PreTrainedTokenizerFast,
5
+ AutoModelForCausalLM,
6
+ )
7
+
8
+ model_dict = {
9
+ "NanoTranslator-XS": "Mxode/NanoTranslator-XS",
10
+ "NanoTranslator-S": "Mxode/NanoTranslator-S",
11
+ "NanoTranslator-M": "Mxode/NanoTranslator-M",
12
+ "NanoTranslator-M2": "Mxode/NanoTranslator-M2",
13
+ "NanoTranslator-L": "Mxode/NanoTranslator-L",
14
+ "NanoTranslator-XL": "Mxode/NanoTranslator-XL",
15
+ "NanoTranslator-XXL": "Mxode/NanoTranslator-XXL",
16
+ "NanoTranslator-XXL2": "Mxode/NanoTranslator-XXL2",
17
+ }
18
+
19
+
20
+ # initialize model
21
+ @st.cache_resource
22
+ def load_model(model_path: str):
23
+ model = AutoModelForCausalLM.from_pretrained(model_path)
24
+ tokenizer = PreTrainedTokenizerFast.from_pretrained(model_path)
25
+ return model, tokenizer
26
+
27
+
28
+ def translate(text: str, model, tokenizer: PreTrainedTokenizerBase, **kwargs):
29
+ generation_args = dict(
30
+ max_new_tokens=kwargs.pop("max_new_tokens", 64),
31
+ do_sample=kwargs.pop("do_sample", True),
32
+ temperature=kwargs.pop("temperature", 0.55),
33
+ top_p=kwargs.pop("top_p", 0.8),
34
+ top_k=kwargs.pop("top_k", 40),
35
+ **kwargs
36
+ )
37
+
38
+ prompt = "<|im_start|>" + text + "<|endoftext|>"
39
+ model_inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
40
+
41
+ generated_ids = model.generate(model_inputs.input_ids, **generation_args)
42
+ generated_ids = [
43
+ output_ids[len(input_ids) :]
44
+ for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
45
+ ]
46
+
47
+ response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
48
+ return response
49
+
50
+
51
+ st.title("NanoTranslator-Demo")
52
+
53
+ st.sidebar.title("Options")
54
+ model_choice = st.sidebar.selectbox("Model", list(model_dict.keys()), index=list(model_options.keys()).index("NanoTranslator-XXL2"))
55
+ do_sample = st.sidebar.checkbox("do_sample", value=True)
56
+ max_new_tokens = st.sidebar.slider(
57
+ "max_new_tokens", min_value=1, max_value=256, value=64
58
+ )
59
+ temperature = st.sidebar.slider(
60
+ "temperature", min_value=0.01, max_value=1.5, value=0.55, step=0.01
61
+ )
62
+ top_p = st.sidebar.slider("top_p", min_value=0.01, max_value=1.0, value=0.8, step=0.01)
63
+ top_k = st.sidebar.slider("top_k", min_value=1, max_value=100, value=40, step=1)
64
+
65
+ # 根据选择的模型加载
66
+ model_path = model_dict[model_choice]
67
+ model, tokenizer = load_model(model_path)
68
+
69
+ input_text = st.text_area(
70
+ "Please input the text to be translated (Currently supports only English to Chinese):",
71
+ "Each step of the cell cycle is monitored by internal.",
72
+ )
73
+
74
+ if st.button("translate"):
75
+ if input_text.strip():
76
+ with st.spinner("Translating..."):
77
+ translation = translate(
78
+ input_text,
79
+ model,
80
+ tokenizer,
81
+ max_new_tokens=max_new_tokens,
82
+ do_sample=do_sample,
83
+ temperature=temperature,
84
+ top_p=top_p,
85
+ top_k=top_k,
86
+ )
87
+ st.success("Translated successfully!")
88
+ st.write(translation)
89
+ else:
90
+ st.warning("Please input text before translation!")