KumaTea committed
Commit 221f925 · 1 Parent(s): c1aa006

follow KumaTea/KumaGLM

Files changed (4):
  1. README.md +1 -1
  2. app.py +139 -0
  3. fix_int8.py +29 -0
  4. requirements.txt +20 -0
README.md CHANGED
@@ -2,7 +2,7 @@
  title: KumaGLM Lite
  emoji: 🐨
  colorFrom: blue
- colorTo: blue
+ colorTo: purple
  sdk: gradio
  sdk_version: 3.24.1
  app_file: app.py
app.py ADDED
@@ -0,0 +1,139 @@
+ from fix_int8 import fix_pytorch_int8
+ fix_pytorch_int8()
+
+
+ # import subprocess
+ # result = subprocess.run(['git', 'clone', 'https://huggingface.co/KumaTea/twitter-int8', 'model'], capture_output=True, text=True)
+ # print(result.stdout)
+
+
+ # Credit:
+ # https://huggingface.co/spaces/ljsabc/Fujisaki/blob/main/app.py
+
+
+ import torch
+ import gradio as gr
+ from transformers import AutoTokenizer, GenerationConfig, AutoModel
+
+
+ # device = torch.device('cpu')
+ # torch.cuda.current_device = lambda: device
+
+ model = AutoModel.from_pretrained(
+     "KumaTea/twitter-int4",
+     trust_remote_code=True,
+     revision="e2aecb2"
+ ).float()  # .to(device)
+ tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, revision="4de8efe")
+
+ # dump a log to ensure everything works well
+ # print(model.peft_config)
+ # We have to use full precision, as some tokens are >65535
+ model.eval()
+ # print(model)
+
+ torch.set_default_tensor_type(torch.FloatTensor)
+
+
+ def evaluate(context, temperature, top_p, top_k):
+     generation_config = GenerationConfig(
+         temperature=temperature,
+         top_p=top_p,
+         top_k=top_k,
+         # repetition_penalty=1.1,
+         num_beams=1,
+         do_sample=True,
+     )
+     with torch.no_grad():
+         input_text = f"Context: {context}Answer: "
+         ids = tokenizer.encode(input_text)
+         input_ids = torch.LongTensor([ids]).to('cpu')
+         out = model.generate(
+             input_ids=input_ids,
+             max_length=160,
+             generation_config=generation_config
+         )
+         out_text = tokenizer.decode(out[0]).split("Answer: ")[1]
+         return out_text
+
+
+ def evaluate_stream(msg, history, temperature, top_p):
+     generation_config = GenerationConfig(
+         temperature=temperature,
+         top_p=top_p,
+         # repetition_penalty=1.1,
+         num_beams=1,
+         do_sample=True,
+     )
+
+     # append the new message with a placeholder for the streamed answer
+     history.append([msg, None])
+
+     context = ""
+     # keep at most 4 turns of history
+     if len(history) > 4:
+         history.pop(0)
+
+     for j in range(len(history)):
+         history[j][0] = history[j][0].replace("<br>", "")
+
+     # concatenate completed turns into the context, separated by "||"
+     for h in history[:-1]:
+         context += h[0] + "||" + h[1] + "||"
+
+     context += history[-1][0]
+     context = context.replace(r'<br>', '')
+
+     # TODO: avoid overly long inputs
+     CUTOFF = 224
+     while len(tokenizer.encode(context)) > CUTOFF:
+         # drop 15 characters at a time to leave room for the answer
+         context = context[15:]
+
+     h = []
+     print("History:", history)
+     print("Context:", context)
+     for response, h in model.stream_chat(tokenizer, context, h, max_length=CUTOFF, top_p=top_p, temperature=temperature):
+         history[-1][1] = response
+         yield history, ""
+
+     # return response
+
+
+ title = """<h1 align="center">KumaGLM</h1>
+ <h3 align='center'>这是一个 AI Kuma,你可以与他聊天,或者直接在文本框按下Enter</h3>
+ <p align='center'>采用 INT4 量化,速度很慢,仅作备用</p>"""
+ footer = """<p align='center'>
+ 本项目基于
+ <a href='https://github.com/ljsabc/Fujisaki' target='_blank'>ljsabc/Fujisaki</a>
+ ,模型采用
+ <a href='https://huggingface.co/THUDM/chatglm-6b' target='_blank'>THUDM/chatglm-6b</a>
+
+ </p>
+ <p align='center'>
+ <em>每天起床第一句!</em>
+ </p>"""
+
+ with gr.Blocks() as demo:
+     gr.HTML(title)
+     state = gr.State()
+     with gr.Row():
+         with gr.Column(scale=2):
+             temp = gr.components.Slider(minimum=0, maximum=1.1, value=0.8, label="Temperature",
+                                         info="温度参数,越高的温度生成的内容越丰富,但是有可能出现语法问题。小的温度也能帮助生成更相关的回答。")
+             top_p = gr.components.Slider(minimum=0.5, maximum=1.0, value=0.975, label="Top-p",
+                                          info="top-p参数,只输出前p>top-p的文字,越大生成的内容越丰富,但也可能出现语法问题。数字越小似乎上下文的衔接性越好。")
+             # code = gr.Textbox(label="temp_output", info="解码器输出")
+             # top_k = gr.components.Slider(minimum=1, maximum=200, step=1, value=25, label="Top k",
+             #                              info="top-k参数,下一个输出的文字会从top-k个文字中进行选择,越大生成的内容越丰富,但也可能出现语法问题。数字越小似乎上下文的衔接性越好。")
+
+         with gr.Column(scale=3):
+             chatbot = gr.Chatbot(label="聊天框")
+             msg = gr.Textbox(label="输入框", placeholder="最近过得怎么样?",
+                              info="输入你的内容,按[Enter]发送。也可以什么都不填写生成随机数据。对话一般不能太长,否则就复读机了,建议清除数据。")
+             clear = gr.Button("清除聊天")
+
+     msg.submit(evaluate_stream, [msg, chatbot, temp, top_p], [chatbot, msg])
+     clear.click(lambda: None, None, chatbot, queue=False)
+     gr.HTML(footer)
+
+ demo.queue()
+ demo.launch(debug=False)
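
Note: as a sketch of the context format that evaluate_stream builds above, completed turns are joined with "||" and the newest message is appended without a trailing separator. A minimal standalone illustration (hypothetical history values, no model involved):

history = [
    ["你好", "你好!"],
    ["最近过得怎么样?", None],  # newest message, not yet answered
]

context = ""
for h in history[:-1]:
    context += h[0] + "||" + h[1] + "||"
context += history[-1][0]

print(context)  # 你好||你好!||最近过得怎么样?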
fix_int8.py ADDED
@@ -0,0 +1,29 @@
+ import os
+ import sys
+
+
+ def fix_pytorch_int8():
+     # find the site-packages directory that holds the torch package
+     valid_path = [p for p in sys.path if p and os.path.isdir(p)]
+
+     packages_path = None
+     for path in valid_path:
+         for folder in os.listdir(path):
+             if 'torch' in folder:
+                 packages_path = path
+                 break
+
+     fix_path = f'{packages_path}/torch/nn/parameter.py'
+
+     with open(fix_path, 'r') as f:
+         text = f.read()
+
+     # force requires_grad=False for int8 tensors, which cannot require gradients
+     if 'if data.dtype == torch.int8' not in text:
+         text = text.replace(
+             '        return torch.Tensor._make_subclass(cls, data, requires_grad)',
+             '        if data.dtype == torch.int8:\n'
+             '            requires_grad = False\n'
+             '        return torch.Tensor._make_subclass(cls, data, requires_grad)'
+         )
+         with open(fix_path, 'w') as f:
+             f.write(text)
+
+     print('Fixed torch/nn/parameter.py')
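
Note: this patch exists because the quantized checkpoint stores weights as torch.int8 tensors, while nn.Parameter defaults to requires_grad=True, which integer dtypes reject. A minimal sketch of the failure the patch avoids (observed behavior on stock PyTorch 2.0; the exact error text may vary by version):

import torch

weight = torch.zeros(4, dtype=torch.int8)

try:
    p = torch.nn.Parameter(weight)  # unpatched torch raises RuntimeError:
except RuntimeError as e:           # only floating point (and complex) dtypes can require gradients
    print(e)

# what the patched parameter.py effectively does for int8 data:
p = torch.nn.Parameter(weight, requires_grad=False)
print(p.dtype, p.requires_grad)  # torch.int8 False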
requirements.txt ADDED
@@ -0,0 +1,20 @@
+ # https://huggingface.co/spaces/ljsabc/Fujisaki/blob/main/requirements.txt
+
+ # int8
+ bitsandbytes>=0.37.1
+ accelerate>=0.17.1
+
+ # chatglm
+ protobuf>=3.19.5,<4
+ transformers>=4.27.1
+ icetk
+ cpm_kernels>=1.0.11
+
+ #
+ datasets>=2.10.1
+ git+https://github.com/huggingface/peft.git  # latest version >=0.3.0.dev0
+
+ --extra-index-url https://download.pytorch.org/whl/cpu
+ torch==2.0.0+cpu
+ torchvision==0.15.1+cpu
+ torchaudio==2.0.1+cpu
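
Note: after installing with pip install -r requirements.txt, the --extra-index-url line should have pulled the +cpu wheels from the PyTorch index. A quick sanity check from Python (assuming the pins above):

import torch

print(torch.__version__)          # expected: 2.0.0+cpu
print(torch.cuda.is_available())  # expected: False, since these are CPU-only wheels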