duzx16
commited on
Commit
·
1676f07
1
Parent(s):
591fa87
Implement new interface
Browse files- modeling_chatglm.py +25 -17
- tokenization_chatglm.py +15 -10
modeling_chatglm.py
CHANGED
@@ -996,18 +996,23 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
996 |
for layer_past in past
|
997 |
)
|
998 |
|
999 |
-
def process_response(self,
|
1000 |
-
|
1001 |
-
response
|
1002 |
-
|
1003 |
-
|
1004 |
-
|
1005 |
-
|
1006 |
-
|
1007 |
-
|
|
|
|
|
|
|
|
|
|
|
1008 |
|
1009 |
@torch.inference_mode()
|
1010 |
-
def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None,
|
1011 |
max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
|
1012 |
**kwargs):
|
1013 |
if history is None:
|
@@ -1017,17 +1022,19 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
1017 |
logits_processor.append(InvalidScoreLogitsProcessor())
|
1018 |
gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
|
1019 |
"temperature": temperature, "logits_processor": logits_processor, **kwargs}
|
1020 |
-
inputs =
|
1021 |
-
|
|
|
|
|
1022 |
outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)
|
1023 |
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
|
1024 |
response = tokenizer.decode(outputs)
|
1025 |
-
|
1026 |
-
history =
|
1027 |
return response, history
|
1028 |
|
1029 |
@torch.inference_mode()
|
1030 |
-
def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None,
|
1031 |
past_key_values=None,max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8,
|
1032 |
logits_processor=None, return_past_key_values=False, **kwargs):
|
1033 |
if history is None:
|
@@ -1040,9 +1047,10 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
1040 |
gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
|
1041 |
"temperature": temperature, "logits_processor": logits_processor, **kwargs}
|
1042 |
if past_key_values is None:
|
1043 |
-
inputs =
|
1044 |
else:
|
1045 |
-
inputs =
|
|
|
1046 |
if past_key_values is not None:
|
1047 |
past_length = past_key_values[0][0].shape[0]
|
1048 |
if self.transformer.pre_seq_len is not None:
|
|
|
996 |
for layer_past in past
|
997 |
)
|
998 |
|
999 |
+
def process_response(self, output, history):
|
1000 |
+
content = ""
|
1001 |
+
for response in output.split("<|assistant|>"):
|
1002 |
+
metadata, content = response.split("\n", maxsplit=1)
|
1003 |
+
history.append({"role": "assistant", "metadata": metadata, "content": content})
|
1004 |
+
if not metadata.strip():
|
1005 |
+
content = content.strip()
|
1006 |
+
content = content.replace("[[训练时间]]", "2023年")
|
1007 |
+
else:
|
1008 |
+
content = "\n".join(content.split("\n")[1:-1])
|
1009 |
+
def tool_call(**kwargs):
|
1010 |
+
return kwargs
|
1011 |
+
content = eval(content)
|
1012 |
+
return content, history
|
1013 |
|
1014 |
@torch.inference_mode()
|
1015 |
+
def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, role: str = None,
|
1016 |
max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
|
1017 |
**kwargs):
|
1018 |
if history is None:
|
|
|
1022 |
logits_processor.append(InvalidScoreLogitsProcessor())
|
1023 |
gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
|
1024 |
"temperature": temperature, "logits_processor": logits_processor, **kwargs}
|
1025 |
+
inputs = tokenizer.build_chat_input(query, history=history, role=role)
|
1026 |
+
inputs = inputs.to(self.device)
|
1027 |
+
eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
|
1028 |
+
tokenizer.get_command("<|observation|>")]
|
1029 |
outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)
|
1030 |
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
|
1031 |
response = tokenizer.decode(outputs)
|
1032 |
+
history.append({"role": role, "content": query})
|
1033 |
+
response, history = self.process_response(response, history)
|
1034 |
return response, history
|
1035 |
|
1036 |
@torch.inference_mode()
|
1037 |
+
def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, role: str = None,
|
1038 |
past_key_values=None,max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8,
|
1039 |
logits_processor=None, return_past_key_values=False, **kwargs):
|
1040 |
if history is None:
|
|
|
1047 |
gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
|
1048 |
"temperature": temperature, "logits_processor": logits_processor, **kwargs}
|
1049 |
if past_key_values is None:
|
1050 |
+
inputs = tokenizer.build_chat_input(query, history=history, role=role)
|
1051 |
else:
|
1052 |
+
inputs = tokenizer.build_chat_input(query, role=role)
|
1053 |
+
input = inputs.to(self.device)
|
1054 |
if past_key_values is not None:
|
1055 |
past_length = past_key_values[0][0].shape[0]
|
1056 |
if self.transformer.pre_seq_len is not None:
|
tokenization_chatglm.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
import os
|
2 |
import torch
|
3 |
from typing import List, Optional, Union, Dict
|
@@ -173,19 +174,23 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
|
|
173 |
prefix_tokens = [self.get_command("[gMASK]"), self.get_command("sop")]
|
174 |
return prefix_tokens
|
175 |
|
176 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
177 |
if history is None:
|
178 |
history = []
|
179 |
input_ids = []
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
input_ids.extend(
|
185 |
-
|
186 |
-
input_ids.extend(
|
187 |
-
[self.get_command("<|assistant|>")] + self.tokenizer.encode("\n") + self.tokenizer.encode(old_response))
|
188 |
-
input_ids.extend([self.get_command("<|user|>")] + self.tokenizer.encode("\n") + self.tokenizer.encode(query))
|
189 |
input_ids.extend([self.get_command("<|assistant|>")])
|
190 |
return self.batch_encode_plus([input_ids], return_tensors="pt", is_split_into_words=True)
|
191 |
|
|
|
1 |
+
import json
|
2 |
import os
|
3 |
import torch
|
4 |
from typing import List, Optional, Union, Dict
|
|
|
174 |
prefix_tokens = [self.get_command("[gMASK]"), self.get_command("sop")]
|
175 |
return prefix_tokens
|
176 |
|
177 |
+
def build_single_message(self, role, metadata, message):
|
178 |
+
assert role in ["system", "user", "assistant", "observation"], role
|
179 |
+
role_tokens = [self.get_command(f"<|{role}|>")] + self.tokenizer.encode(f"{metadata}\n")
|
180 |
+
message_tokens = self.tokenizer.encode(message)
|
181 |
+
tokens = role_tokens + message_tokens
|
182 |
+
return tokens
|
183 |
+
|
184 |
+
def build_chat_input(self, query, history=None, role="user"):
|
185 |
if history is None:
|
186 |
history = []
|
187 |
input_ids = []
|
188 |
+
for item in history:
|
189 |
+
content = item["content"]
|
190 |
+
if item["role"] == "system" and "tools" in item:
|
191 |
+
content = content + "\n" + json.dumps(item["tools"], indent=4, ensure_ascii=False)
|
192 |
+
input_ids.extend(self.build_single_message(item["role"], item.get("metadata", ""), content))
|
193 |
+
input_ids.extend(self.build_single_message(role, "", query))
|
|
|
|
|
|
|
194 |
input_ids.extend([self.get_command("<|assistant|>")])
|
195 |
return self.batch_encode_plus([input_ids], return_tensors="pt", is_split_into_words=True)
|
196 |
|