duzx16
commited on
Commit
·
591fa87
1
Parent(s):
85ba2d2
Add system prompt
Browse files- modeling_chatglm.py +14 -17
- tokenization_chatglm.py +7 -2
modeling_chatglm.py
CHANGED
@@ -1001,19 +1001,15 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
1001 |
response = response.replace("[[训练时间]]", "2023年")
|
1002 |
return response
|
1003 |
|
1004 |
-
def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None):
|
1005 |
-
inputs = tokenizer.build_chat_input(query, history=history)
|
1006 |
-
inputs = inputs.to(self.device)
|
1007 |
-
return inputs
|
1008 |
-
|
1009 |
-
def build_stream_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None):
|
1010 |
-
inputs = tokenizer.build_chat_input(query)
|
1011 |
inputs = inputs.to(self.device)
|
1012 |
return inputs
|
1013 |
|
1014 |
@torch.inference_mode()
|
1015 |
-
def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None,
|
1016 |
-
do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
|
|
|
1017 |
if history is None:
|
1018 |
history = []
|
1019 |
if logits_processor is None:
|
@@ -1021,7 +1017,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
1021 |
logits_processor.append(InvalidScoreLogitsProcessor())
|
1022 |
gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
|
1023 |
"temperature": temperature, "logits_processor": logits_processor, **kwargs}
|
1024 |
-
inputs = self.build_inputs(tokenizer, query, history=history)
|
1025 |
eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>")]
|
1026 |
outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)
|
1027 |
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
|
@@ -1031,21 +1027,22 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
1031 |
return response, history
|
1032 |
|
1033 |
@torch.inference_mode()
|
1034 |
-
def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None,
|
1035 |
-
max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8,
|
1036 |
-
return_past_key_values=False, **kwargs):
|
1037 |
if history is None:
|
1038 |
history = []
|
1039 |
if logits_processor is None:
|
1040 |
logits_processor = LogitsProcessorList()
|
1041 |
logits_processor.append(InvalidScoreLogitsProcessor())
|
1042 |
-
eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>")
|
|
|
1043 |
gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
|
1044 |
"temperature": temperature, "logits_processor": logits_processor, **kwargs}
|
1045 |
-
if past_key_values is None
|
1046 |
-
inputs = self.build_inputs(tokenizer, query, history=history)
|
1047 |
else:
|
1048 |
-
inputs = self.
|
1049 |
if past_key_values is not None:
|
1050 |
past_length = past_key_values[0][0].shape[0]
|
1051 |
if self.transformer.pre_seq_len is not None:
|
|
|
1001 |
response = response.replace("[[训练时间]]", "2023年")
|
1002 |
return response
|
1003 |
|
1004 |
+
def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, system: str = None):
|
1005 |
+
inputs = tokenizer.build_chat_input(query, history=history, system=system)
|
|
|
|
|
|
|
|
|
|
|
1006 |
inputs = inputs.to(self.device)
|
1007 |
return inputs
|
1008 |
|
1009 |
@torch.inference_mode()
|
1010 |
+
def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, system: str = None,
|
1011 |
+
max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
|
1012 |
+
**kwargs):
|
1013 |
if history is None:
|
1014 |
history = []
|
1015 |
if logits_processor is None:
|
|
|
1017 |
logits_processor.append(InvalidScoreLogitsProcessor())
|
1018 |
gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
|
1019 |
"temperature": temperature, "logits_processor": logits_processor, **kwargs}
|
1020 |
+
inputs = self.build_inputs(tokenizer, query, history=history, system=system)
|
1021 |
eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>")]
|
1022 |
outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)
|
1023 |
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
|
|
|
1027 |
return response, history
|
1028 |
|
1029 |
@torch.inference_mode()
|
1030 |
+
def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, system: str = None,
|
1031 |
+
past_key_values=None,max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8,
|
1032 |
+
logits_processor=None, return_past_key_values=False, **kwargs):
|
1033 |
if history is None:
|
1034 |
history = []
|
1035 |
if logits_processor is None:
|
1036 |
logits_processor = LogitsProcessorList()
|
1037 |
logits_processor.append(InvalidScoreLogitsProcessor())
|
1038 |
+
eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
|
1039 |
+
tokenizer.get_command("<|observation|>")]
|
1040 |
gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
|
1041 |
"temperature": temperature, "logits_processor": logits_processor, **kwargs}
|
1042 |
+
if past_key_values is None:
|
1043 |
+
inputs = self.build_inputs(tokenizer, query, history=history, system=system)
|
1044 |
else:
|
1045 |
+
inputs = self.build_inputs(tokenizer, query)
|
1046 |
if past_key_values is not None:
|
1047 |
past_length = past_key_values[0][0].shape[0]
|
1048 |
if self.transformer.pre_seq_len is not None:
|
tokenization_chatglm.py
CHANGED
@@ -67,7 +67,9 @@ class SPTokenizer:
|
|
67 |
|
68 |
def convert_id_to_token(self, index):
|
69 |
"""Converts an index (integer) in a token (str) using the vocab."""
|
70 |
-
if index in self.index_special_tokens
|
|
|
|
|
71 |
return ""
|
72 |
return self.sp_model.IdToPiece(index)
|
73 |
|
@@ -171,10 +173,13 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
|
|
171 |
prefix_tokens = [self.get_command("[gMASK]"), self.get_command("sop")]
|
172 |
return prefix_tokens
|
173 |
|
174 |
-
def build_chat_input(self, query, history=None):
|
175 |
if history is None:
|
176 |
history = []
|
177 |
input_ids = []
|
|
|
|
|
|
|
178 |
for i, (old_query, old_response) in enumerate(history):
|
179 |
input_ids.extend(
|
180 |
[self.get_command("<|user|>")] + self.tokenizer.encode("\n") + self.tokenizer.encode(old_query))
|
|
|
67 |
|
68 |
def convert_id_to_token(self, index):
|
69 |
"""Converts an index (integer) in a token (str) using the vocab."""
|
70 |
+
if index in self.index_special_tokens:
|
71 |
+
return self.index_special_tokens[index]
|
72 |
+
if index in [self.eos_id, self.bos_id, self.pad_id] or index < 0:
|
73 |
return ""
|
74 |
return self.sp_model.IdToPiece(index)
|
75 |
|
|
|
173 |
prefix_tokens = [self.get_command("[gMASK]"), self.get_command("sop")]
|
174 |
return prefix_tokens
|
175 |
|
176 |
+
def build_chat_input(self, query, history=None, system=None):
|
177 |
if history is None:
|
178 |
history = []
|
179 |
input_ids = []
|
180 |
+
if system is not None:
|
181 |
+
input_ids.extend(
|
182 |
+
[self.get_command("<|system|>")] + self.tokenizer.encode("\n") + self.tokenizer.encode(system))
|
183 |
for i, (old_query, old_response) in enumerate(history):
|
184 |
input_ids.extend(
|
185 |
[self.get_command("<|user|>")] + self.tokenizer.encode("\n") + self.tokenizer.encode(old_query))
|