Text Generation
Transformers
PyTorch
RefinedWeb
custom_code
text-generation-inference
psinger committed
Commit 761cf67 · 1 Parent(s): 7700b1c

Update modelling_RW.py

Files changed (1)
  1. modelling_RW.py +41 -35
modelling_RW.py CHANGED
@@ -175,32 +175,42 @@ class Attention(nn.Module):
 
         self.query_key_value = Linear(
             self.hidden_size,
-            3 * self.hidden_size if not config.multi_query else (self.hidden_size + 2 * self.head_dim),
+            (config.n_head_kv * 2 + config.n_head) * self.head_dim,
             bias=config.bias,
         )
-        self.multi_query = config.multi_query
         self.dense = Linear(self.hidden_size, self.hidden_size, bias=config.bias)
         self.attention_dropout = nn.Dropout(config.attention_dropout)
-        self.num_kv = config.n_head if not self.multi_query else 1
+        self.num_kv = config.n_head_kv
 
     def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
-        Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory
+        Split the last dimension into (num_heads, head_dim), results share same memory
         storage as `fused_qkv`
         Args:
             fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]
         Returns:
-            query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
+            query: [batch_size, seq_length, num_heads, head_dim]
+            key: [batch_size, seq_length, num_heads, head_dim]
             value: [batch_size, seq_length, num_heads, head_dim]
         """
-        if not self.multi_query:
-            batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
-            fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
-            return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]
-        else:
-            batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
-            fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads + 2, self.head_dim)
-            return fused_qkv[..., :-2, :], fused_qkv[..., [-2], :], fused_qkv[..., [-1], :]
+        batch, seq_len, _ = fused_qkv.shape
+        qkv = fused_qkv.view(batch, seq_len, -1, self.num_heads // self.num_kv + 2, 64)
+        q = qkv[:, :, :, :-2]
+        k = qkv[:, :, :, [-2]]
+        v = qkv[:, :, :, [-1]]
+        k = torch.broadcast_to(k, q.shape)
+        v = torch.broadcast_to(v, q.shape)
+
+        q, k, v = [
+            rearrange(
+                x,
+                "batch seq_len group num_heads head_dim ->\
+                batch seq_len (group num_heads) head_dim",
+                head_dim=self.head_dim,
+            )
+            for x in [q, k, v]
+        ]
+        return q, k, v
 
     def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
         """
@@ -244,11 +254,11 @@ class Attention(nn.Module):
 
         query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
         key_layer = key_layer.transpose(1, 2).reshape(
-            batch_size * self.num_kv,
+            batch_size * self.num_heads,
             q_length,
             self.head_dim,
         )
-        value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_kv, q_length, self.head_dim)
+        value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
 
         query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)
 
@@ -269,8 +279,8 @@ class Attention(nn.Module):
 
         if alibi is None:
             query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
-            key_layer_ = key_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
-            value_layer_ = value_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
+            key_layer_ = key_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
+            value_layer_ = value_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
 
             attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, torch.finfo(torch.float16).min).to(query_layer_.dtype)
             attn_output = F.scaled_dot_product_attention(
@@ -300,7 +310,8 @@ class Attention(nn.Module):
             attention_scores = attention_scores.to(torch.float32)
             # attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
             attention_probs = F.softmax(
-                (attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)) * self.inv_norm_factor + attention_mask_float,
+                (attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)) * self.inv_norm_factor
+                + attention_mask_float,
                 dim=-1,
                 dtype=hidden_states.dtype,
             )
@@ -349,14 +360,12 @@ class DecoderLayer(nn.Module):
         super().__init__()
         hidden_size = config.hidden_size
 
-        self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+        self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+        self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+
         self.num_heads = config.n_head
         self.self_attention = Attention(config)
 
-        if not config.parallel_attn:
-            # unused if parallel attn
-            self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-
         self.mlp = MLP(config)
 
         self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
@@ -375,12 +384,14 @@ class DecoderLayer(nn.Module):
         output_attentions: bool = False,
     ):
 
-        layernorm_output = self.input_layernorm(hidden_states)
+        ln_attn = self.ln_attn(hidden_states)
+        ln_mlp = self.ln_mlp(hidden_states)
+
         residual = hidden_states
 
         # Self attention.
         attn_outputs = self.self_attention(
-            layernorm_output,
+            ln_attn,
             layer_past=layer_past,
            attention_mask=attention_mask,
             alibi=alibi,
@@ -391,19 +402,14 @@ class DecoderLayer(nn.Module):
 
         attention_output = attn_outputs[0]
 
-        if not self.config.parallel_attn:
-            residual = dropout_add(attention_output, residual, self.config.attention_dropout, training=self.training)
-            layernorm_output = self.post_attention_layernorm(residual)
-
         outputs = attn_outputs[1:]
 
         # MLP.
-        mlp_output = self.mlp(layernorm_output)
-
-        if self.config.parallel_attn:
-            mlp_output += attention_output
+        mlp_output = self.mlp(ln_mlp)
 
-        output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training)
+        output = dropout_add(
+            mlp_output + attention_output, residual, self.config.hidden_dropout, training=self.training
+        )
 
         if use_cache:
             outputs = (output,) + outputs
@@ -1093,4 +1099,4 @@ class RWForQuestionAnswering(RWPreTrainedModel):
             end_logits=end_logits,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
-        )
+        )
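The new _split_heads above implements grouped key/value heads: the fused projection packs the query heads into config.n_head_kv groups, and each group carries one key head and one value head that are broadcast across its query heads. Below is a minimal, standalone sketch of that shape logic, assuming toy sizes (n_head=8, n_head_kv=2, head_dim=64) rather than any real checkpoint config.

# Standalone sketch of the grouped-head split; all sizes are illustrative only.
import torch
from einops import rearrange

batch, seq_len = 2, 5
n_head, n_head_kv, head_dim = 8, 2, 64

# The fused projection now emits n_head query heads plus n_head_kv key heads
# and n_head_kv value heads: (n_head + 2 * n_head_kv) * head_dim features.
fused_qkv = torch.randn(batch, seq_len, (n_head + 2 * n_head_kv) * head_dim)

# One group per kv head: (n_head // n_head_kv) query heads, then one key and one value head.
qkv = fused_qkv.view(batch, seq_len, n_head_kv, n_head // n_head_kv + 2, head_dim)
q = qkv[:, :, :, :-2]                                # (batch, seq, groups, heads_per_group, head_dim)
k = torch.broadcast_to(qkv[:, :, :, [-2]], q.shape)  # key head repeated within its group
v = torch.broadcast_to(qkv[:, :, :, [-1]], q.shape)  # value head repeated within its group

# Collapse (groups, heads_per_group) back into a single head axis, as the commit does with rearrange.
q, k, v = [rearrange(x, "b s g h d -> b s (g h) d") for x in (q, k, v)]
assert q.shape == k.shape == v.shape == (batch, seq_len, n_head, head_dim)

After this broadcast, downstream code can treat keys and values as if there were n_head of them, which is why the reshape and SDPA calls later in the diff switch from self.num_kv to self.num_heads.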
 
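The DecoderLayer change replaces the single input_layernorm and the parallel_attn flag with two LayerNorms over the same hidden states, attention and MLP branches computed side by side, and one shared residual add. Below is a minimal sketch of that dataflow using stand-in Linear modules for the real Attention and MLP classes; the names and sizes here are illustrative, not taken from the repo.

# Standalone sketch of the parallel attention/MLP block; the submodules are stand-ins.
import torch
import torch.nn as nn
import torch.nn.functional as F


def dropout_add(x, residual, prob, training):
    # Same call pattern as the dropout_add helper used in the diff:
    # dropout on the branch output, then a residual addition.
    return residual + F.dropout(x, p=prob, training=training)


class ParallelBlock(nn.Module):
    def __init__(self, hidden_size, hidden_dropout=0.0):
        super().__init__()
        self.ln_attn = nn.LayerNorm(hidden_size)          # feeds the attention branch
        self.ln_mlp = nn.LayerNorm(hidden_size)           # feeds the MLP branch
        self.attn = nn.Linear(hidden_size, hidden_size)   # stand-in for Attention(config)
        self.mlp = nn.Linear(hidden_size, hidden_size)    # stand-in for MLP(config)
        self.hidden_dropout = hidden_dropout

    def forward(self, hidden_states):
        residual = hidden_states
        attn_output = self.attn(self.ln_attn(hidden_states))
        mlp_output = self.mlp(self.ln_mlp(hidden_states))
        # Both branches share a single residual connection, as in the new DecoderLayer.forward.
        return dropout_add(mlp_output + attn_output, residual, self.hidden_dropout, self.training)


x = torch.randn(2, 5, 64)
print(ParallelBlock(64)(x).shape)  # torch.Size([2, 5, 64])

Because the MLP no longer consumes the attention output, the two branches can run from the same input, which is what lets the commit drop the post_attention_layernorm and the parallel_attn conditionals entirely.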