# LM-Diff.py -- uploaded by Josephgflowers (commit eacb34e, 18.2 kB).
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoConfig, AutoTokenizer, LlamaForCausalLM
from transformers.models.llama.modeling_llama import LlamaModel, LlamaConfig
from transformers.modeling_outputs import BaseModelOutputWithPast
# Custom Modules
class AdaptiveRMSNorm(nn.Module):
    """
    RMSNorm whose output is additionally scaled by a gain predicted from a
    conditioning vector.

    ``forward(x, adapt_input)`` returns::

        weight * x / sqrt(mean(x**2, dim=-1) + eps) * fc_gamma(adapt_input)

    ``x`` is expected to be [batch, seq, normalized_shape] and ``adapt_input``
    [batch, adaptive_dim]; the predicted gain is unsqueezed to
    [batch, 1, normalized_shape] and broadcast over the sequence axis.

    Args:
        normalized_shape: size of the last dimension of ``x`` (an int).
        adaptive_dim: size of the last dimension of ``adapt_input``.
        eps: added to the mean square before the square root.
    """

    def __init__(self, normalized_shape, adaptive_dim, eps=1e-6):
        super().__init__()
        self.normalized_shape = normalized_shape
        self.eps = eps
        # Standard RMSNorm weight parameter
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        # Adaptive scaling parameter
        self.fc_gamma = nn.Linear(adaptive_dim, normalized_shape)
        # Start as an identity modulation (gamma == 1 for any input) so a freshly
        # created layer behaves exactly like a plain RMSNorm instead of scaling
        # activations by random gains.
        nn.init.zeros_(self.fc_gamma.weight)
        nn.init.ones_(self.fc_gamma.bias)

    def forward(self, x, adapt_input):
        # Predicted gain: [batch, hidden] -> [batch, 1, hidden].
        gamma = self.fc_gamma(adapt_input).unsqueeze(1)
        input_dtype = x.dtype
        x_float = x.float()
        # Divide by the root-mean-square (not the L2 norm, which is
        # sqrt(hidden) times larger); done in float32 for stability.
        mean_square = x_float.pow(2).mean(dim=-1, keepdim=True)
        norm_x = (x_float * torch.rsqrt(mean_square + self.eps)).to(input_dtype)
        return self.weight * norm_x * gamma
class TokenMixing(nn.Module):
    """
    Depthwise 1-D convolution along the sequence axis (one filter per channel).

    Input and output are [batch_size, seq_length, hidden_size].

    Args:
        hidden_size: number of channels; also the number of conv groups, so
            each channel is convolved independently.
        kernel_size: number of positions each filter spans (default 3).
        causal: when True (default) the sequence is left-padded by
            ``kernel_size - 1`` zeros, so the output at position t depends only
            on positions <= t and a causal language model cannot see future
            tokens. When False, "same" padding is used (centered window for
            odd kernel sizes).

    NOTE(review): no state is carried between calls, so during cached
    token-by-token decoding earlier positions are seen as zero padding and
    the result differs from a full-sequence pass.
    """

    def __init__(self, hidden_size, kernel_size=3, causal=True):
        super().__init__()
        if kernel_size < 1:
            raise ValueError("kernel_size must be >= 1")
        self.kernel_size = kernel_size
        self.causal = causal
        self.token_mixing = nn.Conv1d(
            in_channels=hidden_size,
            out_channels=hidden_size,
            kernel_size=kernel_size,
            # Causal padding is applied manually (left side only) in forward.
            padding=0 if causal else "same",
            groups=hidden_size,  # Depthwise convolution
        )

    def forward(self, x):
        # [batch, seq, hidden] -> [batch, hidden, seq] for Conv1d.
        x = x.transpose(1, 2)
        if self.causal:
            x = F.pad(x, (self.kernel_size - 1, 0))
        x = self.token_mixing(x)
        # Back to [batch, seq, hidden]; the sequence length is preserved.
        return x.transpose(1, 2)
class SEBlock(nn.Module):
    """
    Squeeze-and-Excitation gating over the hidden (channel) dimension.

    ``x`` ([batch_size, seq_length, hidden_size]) is averaged over the sequence
    axis, passed through a bottleneck MLP ending in a sigmoid, and every
    position of ``x`` is multiplied by the resulting per-channel gate in (0, 1).

    Args:
        hidden_size: size of the last dimension of ``x``.
        reduction: bottleneck width is ``hidden_size // reduction`` (at least 1).

    NOTE(review): the pooling is a plain mean over all positions, so in a
    causal LM each position's gate depends on later tokens (and on any
    padding positions).
    """

    def __init__(self, hidden_size, reduction=16):
        super().__init__()
        # hidden_size // reduction is 0 when hidden_size < reduction; a zero-width
        # bottleneck would make the gate a constant sigmoid(0) = 0.5.
        bottleneck = max(1, hidden_size // reduction)
        self.fc = nn.Sequential(
            nn.Linear(hidden_size, bottleneck, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(bottleneck, hidden_size, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # Squeeze: [batch, seq, hidden] -> [batch, hidden]; excite; then
        # broadcast the gate back over the sequence axis.
        gate = self.fc(x.mean(dim=1)).unsqueeze(1)
        return x * gate
class DifferentialSelfAttention(nn.Module):
    """
    Causal self-attention with the differential attention mechanism.

    Each head's query and key vectors are split in half along the head
    dimension, giving two softmax attention maps ``A1`` and ``A2``; the head
    then attends with ``A1 - lambda * A2`` where
    ``lambda = exp(q1.k1) - exp(q2.k2) + lambda_init`` (learned vectors).

    Key/value projections produce ``num_key_value_heads`` heads (grouped-query
    attention, read from ``config.num_key_value_heads`` when present), so for
    Llama configs their weights have the same shapes as the pretrained
    checkpoint's; each key/value head is shared by
    ``num_attention_heads // num_key_value_heads`` query heads.

    A causal mask is always applied in addition to ``attention_mask``.
    ``past_key_value`` / the returned cache hold ``(key, value)`` tensors of
    shape [batch, num_key_value_heads, kv_seq_length, head_dim].

    NOTE(review): ``position_ids`` is accepted but unused -- no rotary
    position embedding is applied, so q/k receive no positional signal.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        if self.head_dim * self.num_heads != self.hidden_size:
            raise ValueError("hidden_size must be divisible by num_attention_heads")
        if self.head_dim % 2 != 0:
            raise ValueError("head_dim must be even to split queries/keys into two groups")
        # Fall back to one key/value head per query head when the config has no GQA setting.
        self.num_key_value_heads = getattr(config, "num_key_value_heads", None) or self.num_heads
        if self.num_heads % self.num_key_value_heads != 0:
            raise ValueError("num_attention_heads must be divisible by num_key_value_heads")
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.scaling = self.head_dim ** -0.5

        kv_dim = self.num_key_value_heads * self.head_dim
        # Follow the config's bias setting so checkpoint weights load without
        # leaving randomly initialized biases behind.
        bias = getattr(config, "attention_bias", False)
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=bias)
        self.k_proj = nn.Linear(self.hidden_size, kv_dim, bias=bias)
        self.v_proj = nn.Linear(self.hidden_size, kv_dim, bias=bias)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=bias)

        # Learnable parameters for lambda computation
        self.lambda_q1 = nn.Parameter(torch.randn(self.head_dim) * 0.1)
        self.lambda_k1 = nn.Parameter(torch.randn(self.head_dim) * 0.1)
        self.lambda_q2 = nn.Parameter(torch.randn(self.head_dim) * 0.1)
        self.lambda_k2 = nn.Parameter(torch.randn(self.head_dim) * 0.1)
        self.lambda_init = nn.Parameter(torch.tensor(0.5))
        # Applied to the projected attention output (over hidden_size).
        self.sub_layer_norm = nn.LayerNorm(self.hidden_size)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_ids=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
        **kwargs,
    ):
        """
        Args:
            hidden_states: [batch, seq_length, hidden_size].
            attention_mask: optional mask broadcastable to
                [batch, 1, seq_length, kv_seq_length] ([batch, kv] and
                [batch, q, kv] are expanded). Floating-point masks are treated
                as additive; bool/integer masks as keep-masks (1 = attend).
            past_key_value: optional cached ``(key, value)`` from earlier calls.

        Returns:
            ``(attn_output, present_key_value, attn_probs)``; the last two are
            None unless ``use_cache`` / ``output_attentions`` are set.
        """
        batch_size, seq_length, _ = hidden_states.size()

        # Project and split into heads: q -> [B, H, T, D], k/v -> [B, KV, T, D].
        query_states = (self.q_proj(hidden_states) * self.scaling).view(
            batch_size, seq_length, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(
            batch_size, seq_length, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(
            batch_size, seq_length, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)

        # Prepend cached keys/values along the sequence axis.
        if past_key_value is not None:
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        # The cache keeps the un-repeated key/value heads.
        present_key_value = (key_states, value_states) if use_cache else None
        kv_seq_length = key_states.size(2)

        # Share each key/value head with its group of query heads.
        if self.num_key_value_groups > 1:
            key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1)
            value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)

        # Split Q and K into two groups for differential attention: [B, H, *, D/2] each.
        q1, q2 = torch.chunk(query_states, 2, dim=-1)
        k1, k2 = torch.chunk(key_states, 2, dim=-1)
        # [B, H, seq_length, kv_seq_length]
        attn_scores1 = torch.matmul(q1, k1.transpose(-2, -1))
        attn_scores2 = torch.matmul(q2, k2.transpose(-2, -1))

        min_value = torch.finfo(attn_scores1.dtype).min
        if attention_mask is not None:
            mask = attention_mask
            if mask.dim() == 2:
                mask = mask[:, None, None, :]
            elif mask.dim() == 3:
                mask = mask[:, None, :, :]
            if not mask.is_floating_point():
                # Keep-mask (1/True = attend) -> additive form.
                mask = (1.0 - mask.to(attn_scores1.dtype)) * min_value
            mask = mask.to(dtype=attn_scores1.dtype)
            # Clamp so that adding a very negative mask cannot overflow to -inf
            # (a fully masked row of -inf would softmax to NaN).
            attn_scores1 = (attn_scores1 + mask).clamp(min=min_value)
            attn_scores2 = (attn_scores2 + mask).clamp(min=min_value)

        # Causal mask: query i sits at absolute position kv_seq_length - seq_length + i
        # and must not see later keys. masked_fill is idempotent, so an already
        # causal attention_mask is not applied twice.
        future = torch.ones(
            seq_length, kv_seq_length, dtype=torch.bool, device=hidden_states.device
        ).triu(diagonal=kv_seq_length - seq_length + 1)
        attn_scores1 = attn_scores1.masked_fill(future, min_value)
        attn_scores2 = attn_scores2.masked_fill(future, min_value)

        # Softmax in float32 for stability, then back to the working dtype.
        attn_probs1 = nn.functional.softmax(attn_scores1, dim=-1, dtype=torch.float32).to(attn_scores1.dtype)
        attn_probs2 = nn.functional.softmax(attn_scores2, dim=-1, dtype=torch.float32).to(attn_scores2.dtype)

        lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1))
        lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2))
        lambda_full = lambda_1 - lambda_2 + self.lambda_init
        attn_probs = attn_probs1 - lambda_full * attn_probs2

        # [B, H, T, D] -> [B, T, H * D] -> [B, T, hidden_size]
        attn_output = torch.matmul(attn_probs, value_states)
        attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_length, self.num_heads * self.head_dim)
        attn_output = self.o_proj(attn_output)
        attn_output = self.sub_layer_norm(attn_output)

        attn_probs_return = attn_probs if output_attentions else None
        return attn_output, present_key_value, attn_probs_return
# Modified Decoder Layer
class ModifiedLlamaDecoderLayer(nn.Module):
    """
    Llama decoder layer rebuilt around DifferentialSelfAttention,
    AdaptiveRMSNorm, TokenMixing and SEBlock.

    Data flow (pre-norm residual layout)::

        h   = x + Attn(Norm1(x, a))
        h   = h + TokenMixing(h)
        out = h + SE(MLP(Norm2(h, a)))

    where ``a`` is the mean of the layer input over the sequence axis. The MLP
    module is taken from ``original_layer`` so its weights are reused.

    Returns a tuple ``(hidden_states[, present_key_value][, attn_weights])``:
    the cache entry is present only when ``use_cache`` is set, the attention
    weights only when ``output_attentions`` is set.
    """

    def __init__(self, original_layer, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.adaptive_dim = config.hidden_size  # adapt_input is a pooled hidden state
        self.self_attn = DifferentialSelfAttention(config)
        # Reuse the original MLP module (and therefore its weights).
        self.mlp = original_layer.mlp
        self.input_layernorm = AdaptiveRMSNorm(
            self.hidden_size, self.adaptive_dim, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = AdaptiveRMSNorm(
            self.hidden_size, self.adaptive_dim, eps=config.rms_norm_eps
        )
        self.token_mixing = TokenMixing(self.hidden_size)
        self.se_block = SEBlock(self.hidden_size, reduction=16)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_ids=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
        **kwargs,
    ):
        # Conditioning vector for both adaptive norms: [batch, hidden].
        # NOTE(review): a mean over the whole sequence lets every position see
        # later tokens, and during cached decoding (seq_length == 1) it only
        # covers the current token.
        adapt_input = hidden_states.mean(dim=1)

        # Attention sub-block.
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states, adapt_input)
        attn_output, present_key_value, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
            output_attentions=output_attentions,
            **kwargs,
        )
        hidden_states = residual + attn_output

        # Token mixing with its own residual connection.
        hidden_states = hidden_states + self.token_mixing(hidden_states)

        # MLP sub-block. Capture the residual BEFORE normalizing so the
        # un-normalized residual stream is carried through (pre-norm layout,
        # as in the Llama layer whose MLP is reused).
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states, adapt_input)
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.se_block(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if use_cache:
            outputs += (present_key_value,)
        if output_attentions:
            outputs += (attn_weights,)
        return outputs
# Modified Model
class ModifiedLlamaModel(LlamaModel):
    """
    LlamaModel whose decoder layers are replaced by ModifiedLlamaDecoderLayer.

    ``forward`` is re-implemented because the modified layers take a per-layer
    ``(key, value)`` tuple cache and a 4-D additive float attention mask
    (0 = attend, ``finfo(dtype).min`` = masked). The mask built here is causal
    and combined with the caller's padding mask. The result is a
    ``BaseModelOutputWithPast`` (or the equivalent tuple when ``return_dict``
    is false) whose ``past_key_values`` is a list with one ``(key, value)``
    tuple per layer.
    """

    def __init__(self, config):
        super().__init__(config)
        # Replace every decoder layer; each replacement reuses the MLP of the
        # layer it replaces.
        self.layers = nn.ModuleList([
            ModifiedLlamaDecoderLayer(layer, config)
            for layer in self.layers
        ])

    @staticmethod
    def _past_length(past_key_values):
        """Return the number of cached positions in a per-layer ``(key, value)`` cache."""
        for layer_past in past_key_values:
            if layer_past is not None:
                # Cached keys are [..., seq_len, head_dim].
                return layer_past[0].shape[-2]
        return 0

    @staticmethod
    def _build_attention_mask(attention_mask, seq_length, past_length, dtype, device):
        """
        Build the additive causal mask for ``seq_length`` new queries over
        ``past_length + seq_length`` keys.

        ``attention_mask`` (optional) is a keep-mask (1 = attend, 0 = masked) of
        shape [batch, kv_len], [batch, q_len, kv_len] or already 4-D. The result
        broadcasts to [batch, 1, seq_length, past_length + seq_length].
        """
        min_value = torch.finfo(dtype).min
        kv_length = past_length + seq_length
        if attention_mask is None:
            mask = torch.zeros(1, 1, seq_length, kv_length, dtype=dtype, device=device)
        else:
            if attention_mask.dim() == 2:
                attention_mask = attention_mask[:, None, None, :]
            elif attention_mask.dim() == 3:
                attention_mask = attention_mask[:, None, :, :]
            attention_mask = attention_mask.to(dtype=dtype, device=device)
            mask = (1.0 - attention_mask) * min_value
        # True where query i (absolute position past_length + i) would see a later key.
        future = torch.ones(seq_length, kv_length, dtype=torch.bool, device=device).triu(
            diagonal=past_length + 1
        )
        # torch.where (not addition) so positions masked twice stay at min_value
        # instead of overflowing to -inf.
        return torch.where(future, torch.full((), min_value, dtype=dtype, device=device), mask)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,  # Forwarded unchanged to every decoder layer
    ):
        # Resolve defaults from the config.
        output_attentions = (
            output_attentions if output_attentions is not None else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time.")
        elif input_ids is not None:
            _, seq_length = input_ids.size()
        elif inputs_embeds is not None:
            _, seq_length = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        # An empty cache container would make the zip below run zero layers,
        # so treat it like "no cache".
        # NOTE(review): only legacy per-layer tuple caches are supported here;
        # cache objects from newer transformers versions are untested.
        if past_key_values is None or len(past_key_values) == 0:
            past_key_values = [None] * len(self.layers)
        past_length = self._past_length(past_key_values)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        hidden_states = inputs_embeds

        attention_mask = self._build_attention_mask(
            attention_mask, seq_length, past_length, hidden_states.dtype, hidden_states.device
        )

        next_decoder_cache = [] if use_cache else None
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        for decoder_layer, layer_past in zip(self.layers, past_key_values):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=layer_past,
                use_cache=use_cache,
                output_attentions=output_attentions,
                **kwargs,
            )
            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache.append(layer_outputs[1])
            if output_attentions:
                # Attention weights are always the last element of the layer output.
                all_attentions = all_attentions + (layer_outputs[-1],)

        hidden_states = self.norm(hidden_states)
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            outputs = (hidden_states,)
            if use_cache:
                outputs += (next_decoder_cache,)
            if output_hidden_states:
                outputs += (all_hidden_states,)
            if output_attentions:
                outputs += (all_attentions,)
            return outputs
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache if use_cache else None,
            hidden_states=all_hidden_states if output_hidden_states else None,
            attentions=all_attentions if output_attentions else None,
        )
# Load the pre-trained model
# Hub id of the checkpoint whose config, weights and tokenizer are reused below.
MODEL_ID = 'Josephgflowers/TinyLlama-v1.1-Cinders-World'
# Load the configuration from the pre-trained model
config = AutoConfig.from_pretrained(MODEL_ID)
# Initialize the modified model: a stock causal-LM wrapper whose backbone is
# replaced by ModifiedLlamaModel.
modified_model = LlamaForCausalLM(config)
modified_model.model = ModifiedLlamaModel(config)
# Load the pre-trained weights. strict=False lets the checkpoint load even though
# the modified layers own parameters it does not contain; report what did not
# match instead of ignoring it silently (tensors with mismatched shapes still raise).
pretrained_model = LlamaForCausalLM.from_pretrained(MODEL_ID)
load_result = modified_model.load_state_dict(pretrained_model.state_dict(), strict=False)
if load_result.missing_keys:
    print(
        f"{len(load_result.missing_keys)} parameters were not in the checkpoint and keep "
        f"their initial values, e.g. {load_result.missing_keys[:5]}"
    )
if load_result.unexpected_keys:
    print(
        f"{len(load_result.unexpected_keys)} checkpoint tensors were not used, "
        f"e.g. {load_result.unexpected_keys[:5]}"
    )
# Save the model and tokenizer
# NOTE(review): the directory name refers to a different model (salamandra-2b) and
# looks like a leftover from another script -- confirm before relying on it.
# NOTE(review): the saved config is the unmodified Llama config, so reloading with
# LlamaForCausalLM.from_pretrained presumably rebuilds stock decoder layers -- confirm.
output_dir = "./BSC-LT-salamandra-2b-instruct-saved_model"
modified_model.save_pretrained(output_dir)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, legacy=False)
tokenizer.save_pretrained(output_dir)
print(f"Model and tokenizer saved to {output_dir}")
# Example Usage
import time
def chat_with_model(prompt_text, stop_token, model, tokenizer):
    """
    Sample a continuation of ``prompt_text`` and return it as text.

    Args:
        prompt_text: prompt string; encoded without special tokens.
        stop_token: token id that ends generation; ``None`` falls back to
            ``tokenizer.eos_token_id``.
        model: a generation-capable model; it is moved to CUDA when available.
        tokenizer: tokenizer used to encode the prompt and decode the output.

    Returns:
        Only the newly generated text (prompt excluded, special tokens such as
        EOS removed), stripped of surrounding whitespace.

    Prints the wall-clock time and the number of generated tokens per second.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    start_time = time.time()
    encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt").to(device)
    prompt_length = encoded_prompt.shape[-1]

    eos_token_id = stop_token if stop_token is not None else tokenizer.eos_token_id
    pad_token_id = getattr(tokenizer, "pad_token_id", None)
    if pad_token_id is None:
        pad_token_id = eos_token_id

    output_sequences = model.generate(
        input_ids=encoded_prompt,
        # Single unpadded prompt: every position is a real token.
        attention_mask=torch.ones_like(encoded_prompt),
        max_new_tokens=512,
        temperature=0.2,
        repetition_penalty=1.2,
        top_k=30,
        top_p=0.9,
        do_sample=True,
        num_return_sequences=1,
        eos_token_id=eos_token_id,
        pad_token_id=pad_token_id,
        use_cache=True,  # Ensure use_cache is True for generation
    )

    # Decode only the generated tokens; slicing the decoded text by the prompt's
    # character length is fragile because decoding need not reproduce it exactly.
    new_tokens = output_sequences[0, prompt_length:]
    response_text = tokenizer.decode(
        new_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True
    ).strip()

    total_time = time.time() - start_time
    print(f"Total time: {total_time:.3f} seconds")
    # Count generated tokens only, not the prompt.
    num_new_tokens = new_tokens.numel()
    tokens_per_second = num_new_tokens / total_time if total_time > 0 else float("inf")
    print(f"Tokens per second: {tokens_per_second:.3f}")
    return response_text
# Example usage: run one chat turn against the modified model, stopping at the
# tokenizer's end-of-sequence token.
input_text = "Hello, how are you?"
stop_token = tokenizer.eos_token_id  # EOS doubles as the stop token
response = chat_with_model(input_text, stop_token, model=modified_model, tokenizer=tokenizer)
print(f"Model response: {response}")