K024 committed on
Commit
d1a642c
·
1 Parent(s): 498424c
.gitignore ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Project ignores
2
+ models/
3
+ scripts/
4
+ data/
5
+
6
+ # Byte-compiled / optimized / DLL files
7
+ __pycache__/
8
+ *.py[cod]
9
+ *$py.class
10
+
11
+ # C extensions
12
+ *.so
13
+
14
+ # Distribution / packaging
15
+ .Python
16
+ build/
17
+ develop-eggs/
18
+ dist/
19
+ downloads/
20
+ eggs/
21
+ .eggs/
22
+ lib/
23
+ lib64/
24
+ parts/
25
+ sdist/
26
+ var/
27
+ wheels/
28
+ share/python-wheels/
29
+ *.egg-info/
30
+ .installed.cfg
31
+ *.egg
32
+ MANIFEST
33
+
34
+ # PyInstaller
35
+ # Usually these files are written by a python script from a template
36
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
37
+ *.manifest
38
+ *.spec
39
+
40
+ # Installer logs
41
+ pip-log.txt
42
+ pip-delete-this-directory.txt
43
+
44
+ # Unit test / coverage reports
45
+ htmlcov/
46
+ .tox/
47
+ .nox/
48
+ .coverage
49
+ .coverage.*
50
+ .cache
51
+ nosetests.xml
52
+ coverage.xml
53
+ *.cover
54
+ *.py,cover
55
+ .hypothesis/
56
+ .pytest_cache/
57
+ cover/
58
+
59
+ # Translations
60
+ *.mo
61
+ *.pot
62
+
63
+ # Django stuff:
64
+ *.log
65
+ local_settings.py
66
+ db.sqlite3
67
+ db.sqlite3-journal
68
+
69
+ # Flask stuff:
70
+ instance/
71
+ .webassets-cache
72
+
73
+ # Scrapy stuff:
74
+ .scrapy
75
+
76
+ # Sphinx documentation
77
+ docs/_build/
78
+
79
+ # PyBuilder
80
+ .pybuilder/
81
+ target/
82
+
83
+ # Jupyter Notebook
84
+ .ipynb_checkpoints
85
+
86
+ # IPython
87
+ profile_default/
88
+ ipython_config.py
89
+
90
+ # pyenv
91
+ # For a library or package, you might want to ignore these files since the code is
92
+ # intended to run in multiple environments; otherwise, check them in:
93
+ # .python-version
94
+
95
+ # pipenv
96
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
97
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
98
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
99
+ # install all needed dependencies.
100
+ #Pipfile.lock
101
+
102
+ # poetry
103
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
104
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
105
+ # commonly ignored for libraries.
106
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
107
+ #poetry.lock
108
+
109
+ # pdm
110
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
111
+ #pdm.lock
112
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
113
+ # in version control.
114
+ # https://pdm.fming.dev/#use-with-ide
115
+ .pdm.toml
116
+
117
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
118
+ __pypackages__/
119
+
120
+ # Celery stuff
121
+ celerybeat-schedule
122
+ celerybeat.pid
123
+
124
+ # SageMath parsed files
125
+ *.sage.py
126
+
127
+ # Environments
128
+ .env
129
+ .venv
130
+ env/
131
+ venv/
132
+ ENV/
133
+ env.bak/
134
+ venv.bak/
135
+
136
+ # Spyder project settings
137
+ .spyderproject
138
+ .spyproject
139
+
140
+ # Rope project settings
141
+ .ropeproject
142
+
143
+ # mkdocs documentation
144
+ /site
145
+
146
+ # mypy
147
+ .mypy_cache/
148
+ .dmypy.json
149
+ dmypy.json
150
+
151
+ # Pyre type checker
152
+ .pyre/
153
+
154
+ # pytype static type analyzer
155
+ .pytype/
156
+
157
+ # Cython debug symbols
158
+ cython_debug/
159
+
160
+ # PyCharm
161
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
162
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
163
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
164
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
165
+ .idea/
chatglm-6b-int8-onnx-merged/chatglm-6b-int8.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:93c988ddb30e2eb97aafe05fd8086f56faec47e8488bc2bb6dbd19ee50ce36ae
3
+ size 459821
chatglm-6b-int8-onnx-merged/model_weights_0.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:721f5497129c8f2bbffe685892a99bdc87e00fd29b70d54d5f75df8810811cf1
3
+ size 1069807488
chatglm-6b-int8-onnx-merged/model_weights_1.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:320f96165f0ba496292eb4dd35979d5fb5c0bbfc0fbaf83b0e8150a9959d4c8d
3
+ size 948125696
chatglm-6b-int8-onnx-merged/model_weights_2.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:92bc601207b27b08803e223b6a414eb533d3f4eeab26ed9c3b75ca4b0b977f41
3
+ size 1006960640
chatglm-6b-int8-onnx-merged/model_weights_3.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:26218891b8d13a8c3b3b5cc15b47c6ba1b5b140a614cd9a5ffb95a69e5180025
3
+ size 1006960640
chatglm-6b-int8-onnx-merged/model_weights_4.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:22f6b5087d50d39c566079a8677c1e1ef41e3b16763f4d022e00d385d4dc88af
3
+ size 1006960640
chatglm-6b-int8-onnx-merged/model_weights_5.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d5c6502fdf30878a5e75be2da7e2e134e5bfe3a132b1e98880880687cce1e703
3
+ size 1006960640
chatglm-6b-int8-onnx-merged/model_weights_6.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:82b140850685302b6939fca378a4174246304c4afb7b58b26aaecad370d2a15a
3
+ size 671842304
chatglm-6b-int8-onnx-merged/sentencepiece.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5e974d9a69c242ce014c88c2b26089270f6198f3c0b700a887666cd3e816f17e
3
+ size 2706249
model.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import numpy as np
3
+ from tokenizer import ChatGLMTokenizer
4
+ # import torch
5
+ from onnxruntime import InferenceSession, SessionOptions
6
+
7
+
8
# ONNX Runtime execution providers, in priority order.
# `MatMulInteger` and `DynamicQuantizeLinear` are documented as available on
# CUDA but currently only run on CPU, so the CUDA provider stays disabled.
# To re-enable it once supported, prepend "CUDAExecutionProvider" to this list
# (the original guarded that with `torch.cuda.is_available()`).
providers = ["CPUExecutionProvider"]


# Default locations of the tokenizer vocabulary and the quantized ONNX graph.
tokenizer_path = "chatglm-6b-int8-onnx-merged/sentencepiece.model"
onnx_model_path = "chatglm-6b-int8-onnx-merged/chatglm-6b-int8.onnx"


# Names of the key/value cache tensors, interleaved per index:
# key_0, value_0, key_1, value_1, ... for indices 0..27.
past_names = [
    "past_{}_{}".format(kind, layer)
    for layer in range(28)
    for kind in ("key", "value")
]
present_names = [
    "present_{}_{}".format(kind, layer)
    for layer in range(28)
    for kind in ("key", "value")
]
# Outputs requested from the session: the logits followed by the updated cache.
output_names = ["logits", *present_names]


# Empty key/value cache (second axis has length zero) fed to the first
# inference step, one separate array per cache input.
default_past_key_values = dict(
    (cache_name, np.zeros((1, 0, 32, 128), dtype=np.float32))
    for cache_name in past_names
)
31
+
32
+
33
def chat_template(history: list[tuple[str, str]], current: str):
    """Render the conversation so far plus the new question as one prompt.

    Every finished (question, answer) pair becomes a numbered ``[Round i]``
    section, starting at round 0. The final section holds ``current`` and ends
    with an open answer marker for the model to continue from.

    NOTE(review): the separators after 问/答 are ASCII colons here; upstream
    ChatGLM prompts may use a full-width colon - confirm before changing.
    """
    sections = [
        f"[Round {index}]\n问:{question}\n答:{answer}\n"
        for index, (question, answer) in enumerate(history)
    ]
    # The open round is numbered right after the last completed one.
    sections.append(f"[Round {len(sections)}]\n问:{current}\n答:")
    return "".join(sections)
41
+
42
+
43
def process_response(response: str):
    """Clean up decoded model output for display.

    Strips surrounding whitespace, replaces the ``[[训练时间]]`` placeholder
    with ``2023年``, and converts half-width punctuation that directly follows
    or precedes a CJK ideograph (U+4E00..U+9FFF) into its full-width form.
    Punctuation between non-CJK characters is left untouched.

    Args:
        response: Raw decoded text.

    Returns:
        The cleaned-up text.
    """
    response = response.strip()
    response = response.replace("[[训练时间]]", "2023年")
    # (half-width, full-width) pairs. The previous table mapped every mark to
    # itself, which made both substitutions below no-ops.
    punkts = [
        (",", "\uff0c"),
        ("!", "\uff01"),
        (":", "\uff1a"),
        (";", "\uff1b"),
        ("?", "\uff1f"),
    ]
    for half, full in punkts:
        # re.escape keeps "?" literal without relying on an invalid "\?"
        # escape inside a normal string literal.
        pattern = re.escape(half)
        response = re.sub(r"([\u4e00-\u9fff])%s" % pattern, r"\1%s" % full, response)
        response = re.sub(r"%s([\u4e00-\u9fff])" % pattern, r"%s\1" % full, response)
    return response
57
+
58
+
59
class ChatGLMModel():
    """ChatGLM text generation on top of an ONNX Runtime inference session."""

    def __init__(self, onnx_model_path=onnx_model_path, tokenizer_path=tokenizer_path, profile=False) -> None:
        """Load the tokenizer and create the ONNX Runtime session.

        Args:
            onnx_model_path: Path of the ONNX graph to load.
            tokenizer_path: Path of the SentencePiece model file.
            profile: Whether to enable ONNX Runtime profiling for the session.
        """
        self.tokenizer = ChatGLMTokenizer(tokenizer_path)
        options = SessionOptions()
        options.enable_profiling = profile
        self.session = InferenceSession(onnx_model_path, options, providers=providers)
        # Sampling this token id ends generation.
        self.eop_token_id = self.tokenizer["<eop>"]

    def prepare_input(self, prompt: str):
        """Tokenize ``prompt`` into the inputs of the first inference step.

        Returns:
            A tuple ``(input_ids, prefix_mask, past_key_values)``: two integer
            arrays with a leading batch axis of size 1, and a dict holding the
            empty key/value cache.
        """
        input_ids, prefix_mask = self.tokenizer.encode(prompt)

        input_ids = np.array([input_ids], dtype=np.longlong)
        prefix_mask = np.array([prefix_mask], dtype=np.longlong)

        # Hand out a shallow copy so callers cannot alter the shared
        # module-level default cache mapping.
        return input_ids, prefix_mask, dict(default_past_key_values)

    def sample_next_token(self, logits: np.ndarray, top_k=50, top_p=0.7, temperature=1):
        """Sample one token id from ``logits`` using top-k then top-p filtering.

        Args:
            logits: 1-D array of unnormalized scores, one per token id.
            top_k: Number of highest-probability candidates to keep.
            top_p: Cumulative-probability threshold applied to the candidates.
            temperature: Divisor applied to the logits before the softmax.

        Returns:
            The sampled token id as a Python int.
        """
        # Softmax with temperature. Subtracting the maximum leaves the
        # probabilities unchanged but keeps np.exp from overflowing to inf
        # (which would turn the distribution into NaNs).
        scaled = logits / temperature
        exp_logits = np.exp(scaled - np.max(scaled))
        probs = exp_logits / np.sum(exp_logits)

        # top k: indices of the most probable tokens, in descending order
        top_k_idx = np.argsort(-probs)[:top_k]
        top_k_probs = probs[top_k_idx]

        # top p: drop every candidate whose preceding cumulative mass already
        # exceeds top_p (the first candidate is therefore always kept), then
        # renormalize the survivors.
        cumsum_probs = np.cumsum(top_k_probs)
        top_k_probs[(cumsum_probs - top_k_probs) > top_p] = 0.0
        top_k_probs = top_k_probs / np.sum(top_k_probs)

        # sample
        next_token = np.random.choice(top_k_idx, size=1, p=top_k_probs)
        return next_token[0].item()

    def generate_iterate(self, prompt: str, max_generated_tokens=100, top_k=50, top_p=0.7, temperature=1):
        """Generate a reply token by token, yielding the text decoded so far.

        Generation stops when ``<eop>`` is sampled or after
        ``max_generated_tokens`` tokens. The budget is checked before each
        forward pass, so no inference is spent on a token that would be
        discarded, and the ``<eop>`` token is never included in the output.

        Args:
            prompt: Full prompt text.
            max_generated_tokens: Upper bound on the number of generated tokens.
            top_k: Passed through to ``sample_next_token``.
            top_p: Passed through to ``sample_next_token``.
            temperature: Passed through to ``sample_next_token``.

        Yields:
            The post-processed text of all tokens generated so far, once per
            generated token.

        Returns:
            The final post-processed text (the generator's ``StopIteration``
            value); equal to the last yielded text, or the decoding of an
            empty token list when nothing was generated.
        """
        input_ids, prefix_mask, past_key_values = self.prepare_input(prompt)
        output_tokens = []

        while len(output_tokens) < max_generated_tokens:
            inputs = {
                "input_ids": input_ids,
                "prefix_mask": prefix_mask,
                # False only on the first step, when the cache is still empty.
                "use_past": np.array(len(output_tokens) > 0),
            }
            inputs.update(past_key_values)

            logits, *present_key_values = self.session.run(output_names, inputs)
            # The "present" outputs of this step become the "past" inputs of
            # the next one.
            past_key_values = dict(zip(past_names, present_key_values))

            # Sample from the logits at the last position of the first batch item.
            next_token = self.sample_next_token(logits[0, -1], top_k=top_k, top_p=top_p, temperature=temperature)

            if next_token == self.eop_token_id:
                break

            output_tokens.append(next_token)

            # From now on only the new token is fed; earlier context lives in
            # the key/value cache. The mask grows by one zero entry per token.
            input_ids = np.array([[next_token]], dtype=np.longlong)
            prefix_mask = np.concatenate([prefix_mask, np.array([[0]], dtype=np.longlong)], axis=1)

            yield process_response(self.tokenizer.decode(output_tokens))

        return process_response(self.tokenizer.decode(output_tokens))
125
+
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ numpy
2
+ onnxruntime
3
+ sentencepiece
4
+ streamlit
5
+ streamlit-chat
tokenizer.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from sentencepiece import SentencePieceProcessor
3
+
4
+
5
def replace_spaces_with_blank(match: re.Match[str]):
    """Return the ``<|blank_N|>`` marker for a match, N being the match length."""
    start, end = match.span()
    return "<|blank_" + str(end - start) + "|>"
7
+
8
+
9
def replace_blank_with_spaces(match: re.Match[str]):
    """Return as many spaces as the number captured in the match's first group."""
    count = int(match[1])
    return count * " "
11
+
12
+
13
class ChatGLMTokenizer:
    """SentencePiece-based tokenizer with ChatGLM's whitespace/newline markers."""

    def __init__(self, vocab_file):
        """Load the SentencePiece model stored at ``vocab_file``."""
        assert vocab_file is not None
        self.vocab_file = vocab_file
        self.special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "<unused_0>", "<sop>", "<eop>", "<ENC>", "<dBLOCK>"]
        self.text_tokenizer = SentencePieceProcessor(str(vocab_file))

    def __len__(self):
        """Return the length reported by the underlying SentencePiece processor."""
        return len(self.text_tokenizer)

    def __getitem__(self, key: str):
        """Look up ``key`` in the underlying SentencePiece processor."""
        return self.text_tokenizer[key]

    def preprocess(self, text: str, linebreak=True, whitespaces=True):
        """Replace newlines, tabs and runs of spaces with marker strings.

        Args:
            text: Text to transform.
            linebreak: Replace each newline with ``<n>``.
            whitespaces: Replace each tab with ``<|tab|>`` and each run of
                2 to 80 spaces with ``<|blank_N|>``.

        Returns:
            The transformed text.
        """
        if linebreak:
            text = text.replace("\n", "<n>")
        if whitespaces:
            text = text.replace("\t", "<|tab|>")
            text = re.sub(r" {2,80}", replace_spaces_with_blank, text)
        return text

    def encode(
        self, text: str, text_pair: str = None,
        linebreak=True, whitespaces=True,
        add_dummy_prefix=True, special_tokens=True,
    ) -> tuple[list[int], list[int]]:
        """
        Encode text into token ids plus a same-length prefix mask.

        text: Text to encode. Bidirectional part with a [gMASK] and an <sop> for causal LM.
        text_pair: causal LM part.
        linebreak: Whether to encode newline characters in text.
        whitespaces: Whether to encode multiple whitespaces or tab in text, useful for source code encoding.
        special_tokens: Whether to encode special token ([MASK], [gMASK], etc.) in text.
        add_dummy_prefix: Whether to add dummy blank space in the beginning.

        Returns (token_ids, prefix_mask). The mask holds 1 for the tokens of
        `text` and for [gMASK], and 0 for <sop>, the `text_pair` tokens and <eop>.
        """
        text = self.preprocess(text, linebreak, whitespaces)
        if not add_dummy_prefix:
            # Prepend a throwaway "<n>"; the first two resulting tokens are
            # dropped again before returning.
            text = "<n>" + text

        tokens = self.text_tokenizer.encode(text)
        prefix_mask = [1] * len(tokens)
        if special_tokens:
            tokens += [self.text_tokenizer["[gMASK]"], self.text_tokenizer["<sop>"]]
            prefix_mask += [1, 0]

        if text_pair is not None:
            # NOTE(review): text_pair is encoded without preprocess(), so its
            # newlines/tabs are not turned into markers - confirm intended.
            pair_tokens = self.text_tokenizer.encode(text_pair)
            tokens += pair_tokens
            prefix_mask += [0] * len(pair_tokens)
            if special_tokens:
                tokens += [self.text_tokenizer["<eop>"]]
                prefix_mask += [0]

        if add_dummy_prefix:
            return tokens, prefix_mask
        # Drop the two leading tokens from BOTH lists; slicing only the ids
        # left the mask two entries longer than the ids.
        return tokens[2:], prefix_mask[2:]

    def decode(self, text_ids: list[int]) -> str:
        """Decode token ids to text and turn the markers back into whitespace."""
        text = self.text_tokenizer.decode(text_ids)
        text = text.replace("<n>", "\n")
        text = text.replace("<|tab|>", "\t")
        text = re.sub(r"<\|blank_(\d\d?)\|>", replace_blank_with_spaces, text)
        return text
web-ui.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from streamlit_chat import message
3
+ from model import ChatGLMModel, chat_template
4
+
5
+
6
# page state

@st.cache_resource
def create_model():
    """Create the ChatGLM model; ``st.cache_resource`` reuses it across reruns."""
    return ChatGLMModel()


with st.spinner("加载模型中..."):
    model = create_model()


if "history" not in st.session_state:
    st.session_state["history"] = []


# parameters

with st.sidebar:
    st.markdown("## 采样参数")

    max_tokens = st.number_input("max_tokens", min_value=1, max_value=500, value=200)
    temperature = st.number_input("temperature", min_value=0.1, max_value=4.0, value=1.0)
    top_p = st.number_input("top_p", min_value=0.1, max_value=1.0, value=0.7)
    top_k = st.number_input("top_k", min_value=1, max_value=500, value=50)

    # Reset both the input box (widget key "message") and the conversation.
    if st.button("清空上下文"):
        st.session_state.message = ""
        st.session_state.history = []

    st.markdown("""
[ChatGLM](https://huggingface.co/THUDM/chatglm-6b) + [ONNXRuntime](https://onnxruntime.ai/)
""")


# main body

st.markdown("## ChatGLM + ONNXRuntime")

history: list[tuple[str, str]] = st.session_state.history

if not history:
    st.caption("请在下方输入消息开始会话")


# Replay the finished rounds of the conversation.
for idx, (question, answer) in enumerate(history):
    message(question, is_user=True, key=f"history_question_{idx}")
    message(answer, key=f"history_answer_{idx}")


# Placeholder created before the input box so the in-progress exchange is
# rendered above it.
next_answer = st.container()

question = st.text_area(label="消息", key="message")

if st.button("发送") and question.strip():
    # Start from an empty reply: if the model yields nothing, `answer` would
    # otherwise be undefined (or still hold the last answer from the history
    # loop above).
    answer = ""
    with next_answer:
        message(question, is_user=True, key="message_question")
        with st.spinner("正在回复中"):
            # st.empty() holds a single element, so each partial answer
            # replaces the previous one while streaming.
            with st.empty():
                prompt = chat_template(history, question)
                for answer in model.generate_iterate(
                    prompt,
                    max_generated_tokens=max_tokens,
                    top_k=top_k,
                    top_p=top_p,
                    temperature=temperature,
                ):
                    st.write(answer)
        message(answer, key="message_answer")

    st.session_state.history = history + [(question, answer)]