# Copyright (c) 2023, Baichuan Intelligent Technology. All rights reserved.
"""Baichuan decoder-only language model (attention with an additive alibi bias)."""

from .configuration_baichuan import BaichuanConfig
from .generation_utils import build_chat_input, TextIterStreamer

import math
from threading import Thread
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.nn import functional as F
from transformers import PreTrainedModel, PretrainedConfig
from transformers.activations import ACT2FN
from transformers.generation.utils import GenerationConfig
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.utils import logging, ContextManagers

import os
from contextlib import contextmanager
from accelerate import init_empty_weights

logger = logging.get_logger(__name__)

try:
    from xformers import ops as xops
except ImportError:
    xops = None
    logger.warning(
        "Xformers is not installed correctly. If you want to use memory_efficient_attention to accelerate training use the following command to install Xformers\npip install xformers."
    )


def _get_interleave(n):
    """Return a list of ``n`` per-head slopes used to build the alibi bias.

    For a power-of-two ``n`` the slopes form a geometric sequence. Otherwise
    the slopes for the closest lower power of two are extended with every
    second slope of the next power of two.
    """

    def _get_interleave_power_of_2(n):
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
        ratio = start
        return [start * ratio**i for i in range(n)]

    if math.log2(n).is_integer():
        return _get_interleave_power_of_2(n)
    else:
        closest_power_of_2 = 2 ** math.floor(math.log2(n))
        return (
            _get_interleave_power_of_2(closest_power_of_2)
            + _get_interleave(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
        )


def _fill_with_neg_inf(t):
    """FP16-compatible function that fills a tensor with -inf."""
    return t.float().fill_(float("-inf")).type_as(t)


def _buffered_future_mask(tensor, maxpos, alibi, attn_heads):
    """Combine a causal ``-inf`` upper-triangular mask with ``alibi``.

    The result is converted with ``.to(tensor)`` and sliced to at most
    ``tensor.shape[0] * attn_heads`` entries on the first axis and ``maxpos``
    on the last two.
    """
    _future_mask = torch.triu(_fill_with_neg_inf(torch.zeros([maxpos, maxpos])), 1)
    _future_mask = _future_mask.unsqueeze(0) + alibi
    new_future_mask = _future_mask.to(tensor)
    return new_future_mask[: tensor.shape[0] * attn_heads, :maxpos, :maxpos]


def _gen_alibi_mask(tensor, n_head, max_pos):
    """Build an ``(n_head, max_pos, max_pos)`` causal mask with alibi bias added.

    ``tensor`` is accepted for signature compatibility but is not read here;
    callers move the result with ``.to(tensor)`` themselves.
    """
    slopes = torch.Tensor(_get_interleave(n_head))
    position_point = torch.arange(max_pos) - max_pos + 1
    position_point = position_point.unsqueeze(0).unsqueeze(0).expand(n_head, -1, -1)
    diag = torch.diag(position_point[0])
    position_point = position_point - diag.unsqueeze(0).unsqueeze(0).transpose(-1, -2)
    alibi = slopes.unsqueeze(1).unsqueeze(1) * position_point
    alibi = alibi.view(n_head, 1, max_pos)
    alibi_mask = torch.triu(_fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1)
    alibi_mask = alibi_mask.unsqueeze(0) + alibi
    return alibi_mask


def _is_4bit_quantized(config) -> bool:
    """Return True when ``config.quantization_config`` is a dict requesting 4-bit loading."""
    quantization_config = getattr(config, "quantization_config", None)
    return isinstance(quantization_config, dict) and bool(
        quantization_config.get("load_in_4bit", False)
    )


class RMSNorm(torch.nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned scale."""

    def __init__(self, hidden_size, epsilon=1e-6):
        super().__init__()
        # Start from an identity scale rather than uninitialised memory, so a
        # model built without a checkpoint has well-defined weights.
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.epsilon = epsilon

    def forward(self, hidden_states):
        """Scale ``hidden_states`` by the reciprocal RMS of its last axis, then by ``weight``."""
        # The variance is accumulated in float32 regardless of the input dtype.
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon)

        # convert into half-precision
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)

        return self.weight * hidden_states


class MLP(torch.nn.Module):
    """Gated feed-forward block: ``down_proj(act(gate_proj(x)) * up_proj(x))``."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
    ):
        super().__init__()
        self.gate_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = torch.nn.Linear(intermediate_size, hidden_size, bias=False)
        self.up_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
        self.act_fn = ACT2FN[hidden_act]

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


class BaichuanAttention(torch.nn.Module):
    """Multi-head self-attention with a packed QKV projection (``W_pack``)."""

    def __init__(self, config: BaichuanConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.max_position_embeddings = config.model_max_length

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size {self.hidden_size} is not divisible by num_heads {self.num_heads}"
            )
        self.W_pack = torch.nn.Linear(
            self.hidden_size, 3 * self.hidden_size, bias=False
        )
        self.o_proj = torch.nn.Linear(
            self.num_heads * self.head_dim, self.hidden_size, bias=False
        )

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Reshape to ``(bsz, num_heads, seq_len, head_dim)`` and make contiguous."""
        return (
            tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
            .contiguous()
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Run self-attention over ``hidden_states``.

        Returns:
            ``(attn_output, attn_weights, past_key_value)``. ``attn_weights`` is
            ``None`` unless ``output_attentions`` is set, and is always ``None``
            on the xformers training path. ``past_key_value`` is the
            ``(key, value)`` pair when ``use_cache`` is set, else ``None``.
        """
        bsz, q_len, _ = hidden_states.size()

        # Split the packed projection into three (query, key, value) slices
        # indexed along the leading axis.
        proj = self.W_pack(hidden_states)
        proj = (
            proj.unflatten(-1, (3, self.hidden_size))
            .unsqueeze(0)
            .transpose(0, -2)
            .squeeze(-2)
        )
        query_states = (
            proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        )
        key_states = (
            proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        )
        value_states = (
            proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        )

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        if xops is not None and self.training:
            attn_weights = None
            query_states = query_states.transpose(1, 2)
            key_states = key_states.transpose(1, 2)
            value_states = value_states.transpose(1, 2)
            attn_output = xops.memory_efficient_attention(
                query_states, key_states, value_states, attn_bias=attention_mask
            )
        else:
            attn_weights = torch.matmul(
                query_states, key_states.transpose(2, 3)
            ) / math.sqrt(self.head_dim)

            if attention_mask is not None:
                if q_len == 1:  # inference with cache
                    # Only the mask row of the newest query position is needed.
                    if len(attention_mask.size()) == 4:
                        attention_mask = attention_mask[:, :, -1:, :]
                    else:
                        attention_mask = attention_mask[:, -1:, :]
                attn_weights = attn_weights + attention_mask
                # Clamp to the dtype's finite minimum so that -inf entries do
                # not reach the softmax.
                attn_weights = torch.max(
                    attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
                )

            attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
            attn_output = torch.matmul(attn_weights, value_states)

            attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


class BaichuanLayer(torch.nn.Module):
    """One decoder layer: pre-norm self-attention and pre-norm MLP, each with a residual."""

    def __init__(self, config: BaichuanConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = BaichuanAttention(config=config)
        self.mlp = MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )
        self.input_layernorm = RMSNorm(config.hidden_size, epsilon=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, epsilon=config.rms_norm_eps
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple:
        """Apply the layer.

        Returns:
            A tuple ``(hidden_states, [attn_weights], [present_key_value])``.
            ``attn_weights`` is present only when ``output_attentions`` is
            truthy and ``present_key_value`` only when ``use_cache`` is truthy,
            in that order.
        """
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        # BaichuanModel.forward reads index 1 as the attention weights and
        # index ``2 if output_attentions else 1`` as the cache entry, so the
        # attention weights must come before the cache.
        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs


class BaichuanPreTrainedModel(PreTrainedModel):
    """Base class wiring Baichuan models into the ``PreTrainedModel`` machinery."""

    config_class = BaichuanConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["BaichuanLayer"]
    _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]

    def _init_weights(self, module):
        """Initialise ``Linear`` and ``Embedding`` weights from N(0, initializer_range)."""
        std = self.config.initializer_range
        if isinstance(module, torch.nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, torch.nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, BaichuanModel):
            module.gradient_checkpointing = value


class BaichuanModel(BaichuanPreTrainedModel):
    """Token embedding, a stack of ``BaichuanLayer`` modules and a final ``RMSNorm``."""

    def __init__(self, config: BaichuanConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.n_head = config.num_attention_heads
        self.embed_tokens = torch.nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx
        )
        self.layers = torch.nn.ModuleList(
            [BaichuanLayer(config) for _ in range(config.num_hidden_layers)]
        )
        self.norm = RMSNorm(config.hidden_size, epsilon=config.rms_norm_eps)

        self.gradient_checkpointing = config.gradient_checkpointing
        self.post_init()
        # Size of the cached inference mask; grown on demand in get_alibi_mask.
        self.max_cache_pos = config.model_max_length
        self.first_run = True
        self.alibi_mask = None

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def get_alibi_mask(self, tensor, seq_length_with_past):
        """Return the causal+alibi mask for ``seq_length_with_past`` positions.

        In training mode the mask is rebuilt from scratch on every call. In
        eval mode a non-persistent ``future_mask`` buffer is built on the first
        call, rebuilt whenever a longer sequence is seen, and sliced to the
        requested length.
        """
        if self.training:
            slopes = torch.Tensor(_get_interleave(self.n_head))
            position_point = (
                torch.arange(seq_length_with_past) - seq_length_with_past + 1
            )
            position_point = (
                position_point.unsqueeze(0)
                .unsqueeze(0)
                .expand(self.n_head, seq_length_with_past, -1)
            )
            diag = torch.diag(position_point[0])
            position_point = position_point - diag.unsqueeze(0).unsqueeze(0).transpose(
                -1, -2
            )
            alibi = slopes.unsqueeze(1).unsqueeze(1) * position_point
            mask = _buffered_future_mask(
                tensor, seq_length_with_past, alibi, self.n_head
            )
        else:
            rebuild = self.first_run
            if seq_length_with_past > self.max_cache_pos:
                self.max_cache_pos = seq_length_with_past
                rebuild = True
            if rebuild:
                self.first_run = False
                self.register_buffer(
                    "future_mask",
                    _gen_alibi_mask(tensor, self.n_head, self.max_cache_pos).to(
                        tensor
                    ),
                    persistent=False,
                )
            mask = self.future_mask[
                : self.n_head, :seq_length_with_past, :seq_length_with_past
            ]
        return mask

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the decoder stack.

        Exactly one of ``input_ids`` and ``inputs_embeds`` must be given.

        Raises:
            ValueError: if both or neither of ``input_ids`` / ``inputs_embeds``
                are provided.

        Returns:
            A ``BaseModelOutputWithPast`` when ``return_dict`` is truthy,
            otherwise a tuple of the non-``None`` values among
            ``(hidden_states, next_cache, all_hidden_states, all_self_attns)``.
        """
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot provide both input_ids and inputs_embeds simultaneously"
            )
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You need to provide input_ids or inputs_embeds")

        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        seq_length_with_past = seq_length

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if self.training:
            # Reuse the training mask until the sequence length changes.
            if (
                self.alibi_mask is None
                or self.alibi_mask.shape[-1] != seq_length_with_past
            ):
                self.alibi_mask = self.get_alibi_mask(
                    inputs_embeds, seq_length_with_past
                )
            alibi_mask = self.alibi_mask
        else:
            alibi_mask = self.get_alibi_mask(inputs_embeds, seq_length_with_past)

        if attention_mask is not None:
            if len(attention_mask.shape) == 2:
                # Expand a per-token mask into a pairwise one: position i may
                # attend to position j <= i only when both mask values are
                # positive and equal to each other.
                expanded_mask = attention_mask.to(alibi_mask.dtype)
                expanded_mask = torch.tril(
                    torch.gt(expanded_mask[:, :, None] * expanded_mask[:, None, :], 0)
                ) * torch.eq(expanded_mask[:, :, None] - expanded_mask[:, None, :], 0)
            else:
                expanded_mask = attention_mask
            bsz = inputs_embeds.size(0)
            src_len, tgt_len = alibi_mask.size()[-2:]
            expanded_mask = (
                expanded_mask.unsqueeze(1)
                .expand(bsz, 1, src_len, tgt_len)
                .to(alibi_mask.dtype)
            )
            # Turn the keep-mask into an additive one: 0 where attention is
            # allowed, the dtype's finite minimum where it is blocked.
            inverted_mask = 1.0 - expanded_mask
            inverted_mask = inverted_mask.masked_fill(
                inverted_mask.to(torch.bool), torch.finfo(alibi_mask.dtype).min
            )
            attention_mask = inverted_mask + alibi_mask.unsqueeze(0)
        else:
            attention_mask = alibi_mask

        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        def create_custom_forward(module):
            def custom_forward(*inputs):
                # None for past_key_value
                return module(*inputs, output_attentions, None)

            return custom_forward

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = (
                past_key_values[idx] if past_key_values is not None else None
            )

            if self.gradient_checkpointing and self.training:
                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
                if v is not None
            )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )


class NormHead(nn.Module):
    """Output projection whose weight rows are normalised before the linear map."""

    def __init__(self, hidden_size, vocab_size, bias=False):
        super().__init__()
        self.weight = nn.Parameter(torch.empty((vocab_size, hidden_size)))
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        self.first_flag = True

    def forward(self, hidden_states):
        """Project ``hidden_states`` with the normalised weight.

        In training mode the weight is normalised on every call. In eval mode
        it is normalised once, stored back as a new ``Parameter``, and reused.

        NOTE(review): ``first_flag`` is never reset, so if the module is
        trained further after its first eval-mode call, later eval-mode calls
        use the stored weight without normalising it again.
        """
        if self.training:
            norm_weight = nn.functional.normalize(self.weight)
        elif self.first_flag:
            self.first_flag = False
            self.weight = nn.Parameter(nn.functional.normalize(self.weight))
            norm_weight = self.weight
        else:
            norm_weight = self.weight
        return nn.functional.linear(hidden_states, norm_weight)


_init_weights = True


@contextmanager
def no_init_weights(_enable=True):
    """Set the module-level ``_init_weights`` flag to False inside the ``with`` block.

    The previous value is restored on exit, including on error. With
    ``_enable=False`` the flag is left untouched.
    """
    global _init_weights
    old_init_weights = _init_weights
    if _enable:
        _init_weights = False
    try:
        yield
    finally:
        _init_weights = old_init_weights


class BaichuanForCausalLM(BaichuanPreTrainedModel):
    """``BaichuanModel`` with a ``NormHead`` language-modelling head and chat helpers."""

    def __init__(self, config, *model_args, **model_kwargs):
        super().__init__(config, *model_args, **model_kwargs)
        self.model = BaichuanModel(config)

        self.lm_head = NormHead(config.hidden_size, config.vocab_size, bias=False)
        if _is_4bit_quantized(config):
            try:
                from .quantizer import quantize_offline
            except ImportError as err:
                raise ImportError("Needs quantize_offline to run quantize.") from err
            quantize_offline(self, 4)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        *model_args,
        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        ignore_mismatched_sizes: bool = False,
        force_download: bool = False,
        local_files_only: bool = False,
        token: Optional[Union[str, bool]] = None,
        revision: str = "main",
        use_safetensors: Optional[bool] = None,
        **kwargs,
    ):
        """Load a model, with a dedicated path for 4-bit quantised checkpoints.

        When the config's ``quantization_config`` dict has a truthy
        ``load_in_4bit``, the model is instantiated on empty weights, the state
        dict is read from ``<pretrained_model_name_or_path>/pytorch_model.bin``
        and handed to ``init_model_weight_int4``. Otherwise the call is
        delegated to ``PreTrainedModel.from_pretrained``.

        Raises:
            ImportError: if the quantised path is taken and the quantizer or
                accelerate helpers cannot be imported.
        """
        # Load config if we don't provide a configuration
        if not isinstance(config, PretrainedConfig):
            config_path = config if config is not None else pretrained_model_name_or_path
            config, _ = cls.config_class.from_pretrained(
                config_path,
                cache_dir=cache_dir,
                return_unused_kwargs=True,
                force_download=force_download,
                resume_download=False,
                proxies=None,
                local_files_only=local_files_only,
                token=token,
                revision=revision,
                subfolder="",
                _from_auto=False,
                _from_pipeline=None,
                **kwargs,
            )

        if _is_4bit_quantized(config):
            try:
                from .quantizer import init_model_weight_int4
                from accelerate import init_empty_weights, dispatch_model, infer_auto_device_map
                from accelerate.utils import CustomDtype
                from accelerate.utils import get_balanced_memory
            except ImportError as err:
                raise ImportError(
                    "Needs import model weight init func to run quantize."
                ) from err
            # Instantiate model.
            init_contexts = [no_init_weights(_enable=True)]
            init_contexts.append(init_empty_weights())
            with ContextManagers(init_contexts):
                model = cls(config)

            model_file = os.path.join(pretrained_model_name_or_path, 'pytorch_model.bin')
            # NOTE(review): torch.load unpickles the file; only load
            # checkpoints from a trusted source.
            state_dict = torch.load(model_file, map_location="cpu")
            model.is_quantized = True

            device_map = kwargs.pop("device_map", None)
            # Removed so that it is not forwarded below; not otherwise used on
            # this path.
            kwargs.pop("torch_dtype", None)

            if device_map is not None:
                # Kept separate from ``kwargs`` so the device-placement
                # arguments are not passed on to GenerationConfig.from_pretrained.
                placement_kwargs = {"no_split_module_classes": model._no_split_modules}
                target_dtype = CustomDtype.INT4
                max_memory = get_balanced_memory(
                    model,
                    dtype=target_dtype,
                    low_zero=(device_map == "balanced_low_0"),
                    max_memory=None,
                    **placement_kwargs,
                )
                placement_kwargs["max_memory"] = max_memory
                device_map = infer_auto_device_map(
                    model, dtype=target_dtype, **placement_kwargs
                )

            model = init_model_weight_int4(config, model, state_dict)

            # Set model in evaluation mode to deactivate DropOut modules by default
            model.eval()
            # If it is a model with generation capabilities, attempt to load the generation config
            if model.can_generate():
                try:
                    model.generation_config = GenerationConfig.from_pretrained(
                        pretrained_model_name_or_path,
                        cache_dir=cache_dir,
                        force_download=force_download,
                        resume_download=False,
                        proxies=None,
                        local_files_only=local_files_only,
                        token=token,
                        revision=revision,
                        subfolder="",
                        _from_auto=False,
                        _from_pipeline=None,
                        **kwargs,
                    )
                except (OSError, TypeError):
                    logger.info(
                        "Generation config file not found, using a generation config created from the model config."
                    )

            if device_map is not None:
                dispatch_model(model, device_map=device_map)

            return model
        return super(BaichuanForCausalLM, cls).from_pretrained(
            pretrained_model_name_or_path,
            *model_args,
            config=config,
            cache_dir=cache_dir,
            ignore_mismatched_sizes=ignore_mismatched_sizes,
            force_download=force_download,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            use_safetensors=use_safetensors,
            **kwargs,
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Compute logits and, when ``labels`` is given, the training loss.

        The loss is next-token cross-entropy plus
        ``config.z_loss_weight * mean(max_logit ** 2)``.

        Returns:
            A ``CausalLMOutputWithPast`` when ``return_dict`` is truthy,
            otherwise ``(loss?, logits, *decoder_outputs[1:])``.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            softmax_normalizer = shift_logits.max(-1).values ** 2
            z_loss = self.config.z_loss_weight * softmax_normalizer.mean()
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels) + z_loss

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def quantize(self, bits: int):
        """Quantise this model through ``quantizer.quantize_online`` and return its result."""
        try:
            from .quantizer import quantize_online
        except ImportError as err:
            raise ImportError("Needs QLinear to run quantize.") from err
        return quantize_online(self, bits)

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """Assemble the keyword arguments for one generation step."""
        if past_key_values:
            # With a cache only the newest token has to be fed.
            input_ids = input_ids[:, -1:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Select ``beam_idx`` along axis 0 of every cached tensor."""
        return tuple(
            tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
            for layer_past in past_key_values
        )

    def _build_chat_input(
        self, tokenizer, messages: List[dict], max_new_tokens: int = 0
    ):
        """Encode ``messages`` into one ``(1, length)`` LongTensor of token ids.

        Messages are walked newest-first and grouped into rounds that are
        closed by a ``user`` message. Rounds are prepended until the token
        budget is reached, the result is truncated on the left, and the
        assistant token id is appended.

        Raises:
            ValueError: if a message role is neither ``user`` nor ``assistant``.
        """
        max_new_tokens = max_new_tokens or self.generation_config.max_new_tokens
        max_input_tokens = self.config.model_max_length - max_new_tokens
        # Always leave at least half of the context window for the prompt.
        max_input_tokens = max(self.config.model_max_length // 2, max_input_tokens)
        total_input, round_input = [], []
        for message in reversed(messages):
            content_tokens = tokenizer.encode(message["content"])
            if message["role"] == "user":
                round_input = (
                    [self.generation_config.user_token_id]
                    + content_tokens
                    + round_input
                )
                if (
                    total_input
                    and len(total_input) + len(round_input) > max_input_tokens
                ):
                    break
                else:
                    total_input = round_input + total_input
                    if len(total_input) >= max_input_tokens:
                        break
                    else:
                        round_input = []
            elif message["role"] == "assistant":
                round_input = (
                    [self.generation_config.assistant_token_id]
                    + content_tokens
                    + [self.generation_config.eos_token_id]
                    + round_input
                )
            else:
                raise ValueError(f"message role not supported yet: {message['role']}")
        total_input = total_input[-max_input_tokens:]  # truncate left
        total_input.append(self.generation_config.assistant_token_id)
        total_input = torch.LongTensor([total_input]).to(self.device)
        return total_input

    def chat(self, tokenizer, messages: List[dict], stream=False,
             generation_config: Optional[GenerationConfig]=None):
        """Generate a reply to ``messages``.

        Returns:
            The decoded response string, or, when ``stream`` is truthy, a
            ``TextIterStreamer`` fed by a background generation thread.
        """
        generation_config = generation_config or self.generation_config
        input_ids = build_chat_input(self, tokenizer, messages, generation_config.max_new_tokens)
        if stream:
            streamer = TextIterStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
            Thread(target=self.generate, kwargs=dict(
                inputs=input_ids, streamer=streamer,
                generation_config=generation_config,
            )).start()
            return streamer
        else:
            outputs = self.generate(input_ids, generation_config=generation_config)
            # Decode only the newly generated tokens, skipping the prompt.
            response = tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True)
            return response