choco9966 committed on
Commit
57814fa
·
1 Parent(s): 30cd044

[Add] Upload Qwen-7b model

Browse files
config.json ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "/data/project/public/checkpoints/Qwen-7B-HF-Toknizer",
3
+ "apply_residual_connection_post_layernorm": false,
4
+ "architectures": [
5
+ "QWenLMHeadModel"
6
+ ],
7
+ "attn_pdrop": 0.0,
8
+ "auto_map": {
9
+ "AutoConfig": "configuration_qwen.QWenConfig",
10
+ "AutoModelForCausalLM": "modeling_qwen.QWenLMHeadModel"
11
+ },
12
+ "bf16": true,
13
+ "bias_dropout_fusion": true,
14
+ "bos_token_id": 151643,
15
+ "embd_pdrop": 0.0,
16
+ "eos_token_id": 151643,
17
+ "ffn_hidden_size": 22016,
18
+ "fp16": false,
19
+ "fp32": false,
20
+ "initializer_range": 0.02,
21
+ "kv_channels": 128,
22
+ "layer_norm_epsilon": 1e-06,
23
+ "model_type": "qwen",
24
+ "n_embd": 4096,
25
+ "n_head": 32,
26
+ "n_inner": null,
27
+ "n_layer": 32,
28
+ "n_positions": 6144,
29
+ "no_bias": true,
30
+ "onnx_safe": null,
31
+ "padded_vocab_size": 151936,
32
+ "params_dtype": "torch.bfloat16",
33
+ "pos_emb": "rotary",
34
+ "resid_pdrop": 0.1,
35
+ "rotary_emb_base": 10000,
36
+ "rotary_pct": 1.0,
37
+ "scale_attn_weights": true,
38
+ "seq_length": 2048,
39
+ "tie_word_embeddings": false,
40
+ "tokenizer_type": "QWenTokenizer",
41
+ "torch_dtype": "bfloat16",
42
+ "transformers_version": "4.32.0",
43
+ "use_cache": false,
44
+ "use_dynamic_ntk": true,
45
+ "use_flash_attn": true,
46
+ "use_logn_attn": true,
47
+ "vocab_size": 151936
48
+ }
configuration_qwen.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba Cloud.
2
+ #
3
+ # This source code is licensed under the license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from transformers import PretrainedConfig
7
+
8
+
9
class QWenConfig(PretrainedConfig):
    """Configuration for QWen models.

    ``attribute_map`` exposes the GPT-2 style names stored on the config
    (``n_embd``, ``n_head``, ``n_positions``, ``n_layer``) under the generic
    names ``hidden_size``, ``num_attention_heads``,
    ``max_position_embeddings`` and ``num_hidden_layers``.

    In addition to the originally declared hyper-parameters, the constructor
    declares the fields that the accompanying modeling code reads from the
    config (``n_positions``, ``seq_length``, ``pos_emb``,
    ``padded_vocab_size``, ``params_dtype``).  Previously these only existed
    when supplied through ``**kwargs`` (e.g. from ``config.json``), so a
    default-constructed config could not be used to build a model.  Their
    defaults mirror the ``config.json`` shipped in this commit.

    Args:
        n_positions: Size of the causal-mask buffer; reachable as
            ``max_position_embeddings`` through ``attribute_map``.
        seq_length: Reference sequence length used by the attention layers
            for dynamic-NTK and log-n scaling.
        pos_emb: Position-embedding type; the model adds a learned position
            table only when this equals ``"learned"``.
        padded_vocab_size: Number of rows of the token embedding; falls back
            to ``vocab_size`` when ``None``.
        params_dtype: Stored on the attention layers as-is; ``None`` when
            unspecified.
        **kwargs: Forwarded to ``PretrainedConfig``.

    The remaining arguments are stored unchanged as attributes of the same
    name.
    """

    model_type = "qwen"
    keys_to_ignore_at_inference = ["past_key_values"]
    attribute_map = {
        "hidden_size": "n_embd",
        "num_attention_heads": "n_head",
        "max_position_embeddings": "n_positions",
        "num_hidden_layers": "n_layer",
    }

    def __init__(
        self,
        vocab_size=151851,
        n_embd=4096,
        n_layer=32,
        n_head=32,
        n_inner=None,
        embd_pdrop=0.0,
        attn_pdrop=0.0,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        scale_attn_weights=True,
        use_cache=True,
        eos_token_id=151643,
        apply_residual_connection_post_layernorm=False,
        bf16=False,
        fp16=False,
        fp32=False,
        kv_channels=128,
        rotary_pct=1.0,
        rotary_emb_base=10000,
        use_dynamic_ntk=False,
        use_logn_attn=False,
        use_flash_attn=True,
        ffn_hidden_size=22016,
        no_bias=True,
        tie_word_embeddings=False,
        n_positions=6144,
        seq_length=2048,
        pos_emb="rotary",
        padded_vocab_size=None,
        params_dtype=None,
        **kwargs,
    ):
        self.eos_token_id = eos_token_id
        super().__init__(
            eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs
        )

        self.vocab_size = vocab_size
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_inner = n_inner
        self.embd_pdrop = embd_pdrop
        self.attn_pdrop = attn_pdrop
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.scale_attn_weights = scale_attn_weights
        self.use_cache = use_cache
        self.apply_residual_connection_post_layernorm = (
            apply_residual_connection_post_layernorm
        )
        self.bf16 = bf16
        self.fp16 = fp16
        self.fp32 = fp32
        self.kv_channels = kv_channels
        self.rotary_pct = rotary_pct
        self.rotary_emb_base = rotary_emb_base
        self.use_dynamic_ntk = use_dynamic_ntk
        self.use_logn_attn = use_logn_attn
        self.use_flash_attn = use_flash_attn
        self.ffn_hidden_size = ffn_hidden_size
        self.no_bias = no_bias
        self.tie_word_embeddings = tie_word_embeddings

        # Fields consumed by the modeling code; declared here so that a config
        # built without a config.json still carries them.
        self.n_positions = n_positions
        self.seq_length = seq_length
        self.pos_emb = pos_emb
        self.padded_vocab_size = (
            vocab_size if padded_vocab_size is None else padded_vocab_size
        )
        self.params_dtype = params_dtype
generation_config.json ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "chat_format": "raw",
3
+ "do_sample": true,
4
+ "eos_token_id": 151643,
5
+ "max_new_tokens": 512,
6
+ "pad_token_id": 151643,
7
+ "stop_words_ids": [
8
+ [
9
+ 151643
10
+ ]
11
+ ],
12
+ "top_k": 0,
13
+ "top_p": 0.8,
14
+ "transformers_version": "4.32.0"
15
+ }
modeling_qwen.py ADDED
@@ -0,0 +1,1210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba Cloud.
2
+ #
3
+ # This source code is licensed under the license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import importlib
7
+ import math
8
+ from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List, Any, Generator
9
+
10
+ import torch
11
+ import torch.nn.functional as F
12
+ import torch.utils.checkpoint
13
+ from torch.cuda.amp import autocast
14
+
15
+ from torch.nn import CrossEntropyLoss
16
+ from transformers import PreTrainedTokenizer, GenerationConfig, StoppingCriteriaList
17
+ from transformers.generation.logits_process import LogitsProcessorList
18
+
19
+ if TYPE_CHECKING:
20
+ from transformers.generation.streamers import BaseStreamer
21
+ from transformers.generation.utils import GenerateOutput
22
+ from transformers.modeling_outputs import (
23
+ BaseModelOutputWithPast,
24
+ CausalLMOutputWithPast,
25
+ )
26
+ from transformers.modeling_utils import PreTrainedModel
27
+ from transformers.utils import logging
28
+
29
+ try:
30
+ from einops import rearrange
31
+ except ImportError:
32
+ rearrange = None
33
+ from torch import nn
34
+
35
# Whether torch reports a usable CUDA runtime in this process.
SUPPORT_CUDA = torch.cuda.is_available()
# bf16 requires CUDA and torch reporting bf16 support.
SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported()
# fp16 is enabled only when device 0 has a compute-capability major version >= 7.
SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7
38
+
39
+ from .configuration_qwen import QWenConfig
40
+ from .qwen_generation_utils import (
41
+ HistoryType,
42
+ make_context,
43
+ decode_tokens,
44
+ get_stop_words_ids,
45
+ StopWordsLogitsProcessor,
46
+ )
47
+
48
+
49
# Module-level logger obtained from the transformers logging utilities.
logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "qwen"
_CONFIG_FOR_DOC = "QWenConfig"

QWen_PRETRAINED_MODEL_ARCHIVE_LIST = ["qwen-7b"]

# Bilingual (English / Chinese) message about a generation_config whose
# chat_format is not "chatml".
_ERROR_BAD_CHAT_FORMAT = """\
We detect you are probably using the pretrained model (rather than chat model) for chatting, since the chat_format in generation_config is not "chatml".
If you are directly using the model downloaded from Huggingface, please make sure you are using our "Qwen/Qwen-7B-Chat" Huggingface model (rather than "Qwen/Qwen-7B") when you call model.chat().
我们检测到您可能在使用预训练模型(而非chat模型)进行多轮chat,因为您当前在generation_config指定的chat_format,并未设置为我们在对话中所支持的"chatml"格式。
如果您在直接使用我们从Huggingface提供的模型,请确保您在调用model.chat()时,使用的是"Qwen/Qwen-7B-Chat"模型(而非"Qwen/Qwen-7B"预训练模型)。
"""

# Unique placeholder object; presumably used to detect "argument not passed" — TODO confirm at use sites.
_SENTINEL = object()
# Bilingual message about the deprecated `stream` argument of model.chat().
_ERROR_STREAM_IN_CHAT = """\
Pass argument `stream` to model.chat() is buggy, deprecated, and marked for removal. Please use model.chat_stream(...) instead of model.chat(..., stream=True).
向model.chat()传入参数stream的用法可能存在Bug,该用法已被废弃,将在未来被移除。请使用model.chat_stream(...)代替model.chat(..., stream=True)。
"""

# Optional FlashAttention kernels; they stay None until _import_flash_attn()
# manages to import them.
apply_rotary_emb_func = None
rms_norm = None
flash_attn_unpadded_func = None
72
+
73
+
74
def _import_flash_attn():
    """Try to bind the optional FlashAttention kernels to the module globals.

    Each of the three imports is attempted independently.  On ``ImportError``
    the corresponding global (``apply_rotary_emb_func``, ``rms_norm`` or
    ``flash_attn_unpadded_func``) is left untouched and a warning is logged,
    so the model can fall back to its plain PyTorch code paths.
    """
    global apply_rotary_emb_func, rms_norm, flash_attn_unpadded_func

    try:
        from flash_attn.layers.rotary import apply_rotary_emb_func as _rotary_fn
        apply_rotary_emb_func = _rotary_fn
    except ImportError:
        # logger.warning replaces the deprecated logger.warn alias.
        logger.warning(
            "Warning: import flash_attn rotary fail, please install FlashAttention rotary to get higher efficiency "
            "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/rotary"
        )

    try:
        from flash_attn.ops.rms_norm import rms_norm as _rms_norm_fn
        rms_norm = _rms_norm_fn
    except ImportError:
        logger.warning(
            "Warning: import flash_attn rms_norm fail, please install FlashAttention layer_norm to get higher efficiency "
            "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm"
        )

    try:
        import flash_attn

        # Builds without __version__ are treated like 1.x: they expose the
        # kernel under the `flash_attn_unpadded_func` name, while 2.x and
        # later expose it as `flash_attn_varlen_func`.
        if (
            hasattr(flash_attn, "__version__")
            and int(flash_attn.__version__.split(".")[0]) >= 2
        ):
            from flash_attn.flash_attn_interface import flash_attn_varlen_func as _attn_fn
        else:
            from flash_attn.flash_attn_interface import flash_attn_unpadded_func as _attn_fn
        flash_attn_unpadded_func = _attn_fn
    except ImportError:
        logger.warning(
            "Warning: import flash_attn fail, please install FlashAttention to get higher efficiency "
            "https://github.com/Dao-AILab/flash-attention"
        )
109
+
110
+
111
class FlashSelfAttention(torch.nn.Module):
    """Thin wrapper that runs self-attention through the FlashAttention kernel.

    Args:
        causal: Whether to apply causal masking while training.
        softmax_scale: Forwarded to the kernel as ``softmax_scale``.
        attention_dropout: Dropout probability used while training.

    Raises:
        AssertionError: At construction time if the FlashAttention kernel or
            ``einops.rearrange`` has not been imported.
    """

    def __init__(
        self,
        causal=False,
        softmax_scale=None,
        attention_dropout=0.0,
    ):
        super().__init__()
        assert flash_attn_unpadded_func is not None, (
            "Please install FlashAttention first, " "e.g., with pip install flash-attn"
        )
        assert (
            rearrange is not None
        ), "Please install einops first, e.g., with pip install einops"
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout

    def forward(self, q, k, v):
        """Attend ``q`` over ``k``/``v``; the first two axes are (batch, sequence).

        All three tensors must be fp16/bf16 CUDA tensors.  Every sequence in
        the batch is treated as having the full length (no padding handling).
        """
        assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q, k, v)))
        assert all((i.is_cuda for i in (q, k, v)))
        batch_size, seqlen_q = q.shape[0], q.shape[1]
        seqlen_k = k.shape[1]
        # The kernel takes sequences packed along one axis plus cumulative
        # sequence-length offsets.
        q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
        cu_seqlens_q = torch.arange(
            0,
            (batch_size + 1) * seqlen_q,
            step=seqlen_q,
            dtype=torch.int32,
            device=q.device,
        )

        if self.training:
            assert seqlen_k == seqlen_q

            is_causal = self.causal
            cu_seqlens_k = cu_seqlens_q
            dropout_p = self.dropout_p
        else:
            # With a key/value cache the query is shorter than the keys and
            # may attend to all of them, so causal masking is only needed
            # when the lengths match.
            is_causal = seqlen_q == seqlen_k
            cu_seqlens_k = torch.arange(
                0,
                (batch_size + 1) * seqlen_k,
                step=seqlen_k,
                dtype=torch.int32,
                device=q.device,
            )
            # Disable dropout for this call only.  Overwriting self.dropout_p
            # here would silently turn dropout off for any later training.
            dropout_p = 0
        output = flash_attn_unpadded_func(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            seqlen_q,
            seqlen_k,
            dropout_p,
            softmax_scale=self.softmax_scale,
            causal=is_causal,
        )

        output = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
        return output
173
+
174
+
175
+ class QWenAttention(nn.Module):
176
+ def __init__(self, config, layer_number=None):
177
+ super().__init__()
178
+
179
+ max_positions = config.max_position_embeddings
180
+ self.register_buffer(
181
+ "bias",
182
+ torch.tril(
183
+ torch.ones((max_positions, max_positions), dtype=torch.bool)
184
+ ).view(1, 1, max_positions, max_positions),
185
+ persistent=False,
186
+ )
187
+ self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
188
+ self.layer_number = max(1, layer_number)
189
+ self.params_dtype = config.params_dtype
190
+ self.seq_length = config.seq_length
191
+
192
+ self.hidden_size = config.hidden_size
193
+ self.split_size = config.hidden_size
194
+ self.num_heads = config.num_attention_heads
195
+ self.head_dim = self.hidden_size // self.num_heads
196
+
197
+ self.use_flash_attn = config.use_flash_attn
198
+ self.scale_attn_weights = True
199
+
200
+ self.layer_idx = None
201
+
202
+ self.projection_size = config.kv_channels * config.num_attention_heads
203
+
204
+ assert self.projection_size % config.num_attention_heads == 0
205
+ self.hidden_size_per_attention_head = (
206
+ self.projection_size // config.num_attention_heads
207
+ )
208
+
209
+ self.c_attn = nn.Linear(config.hidden_size, 3 * self.projection_size)
210
+
211
+ self.c_proj = nn.Linear(
212
+ config.hidden_size, self.projection_size, bias=not config.no_bias
213
+ )
214
+
215
+ self.is_fp32 = not (config.bf16 or config.fp16)
216
+ if (
217
+ self.use_flash_attn
218
+ and flash_attn_unpadded_func is not None
219
+ and not self.is_fp32
220
+ ):
221
+ self.core_attention_flash = FlashSelfAttention(
222
+ causal=True, attention_dropout=config.attn_pdrop
223
+ )
224
+
225
+ self.bf16 = config.bf16
226
+
227
+ if config.rotary_pct == 1.0:
228
+ self.rotary_ndims = None
229
+ else:
230
+ assert config.rotary_pct < 1
231
+ self.rotary_ndims = int(
232
+ self.hidden_size_per_attention_head * config.rotary_pct
233
+ )
234
+ dim = (
235
+ self.rotary_ndims
236
+ if self.rotary_ndims is not None
237
+ else self.hidden_size_per_attention_head
238
+ )
239
+ self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base)
240
+
241
+ self.use_dynamic_ntk = config.use_dynamic_ntk
242
+ self.use_logn_attn = config.use_logn_attn
243
+
244
+ logn_list = [
245
+ math.log(i, self.seq_length) if i > self.seq_length else 1
246
+ for i in range(1, 32768)
247
+ ]
248
+ self.logn_tensor = torch.tensor(logn_list)[None, :, None, None]
249
+ self._ntk_cached = 1.0
250
+
251
+ self.attn_dropout = nn.Dropout(config.attn_pdrop)
252
+
253
+ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
254
+ attn_weights = torch.matmul(query, key.transpose(-1, -2))
255
+
256
+ if self.scale_attn_weights:
257
+ attn_weights = attn_weights / torch.full(
258
+ [],
259
+ value.size(-1) ** 0.5,
260
+ dtype=attn_weights.dtype,
261
+ device=attn_weights.device,
262
+ )
263
+
264
+ query_length, key_length = query.size(-2), key.size(-2)
265
+ causal_mask = self.bias[
266
+ :, :, key_length - query_length : key_length, :key_length
267
+ ]
268
+ mask_value = torch.finfo(attn_weights.dtype).min
269
+ mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(
270
+ attn_weights.device
271
+ )
272
+ attn_weights = torch.where(
273
+ causal_mask, attn_weights.to(attn_weights.dtype), mask_value
274
+ )
275
+
276
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
277
+
278
+ attn_weights = attn_weights.type(value.dtype)
279
+ attn_weights = self.attn_dropout(attn_weights)
280
+
281
+ if head_mask is not None:
282
+ attn_weights = attn_weights * head_mask
283
+
284
+ attn_output = torch.matmul(attn_weights, value)
285
+ attn_output = attn_output.transpose(1, 2)
286
+
287
+ return attn_output, attn_weights
288
+
289
+ def _upcast_and_reordered_attn(
290
+ self, query, key, value, attention_mask=None, head_mask=None
291
+ ):
292
+ bsz, num_heads, q_seq_len, dk = query.size()
293
+ _, _, k_seq_len, _ = key.size()
294
+
295
+ attn_weights = torch.empty(
296
+ bsz * num_heads,
297
+ q_seq_len,
298
+ k_seq_len,
299
+ dtype=torch.float32,
300
+ device=query.device,
301
+ )
302
+
303
+ scale_factor = 1.0
304
+ if self.scale_attn_weights:
305
+ scale_factor /= float(value.size(-1)) ** 0.5
306
+
307
+ with autocast(enabled=False):
308
+ q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(
309
+ -1, dk, k_seq_len
310
+ )
311
+ attn_weights = torch.baddbmm(
312
+ attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor
313
+ )
314
+ attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
315
+
316
+ query_length, key_length = query.size(-2), key.size(-2)
317
+ causal_mask = self.bias[
318
+ :, :, key_length - query_length : key_length, :key_length
319
+ ]
320
+ mask_value = torch.finfo(attn_weights.dtype).min
321
+ mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(
322
+ attn_weights.device
323
+ )
324
+ attn_weights = torch.where(causal_mask, attn_weights, mask_value)
325
+
326
+ if attention_mask is not None:
327
+ attn_weights = attn_weights + attention_mask
328
+
329
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
330
+
331
+ if attn_weights.dtype != torch.float32:
332
+ raise RuntimeError(
333
+ "Error with upcasting, attn_weights does not have dtype torch.float32"
334
+ )
335
+ attn_weights = attn_weights.type(value.dtype)
336
+ attn_weights = self.attn_dropout(attn_weights)
337
+
338
+ if head_mask is not None:
339
+ attn_weights = attn_weights * head_mask
340
+
341
+ attn_output = torch.matmul(attn_weights, value)
342
+
343
+ return attn_output, attn_weights
344
+
345
+ def _split_heads(self, tensor, num_heads, attn_head_size):
346
+ new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
347
+ tensor = tensor.view(new_shape)
348
+ return tensor
349
+
350
+ def _merge_heads(self, tensor, num_heads, attn_head_size):
351
+ tensor = tensor.contiguous()
352
+ new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
353
+ return tensor.view(new_shape)
354
+
355
+ def forward(
356
+ self,
357
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
358
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
359
+ attention_mask: Optional[torch.FloatTensor] = None,
360
+ head_mask: Optional[torch.FloatTensor] = None,
361
+ encoder_hidden_states: Optional[torch.Tensor] = None,
362
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
363
+ output_attentions: Optional[bool] = False,
364
+ use_cache: Optional[bool] = False,
365
+ ):
366
+
367
+ mixed_x_layer = self.c_attn(hidden_states)
368
+ query, key, value = mixed_x_layer.split(self.split_size, dim=2)
369
+
370
+ query = self._split_heads(query, self.num_heads, self.head_dim)
371
+ key = self._split_heads(key, self.num_heads, self.head_dim)
372
+ value = self._split_heads(value, self.num_heads, self.head_dim)
373
+
374
+ kv_seq_len = hidden_states.size()[1]
375
+ if layer_past:
376
+ # layer past[0] shape: bs * seq_len * head_num * dim
377
+ kv_seq_len += layer_past[0].shape[1]
378
+ if (
379
+ self.use_dynamic_ntk
380
+ and kv_seq_len == hidden_states.size()[1]
381
+ and not self.training
382
+ ):
383
+ context_value = math.log(kv_seq_len / self.seq_length, 2) + 1
384
+ ntk_alpha = 2 ** math.ceil(context_value) - 1
385
+ ntk_alpha = max(ntk_alpha, 1)
386
+ self._ntk_cached = ntk_alpha
387
+ else:
388
+ ntk_alpha = self._ntk_cached
389
+ rotary_pos_emb = self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha).to(
390
+ hidden_states.device
391
+ )
392
+
393
+ if rotary_pos_emb is not None:
394
+ if isinstance(rotary_pos_emb, tuple):
395
+ rotary_pos_emb = rotary_pos_emb
396
+ else:
397
+ rotary_pos_emb = (rotary_pos_emb,) * 2
398
+
399
+ if rotary_pos_emb is not None:
400
+ q_pos_emb, k_pos_emb = rotary_pos_emb
401
+ # Slice the pos emb for current inference
402
+ cur_len = query.shape[1]
403
+ q_pos_emb = q_pos_emb[:, -cur_len:, :, :]
404
+ k_pos_emb = k_pos_emb[:, -cur_len:, :, :]
405
+ query = apply_rotary_pos_emb(query, q_pos_emb)
406
+ key = apply_rotary_pos_emb(key, k_pos_emb)
407
+
408
+ if layer_past is not None:
409
+ past_key, past_value = layer_past[0], layer_past[1]
410
+ key = torch.cat((past_key, key), dim=1)
411
+ value = torch.cat((past_value, value), dim=1)
412
+
413
+ if use_cache:
414
+ present = (key, value)
415
+ else:
416
+ present = None
417
+
418
+ if self.use_logn_attn and not self.training:
419
+ if self.logn_tensor.device != query.device or self.logn_tensor.dtype != query.dtype:
420
+ self.logn_tensor = self.logn_tensor.to(query.device).type_as(query)
421
+ seq_start = key.size(1) - query.size(1)
422
+ seq_end = key.size(1)
423
+ logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :]
424
+ query = query * logn_tensor.expand_as(query)
425
+
426
+ if (
427
+ self.use_flash_attn
428
+ and flash_attn_unpadded_func is not None
429
+ and not self.is_fp32
430
+ and query.is_cuda
431
+ ):
432
+ q, k, v = query, key, value
433
+ context_layer = self.core_attention_flash(q, k, v)
434
+
435
+ context_layer = rearrange(
436
+ context_layer, "b s h d -> b s (h d)"
437
+ ).contiguous()
438
+ else:
439
+ query = query.permute(0, 2, 1, 3)
440
+ key = key.permute(0, 2, 1, 3)
441
+ value = value.permute(0, 2, 1, 3)
442
+ attn_output, attn_weight = self._attn(
443
+ query, key, value, attention_mask, head_mask
444
+ )
445
+ context_layer = self._merge_heads(
446
+ attn_output, self.num_heads, self.head_dim
447
+ )
448
+
449
+ attn_output = self.c_proj(context_layer)
450
+ outputs = (attn_output, present)
451
+ if output_attentions:
452
+ if (
453
+ self.use_flash_attn
454
+ and flash_attn_unpadded_func is not None
455
+ and not self.is_fp32
456
+ ):
457
+ raise ValueError("Cannot output attentions while using flash-attn")
458
+ else:
459
+ outputs += (attn_weight,)
460
+
461
+ return outputs
462
+
463
+
464
class QWenMLP(nn.Module):
    """Gated feed-forward network: ``c_proj(w1(x) * silu(w2(x)))``.

    Both input projections map ``hidden_size`` to half of
    ``config.ffn_hidden_size``; ``c_proj`` maps back to ``hidden_size``.
    All three layers carry a bias only when ``config.no_bias`` is false.
    """

    def __init__(self, config):
        super().__init__()
        model_width = config.hidden_size
        # Creation order (w1, w2, c_proj) is kept so that parameter
        # registration and random initialisation order stay the same.
        self.w1 = nn.Linear(
            model_width, config.ffn_hidden_size // 2, bias=not config.no_bias
        )
        self.w2 = nn.Linear(
            model_width, config.ffn_hidden_size // 2, bias=not config.no_bias
        )
        inner_width = config.ffn_hidden_size // 2
        self.c_proj = nn.Linear(inner_width, model_width, bias=not config.no_bias)

    def forward(self, hidden_states):
        """Apply the gated projection to ``hidden_states`` and return the result."""
        value_branch = self.w1(hidden_states)
        gate_branch = self.w2(hidden_states)
        return self.c_proj(value_branch * F.silu(gate_branch))
482
+
483
+
484
class QWenBlock(nn.Module):
    """One decoder layer: norm -> attention -> residual, norm -> MLP -> residual.

    When ``config.apply_residual_connection_post_layernorm`` is true the
    residual branches start from the normalised tensors instead of the raw
    inputs of each sub-layer.
    """

    def __init__(self, config, layer_idx=None, num_expert=1):
        super().__init__()
        self.num_expert = num_expert
        self.layer_number = layer_idx
        self.apply_residual_connection_post_layernorm = (
            config.apply_residual_connection_post_layernorm
        )
        self.bf16 = config.bf16

        width = config.hidden_size
        # Sub-modules are registered in the order ln_1, attn, ln_2, mlp.
        self.ln_1 = RMSNorm(width, eps=config.layer_norm_epsilon)
        self.attn = QWenAttention(config, layer_number=layer_idx)
        self.ln_2 = RMSNorm(width, eps=config.layer_norm_epsilon)
        self.mlp = QWenMLP(config)

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ):
        """Run the layer.

        Returns a tuple starting with the new hidden states, followed by the
        attention module's extra outputs.  The first extra output (the
        key/value cache) is dropped when ``use_cache`` is false.
        ``encoder_hidden_states`` and ``encoder_attention_mask`` are unused.
        """
        post_ln_residual = self.apply_residual_connection_post_layernorm

        normed_input = self.ln_1(hidden_states)
        attn_result = self.attn(
            normed_input,
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        attn_out, extras = attn_result[0], attn_result[1:]

        skip = normed_input if post_ln_residual else hidden_states
        after_attn = attn_out + skip

        normed_after_attn = self.ln_2(after_attn)
        skip = normed_after_attn if post_ln_residual else after_attn
        hidden_states = skip + self.mlp(normed_after_attn)

        # extras[0] is the cache slot; keep it only when caching is on.
        tail = extras if use_cache else extras[1:]
        return (hidden_states,) + tail
557
+
558
+
559
class QWenPreTrainedModel(PreTrainedModel):
    """Base class wiring QWen models into the ``PreTrainedModel`` machinery.

    Provides weight initialisation and the gradient-checkpointing toggle.
    """

    config_class = QWenConfig
    base_model_prefix = "transformer"
    is_parallelizable = False
    supports_gradient_checkpointing = True
    _no_split_modules = ["QWenBlock"]

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, RMSNorm):
            module.weight.data.fill_(1.0)
        elif isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if isinstance(module, nn.Linear):
                if module.bias is not None:
                    module.bias.data.zero_()
            elif module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

        # Output projections named "c_proj" are re-drawn with a standard
        # deviation scaled by 1 / sqrt(2 * n_layer); this runs for every
        # module, whatever its type.
        for param_name, param in module.named_parameters():
            if param_name != "c_proj.weight":
                continue
            scaled_std = self.config.initializer_range / math.sqrt(
                2 * self.config.n_layer
            )
            param.data.normal_(mean=0.0, std=scaled_std)

    def _set_gradient_checkpointing(self, module, value=False):
        """Set ``gradient_checkpointing`` on ``module`` if it is a ``QWenModel``."""
        if not isinstance(module, QWenModel):
            return
        module.gradient_checkpointing = value
595
+
596
+
597
class QWenModel(QWenPreTrainedModel):
    """Stack of ``QWenBlock`` layers over a token embedding, ending in RMSNorm.

    A learned position-embedding table (``wpe``) is created only when
    ``config.pos_emb == "learned"``; otherwise ``wpe`` is ``None``.
    """

    _keys_to_ignore_on_load_missing = ["attn.masked_bias"]

    def __init__(self, config):
        super().__init__(config)
        self.vocab_size = config.padded_vocab_size
        self.num_hidden_layers = config.num_hidden_layers
        self.embed_dim = config.hidden_size

        max_sequence_length = config.max_position_embeddings
        self.position_embedding_type = config.pos_emb
        self.gradient_checkpointing = False

        if self.position_embedding_type == "learned":
            # The table is initialised by post_init() through _init_weights.
            # The earlier calls to self.init_method(self.position_embeddings
            # .weight) referenced attributes that do not exist on this class.
            self.wpe = nn.Embedding(max_sequence_length, self.embed_dim)
            self._position_embeddings_key = "position_embeddings"
        else:
            self.wpe = None
            self._position_embeddings_key = ""

        self.wte = nn.Embedding(self.vocab_size, self.embed_dim)

        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList(
            [
                QWenBlock(
                    config,
                    layer_idx=i,
                )
                for i in range(config.num_hidden_layers)
            ]
        )
        self.ln_f = RMSNorm(
            self.embed_dim,
            eps=config.layer_norm_epsilon,
        )

        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.wte

    def set_input_embeddings(self, new_embeddings):
        """Replace the token embedding module."""
        self.wte = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Run the decoder stack.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be given.
        ``attention_mask`` (1 = keep, 0 = mask) is converted to an additive
        mask before being handed to the layers.  ``encoder_attention_mask``
        is always overridden with ``None``.

        Returns:
            ``BaseModelOutputWithPast`` when ``return_dict`` is true,
            otherwise a tuple of the non-``None`` values among
            ``(hidden_states, presents, all_hidden_states,
            all_self_attentions)``.

        Raises:
            ValueError: If both or neither of ``input_ids`` and
                ``inputs_embeds`` are given, or if ``attention_mask`` is
                given with a non-positive batch size.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            batch_size = input_ids.shape[0]
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size = inputs_embeds.shape[0]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, input_shape[-1])
        if position_ids is not None:
            position_ids = position_ids.view(-1, input_shape[-1])

        if past_key_values is None:
            past_length = 0
            past_key_values = tuple([None] * len(self.h))
        else:
            # Cached keys are laid out (batch, seq_len, heads, head_dim), so
            # the past length is axis 1; size(-2) would be the head count.
            past_length = past_key_values[0][0].size(1)

        if position_ids is None:
            position_ids = torch.arange(
                past_length,
                input_shape[-1] + past_length,
                dtype=torch.long,
                device=device,
            )
            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])

        if attention_mask is not None:
            if batch_size <= 0:
                raise ValueError("batch_size has to be defined and > 0")
            attention_mask = attention_mask.view(batch_size, -1)
            attention_mask = attention_mask[:, None, None, :]
            attention_mask = attention_mask.to(dtype=self.dtype)
            # 1 -> 0.0 (keep), 0 -> most negative value (masked out).
            attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min

        encoder_attention_mask = None
        head_mask = self.get_head_mask(head_mask, self.config.n_layer)

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)
        hidden_states = inputs_embeds
        if self.wpe is not None:
            position_embeds = self.wpe(position_ids)
            hidden_states = hidden_states + position_embeds

        hidden_states = self.drop(hidden_states)
        output_shape = input_shape + (hidden_states.size(-1),)

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None
        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, use_cache, output_attentions)

                    return custom_forward

                outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    None,
                    attention_mask,
                    head_mask[i],
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                outputs = block(
                    hidden_states,
                    layer_past=layer_past,
                    attention_mask=attention_mask,
                    head_mask=head_mask[i],
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )

            hidden_states = outputs[0]
            # A block returns (hidden_states, present, attn_weights) with the
            # cache slot dropped when use_cache is false: the cache is always
            # at index 1, and the weights follow it when it is present.
            if use_cache is True:
                presents = presents + (outputs[1],)

            if output_attentions:
                all_self_attentions = all_self_attentions + (
                    outputs[2 if use_cache else 1],
                )

        hidden_states = self.ln_f(hidden_states)
        hidden_states = hidden_states.view(output_shape)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    presents,
                    all_hidden_states,
                    all_self_attentions,
                ]
                if v is not None
            )

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
798
+
799
+
800
class QWenLMHeadModel(QWenPreTrainedModel):
    """Causal language model: a ``QWenModel`` transformer plus a bias-free linear LM head.

    In addition to ``forward`` and ``generate``, the class provides the
    ``chat`` and ``chat_stream`` helpers, which both require the generation
    config's ``chat_format`` to be ``'chatml'``.
    """

    _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.rotary_emb\.inv_freq"]
    _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.masked_bias"]

    def __init__(self, config):
        """Resolve precision / flash-attention flags on ``config`` and build the model.

        ``config`` is modified in place: when none of ``bf16``/``fp16``/``fp32``
        is set, one is chosen from the ``SUPPORT_*`` module flags, and a
        ``use_flash_attn`` value of ``"auto"`` is resolved to a boolean.

        Raises:
            AssertionError: if more than one of ``bf16``, ``fp16``, ``fp32`` is set.
        """
        super().__init__(config)
        assert (
            config.bf16 + config.fp16 + config.fp32 <= 1
        ), "Only one of \"bf16\", \"fp16\", \"fp32\" can be true"

        autoset_precision = config.bf16 + config.fp16 + config.fp32 == 0

        if autoset_precision:
            if SUPPORT_BF16:
                logger.warning(
                    "The model is automatically converting to bf16 for faster inference. "
                    "If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"."
                )
                config.bf16 = True
            elif SUPPORT_FP16:
                logger.warning(
                    "The model is automatically converting to fp16 for faster inference. "
                    "If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"."
                )
                config.fp16 = True
            else:
                config.fp32 = True

        if config.bf16 and SUPPORT_CUDA and not SUPPORT_BF16:
            logger.warning("Your device does NOT seem to support bf16, you can switch to fp16 or fp32 by passing fp16/fp32=True in \"AutoModelForCausalLM.from_pretrained\".")
        if config.fp16 and SUPPORT_CUDA and not SUPPORT_FP16:
            logger.warning("Your device does NOT support faster inference with fp16, please switch to fp32 which is likely to be faster")
        if config.fp32:
            if SUPPORT_BF16:
                logger.warning("Your device support faster inference by passing bf16=True in \"AutoModelForCausalLM.from_pretrained\".")
            elif SUPPORT_FP16:
                logger.warning("Your device support faster inference by passing fp16=True in \"AutoModelForCausalLM.from_pretrained\".")

        if config.use_flash_attn == "auto":
            if config.bf16 or config.fp16:
                logger.warning("Try importing flash-attention for faster inference...")
                config.use_flash_attn = True
            else:
                config.use_flash_attn = False
        # NOTE: this only warns; the flag itself is left untouched here and
        # `_import_flash_attn()` below is still called when it is truthy.
        if config.use_flash_attn and config.fp32:
            logger.warning("Flash attention will be disabled because it does NOT support fp32.")

        if config.use_flash_attn:
            _import_flash_attn()

        self.transformer = QWenModel(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Cast both sub-modules to the selected reduced precision, if any.
        if config.bf16:
            self.transformer.bfloat16()
            self.lm_head.bfloat16()
        if config.fp16:
            self.transformer.half()
            self.lm_head.half()
        self.post_init()

    def get_output_embeddings(self):
        """Return the LM head module."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head module with ``new_embeddings``."""
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
    ):
        """Build the keyword arguments for one ``forward`` call during generation.

        When ``past_key_values`` is truthy, only the last column of
        ``input_ids`` (and of ``token_type_ids`` / derived ``position_ids``)
        is kept. ``position_ids`` are derived from ``attention_mask`` when the
        mask is given and no explicit ``position_ids`` were passed; otherwise
        they are set to ``None``.

        Returns:
            dict with ``input_ids`` or ``inputs_embeds`` plus
            ``past_key_values``, ``use_cache``, ``position_ids``,
            ``attention_mask`` and ``token_type_ids``.
        """
        token_type_ids = kwargs.get("token_type_ids", None)
        if past_key_values:
            input_ids = input_ids[:, -1].unsqueeze(-1)
            if token_type_ids is not None:
                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)

        attention_mask = kwargs.get("attention_mask", None)
        position_ids = kwargs.get("position_ids", None)

        if attention_mask is not None and position_ids is None:
            # Running count of attended tokens, minus one; masked-out slots
            # are overwritten with the placeholder value 1.
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -1].unsqueeze(-1)
        else:
            position_ids = None

        # `inputs_embeds` is only forwarded when `past_key_values` is None.
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "position_ids": position_ids,
                "attention_mask": attention_mask,
                "token_type_ids": token_type_ids,
            }
        )
        return model_inputs

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the transformer and the LM head; optionally compute the LM loss.

        When ``labels`` is given, the loss is the cross-entropy between the
        logits at every position but the last and the labels shifted left by
        one position.

        Returns:
            A ``CausalLMOutputWithPast`` when ``return_dict`` is truthy
            (defaulting to ``self.config.use_return_dict``), otherwise a tuple
            ``(loss?, logits, *transformer_outputs[1:])`` where ``loss`` is
            present only if it was computed.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            labels = labels.to(lm_logits.device)
            # Position i predicts token i + 1.
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    @staticmethod
    def _reorder_cache(
        past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
    ) -> Tuple[Tuple[torch.Tensor]]:
        """Re-index every cached tensor along dim 0 with ``beam_idx``.

        ``beam_idx`` is moved to each cached tensor's device before selection.
        """
        return tuple(
            tuple(
                past_state.index_select(0, beam_idx.to(past_state.device))
                for past_state in layer_past
            )
            for layer_past in past_key_values
        )

    def chat(
        self,
        tokenizer: PreTrainedTokenizer,
        query: str,
        history: Optional[HistoryType],
        system: str = "You are a helpful assistant.",
        append_history: bool = True,
        stream: Optional[bool] = _SENTINEL,
        stop_words_ids: Optional[List[List[int]]] = None,
        **kwargs,
    ) -> Tuple[str, HistoryType]:
        """Generate one chat reply to ``query`` and return it with the history.

        Args:
            tokenizer: tokenizer used to build the prompt and decode the reply.
            query: the new user message.
            history: earlier ``(query, response)`` pairs, or ``None`` for none.
                When ``append_history`` is true the new pair is appended to
                this list in place.
            system: system prompt passed to ``make_context``.
            append_history: whether to record ``(query, response)`` in ``history``.
            stream: must be left at its default; see ``chat_stream``.
            stop_words_ids: extra stop-word token-id sequences. The caller's
                list is not modified.
            **kwargs: forwarded to ``generate``.

        Returns:
            ``(response, history)``.

        Raises:
            AssertionError: if ``stream`` is passed, or the generation config's
                ``chat_format`` is not ``'chatml'``.
        """
        assert stream is _SENTINEL, _ERROR_STREAM_IN_CHAT
        assert self.generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT
        if history is None:
            history = []
        # Work on a copy: extending the caller's list would make it grow with
        # the format's stop words on every call.
        stop_words_ids = [] if stop_words_ids is None else list(stop_words_ids)

        raw_text, context_tokens = make_context(
            tokenizer,
            query,
            history=history,
            system=system,
            max_window_size=6144,
            chat_format=self.generation_config.chat_format,
        )

        stop_words_ids.extend(get_stop_words_ids(
            self.generation_config.chat_format, tokenizer
        ))
        input_ids = torch.tensor([context_tokens]).to(self.device)
        outputs = self.generate(
            input_ids,
            stop_words_ids=stop_words_ids,
            return_dict_in_generate=False,
            **kwargs,
        )

        response = decode_tokens(
            outputs[0],
            tokenizer,
            raw_text_len=len(raw_text),
            context_length=len(context_tokens),
            chat_format=self.generation_config.chat_format,
            verbose=False,
            errors='replace'
        )

        if append_history:
            history.append((query, response))

        return response, history

    def chat_stream(
        self,
        tokenizer: PreTrainedTokenizer,
        query: str,
        history: Optional[HistoryType],
        system: str = "You are a helpful assistant.",
        stop_words_ids: Optional[List[List[int]]] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        **kwargs,
    ) -> Generator[str, Any, None]:
        """Return a generator that yields the reply decoded so far, token by token.

        Each yielded string is the decoding of all tokens generated up to that
        point. Neither ``stop_words_ids`` nor ``logits_processor`` is modified;
        ``history`` is only read.

        Raises:
            AssertionError: if the generation config's ``chat_format`` is not
                ``'chatml'``.
        """
        assert self.generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT
        if history is None:
            history = []
        # Copy so the caller's list does not accumulate the format's stop words.
        stop_words_ids = [] if stop_words_ids is None else list(stop_words_ids)

        raw_text, context_tokens = make_context(
            tokenizer,
            query,
            history=history,
            system=system,
            max_window_size=6144,
            chat_format=self.generation_config.chat_format,
        )

        stop_words_ids.extend(get_stop_words_ids(
            self.generation_config.chat_format, tokenizer
        ))
        stop_words_logits_processor = StopWordsLogitsProcessor(
            stop_words_ids=stop_words_ids,
            eos_token_id=self.generation_config.eos_token_id,
        )
        # Build a fresh list instead of appending to the caller's processor
        # list, which would otherwise gain one processor per call.
        if logits_processor is None:
            logits_processor = LogitsProcessorList([stop_words_logits_processor])
        else:
            logits_processor = LogitsProcessorList(
                [*logits_processor, stop_words_logits_processor]
            )
        input_ids = torch.tensor([context_tokens]).to(self.device)

        from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig
        # Attach the streaming entry points to the class (done on every call).
        self.__class__.generate_stream = NewGenerationMixin.generate
        self.__class__.sample_stream = NewGenerationMixin.sample_stream
        stream_config = StreamGenerationConfig(**self.generation_config.to_dict(), do_stream=True)

        def stream_generator():
            outputs = []
            for token in self.generate_stream(
                    input_ids,
                    return_dict_in_generate=False,
                    generation_config=stream_config,
                    logits_processor=logits_processor,
                    seed=-1,
                    **kwargs):
                outputs.append(token.item())
                yield tokenizer.decode(outputs, skip_special_tokens=True, errors='ignore')

        return stream_generator()

    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[
            Callable[[int, torch.Tensor], List[int]]
        ] = None,
        synced_gpus: Optional[bool] = None,
        assistant_model: Optional["PreTrainedModel"] = None,
        streamer: Optional["BaseStreamer"] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        """Call the parent ``generate`` with an optional stop-words logits processor.

        ``stop_words_ids`` is looked up, in order, in ``kwargs`` (and removed
        from it), in ``generation_config`` and in ``self.generation_config``.
        When found, a ``StopWordsLogitsProcessor`` is added after any
        caller-supplied processors; the caller's ``logits_processor`` list is
        left unmodified.
        """
        # Process stop_words_ids.
        stop_words_ids = kwargs.pop("stop_words_ids", None)
        if stop_words_ids is None and generation_config is not None:
            stop_words_ids = getattr(generation_config, "stop_words_ids", None)
        if stop_words_ids is None:
            stop_words_ids = getattr(self.generation_config, "stop_words_ids", None)

        if stop_words_ids is not None:
            stop_words_logits_processor = StopWordsLogitsProcessor(
                stop_words_ids=stop_words_ids,
                eos_token_id=self.generation_config.eos_token_id,
            )
            if logits_processor is None:
                logits_processor = LogitsProcessorList([stop_words_logits_processor])
            else:
                # New list: appending in place would leak the processor into
                # the caller's list and duplicate it on repeated calls.
                logits_processor = LogitsProcessorList(
                    [*logits_processor, stop_words_logits_processor]
                )

        return super().generate(
            inputs,
            generation_config=generation_config,
            logits_processor=logits_processor,
            stopping_criteria=stopping_criteria,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            synced_gpus=synced_gpus,
            assistant_model=assistant_model,
            streamer=streamer,
            **kwargs,
        )
1131
+
1132
+
1133
class RotaryEmbedding(torch.nn.Module):
    """Cached table of rotary position-embedding angles with NTK base scaling.

    ``forward`` returns a slice of a cached angle tensor of shape
    ``(1, n, 1, dim)``; the cache is rebuilt whenever a longer sequence or a
    different ``ntk_alpha`` is requested.
    """

    def __init__(self, dim, base=10000):
        """Store ``dim``/``base`` and compute the initial inverse frequencies.

        Raises:
            RuntimeError: if the ``einops`` package cannot be found.
        """
        super().__init__()
        self.dim = dim
        self.base = base
        # Kept as a plain tensor attribute (not registered as a buffer).
        self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        # Construction has always failed fast when einops is missing; that
        # contract is preserved.
        if importlib.util.find_spec("einops") is None:
            raise RuntimeError("einops is required for Rotary Embedding")

        self._rotary_pos_emb_cache = None
        self._seq_len_cached = 0
        self._ntk_alpha_cached = 1.0

    def update_rotary_pos_emb_cache(self, max_seq_len, offset=0, ntk_alpha=1.0):
        """Rebuild the angle cache if it is too short or ``ntk_alpha`` changed.

        The cache is sized to ``max(2 * (max_seq_len + offset), 16)`` positions
        so that small growth in sequence length does not trigger a rebuild.
        """
        seqlen = max_seq_len + offset
        if seqlen > self._seq_len_cached or ntk_alpha != self._ntk_alpha_cached:
            # Scale the rotary base by ntk_alpha ** (dim / (dim - 2)).
            base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))
            self.inv_freq = 1.0 / (
                base
                ** (
                    torch.arange(0, self.dim, 2, device=self.inv_freq.device).float()
                    / self.dim
                )
            )
            self._seq_len_cached = max(2 * seqlen, 16)
            self._ntk_alpha_cached = ntk_alpha
            seq = torch.arange(self._seq_len_cached, device=self.inv_freq.device)
            freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1)
            # (n, d) -> (1, n, 1, d) using native indexing; no per-rebuild
            # third-party import is needed for this reshape.
            self._rotary_pos_emb_cache = emb[None, :, None, :]

    def forward(self, max_seq_len, offset=0, ntk_alpha=1.0):
        """Return cached angles for positions ``offset .. offset + max_seq_len - 1``."""
        self.update_rotary_pos_emb_cache(max_seq_len, offset, ntk_alpha)
        return self._rotary_pos_emb_cache[:, offset : offset + max_seq_len]
1169
+
1170
+
1171
def _rotate_half(x):
    """Split the last dimension of ``x`` into halves ``(x1, x2)`` and return ``cat(-x2, x1)``.

    The last dimension must have even length; otherwise the underlying
    reshape raises a ``RuntimeError``.
    """
    # Native equivalent of einops' "... (j d) -> ... j d" with j=2, without a
    # function-scope third-party import on every call.
    x1, x2 = x.unflatten(-1, (2, -1)).unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)
1177
+
1178
+
1179
def apply_rotary_pos_emb(t, freqs):
    """Rotate ``t`` by the rotary angles in ``freqs`` and return it in ``t``'s dtype.

    When the module-level ``apply_rotary_emb_func`` is available and ``t`` is a
    CUDA tensor, that function does the rotation. Otherwise only the leading
    ``freqs.shape[-1]`` channels of ``t`` are rotated and the remaining
    channels pass through unchanged. Both paths compute in float32.
    """
    fused_available = apply_rotary_emb_func is not None and t.is_cuda
    if fused_available:
        as_float = t.float()
        # Drop the two singleton axes of freqs, then keep the first half of
        # the last axis for cos/sin.
        freqs = freqs.squeeze(0).squeeze(1)
        half_width = freqs.shape[-1] // 2
        cos = freqs[:, :half_width].cos()
        sin = freqs[:, :half_width].sin()
        return apply_rotary_emb_func(as_float, cos, sin).type_as(t)

    rot_dim = freqs.shape[-1]
    rotated = t[..., :rot_dim].float()
    passthrough = t[..., rot_dim:].float()
    rotated = (rotated * freqs.cos()) + (_rotate_half(rotated) * freqs.sin())
    return torch.cat((rotated, passthrough), dim=-1).type_as(t)
1194
+
1195
+
1196
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned scale."""

    def __init__(self, dim: int, eps: float = 1e-6):
        """Create a scale vector of ``dim`` ones and store ``eps``."""
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """Scale ``x`` by the reciprocal square root of (mean of squares + eps)."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        """Normalise ``x`` and multiply by ``self.weight``.

        Uses the module-level ``rms_norm`` function when it is available and
        ``x`` is a CUDA tensor; otherwise normalises in float32, casts back to
        ``x``'s dtype, and applies the weight.
        """
        if rms_norm is not None and x.is_cuda:
            return rms_norm(x, self.weight, self.eps)

        normalised = self._norm(x.float()).type_as(x)
        return normalised * self.weight
pytorch_model-00001-of-00002.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:af49b42893f19f71b0dbb0d500d2da4f62b4e2fe0e46ab269750b0026d43357f
3
+ size 9969772092
pytorch_model-00002-of-00002.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6d3aefb189675af20dbe2f4678769a6762ddacfcbf16f4cfe25400af23f98c8f
3
+ size 5472963479
pytorch_model.bin.index.json ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "metadata": {
3
+ "total_size": 15442649088
4
+ },
5
+ "weight_map": {
6
+ "lm_head.weight": "pytorch_model-00002-of-00002.bin",
7
+ "transformer.h.0.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
8
+ "transformer.h.0.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
9
+ "transformer.h.0.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
10
+ "transformer.h.0.ln_1.weight": "pytorch_model-00001-of-00002.bin",
11
+ "transformer.h.0.ln_2.weight": "pytorch_model-00001-of-00002.bin",
12
+ "transformer.h.0.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
13
+ "transformer.h.0.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
14
+ "transformer.h.0.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
15
+ "transformer.h.1.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
16
+ "transformer.h.1.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
17
+ "transformer.h.1.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
18
+ "transformer.h.1.ln_1.weight": "pytorch_model-00001-of-00002.bin",
19
+ "transformer.h.1.ln_2.weight": "pytorch_model-00001-of-00002.bin",
20
+ "transformer.h.1.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
21
+ "transformer.h.1.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
22
+ "transformer.h.1.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
23
+ "transformer.h.10.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
24
+ "transformer.h.10.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
25
+ "transformer.h.10.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
26
+ "transformer.h.10.ln_1.weight": "pytorch_model-00001-of-00002.bin",
27
+ "transformer.h.10.ln_2.weight": "pytorch_model-00001-of-00002.bin",
28
+ "transformer.h.10.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
29
+ "transformer.h.10.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
30
+ "transformer.h.10.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
31
+ "transformer.h.11.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
32
+ "transformer.h.11.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
33
+ "transformer.h.11.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
34
+ "transformer.h.11.ln_1.weight": "pytorch_model-00001-of-00002.bin",
35
+ "transformer.h.11.ln_2.weight": "pytorch_model-00001-of-00002.bin",
36
+ "transformer.h.11.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
37
+ "transformer.h.11.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
38
+ "transformer.h.11.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
39
+ "transformer.h.12.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
40
+ "transformer.h.12.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
41
+ "transformer.h.12.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
42
+ "transformer.h.12.ln_1.weight": "pytorch_model-00001-of-00002.bin",
43
+ "transformer.h.12.ln_2.weight": "pytorch_model-00001-of-00002.bin",
44
+ "transformer.h.12.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
45
+ "transformer.h.12.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
46
+ "transformer.h.12.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
47
+ "transformer.h.13.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
48
+ "transformer.h.13.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
49
+ "transformer.h.13.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
50
+ "transformer.h.13.ln_1.weight": "pytorch_model-00001-of-00002.bin",
51
+ "transformer.h.13.ln_2.weight": "pytorch_model-00001-of-00002.bin",
52
+ "transformer.h.13.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
53
+ "transformer.h.13.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
54
+ "transformer.h.13.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
55
+ "transformer.h.14.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
56
+ "transformer.h.14.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
57
+ "transformer.h.14.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
58
+ "transformer.h.14.ln_1.weight": "pytorch_model-00001-of-00002.bin",
59
+ "transformer.h.14.ln_2.weight": "pytorch_model-00001-of-00002.bin",
60
+ "transformer.h.14.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
61
+ "transformer.h.14.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
62
+ "transformer.h.14.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
63
+ "transformer.h.15.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
64
+ "transformer.h.15.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
65
+ "transformer.h.15.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
66
+ "transformer.h.15.ln_1.weight": "pytorch_model-00001-of-00002.bin",
67
+ "transformer.h.15.ln_2.weight": "pytorch_model-00001-of-00002.bin",
68
+ "transformer.h.15.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
69
+ "transformer.h.15.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
70
+ "transformer.h.15.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
71
+ "transformer.h.16.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
72
+ "transformer.h.16.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
73
+ "transformer.h.16.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
74
+ "transformer.h.16.ln_1.weight": "pytorch_model-00001-of-00002.bin",
75
+ "transformer.h.16.ln_2.weight": "pytorch_model-00001-of-00002.bin",
76
+ "transformer.h.16.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
77
+ "transformer.h.16.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
78
+ "transformer.h.16.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
79
+ "transformer.h.17.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
80
+ "transformer.h.17.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
81
+ "transformer.h.17.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
82
+ "transformer.h.17.ln_1.weight": "pytorch_model-00001-of-00002.bin",
83
+ "transformer.h.17.ln_2.weight": "pytorch_model-00001-of-00002.bin",
84
+ "transformer.h.17.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
85
+ "transformer.h.17.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
86
+ "transformer.h.17.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
87
+ "transformer.h.18.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
88
+ "transformer.h.18.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
89
+ "transformer.h.18.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
90
+ "transformer.h.18.ln_1.weight": "pytorch_model-00001-of-00002.bin",
91
+ "transformer.h.18.ln_2.weight": "pytorch_model-00001-of-00002.bin",
92
+ "transformer.h.18.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
93
+ "transformer.h.18.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
94
+ "transformer.h.18.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
95
+ "transformer.h.19.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
96
+ "transformer.h.19.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
97
+ "transformer.h.19.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
98
+ "transformer.h.19.ln_1.weight": "pytorch_model-00001-of-00002.bin",
99
+ "transformer.h.19.ln_2.weight": "pytorch_model-00001-of-00002.bin",
100
+ "transformer.h.19.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
101
+ "transformer.h.19.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
102
+ "transformer.h.19.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
103
+ "transformer.h.2.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
104
+ "transformer.h.2.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
105
+ "transformer.h.2.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
106
+ "transformer.h.2.ln_1.weight": "pytorch_model-00001-of-00002.bin",
107
+ "transformer.h.2.ln_2.weight": "pytorch_model-00001-of-00002.bin",
108
+ "transformer.h.2.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
109
+ "transformer.h.2.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
110
+ "transformer.h.2.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
111
+ "transformer.h.20.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
112
+ "transformer.h.20.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
113
+ "transformer.h.20.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
114
+ "transformer.h.20.ln_1.weight": "pytorch_model-00001-of-00002.bin",
115
+ "transformer.h.20.ln_2.weight": "pytorch_model-00001-of-00002.bin",
116
+ "transformer.h.20.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
117
+ "transformer.h.20.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
118
+ "transformer.h.20.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
119
+ "transformer.h.21.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
120
+ "transformer.h.21.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
121
+ "transformer.h.21.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
122
+ "transformer.h.21.ln_1.weight": "pytorch_model-00001-of-00002.bin",
123
+ "transformer.h.21.ln_2.weight": "pytorch_model-00001-of-00002.bin",
124
+ "transformer.h.21.mlp.c_proj.weight": "pytorch_model-00002-of-00002.bin",
125
+ "transformer.h.21.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
126
+ "transformer.h.21.mlp.w2.weight": "pytorch_model-00002-of-00002.bin",
127
+ "transformer.h.22.attn.c_attn.bias": "pytorch_model-00002-of-00002.bin",
128
+ "transformer.h.22.attn.c_attn.weight": "pytorch_model-00002-of-00002.bin",
129
+ "transformer.h.22.attn.c_proj.weight": "pytorch_model-00002-of-00002.bin",
130
+ "transformer.h.22.ln_1.weight": "pytorch_model-00002-of-00002.bin",
131
+ "transformer.h.22.ln_2.weight": "pytorch_model-00002-of-00002.bin",
132
+ "transformer.h.22.mlp.c_proj.weight": "pytorch_model-00002-of-00002.bin",
133
+ "transformer.h.22.mlp.w1.weight": "pytorch_model-00002-of-00002.bin",
134
+ "transformer.h.22.mlp.w2.weight": "pytorch_model-00002-of-00002.bin",
135
+ "transformer.h.23.attn.c_attn.bias": "pytorch_model-00002-of-00002.bin",
136
+ "transformer.h.23.attn.c_attn.weight": "pytorch_model-00002-of-00002.bin",
137
+ "transformer.h.23.attn.c_proj.weight": "pytorch_model-00002-of-00002.bin",
138
+ "transformer.h.23.ln_1.weight": "pytorch_model-00002-of-00002.bin",
139
+ "transformer.h.23.ln_2.weight": "pytorch_model-00002-of-00002.bin",
140
+ "transformer.h.23.mlp.c_proj.weight": "pytorch_model-00002-of-00002.bin",
141
+ "transformer.h.23.mlp.w1.weight": "pytorch_model-00002-of-00002.bin",
142
+ "transformer.h.23.mlp.w2.weight": "pytorch_model-00002-of-00002.bin",
143
+ "transformer.h.24.attn.c_attn.bias": "pytorch_model-00002-of-00002.bin",
144
+ "transformer.h.24.attn.c_attn.weight": "pytorch_model-00002-of-00002.bin",
145
+ "transformer.h.24.attn.c_proj.weight": "pytorch_model-00002-of-00002.bin",
146
+ "transformer.h.24.ln_1.weight": "pytorch_model-00002-of-00002.bin",
147
+ "transformer.h.24.ln_2.weight": "pytorch_model-00002-of-00002.bin",
148
+ "transformer.h.24.mlp.c_proj.weight": "pytorch_model-00002-of-00002.bin",
149
+ "transformer.h.24.mlp.w1.weight": "pytorch_model-00002-of-00002.bin",
150
+ "transformer.h.24.mlp.w2.weight": "pytorch_model-00002-of-00002.bin",
151
+ "transformer.h.25.attn.c_attn.bias": "pytorch_model-00002-of-00002.bin",
152
+ "transformer.h.25.attn.c_attn.weight": "pytorch_model-00002-of-00002.bin",
153
+ "transformer.h.25.attn.c_proj.weight": "pytorch_model-00002-of-00002.bin",
154
+ "transformer.h.25.ln_1.weight": "pytorch_model-00002-of-00002.bin",
155
+ "transformer.h.25.ln_2.weight": "pytorch_model-00002-of-00002.bin",
156
+ "transformer.h.25.mlp.c_proj.weight": "pytorch_model-00002-of-00002.bin",
157
+ "transformer.h.25.mlp.w1.weight": "pytorch_model-00002-of-00002.bin",
158
+ "transformer.h.25.mlp.w2.weight": "pytorch_model-00002-of-00002.bin",
159
+ "transformer.h.26.attn.c_attn.bias": "pytorch_model-00002-of-00002.bin",
160
+ "transformer.h.26.attn.c_attn.weight": "pytorch_model-00002-of-00002.bin",
161
+ "transformer.h.26.attn.c_proj.weight": "pytorch_model-00002-of-00002.bin",
162
+ "transformer.h.26.ln_1.weight": "pytorch_model-00002-of-00002.bin",
163
+ "transformer.h.26.ln_2.weight": "pytorch_model-00002-of-00002.bin",
164
+ "transformer.h.26.mlp.c_proj.weight": "pytorch_model-00002-of-00002.bin",
165
+ "transformer.h.26.mlp.w1.weight": "pytorch_model-00002-of-00002.bin",
166
+ "transformer.h.26.mlp.w2.weight": "pytorch_model-00002-of-00002.bin",
167
+ "transformer.h.27.attn.c_attn.bias": "pytorch_model-00002-of-00002.bin",
168
+ "transformer.h.27.attn.c_attn.weight": "pytorch_model-00002-of-00002.bin",
169
+ "transformer.h.27.attn.c_proj.weight": "pytorch_model-00002-of-00002.bin",
170
+ "transformer.h.27.ln_1.weight": "pytorch_model-00002-of-00002.bin",
171
+ "transformer.h.27.ln_2.weight": "pytorch_model-00002-of-00002.bin",
172
+ "transformer.h.27.mlp.c_proj.weight": "pytorch_model-00002-of-00002.bin",
173
+ "transformer.h.27.mlp.w1.weight": "pytorch_model-00002-of-00002.bin",
174
+ "transformer.h.27.mlp.w2.weight": "pytorch_model-00002-of-00002.bin",
175
+ "transformer.h.28.attn.c_attn.bias": "pytorch_model-00002-of-00002.bin",
176
+ "transformer.h.28.attn.c_attn.weight": "pytorch_model-00002-of-00002.bin",
177
+ "transformer.h.28.attn.c_proj.weight": "pytorch_model-00002-of-00002.bin",
178
+ "transformer.h.28.ln_1.weight": "pytorch_model-00002-of-00002.bin",
179
+ "transformer.h.28.ln_2.weight": "pytorch_model-00002-of-00002.bin",
180
+ "transformer.h.28.mlp.c_proj.weight": "pytorch_model-00002-of-00002.bin",
181
+ "transformer.h.28.mlp.w1.weight": "pytorch_model-00002-of-00002.bin",
182
+ "transformer.h.28.mlp.w2.weight": "pytorch_model-00002-of-00002.bin",
183
+ "transformer.h.29.attn.c_attn.bias": "pytorch_model-00002-of-00002.bin",
184
+ "transformer.h.29.attn.c_attn.weight": "pytorch_model-00002-of-00002.bin",
185
+ "transformer.h.29.attn.c_proj.weight": "pytorch_model-00002-of-00002.bin",
186
+ "transformer.h.29.ln_1.weight": "pytorch_model-00002-of-00002.bin",
187
+ "transformer.h.29.ln_2.weight": "pytorch_model-00002-of-00002.bin",
188
+ "transformer.h.29.mlp.c_proj.weight": "pytorch_model-00002-of-00002.bin",
189
+ "transformer.h.29.mlp.w1.weight": "pytorch_model-00002-of-00002.bin",
190
+ "transformer.h.29.mlp.w2.weight": "pytorch_model-00002-of-00002.bin",
191
+ "transformer.h.3.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
192
+ "transformer.h.3.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
193
+ "transformer.h.3.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
194
+ "transformer.h.3.ln_1.weight": "pytorch_model-00001-of-00002.bin",
195
+ "transformer.h.3.ln_2.weight": "pytorch_model-00001-of-00002.bin",
196
+ "transformer.h.3.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
197
+ "transformer.h.3.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
198
+ "transformer.h.3.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
199
+ "transformer.h.30.attn.c_attn.bias": "pytorch_model-00002-of-00002.bin",
200
+ "transformer.h.30.attn.c_attn.weight": "pytorch_model-00002-of-00002.bin",
201
+ "transformer.h.30.attn.c_proj.weight": "pytorch_model-00002-of-00002.bin",
202
+ "transformer.h.30.ln_1.weight": "pytorch_model-00002-of-00002.bin",
203
+ "transformer.h.30.ln_2.weight": "pytorch_model-00002-of-00002.bin",
204
+ "transformer.h.30.mlp.c_proj.weight": "pytorch_model-00002-of-00002.bin",
205
+ "transformer.h.30.mlp.w1.weight": "pytorch_model-00002-of-00002.bin",
206
+ "transformer.h.30.mlp.w2.weight": "pytorch_model-00002-of-00002.bin",
207
+ "transformer.h.31.attn.c_attn.bias": "pytorch_model-00002-of-00002.bin",
208
+ "transformer.h.31.attn.c_attn.weight": "pytorch_model-00002-of-00002.bin",
209
+ "transformer.h.31.attn.c_proj.weight": "pytorch_model-00002-of-00002.bin",
210
+ "transformer.h.31.ln_1.weight": "pytorch_model-00002-of-00002.bin",
211
+ "transformer.h.31.ln_2.weight": "pytorch_model-00002-of-00002.bin",
212
+ "transformer.h.31.mlp.c_proj.weight": "pytorch_model-00002-of-00002.bin",
213
+ "transformer.h.31.mlp.w1.weight": "pytorch_model-00002-of-00002.bin",
214
+ "transformer.h.31.mlp.w2.weight": "pytorch_model-00002-of-00002.bin",
215
+ "transformer.h.4.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
216
+ "transformer.h.4.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
217
+ "transformer.h.4.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
218
+ "transformer.h.4.ln_1.weight": "pytorch_model-00001-of-00002.bin",
219
+ "transformer.h.4.ln_2.weight": "pytorch_model-00001-of-00002.bin",
220
+ "transformer.h.4.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
221
+ "transformer.h.4.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
222
+ "transformer.h.4.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
223
+ "transformer.h.5.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
224
+ "transformer.h.5.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
225
+ "transformer.h.5.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
226
+ "transformer.h.5.ln_1.weight": "pytorch_model-00001-of-00002.bin",
227
+ "transformer.h.5.ln_2.weight": "pytorch_model-00001-of-00002.bin",
228
+ "transformer.h.5.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
229
+ "transformer.h.5.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
230
+ "transformer.h.5.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
231
+ "transformer.h.6.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
232
+ "transformer.h.6.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
233
+ "transformer.h.6.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
234
+ "transformer.h.6.ln_1.weight": "pytorch_model-00001-of-00002.bin",
235
+ "transformer.h.6.ln_2.weight": "pytorch_model-00001-of-00002.bin",
236
+ "transformer.h.6.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
237
+ "transformer.h.6.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
238
+ "transformer.h.6.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
239
+ "transformer.h.7.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
240
+ "transformer.h.7.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
241
+ "transformer.h.7.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
242
+ "transformer.h.7.ln_1.weight": "pytorch_model-00001-of-00002.bin",
243
+ "transformer.h.7.ln_2.weight": "pytorch_model-00001-of-00002.bin",
244
+ "transformer.h.7.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
245
+ "transformer.h.7.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
246
+ "transformer.h.7.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
247
+ "transformer.h.8.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
248
+ "transformer.h.8.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
249
+ "transformer.h.8.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
250
+ "transformer.h.8.ln_1.weight": "pytorch_model-00001-of-00002.bin",
251
+ "transformer.h.8.ln_2.weight": "pytorch_model-00001-of-00002.bin",
252
+ "transformer.h.8.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
253
+ "transformer.h.8.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
254
+ "transformer.h.8.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
255
+ "transformer.h.9.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
256
+ "transformer.h.9.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
257
+ "transformer.h.9.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
258
+ "transformer.h.9.ln_1.weight": "pytorch_model-00001-of-00002.bin",
259
+ "transformer.h.9.ln_2.weight": "pytorch_model-00001-of-00002.bin",
260
+ "transformer.h.9.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
261
+ "transformer.h.9.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
262
+ "transformer.h.9.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
263
+ "transformer.ln_f.weight": "pytorch_model-00002-of-00002.bin",
264
+ "transformer.wte.weight": "pytorch_model-00001-of-00002.bin"
265
+ }
266
+ }
qwen_generation_utils.py ADDED
@@ -0,0 +1,416 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba Cloud.
2
+ #
3
+ # This source code is licensed under the license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ """Generation support."""
7
+
8
+ from typing import Tuple, List, Union, Iterable
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from transformers import PreTrainedTokenizer
14
+ from transformers import logging
15
+ from transformers.generation import LogitsProcessor
16
+
17
+ logger = logging.get_logger(__name__)
18
+
19
+ # Types.
20
+ HistoryType = List[Tuple[str, str]]
21
+ TokensType = List[int]
22
+ BatchTokensType = List[List[int]]
23
+
24
+
25
def pad_batch(batch: BatchTokensType, pad_id: int, seq_length: int) -> BatchTokensType:
    """Right-pad every token list in ``batch`` with ``pad_id`` up to ``seq_length``.

    The inner lists are extended in place and the very same ``batch`` object
    is returned. Lists that already contain ``seq_length`` tokens or more are
    left untouched; nothing is ever truncated.
    """
    for row in batch:
        shortfall = seq_length - len(row)
        if shortfall > 0:
            row.extend([pad_id] * shortfall)
    return batch
31
+
32
+
33
def get_ltor_masks_and_position_ids(
    data,
    eod_token,
    reset_position_ids,
    reset_attention_mask,
    eod_mask_loss,
):
    """Build the attention mask, loss mask and position ids for a left-to-right model.

    Args:
        data: 2-D tensor of token ids; its size is unpacked as
            ``(micro_batch_size, seq_length)``.
        eod_token: id of the end-of-document (EOD) token.
        reset_position_ids: if true, position numbering restarts at 0 right
            after every EOD token.
        reset_attention_mask: if true, tokens after an EOD token may not attend
            to anything up to and including that EOD token.
        eod_mask_loss: if true, the loss mask is zeroed at EOD positions.

    Returns:
        A tuple ``(attention_mask, loss_mask, position_ids)``:

        * ``attention_mask`` - boolean tensor of shape
          ``(b, 1, seq_length, seq_length)`` where ``b`` is the batch size when
          ``reset_attention_mask`` is set and 1 otherwise. ``True`` marks the
          positions that are masked out (it is the lower-triangular "allowed"
          matrix compared with ``< 0.5``).
        * ``loss_mask`` - float tensor shaped like ``data``.
        * ``position_ids`` - long tensor shaped like ``data``.
    """
    batch_size, seq_len = data.size()
    device = data.device

    # Lower-triangular matrix of ones: 1 means "may attend".
    mask_batch = batch_size if reset_attention_mask else 1
    allowed = torch.ones((mask_batch, seq_len, seq_len), device=device)
    attention_mask = torch.tril(allowed).view(mask_batch, 1, seq_len, seq_len)

    loss_mask = torch.ones(data.size(), dtype=torch.float, device=device)
    if eod_mask_loss:
        loss_mask[data == eod_token] = 0.0

    position_ids = torch.arange(seq_len, dtype=torch.long, device=device)
    position_ids = position_ids.unsqueeze(0).expand_as(data)
    if reset_position_ids:
        # The expanded view shares storage across rows; clone it before the
        # per-row edits below.
        position_ids = position_ids.clone()

    if reset_position_ids or reset_attention_mask:
        for row in range(batch_size):
            # Positions in this row that hold an EOD token.
            eod_positions = position_ids[row, data[row] == eod_token]
            if reset_position_ids:
                # Detach the indices from the positions we are about to modify.
                eod_positions = eod_positions.clone()

            segment_start = 0
            for eod_pos in eod_positions:
                next_start = eod_pos + 1
                if reset_attention_mask:
                    # Later tokens must not see anything up to this EOD.
                    attention_mask[row, 0, next_start:, :next_start] = 0
                if reset_position_ids:
                    # Restart the numbering of everything after this EOD.
                    position_ids[row, next_start:] -= next_start - segment_start
                    segment_start = next_start

    # Convert to a boolean mask where True means "masked out".
    return attention_mask < 0.5, loss_mask, position_ids
92
+
93
+
94
def get_batch(context_tokens: torch.LongTensor, eod_id: int):
    """Return ``(tokens, attention_mask, position_ids)`` for ``context_tokens``.

    The mask and position ids come from
    :func:`get_ltor_masks_and_position_ids` with every reset / loss-mask
    option switched off; the loss mask it produces is discarded.
    """
    # `.to()` is given the tensor's own device, so this only guarantees a
    # contiguous layout.
    tokens = context_tokens.contiguous().to(context_tokens.device)
    masks_and_positions = get_ltor_masks_and_position_ids(
        tokens,
        eod_id,
        reset_position_ids=False,
        reset_attention_mask=False,
        eod_mask_loss=False,
    )
    attention_mask = masks_and_positions[0]
    position_ids = masks_and_positions[2]
    return tokens, attention_mask, position_ids
107
+
108
+
109
def get_stop_words_ids(chat_format, tokenizer):
    """Return the stop sequences (lists of token ids) for ``chat_format``.

    * ``"chatml"`` - the tokenizer's ``im_end_id`` and ``im_start_id``, each
      as a one-token sequence.
    * ``"raw"`` - the encoding of ``"Human:"`` followed by the one-token
      sequence ``[tokenizer.eod_id]``.

    Raises:
        NotImplementedError: for any other ``chat_format``.
    """
    if chat_format == "chatml":
        return [[tokenizer.im_end_id], [tokenizer.im_start_id]]
    if chat_format == "raw":
        return [tokenizer.encode("Human:"), [tokenizer.eod_id]]
    raise NotImplementedError(f"Unknown chat format {chat_format!r}")
117
+
118
+
119
def make_context(
    tokenizer: PreTrainedTokenizer,
    query: str,
    history: List[Tuple[str, str]] = None,
    system: str = "",
    max_window_size: int = 6144,
    chat_format: str = "chatml",
):
    """Build the prompt text and prompt token ids for one chat request.

    Args:
        tokenizer: tokenizer providing ``encode`` plus, for ChatML,
            ``im_start_id`` and ``im_end_id``.
        query: the new user message.
        history: earlier ``(user, assistant)`` turns, oldest first. ``None``
            means no history.
        system: system message (ChatML only).
        max_window_size: token budget used when deciding how many history
            turns to keep. Turns are added newest-first while
            ``system + kept history + candidate turn`` stays strictly below
            this value; the final query is not counted in that check.
        chat_format: ``"chatml"`` or ``"raw"``.

    Returns:
        ``(raw_text, context_tokens)``. For ``"raw"`` this is simply the query
        and its encoding.

    Raises:
        NotImplementedError: for an unknown ``chat_format``.
    """
    if chat_format == "raw":
        return query, tokenizer.encode(query)
    if chat_format != "chatml":
        raise NotImplementedError(f"Unknown chat format {chat_format!r}")

    turns = [] if history is None else history

    im_start, im_end = "<|im_start|>", "<|im_end|>"
    start_ids = [tokenizer.im_start_id]
    end_ids = [tokenizer.im_end_id]
    nl_tokens = tokenizer.encode("\n")

    def _encode_message(role, content):
        # Role and content are encoded as plain text (no special tokens).
        role_ids = tokenizer.encode(role, allowed_special=set())
        content_ids = tokenizer.encode(content, allowed_special=set())
        return f"{role}\n{content}", role_ids + nl_tokens + content_ids

    def _wrap(message_ids):
        return start_ids + message_ids + end_ids

    system_text, system_body = _encode_message("system", system)
    system_tokens = _wrap(system_body)

    history_text = ""
    history_tokens = []
    # Walk the history newest-first so the most recent turns survive when the
    # window is too small for all of them.
    for past_query, past_response in reversed(turns):
        query_text, query_body = _encode_message("user", past_query)
        response_text, response_body = _encode_message("assistant", past_response)

        turn_tokens = nl_tokens + _wrap(query_body) + nl_tokens + _wrap(response_body)
        turn_text = (
            f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}"
        )

        projected_size = len(system_tokens) + len(turn_tokens) + len(history_tokens)
        if projected_size >= max_window_size:
            break
        history_tokens = turn_tokens + history_tokens
        history_text = turn_text + history_text

    user_tokens = _wrap(_encode_message("user", query)[1])
    # The assistant header is left open so the model continues from it.
    assistant_header = start_ids + tokenizer.encode("assistant") + nl_tokens
    context_tokens = (
        system_tokens
        + history_tokens
        + nl_tokens
        + user_tokens
        + nl_tokens
        + assistant_header
    )

    raw_text = f"{im_start}{system_text}{im_end}" + history_text
    raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n"

    return raw_text, context_tokens
190
+
191
+
192
def _decode_default(
    tokens: List[int],
    *,
    stop_words: List[str],
    eod_words: List[str],
    tokenizer: PreTrainedTokenizer,
    raw_text_len: int,
    verbose: bool = False,
    return_end_reason: bool = False,
    errors: str='replace',
):
    """Decode ``tokens`` for the "raw" chat format and clean up the generated text.

    Args:
        tokens: prompt + generated token ids.
        stop_words: strings removed from the generated text.
        eod_words: end-of-document markers; the text is cut at the first
            occurrence of each marker.
        tokenizer: object whose ``decode(tokens, errors=...)`` returns a string.
        raw_text_len: number of leading characters (the prompt text) to drop
            from the decoded string.
        verbose: print the intermediate and final text.
        return_end_reason: also return a string describing why generation ended.
        errors: forwarded to ``tokenizer.decode``.

    Returns:
        The cleaned text, or ``(text, end_reason)`` when ``return_end_reason``
        is true.
    """
    generated = tokenizer.decode(tokens, errors=errors)[raw_text_len:]
    if verbose:
        print("\nRaw Generate: ", generated)

    end_reason = f"Gen length {len(tokens)}"
    # Truncate at the end-of-document marker BEFORE stripping stop words.
    # A marker that is listed as both (e.g. "<|endoftext|>") would otherwise
    # be deleted first, so the text would never be cut and the end reason
    # would never report it.
    for eod_word in eod_words:
        if eod_word in generated:
            end_reason = f"Gen {eod_word!r}"
            generated = generated.split(eod_word)[0]
    for stop_word in stop_words:
        generated = generated.replace(stop_word, "").strip()
    generated = generated.strip()
    if verbose:
        print("\nEnd Reason:", end_reason)
        print("\nGenerate: ", generated)

    if return_end_reason:
        return generated, end_reason
    return generated
223
+
224
+
225
def _decode_chatml(
    tokens: List[int],
    *,
    stop_words: List[str],
    eod_token_ids: List[int],
    tokenizer: PreTrainedTokenizer,
    raw_text_len: int,
    context_length: int,
    verbose: bool = False,
    return_end_reason: bool = False,
    errors: str='replace'
):
    """Decode ``tokens`` for the ChatML format and clean up the generated text.

    Args:
        tokens: prompt + generated token ids.
        stop_words: strings removed from the generated text.
        eod_token_ids: token ids that end the reply; generation is cut at the
            first one found at or after ``context_length``.
        tokenizer: object whose ``decode`` turns token ids into a string.
        raw_text_len: number of leading characters (the prompt text) to drop
            from the decoded string.
        context_length: number of prompt tokens at the start of ``tokens``.
        verbose: print the intermediate and final text.
        return_end_reason: also return a string describing why generation ended.
        errors: forwarded to ``tokenizer.decode``.

    Returns:
        The cleaned text, or ``(text, end_reason)`` when ``return_end_reason``
        is true.
    """
    end_reason = f"Gen length {len(tokens)}"
    # Default to keeping every token. Reusing the loop variable here would
    # leave the cut at ``len(tokens) - 1`` when no end token is generated and
    # silently drop the last generated token.
    eod_token_idx = len(tokens)
    for idx in range(context_length, len(tokens)):
        if tokens[idx] in eod_token_ids:
            end_reason = f"Gen {tokenizer.decode([tokens[idx]])!r}"
            eod_token_idx = idx
            break

    trim_decode_tokens = tokenizer.decode(tokens[:eod_token_idx], errors=errors)[raw_text_len:]
    if verbose:
        print("\nRaw Generate w/o EOD:", tokenizer.decode(tokens, errors=errors)[raw_text_len:])
        print("\nRaw Generate:", trim_decode_tokens)
        print("\nEnd Reason:", end_reason)
    for stop_word in stop_words:
        trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip()
    trim_decode_tokens = trim_decode_tokens.strip()
    if verbose:
        print("\nGenerate:", trim_decode_tokens)

    if return_end_reason:
        return trim_decode_tokens, end_reason
    return trim_decode_tokens
259
+
260
+
261
def decode_tokens(
    tokens: Union[torch.LongTensor, TokensType],
    tokenizer: PreTrainedTokenizer,
    raw_text_len: int,
    context_length: int,
    chat_format: str,
    verbose: bool = False,
    return_end_reason: bool = False,
    errors: str="replace",
) -> str:
    """Decode generated ``tokens`` into text according to ``chat_format``.

    Tensor input is first converted to a plain Python list. ``"chatml"`` is
    handled by :func:`_decode_chatml` (ending on the tokenizer's
    ``im_start_id`` / ``im_end_id``) and ``"raw"`` by :func:`_decode_default`
    (using ``"<|endoftext|>"`` as both stop word and end-of-document word).

    Returns:
        The decoded text. When ``return_end_reason`` is true the helpers
        return a ``(text, end_reason)`` tuple instead, despite the ``str``
        annotation.

    Raises:
        NotImplementedError: for an unknown ``chat_format``.
    """
    token_ids = tokens.cpu().numpy().tolist() if torch.is_tensor(tokens) else tokens

    # Keyword arguments shared by both decoding helpers.
    shared_kwargs = dict(
        tokenizer=tokenizer,
        raw_text_len=raw_text_len,
        verbose=verbose,
        return_end_reason=return_end_reason,
        errors=errors,
    )

    if chat_format == "chatml":
        return _decode_chatml(
            token_ids,
            stop_words=[],
            eod_token_ids=[tokenizer.im_start_id, tokenizer.im_end_id],
            context_length=context_length,
            **shared_kwargs,
        )
    if chat_format == "raw":
        return _decode_default(
            token_ids,
            stop_words=["<|endoftext|>"],
            eod_words=["<|endoftext|>"],
            **shared_kwargs,
        )
    raise NotImplementedError(f"Unknown chat format {chat_format!r}")
299
+
300
+
301
class StopWordsLogitsProcessor(LogitsProcessor):
    """
    :class:`transformers.LogitsProcessor` that stops generation once a stop sequence has appeared.

    When the tokens generated so far for a sample end with one of the stop
    sequences, the score of ``eos_token_id`` for that sample is set to a very
    large value (``2**15``).

    Args:
        stop_words_ids (:obj:`List[List[int]]`):
            Non-empty list of stop sequences, each a non-empty list of
            non-negative token ids. A sequence equal to ``[eos_token_id]`` is
            dropped.
        eos_token_id (:obj:`int`):
            The id of the `end-of-sequence` token.
    """

    def __init__(self, stop_words_ids: Iterable[Iterable[int]], eos_token_id: int):
        # The three checks run as separate passes so that the first failing
        # kind of check decides which error is raised.
        if not isinstance(stop_words_ids, List) or len(stop_words_ids) == 0:
            raise ValueError(
                f"`stop_words_ids` has to be a non-emtpy list, but is {stop_words_ids}."
            )
        if not all(isinstance(seq, list) for seq in stop_words_ids):
            raise ValueError(
                f"`stop_words_ids` has to be a list of lists, but is {stop_words_ids}."
            )
        if not all(
            isinstance(token_id, (int, np.integer)) and token_id >= 0
            for seq in stop_words_ids
            for token_id in seq
        ):
            raise ValueError(
                f"Each list in `stop_words_ids` has to be a list of positive integers, but is {stop_words_ids}."
            )

        # A stop sequence consisting solely of EOS is dropped.
        self.stop_words_ids = [
            seq for seq in stop_words_ids if seq != [eos_token_id]
        ]
        self.eos_token_id = eos_token_id
        for seq in self.stop_words_ids:
            assert (
                len(seq) > 0
            ), "Stop words token sequences {} cannot have an empty list".format(
                stop_words_ids
            )

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        """Set the EOS score to ``2**15`` for every sample that hit a stop sequence.

        ``scores`` is modified in place and returned.
        """
        for row, should_stop in enumerate(self._calc_stopped_samples(input_ids)):
            if should_stop:
                scores[row, self.eos_token_id] = float(2**15)
        return scores

    def _tokens_match(self, prev_tokens: torch.LongTensor, tokens: List[int]) -> bool:
        """Return whether ``prev_tokens`` ends with the sequence ``tokens``."""
        width = len(tokens)
        if width == 0:
            # An empty sequence matches everything.
            return True
        if width > len(prev_tokens):
            # Not enough tokens yet for the sequence to have appeared.
            return False
        return prev_tokens[-width:].tolist() == tokens

    def _calc_stopped_samples(self, prev_input_ids: Iterable[int]) -> Iterable[int]:
        """Return one boolean per sample: does its tail match any stop sequence?"""
        return [
            any(self._tokens_match(sample, seq) for seq in self.stop_words_ids)
            for sample in prev_input_ids
        ]
382
+
383
+
384
def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
    """Apply top-k and/or nucleus (top-p) filtering to ``logits`` in place.

    Mostly taken from the huggingface conversational ai code at
    https://medium.com/huggingface/how-to-build-a-state-of-the-art-
    conversational-ai-with-transfer-learning-2d818ac26313

    Args:
        logits: tensor of scores; the top-p step indexes it row by row along
            the first dimension.
        top_k: when > 0, keep only the ``top_k`` highest scores per row.
        top_p: when > 0, keep the smallest set of highest-scoring tokens whose
            cumulative probability exceeds ``top_p``.
        filter_value: value written over the filtered-out entries.

    Returns:
        The same ``logits`` tensor, modified in place.
    """
    if top_k > 0:
        # Everything scoring below the k-th best entry is filtered out.
        kth_best = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_best] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Mark entries whose cumulative probability is above the threshold,
        # then shift right by one so the first entry crossing it is kept.
        drop_sorted = cumulative_probs > top_p
        drop_sorted[..., 1:] = drop_sorted[..., :-1].clone()
        drop_sorted[..., 0] = 0

        # Map the sorted-order mask back to vocabulary positions, row by row.
        for row in range(sorted_indices.size(0)):
            doomed = sorted_indices[row][drop_sorted[row]]
            logits[row][doomed] = filter_value

    return logits
412
+
413
+
414
def switch(val1, val2, boolean):
    """Select element-wise between two tensors by arithmetic blending.

    ``boolean`` is cast to the type of ``val1``; the result takes ``val2``
    where it is 1 and ``val1`` where it is 0.
    """
    selector = boolean.type_as(val1)
    keep_first = (1 - selector) * val1
    take_second = selector * val2
    return keep_first + take_second
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "bos_token": "<|endoftext|>",
4
+ "clean_up_tokenization_spaces": true,
5
+ "eos_token": "<|endoftext|>",
6
+ "model_max_length": 4096000,
7
+ "tokenizer_class": "GPT2Tokenizer",
8
+ "unk_token": "<|endoftext|>"
9
+ }