Add support for streaming output
Browse files- modeling_chatglm.py +120 -42
modeling_chatglm.py
CHANGED
@@ -3,7 +3,7 @@
|
|
3 |
import math
|
4 |
import copy
|
5 |
import os
|
6 |
-
import
|
7 |
|
8 |
import torch
|
9 |
import torch.utils.checkpoint
|
@@ -11,7 +11,7 @@ import torch.nn.functional as F
|
|
11 |
from torch import nn
|
12 |
from torch.nn import CrossEntropyLoss, LayerNorm
|
13 |
from torch.nn.utils import skip_init
|
14 |
-
from typing import Optional, Tuple, Union, List
|
15 |
|
16 |
from transformers.utils import (
|
17 |
add_code_sample_docstrings,
|
@@ -26,7 +26,7 @@ from transformers.modeling_outputs import (
|
|
26 |
from transformers.modeling_utils import PreTrainedModel
|
27 |
from transformers.utils import logging
|
28 |
from transformers.generation.logits_process import LogitsProcessor
|
29 |
-
from transformers.generation.utils import LogitsProcessorList
|
30 |
|
31 |
from .configuration_chatglm import ChatGLMConfig
|
32 |
|
@@ -1108,7 +1108,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
1108 |
input_ids = tokenizer([prompt], return_tensors="pt", padding=True)
|
1109 |
input_ids = input_ids.to(self.device)
|
1110 |
outputs = self.generate(**input_ids, **gen_kwargs)
|
1111 |
-
outputs = outputs.tolist()[0][len(input_ids["input_ids"][0])
|
1112 |
response = tokenizer.decode(outputs)
|
1113 |
response = response.strip()
|
1114 |
response = response.replace("[[训练时间]]", "2023年")
|
@@ -1116,55 +1116,133 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
1116 |
return response, history
|
1117 |
|
1118 |
@torch.no_grad()
|
1119 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1120 |
self,
|
|
|
|
|
|
|
|
|
|
|
1121 |
**kwargs,
|
1122 |
):
|
1123 |
-
|
1124 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1125 |
|
1126 |
-
if
|
1127 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1128 |
|
1129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
1130 |
|
1131 |
-
|
|
|
|
|
|
|
1132 |
|
|
|
|
|
1133 |
while True:
|
1134 |
-
|
1135 |
-
|
1136 |
-
|
1137 |
-
|
1138 |
-
|
1139 |
-
|
1140 |
-
|
1141 |
-
|
1142 |
-
mask_position = output_seq.index(mask_token)
|
1143 |
-
bos_position = output_seq.index(bos)
|
1144 |
-
if eos in output_seq:
|
1145 |
-
eos_position = output_seq.index(eos)
|
1146 |
-
else:
|
1147 |
-
eos_position = len(output_seq)
|
1148 |
-
|
1149 |
-
return_seq = output_seq[:mask_position] + output_seq[bos_position + 1:eos_position] + output_seq[
|
1150 |
-
mask_position + 1:bos_position]
|
1151 |
-
max_length = max(max_length, len(return_seq))
|
1152 |
-
return_seqs.append(return_seq)
|
1153 |
-
|
1154 |
-
for i in range(output_ids.shape[0]):
|
1155 |
-
return_seqs[i] = [0] * (max_length - len(return_seqs[i])) + return_seqs[i] # padding
|
1156 |
-
if mask_token not in return_seqs[i]:
|
1157 |
-
stop = True
|
1158 |
-
|
1159 |
-
if stop:
|
1160 |
-
break
|
1161 |
|
1162 |
-
|
1163 |
-
return_seq += [bos]
|
1164 |
|
1165 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1166 |
|
1167 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1168 |
|
1169 |
def quantize(self, bits: int):
|
1170 |
from .quantization import quantize
|
|
|
3 |
import math
|
4 |
import copy
|
5 |
import os
|
6 |
+
import warnings
|
7 |
|
8 |
import torch
|
9 |
import torch.utils.checkpoint
|
|
|
11 |
from torch import nn
|
12 |
from torch.nn import CrossEntropyLoss, LayerNorm
|
13 |
from torch.nn.utils import skip_init
|
14 |
+
from typing import Optional, Tuple, Union, List, Callable
|
15 |
|
16 |
from transformers.utils import (
|
17 |
add_code_sample_docstrings,
|
|
|
26 |
from transformers.modeling_utils import PreTrainedModel
|
27 |
from transformers.utils import logging
|
28 |
from transformers.generation.logits_process import LogitsProcessor
|
29 |
+
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig
|
30 |
|
31 |
from .configuration_chatglm import ChatGLMConfig
|
32 |
|
|
|
1108 |
input_ids = tokenizer([prompt], return_tensors="pt", padding=True)
|
1109 |
input_ids = input_ids.to(self.device)
|
1110 |
outputs = self.generate(**input_ids, **gen_kwargs)
|
1111 |
+
outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
|
1112 |
response = tokenizer.decode(outputs)
|
1113 |
response = response.strip()
|
1114 |
response = response.replace("[[训练时间]]", "2023年")
|
|
|
1116 |
return response, history
|
1117 |
|
1118 |
@torch.no_grad()
|
1119 |
+
def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048,
|
1120 |
+
do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
|
1121 |
+
if history is None:
|
1122 |
+
history = []
|
1123 |
+
if logits_processor is None:
|
1124 |
+
logits_processor = LogitsProcessorList()
|
1125 |
+
logits_processor.append(InvalidScoreLogitsProcessor())
|
1126 |
+
gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
|
1127 |
+
"temperature": temperature, "logits_processor": logits_processor, **kwargs}
|
1128 |
+
if not history:
|
1129 |
+
prompt = query
|
1130 |
+
else:
|
1131 |
+
prompt = ""
|
1132 |
+
for i, (old_query, response) in enumerate(history):
|
1133 |
+
prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
|
1134 |
+
prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
|
1135 |
+
input_ids = tokenizer([prompt], return_tensors="pt", padding=True)
|
1136 |
+
input_ids = input_ids.to(self.device)
|
1137 |
+
for outputs in self.stream_generate(**input_ids, **gen_kwargs):
|
1138 |
+
outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
|
1139 |
+
response = tokenizer.decode(outputs)
|
1140 |
+
response = response.strip()
|
1141 |
+
response = response.replace("[[训练时间]]", "2023年")
|
1142 |
+
new_history = history + [(query, response)]
|
1143 |
+
yield response, new_history
|
1144 |
+
|
1145 |
+
@torch.no_grad()
|
1146 |
+
def stream_generate(
|
1147 |
self,
|
1148 |
+
input_ids,
|
1149 |
+
generation_config: Optional[GenerationConfig] = None,
|
1150 |
+
logits_processor: Optional[LogitsProcessorList] = None,
|
1151 |
+
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
1152 |
+
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
|
1153 |
**kwargs,
|
1154 |
):
|
1155 |
+
batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
|
1156 |
+
|
1157 |
+
if generation_config is None:
|
1158 |
+
generation_config = self.generation_config
|
1159 |
+
generation_config = copy.deepcopy(generation_config)
|
1160 |
+
model_kwargs = generation_config.update(**kwargs)
|
1161 |
+
bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id
|
1162 |
+
|
1163 |
+
if isinstance(eos_token_id, int):
|
1164 |
+
eos_token_id = [eos_token_id]
|
1165 |
+
|
1166 |
+
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
|
1167 |
+
if has_default_max_length and generation_config.max_new_tokens is None:
|
1168 |
+
warnings.warn(
|
1169 |
+
f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
|
1170 |
+
"This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
|
1171 |
+
" recommend using `max_new_tokens` to control the maximum length of the generation.",
|
1172 |
+
UserWarning,
|
1173 |
+
)
|
1174 |
+
elif generation_config.max_new_tokens is not None:
|
1175 |
+
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
|
1176 |
+
if not has_default_max_length:
|
1177 |
+
logger.warn(
|
1178 |
+
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
|
1179 |
+
f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
|
1180 |
+
"Please refer to the documentation for more information. "
|
1181 |
+
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
|
1182 |
+
UserWarning,
|
1183 |
+
)
|
1184 |
|
1185 |
+
if input_ids_seq_length >= generation_config.max_length:
|
1186 |
+
input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
|
1187 |
+
logger.warning(
|
1188 |
+
f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
|
1189 |
+
f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
|
1190 |
+
" increasing `max_new_tokens`."
|
1191 |
+
)
|
1192 |
+
|
1193 |
+
# 2. Set generation parameters if not already defined
|
1194 |
+
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
1195 |
+
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
1196 |
|
1197 |
+
logits_processor = self._get_logits_processor(
|
1198 |
+
generation_config=generation_config,
|
1199 |
+
input_ids_seq_length=input_ids_seq_length,
|
1200 |
+
encoder_input_ids=input_ids,
|
1201 |
+
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
|
1202 |
+
logits_processor=logits_processor,
|
1203 |
+
)
|
1204 |
|
1205 |
+
stopping_criteria = self._get_stopping_criteria(
|
1206 |
+
generation_config=generation_config, stopping_criteria=stopping_criteria
|
1207 |
+
)
|
1208 |
+
logits_warper = self._get_logits_warper(generation_config)
|
1209 |
|
1210 |
+
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
|
1211 |
+
scores = None
|
1212 |
while True:
|
1213 |
+
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
1214 |
+
# forward pass to get next token
|
1215 |
+
outputs = self(
|
1216 |
+
**model_inputs,
|
1217 |
+
return_dict=True,
|
1218 |
+
output_attentions=False,
|
1219 |
+
output_hidden_states=False,
|
1220 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1221 |
|
1222 |
+
next_token_logits = outputs.logits[:, -1, :]
|
|
|
1223 |
|
1224 |
+
# pre-process distribution
|
1225 |
+
next_token_scores = logits_processor(input_ids, next_token_logits)
|
1226 |
+
next_token_scores = logits_warper(input_ids, next_token_scores)
|
1227 |
+
|
1228 |
+
# sample
|
1229 |
+
probs = nn.functional.softmax(next_token_scores, dim=-1)
|
1230 |
+
if generation_config.do_sample:
|
1231 |
+
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
|
1232 |
+
else:
|
1233 |
+
next_tokens = torch.argmax(probs, dim=-1)
|
1234 |
|
1235 |
+
# update generated ids, model inputs, and length for next step
|
1236 |
+
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
|
1237 |
+
model_kwargs = self._update_model_kwargs_for_generation(
|
1238 |
+
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
|
1239 |
+
)
|
1240 |
+
unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())
|
1241 |
+
|
1242 |
+
# stop when each sentence is finished, or if we exceed the maximum length
|
1243 |
+
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
|
1244 |
+
break
|
1245 |
+
yield input_ids
|
1246 |
|
1247 |
def quantize(self, bits: int):
|
1248 |
from .quantization import quantize
|