KumaTea committed
Commit 2a4f313 · Parent: 221f925

sync with main version

Files changed (2):
  1. app.py (+55 -35)
  2. requirements.txt (+3 -3)
app.py CHANGED
@@ -12,13 +12,38 @@ fix_pytorch_int8()
 
 
 import torch
+import logging
 import gradio as gr
 from transformers import AutoTokenizer, GenerationConfig, AutoModel
 
 
+gr_title = """<h1 align="center">KumaGLM Lite</h1>
+<h3 align='center'>这是<a href="https://huggingface.co/spaces/KumaTea/KumaGLM" target="_blank">另一个</a> AI Kuma,你可以与他聊天,或者直接在文本框按下Enter</h3>
+<p align='center'>采用 INT4 量化,速度很慢,仅作备用</p>
+<p align='center'>GitHub Repo: <a class="github-button" href="https://github.com/KumaTea/ChatGLM" aria-label="Star KumaTea/ChatGLM on GitHub">KumaTea/ChatGLM</a></p>
+<script async defer src="https://buttons.github.io/buttons.js"></script>
+"""
+gr_footer = """<p align='center'>
+本项目基于
+<a href='https://github.com/ljsabc/Fujisaki' target='_blank'>ljsabc/Fujisaki</a>
+,模型采用
+<a href='https://huggingface.co/THUDM/chatglm-6b' target='_blank'>THUDM/chatglm-6b</a>
+
+</p>
+<p align='center'>
+<em>每天起床第一句!</em>
+</p>"""
+default_start = ["你是谁?", "我是 kuma"]
+
+
 # device = torch.device('cpu')
 # torch.cuda.current_device = lambda : device
 
+logging.basicConfig(
+    format='%(asctime)s %(levelname)-8s %(message)s',
+    level=logging.INFO,
+    datefmt='%m/%d %H:%M:%S')
+
 model = AutoModel.from_pretrained(
     "KumaTea/twitter-int4",
     trust_remote_code=True,
@@ -35,25 +60,32 @@ model.eval()
 torch.set_default_tensor_type(torch.FloatTensor)
 
 
-def evaluate(context, temperature, top_p, top_k):
+def evaluate(context, temperature, top_p, top_k=None):
     generation_config = GenerationConfig(
         temperature=temperature,
         top_p=top_p,
-        top_k=top_k,
+        # top_k=top_k,
         #repetition_penalty=1.1,
         num_beams=1,
         do_sample=True,
     )
     with torch.no_grad():
-        input_text = f"Context: {context}Answer: "
-        ids = tokenizer.encode(input_text)
-        input_ids = torch.LongTensor([ids]).to('cpu')
+        # input_text = f"Context: {context}Answer: "
+        input_text = '||'.join(default_start) + '||'
+        input_text += context + '||'
+        logging.info('[API] Incoming request: ' + input_text)
+        ids = tokenizer([input_text], return_tensors="pt")
+        inputs = ids.to("cpu")
         out = model.generate(
-            input_ids=input_ids,
-            max_length=160,
+            **inputs,
+            max_length=224,
             generation_config=generation_config
         )
-        out_text = tokenizer.decode(out[0]).split("Answer: ")[1]
+        out = out.tolist()[0]
+        decoder_output = tokenizer.decode(out)
+        # out_text = decoder_output.split("Answer: ")[1]
+        out_text = decoder_output
+        logging.info('[API] Result: ' + out_text)
         return out_text
 
 
@@ -65,10 +97,12 @@ def evaluate_stream(msg, history, temperature, top_p):
         num_beams=1,
         do_sample=True,
     )
+    if not msg:
+        msg = '……'
 
-    history.append([msg, None])
+    history.append([msg, ""])
 
-    context = ""
+    context = '||'.join(default_start) + '||'
     if len(history) > 4:
         history.pop(0)
 
@@ -79,7 +113,7 @@ def evaluate_stream(msg, history, temperature, top_p):
     for h in history[:-1]:
         context += h[0] + "||" + h[1] + "||"
 
-    context += history[-1][0]
+    context += history[-1][0] + "||"
     context = context.replace(r'<br>', '')
 
     # TODO: Avoid the tokens are too long.
@@ -89,37 +123,20 @@
         context = context[15:]
 
     h = []
-    print("History:", history)
-    print("Context:", context)
+    logging.info('[UI] Incoming request: ' + context)
     for response, h in model.stream_chat(tokenizer, context, h, max_length=CUTOFF, top_p=top_p, temperature=temperature):
         history[-1][1] = response
         yield history, ""
 
-    #return response
-
-
-title = """<h1 align="center">KumaGLM</h1>
-<h3 align='center'>这是一个 AI Kuma,你可以与他聊天,或者直接在文本框按下Enter</h3>
-<p align='center'>采用 INT4 量化,速度很慢,仅作备用</p>"""
-footer = """<p align='center'>
-本项目基于
-<a href='https://github.com/ljsabc/Fujisaki' target='_blank'>ljsabc/Fujisaki</a>
-,模型采用
-<a href='https://huggingface.co/THUDM/chatglm-6b' target='_blank'>THUDM/chatglm-6b</a>
-
-</p>
-<p align='center'>
-<em>每天起床第一句!</em>
-</p>"""
 
 with gr.Blocks() as demo:
-    gr.HTML(title)
-    state = gr.State()
+    gr.HTML(gr_title)
+    # state = gr.State()
     with gr.Row():
         with gr.Column(scale=2):
-            temp = gr.components.Slider(minimum=0, maximum=1.1, value=0.8, label="Temperature",
+            temp = gr.components.Slider(minimum=0, maximum=1.1, value=0.5, label="Temperature",
                 info="温度参数,越高的温度生成的内容越丰富,但是有可能出现语法问题。小的温度也能帮助生成更相关的回答。")
-            top_p = gr.components.Slider(minimum=0.5, maximum=1.0, value=0.975, label="Top-p",
+            top_p = gr.components.Slider(minimum=0.5, maximum=1.0, value=0.8, label="Top-p",
                 info="top-p参数,只输出前p>top-p的文字,越大生成的内容越丰富,但也可能出现语法问题。数字越小似乎上下文的衔接性越好。")
             #code = gr.Textbox(label="temp_output", info="解码器输出")
             #top_k = gr.components.Slider(minimum=1, maximum=200, step=1, value=25, label="Top k",
@@ -128,12 +145,15 @@ with gr.Blocks() as demo:
         with gr.Column(scale=3):
             chatbot = gr.Chatbot(label="聊天框", info="")
             msg = gr.Textbox(label="输入框", placeholder="最近过得怎么样?",
-                info="输入你的内容,按[Enter]发送。也可以什么都不填写生成随机数据。对话一般不能太长,否则就复读机了,建议清除数据。")
+                info="输入你的内容,按 [Enter] 发送。什么都不填经常会出错。")
             clear = gr.Button("清除聊天")
+            api_handler = gr.Button("API", visible=False)
+            textbox_for_api = gr.Textbox(visible=False)
 
     msg.submit(evaluate_stream, [msg, chatbot, temp, top_p], [chatbot, msg])
     clear.click(lambda: None, None, chatbot, queue=False)
-    gr.HTML(footer)
+    api_handler.click(evaluate, [textbox_for_api, temp, top_p], [textbox_for_api], api_name='chat')
+    gr.HTML(gr_footer)
 
 demo.queue()
 demo.launch(debug=False)
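
Note on the change: both rewritten functions now build their prompt as a ||-delimited turn list seeded with default_start, replacing the old "Context: ... Answer: " template. A minimal sketch of the format, using only values visible in the diff above:

    # Sketch of the prompt format the new evaluate() builds (values from the diff).
    default_start = ["你是谁?", "我是 kuma"]  # seed exchange prepended to every request

    def build_prompt(msg: str) -> str:
        # seed turns, then the incoming message, all '||'-delimited and '||'-terminated
        return '||'.join(default_start) + '||' + msg + '||'

    assert build_prompt("最近过得怎么样?") == "你是谁?||我是 kuma||最近过得怎么样?||"

evaluate_stream() builds the same shape, additionally folding in the prior [user, reply] pairs from history (capped at four entries) before the final message.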
requirements.txt CHANGED
@@ -15,6 +15,6 @@ datasets>=2.10.1
 git+https://github.com/huggingface/peft.git  # latest version >=0.3.0.dev0
 
 --extra-index-url https://download.pytorch.org/whl/cpu
-torch==2.0.0+cpu
-torchvision==0.15.1+cpu
-torchaudio==2.0.1+cpu
+torch>=2.0.0+cpu
+torchvision>=0.15.1+cpu
+torchaudio>=2.0.1+cpu
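
Because the hidden api_handler button wires evaluate() up with api_name='chat', the endpoint is callable outside the UI once the Space is running. A hedged sketch using gradio_client; the Space id, the message, and the parameter values here are illustrative assumptions, not part of this commit:

    # Hypothetical client call; assumes the Space is public and reachable, and
    # that the installed gradio_client version is compatible with this Gradio release.
    from gradio_client import Client

    client = Client("KumaTea/KumaGLM-Lite")  # Space id guessed from gr_title
    result = client.predict(
        "最近过得怎么样?",  # textbox_for_api: the raw user message
        0.5,                 # temp slider value (the new default)
        0.8,                 # top_p slider value (the new default)
        api_name="/chat",    # registered via api_handler.click(..., api_name='chat')
    )
    print(result)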