Qwen
/

Text Generation
Transformers
Safetensors
Chinese
English
qwen
custom_code
yangapku commited on
Commit
d111a1c
·
1 Parent(s): 58362a1

update config and streaming generation

Browse files
Files changed (2) hide show
  1. config.json +2 -2
  2. modeling_qwen.py +32 -15
config.json CHANGED
@@ -14,12 +14,12 @@
14
  "fp32": false,
15
  "bias_dropout_fusion": true,
16
  "bos_token_id": 151643,
17
- "embd_pdrop": 0.1,
18
  "eos_token_id": 151643,
19
  "ffn_hidden_size": 22016,
20
  "initializer_range": 0.02,
21
  "kv_channels": 128,
22
- "layer_norm_epsilon": 1e-05,
23
  "model_type": "qwen",
24
  "n_embd": 4096,
25
  "n_head": 32,
 
14
  "fp32": false,
15
  "bias_dropout_fusion": true,
16
  "bos_token_id": 151643,
17
+ "embd_pdrop": 0.0,
18
  "eos_token_id": 151643,
19
  "ffn_hidden_size": 22016,
20
  "initializer_range": 0.02,
21
  "kv_channels": 128,
22
+ "layer_norm_epsilon": 1e-06,
23
  "model_type": "qwen",
24
  "n_embd": 4096,
25
  "n_head": 32,
modeling_qwen.py CHANGED
@@ -958,8 +958,10 @@ class QWenLMHeadModel(QWenPreTrainedModel):
958
  history: Optional[HistoryType],
959
  system: str = "You are a helpful assistant.",
960
  append_history: bool = True,
 
961
  ) -> Tuple[str, HistoryType]:
962
 
 
963
  if history is None:
964
  history = []
965
 
@@ -976,21 +978,36 @@ class QWenLMHeadModel(QWenPreTrainedModel):
976
  self.generation_config.chat_format, tokenizer
977
  )
978
  input_ids = torch.tensor([context_tokens]).to(self.device)
979
-
980
- outputs = self.generate(
981
- input_ids,
982
- stop_words_ids=stop_words_ids,
983
- return_dict_in_generate=False,
984
- )
985
-
986
- response = decode_tokens(
987
- outputs[0],
988
- tokenizer,
989
- raw_text_len=len(raw_text),
990
- context_length=len(context_tokens),
991
- chat_format=self.generation_config.chat_format,
992
- verbose=False,
993
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
994
 
995
  if append_history:
996
  history.append((query, response))
 
958
  history: Optional[HistoryType],
959
  system: str = "You are a helpful assistant.",
960
  append_history: bool = True,
961
+ stream: Optional[bool] = False
962
  ) -> Tuple[str, HistoryType]:
963
 
964
+
965
  if history is None:
966
  history = []
967
 
 
978
  self.generation_config.chat_format, tokenizer
979
  )
980
  input_ids = torch.tensor([context_tokens]).to(self.device)
981
+ if stream:
982
+ assert self.generation_config.chat_format == 'chatml'
983
+ from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig
984
+ self.__class__.generate = NewGenerationMixin.generate
985
+ self.__class__.sample_stream = NewGenerationMixin.sample_stream
986
+ stream_config = StreamGenerationConfig(**self.generation_config.to_dict(), do_stream=True)
987
+ def stream_generator():
988
+ outputs = []
989
+ for token in self.generate(input_ids, return_dict_in_generate=False, generation_config=stream_config):
990
+ outputs.append(token.item())
991
+ if outputs[-1] in (tokenizer.im_end_id, tokenizer.im_start_id):
992
+ break
993
+ yield tokenizer.decode(outputs, skip_special_tokens=True)
994
+
995
+ return stream_generator()
996
+ else:
997
+ outputs = self.generate(
998
+ input_ids,
999
+ stop_words_ids = stop_words_ids,
1000
+ return_dict_in_generate = False,
1001
+ )
1002
+
1003
+ response = decode_tokens(
1004
+ outputs[0],
1005
+ tokenizer,
1006
+ raw_text_len=len(raw_text),
1007
+ context_length=len(context_tokens),
1008
+ chat_format=self.generation_config.chat_format,
1009
+ verbose=False,
1010
+ )
1011
 
1012
  if append_history:
1013
  history.append((query, response))