# --------------------------------------------------------------------------
# copied and adapted https://github.com/microsoft/LoRA/
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT).
# Supports bitsandbytes (bnb) quantization of the underlying layers
# --------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.utils.checkpoint import checkpoint
from typing import Dict
import os
class LoRALayer:
def __init__(
self,
r: int,
lora_alpha: int,
lora_dropout: float,
merge_weights: bool,
):
self.r = r
self.lora_alpha = lora_alpha
# Optional dropout
if lora_dropout > 0.0:
self.lora_dropout = nn.Dropout(p=lora_dropout)
else:
self.lora_dropout = lambda x: x
# Mark the weight as unmerged
self.merged = False
self.merge_weights = merge_weights
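
# For a frozen base weight W, the LoRA layers below add a trainable low-rank
# term so that the effective weight becomes
#     W_eff = W + (lora_B @ lora_A) * (lora_alpha / r)
# with lora_A of shape (r, fan_in) and lora_B of shape (fan_out, r)
# (transposed for the Embedding case). Switching to eval mode merges this
# product into W when merge_weights is set; switching back to train unmerges it.
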
class Embedding(nn.Embedding, LoRALayer):
    # LoRA implemented in an embedding layer
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
r: int = 0,
lora_alpha: int = 1,
merge_weights: bool = True,
**kwargs,
):
nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
LoRALayer.__init__(
self,
r=r,
lora_alpha=lora_alpha,
lora_dropout=0,
merge_weights=merge_weights,
)
# Actual trainable parameters
if r > 0:
self.lora_A = nn.Parameter(self.weight.new_zeros((r, num_embeddings)))
self.lora_B = nn.Parameter(self.weight.new_zeros((embedding_dim, r)))
self.scaling = self.lora_alpha / self.r
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
self.reset_parameters()
def reset_parameters(self):
nn.Embedding.reset_parameters(self)
if hasattr(self, "lora_A"):
            # initialize A to zeros and B from a normal distribution so the
            # LoRA update starts at zero (reference LoRA Embedding init)
nn.init.zeros_(self.lora_A)
nn.init.normal_(self.lora_B)
def train(self, mode: bool = True):
nn.Embedding.train(self, mode)
if mode:
if self.merge_weights and self.merged:
# Make sure that the weights are not merged
if self.r > 0:
self.weight.data -= (self.lora_B @ self.lora_A).transpose(
0, 1
) * self.scaling
self.merged = False
else:
if self.merge_weights and not self.merged:
# Merge the weights and mark it
if self.r > 0:
self.weight.data += (self.lora_B @ self.lora_A).transpose(
0, 1
) * self.scaling
self.merged = True
def forward(self, x: torch.Tensor):
if self.r > 0 and not self.merged:
result = nn.Embedding.forward(self, x)
            after_A = F.embedding(
                x,
                self.lora_A.transpose(0, 1),
                self.padding_idx,
                self.max_norm,
                self.norm_type,
                self.scale_grad_by_freq,
                self.sparse,
            )
            result += (after_A @ self.lora_B.transpose(0, 1)) * self.scaling
return result
else:
return nn.Embedding.forward(self, x)
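
# Example (illustrative sketch, sizes are arbitrary): a LoRA-augmented
# embedding table. The base weight is frozen at construction time, so only
# lora_A / lora_B receive gradients.
#
#   emb = Embedding(32000, 512, r=4, lora_alpha=8, padding_idx=0)
#   ids = torch.randint(0, 32000, (2, 16))
#   out = emb(ids)  # (2, 16, 512): base lookup plus low-rank correction
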
class QLinear(type):
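    """Metaclass that builds QLoraLinear instances.

    At instantiation time it picks the base class (plain nn.Linear, or a
    bitsandbytes Linear8bitLt / Linear4bit depending on quant_type), defines
    a LoRA-augmented subclass of it on the fly, and returns an instance of
    that subclass instead of QLoraLinear itself.
    """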
def __call__(cls, *args, **kwargs):
quant_type = kwargs.get("quant_type", None)
use_ckpting = kwargs.get("use_ckpting", [])
r = kwargs.get("r", 0)
lora_alpha = kwargs.get("lora_alpha", 1)
lora_dropout = kwargs.get("lora_dropout", 0.0)
bias = kwargs.get("bias", False)
merge_weights = kwargs.get("merge_weights", True)
threshold = kwargs.get("threshold", 6.0)
compute_dtype = kwargs.get("compute_dtype", torch.float16)
if quant_type in ["bnb_8bit", "bnb_FP4", "bnb_NF4"]:
try:
os.environ["BITSANDBYTES_NOWELCOME"] = "1"
import bitsandbytes as bnb
except ImportError:
raise ImportError("Install bitsandbytes to use 4/8bit compression")
if quant_type == "bnb_8bit":
layer_class = bnb.nn.Linear8bitLt
elif quant_type in ["bnb_FP4", "bnb_NF4"]:
layer_class = bnb.nn.Linear4bit
else:
layer_class = nn.Linear
class QLoraLinear_cls(layer_class, LoRALayer):
def __init__(self, *args, **kwargs):
if quant_type == "bnb_8bit":
super(QLoraLinear_cls, self).__init__(
*args, bias=bias, has_fp16_weights=False, threshold=threshold
)
elif quant_type in ["bnb_FP4", "bnb_NF4"]:
super(QLoraLinear_cls, self).__init__(
*args,
bias=bias,
compute_dtype=compute_dtype,
quant_type=quant_type[-3:].lower(),
)
else:
super(QLoraLinear_cls, self).__init__(*args, bias=bias)
self.quant_type = quant_type
LoRALayer.__init__(self, r, lora_alpha, lora_dropout, merge_weights)
# Actual trainable parameters
if r > 0:
self.lora_A = nn.Parameter(self.weight.new_zeros((r, args[0])))
self.lora_B = nn.Parameter(self.weight.new_zeros((args[1], r)))
self.scaling = self.lora_alpha / self.r
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
self.reset_parameters()
self.maybe_ckpt = (
checkpoint if "lora" in use_ckpting else lambda f, x: f(x)
)
def reset_parameters(self):
                # skip super().reset_parameters(): it is slow and pointless
                # since the base weight is frozen (and possibly quantized)
if hasattr(self, "lora_A"):
# initialize A the same way as the default
# for nn.Linear and B to zero
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
nn.init.zeros_(self.lora_B)
def train(self, mode: bool = True):
if self.quant_type is None:
super().train(mode)
if mode:
if self.merge_weights and self.merged:
# Make sure that the weights are not merged
if self.r > 0:
self.weight.data -= (
self.lora_B @ self.lora_A
) * self.scaling
self.merged = False
else:
if self.merge_weights and not self.merged:
# Merge the weights and mark it
if self.r > 0:
self.weight.data += (
self.lora_B @ self.lora_A
) * self.scaling
self.merged = True
else:
                    # cannot merge/unmerge quantized weights with the
                    # unquantized lora_A / lora_B matrices
pass
def forward(self, x: torch.Tensor):
if self.r > 0 and not self.merged:
result = (
self.maybe_ckpt(super().forward, x)
+ (
self.lora_dropout(x)
@ self.lora_A.transpose(0, 1)
@ self.lora_B.transpose(0, 1)
)
* self.scaling
)
else:
result = self.maybe_ckpt(super().forward, x)
return result
instance = QLoraLinear_cls.__new__(
QLoraLinear_cls
) # Create a new instance of QLoraLinear_cls
instance.__init__(
*args, **kwargs
) # Invoke the __init__ method of QLoraLinear_cls
"""
# Check if QLoraLinear has a custom __init__ method
if hasattr(cls, "__init__"):
kwargs.pop("r", None)
kwargs.pop("quant_type", None)
# Invoke the __init__ method of QLoraLinear
cls.__init__(instance, *args, **kwargs)
"""
return instance
class QLoraLinear(metaclass=QLinear):
# LoRA implemented in a dense layer
def __init__(
self,
in_features: int,
out_features: int,
**kwargs,
):
pass
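
# Example (illustrative sketch): constructing LoRA linears directly. The
# float path needs only torch; the NF4 path assumes bitsandbytes is installed
# and the layer is moved to a CUDA device before use.
#
#   lin = QLoraLinear(1024, 1024, r=8, lora_alpha=16, lora_dropout=0.05,
#                     bias=False, quant_type=None)
#   qlin = QLoraLinear(1024, 1024, r=8, lora_alpha=16, lora_dropout=0.05,
#                      bias=False, quant_type="bnb_NF4",
#                      compute_dtype=torch.bfloat16)
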
def mark_only_lora_as_trainable(model: nn.Module, bias: str = "none") -> None:
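    """Freeze every parameter whose name does not contain "lora_".

    bias controls which bias terms stay trainable: "none" (default), "all",
    or "lora_only" (only the biases of LoRA layers).
    """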
for n, p in model.named_parameters():
if "lora_" not in n:
p.requires_grad = False
if bias == "none":
return
elif bias == "all":
for n, p in model.named_parameters():
if "bias" in n:
p.requires_grad = True
elif bias == "lora_only":
for m in model.modules():
if isinstance(m, LoRALayer) and hasattr(m, "bias") and m.bias is not None:
m.bias.requires_grad = True
else:
raise NotImplementedError
def lora_state_dict(model: nn.Module, bias: str = "none") -> Dict[str, torch.Tensor]:
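    """Return a state_dict restricted to LoRA parameters.

    bias selects which bias terms are included as well: "none" (default),
    "all", or "lora_only" (only the biases attached to LoRA layers).
    """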
my_state_dict = model.state_dict()
if bias == "none":
return {k: my_state_dict[k] for k in my_state_dict if "lora_" in k}
elif bias == "all":
return {
k: my_state_dict[k] for k in my_state_dict if "lora_" in k or "bias" in k
}
elif bias == "lora_only":
to_return = {}
for k in my_state_dict:
if "lora_" in k:
to_return[k] = my_state_dict[k]
bias_name = k.split("lora_")[0] + "bias"
if bias_name in my_state_dict:
to_return[bias_name] = my_state_dict[bias_name]
return to_return
else:
raise NotImplementedError
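
# Example (illustrative sketch, the file name is arbitrary): checkpoint only
# the adapter weights and restore them into a freshly LoRA-wrapped model.
#
#   torch.save(lora_state_dict(model), "lora_checkpoint.pt")
#   ...
#   model.load_state_dict(torch.load("lora_checkpoint.pt"), strict=False)
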
def replace_lora_linear(
model,
r=2,
lora_alpha=1,
lora_dropout=0,
layer="",
quant_type=None,
use_ckpting=[],
threshold=6.0,
compute_dtype=torch.float16,
):
"""
Function replacing layers with LoRa layers recursively.
Args:
model:
r: rank of matrix of the Low Rank layer
lora_alpha: cf paper
lora_dropout: cf paper
layer: layer name of the model to be replaced
quant_type: use bnb to quantize nn.Linear sub-layer
"""
for name, module in model.named_children():
if hasattr(module, "children") and len(list(module.children())) > 0:
replace_lora_linear(
module,
r=r,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
layer=layer,
quant_type=quant_type,
use_ckpting=use_ckpting,
threshold=threshold,
compute_dtype=compute_dtype,
)
if isinstance(module, nn.Linear) and name == layer:
model._modules[name] = QLoraLinear(
module.in_features,
module.out_features,
r=r,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
bias=module.bias is not None,
quant_type=quant_type,
use_ckpting=use_ckpting,
threshold=threshold,
compute_dtype=compute_dtype,
)
return model
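
# Example (illustrative sketch): add LoRA adapters on top of NF4-quantized
# base weights for every nn.Linear attribute named "w_1" (a made-up name;
# pass whichever sub-layer name your model uses). Requires bitsandbytes and
# a CUDA device.
#
#   model = replace_lora_linear(
#       model,
#       r=8,
#       lora_alpha=16,
#       lora_dropout=0.05,
#       layer="w_1",
#       quant_type="bnb_NF4",
#       compute_dtype=torch.bfloat16,
#   )
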
def replace_lora_embedding(model, r=2, lora_alpha=1):
"""
Function replacing Embeddings with LoRa ones recursively.
Args:
model:
r: rank of matrix of the Low Rank layer
lora_alpha: cf paper
"""
for name, module in model.named_children():
if len(list(module.children())) > 0:
replace_lora_embedding(module, r, lora_alpha)
if isinstance(module, nn.Embedding):
model._modules[name] = Embedding(
module.num_embeddings,
module.embedding_dim,
r=r,
lora_alpha=lora_alpha,
padding_idx=module.padding_idx,
sparse=module.sparse,
)
return model
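

if __name__ == "__main__":
    # Minimal end-to-end sketch (float path only, no bitsandbytes required).
    # The toy module and the layer name "down" are purely illustrative.
    class _TinyFF(nn.Module):
        def __init__(self):
            super().__init__()
            self.up = nn.Linear(16, 32, bias=False)
            self.down = nn.Linear(32, 16, bias=False)

        def forward(self, x):
            return self.down(torch.relu(self.up(x)))

    model = _TinyFF()
    # Wrap only the sub-layers named "down" with rank-4 LoRA adapters.
    model = replace_lora_linear(model, r=4, lora_alpha=8, layer="down")
    # Freeze everything except the lora_A / lora_B matrices.
    mark_only_lora_as_trainable(model, bias="none")
    out = model(torch.randn(2, 16))
    # Only the adapter tensors need to be saved.
    print(sorted(lora_state_dict(model).keys()))  # ['down.lora_A', 'down.lora_B']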