import math
import os
from typing import Dict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

class LoRALayer:
    """Mixin storing the LoRA hyper-parameters shared by all LoRA layers."""

    def __init__(
        self,
        r: int,
        lora_alpha: int,
        lora_dropout: float,
        merge_weights: bool,
    ):
        self.r = r
        self.lora_alpha = lora_alpha
        # Optional dropout applied to the input of the LoRA branch
        if lora_dropout > 0.0:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = lambda x: x
        # Mark the weights as unmerged
        self.merged = False
        self.merge_weights = merge_weights

class Embedding(nn.Embedding, LoRALayer):
    """nn.Embedding augmented with a low-rank (LoRA) update."""

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        r: int = 0,
        lora_alpha: int = 1,
        merge_weights: bool = True,
        **kwargs,
    ):
        nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
        LoRALayer.__init__(
            self,
            r=r,
            lora_alpha=lora_alpha,
            lora_dropout=0,
            merge_weights=merge_weights,
        )
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, num_embeddings)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((embedding_dim, r)))
            self.scaling = self.lora_alpha / self.r
            # Freeze the pre-trained weight matrix
            self.weight.requires_grad = False
        self.reset_parameters()

    def reset_parameters(self):
        nn.Embedding.reset_parameters(self)
        if hasattr(self, "lora_A"):
            # initialize A to zeros and B from a standard normal distribution
            nn.init.zeros_(self.lora_A)
            nn.init.normal_(self.lora_B)

    def train(self, mode: bool = True):
        nn.Embedding.train(self, mode)
        if mode:
            # Make sure the LoRA weights are not merged while training
            if self.merge_weights and self.merged:
                if self.r > 0:
                    self.weight.data -= (self.lora_B @ self.lora_A).transpose(
                        0, 1
                    ) * self.scaling
                self.merged = False
        else:
            # Merge the LoRA weights into the base weights for inference
            if self.merge_weights and not self.merged:
                if self.r > 0:
                    self.weight.data += (self.lora_B @ self.lora_A).transpose(
                        0, 1
                    ) * self.scaling
                self.merged = True

    def forward(self, x: torch.Tensor):
        if self.r > 0 and not self.merged:
            result = nn.Embedding.forward(self, x)
            # Low-rank branch: look up rows of A^T, then project with B^T
            after_A = F.embedding(
                x,
                self.lora_A.transpose(0, 1),
                self.padding_idx,
                self.max_norm,
                self.norm_type,
                self.scale_grad_by_freq,
                self.sparse,
            )
            result += (after_A @ self.lora_B.transpose(0, 1)) * self.scaling
            return result
        else:
            return nn.Embedding.forward(self, x)

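# Illustrative usage sketch (not part of the module; sizes are made up): wrap a
# vocabulary embedding with a rank-8 LoRA update. The forward pass returns the
# base lookup plus (B @ A)^T[x] * (lora_alpha / r), with only lora_A / lora_B
# trainable.
#
#   emb = Embedding(num_embeddings=32000, embedding_dim=512, r=8, lora_alpha=16)
#   tokens = torch.randint(0, 32000, (2, 10))
#   out = emb(tokens)  # shape (2, 10, 512)
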
class QLinear(type):
    """Metaclass whose __call__ builds a LoRA Linear class on top of a plain or
    bitsandbytes-quantized base Linear and returns an instance of it."""

    def __call__(cls, *args, **kwargs):
        quant_type = kwargs.get("quant_type", None)
        use_ckpting = kwargs.get("use_ckpting", [])
        r = kwargs.get("r", 0)
        lora_alpha = kwargs.get("lora_alpha", 1)
        lora_dropout = kwargs.get("lora_dropout", 0.0)
        bias = kwargs.get("bias", False)
        merge_weights = kwargs.get("merge_weights", True)
        threshold = kwargs.get("threshold", 6.0)
        compute_dtype = kwargs.get("compute_dtype", torch.float16)

        # Select the base Linear implementation according to the quantization type
        if quant_type in ["bnb_8bit", "bnb_FP4", "bnb_NF4"]:
            try:
                os.environ["BITSANDBYTES_NOWELCOME"] = "1"
                import bitsandbytes as bnb
            except ImportError:
                raise ImportError("Install bitsandbytes to use 4/8bit compression")
            if quant_type == "bnb_8bit":
                layer_class = bnb.nn.Linear8bitLt
            elif quant_type in ["bnb_FP4", "bnb_NF4"]:
                layer_class = bnb.nn.Linear4bit
        else:
            layer_class = nn.Linear

        class QLoraLinear_cls(layer_class, LoRALayer):
            def __init__(self, *args, **kwargs):
                if quant_type == "bnb_8bit":
                    super(QLoraLinear_cls, self).__init__(
                        *args, bias=bias, has_fp16_weights=False, threshold=threshold
                    )
                elif quant_type in ["bnb_FP4", "bnb_NF4"]:
                    super(QLoraLinear_cls, self).__init__(
                        *args,
                        bias=bias,
                        compute_dtype=compute_dtype,
                        quant_type=quant_type[-3:].lower(),
                    )
                else:
                    super(QLoraLinear_cls, self).__init__(*args, bias=bias)
                self.quant_type = quant_type
                LoRALayer.__init__(self, r, lora_alpha, lora_dropout, merge_weights)
                # Actual trainable parameters
                if r > 0:
                    self.lora_A = nn.Parameter(self.weight.new_zeros((r, args[0])))
                    self.lora_B = nn.Parameter(self.weight.new_zeros((args[1], r)))
                    self.scaling = self.lora_alpha / self.r
                    # Freeze the pre-trained weight matrix
                    self.weight.requires_grad = False
                self.reset_parameters()
                self.maybe_ckpt = (
                    checkpoint if "lora" in use_ckpting else lambda f, x: f(x)
                )

            def reset_parameters(self):
                if hasattr(self, "lora_A"):
                    # initialize A the same way as the default for nn.Linear
                    # and B to zero
                    nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
                    nn.init.zeros_(self.lora_B)

            def train(self, mode: bool = True):
                if self.quant_type is None:
                    super().train(mode)
                    if mode:
                        # Make sure the LoRA weights are not merged while training
                        if self.merge_weights and self.merged:
                            if self.r > 0:
                                self.weight.data -= (
                                    self.lora_B @ self.lora_A
                                ) * self.scaling
                            self.merged = False
                    else:
                        # Merge the LoRA weights into the base weights for inference
                        if self.merge_weights and not self.merged:
                            if self.r > 0:
                                self.weight.data += (
                                    self.lora_B @ self.lora_A
                                ) * self.scaling
                            self.merged = True
                else:
                    # Merging LoRA weights into quantized weights is not supported
                    pass

            def forward(self, x: torch.Tensor):
                if self.r > 0 and not self.merged:
                    result = (
                        self.maybe_ckpt(super().forward, x)
                        + (
                            self.lora_dropout(x)
                            @ self.lora_A.transpose(0, 1)
                            @ self.lora_B.transpose(0, 1)
                        )
                        * self.scaling
                    )
                else:
                    result = self.maybe_ckpt(super().forward, x)
                return result

        instance = QLoraLinear_cls.__new__(QLoraLinear_cls)
        instance.__init__(*args, **kwargs)

        """
        # Check if QLoraLinear has a custom __init__ method
        if hasattr(cls, "__init__"):
            kwargs.pop("r", None)
            kwargs.pop("quant_type", None)
            # Invoke the __init__ method of QLoraLinear
            cls.__init__(instance, *args, **kwargs)
        """
        return instance

class QLoraLinear(metaclass=QLinear):
    """Facade class: instantiation is intercepted by the QLinear metaclass,
    which returns a LoRA-augmented (and optionally quantized) Linear instance,
    so this __init__ is never executed."""

    def __init__(
        self,
        in_features: int,
        out_features: int,
        **kwargs,
    ):
        pass

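# Illustrative construction sketch (parameter values are made up). With
# quant_type=None the base layer is a plain nn.Linear; passing
# quant_type="bnb_NF4" would additionally require bitsandbytes and a CUDA device.
#
#   proj = QLoraLinear(
#       1024, 1024, r=8, lora_alpha=16, lora_dropout=0.05,
#       bias=False, quant_type=None,
#   )
#   y = proj(torch.randn(4, 1024))  # base forward + (x @ A^T @ B^T) * scaling
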
def mark_only_lora_as_trainable(model: nn.Module, bias: str = "none") -> None:
    """Freeze every parameter except the LoRA ones (and optionally biases)."""
    for n, p in model.named_parameters():
        if "lora_" not in n:
            p.requires_grad = False
    if bias == "none":
        return
    elif bias == "all":
        for n, p in model.named_parameters():
            if "bias" in n:
                p.requires_grad = True
    elif bias == "lora_only":
        for m in model.modules():
            if isinstance(m, LoRALayer) and hasattr(m, "bias") and m.bias is not None:
                m.bias.requires_grad = True
    else:
        raise NotImplementedError

def lora_state_dict(model: nn.Module, bias: str = "none") -> Dict[str, torch.Tensor]:
    """Return only the LoRA entries (and optionally biases) of the state dict."""
    my_state_dict = model.state_dict()
    if bias == "none":
        return {k: my_state_dict[k] for k in my_state_dict if "lora_" in k}
    elif bias == "all":
        return {
            k: my_state_dict[k] for k in my_state_dict if "lora_" in k or "bias" in k
        }
    elif bias == "lora_only":
        to_return = {}
        for k in my_state_dict:
            if "lora_" in k:
                to_return[k] = my_state_dict[k]
                bias_name = k.split("lora_")[0] + "bias"
                if bias_name in my_state_dict:
                    to_return[bias_name] = my_state_dict[bias_name]
        return to_return
    else:
        raise NotImplementedError

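# Typical adapter workflow (sketch; the toy model and file name below are just
# placeholders for illustration): freeze everything except the LoRA parameters,
# train as usual, then save only the adapter weights.
#
#   toy = replace_lora_embedding(nn.Sequential(nn.Embedding(100, 16)), r=4)
#   mark_only_lora_as_trainable(toy, bias="none")
#   # ... training loop ...
#   torch.save(lora_state_dict(toy, bias="none"), "lora_adapters.pt")
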
def replace_lora_linear(
    model,
    r=2,
    lora_alpha=1,
    lora_dropout=0,
    layer="",
    quant_type=None,
    use_ckpting=[],
    threshold=6.0,
    compute_dtype=torch.float16,
):
    """
    Recursively replace nn.Linear sub-modules with (Q)LoRA linear layers.
    Args:
        model: model whose sub-modules are replaced in place
        r: rank of the low-rank matrices
        lora_alpha: scaling factor (cf. the LoRA paper)
        lora_dropout: dropout applied to the LoRA branch input (cf. the LoRA paper)
        layer: name of the nn.Linear sub-modules to be replaced
        quant_type: bitsandbytes quantization of the base nn.Linear
            (None, "bnb_8bit", "bnb_FP4" or "bnb_NF4")
        use_ckpting: gradient-checkpoint the base forward if "lora" is listed
        threshold: outlier threshold for 8-bit quantization
        compute_dtype: computation dtype for 4-bit quantization
    """
    for name, module in model.named_children():
        if hasattr(module, "children") and len(list(module.children())) > 0:
            replace_lora_linear(
                module,
                r=r,
                lora_alpha=lora_alpha,
                lora_dropout=lora_dropout,
                layer=layer,
                quant_type=quant_type,
                use_ckpting=use_ckpting,
                threshold=threshold,
                compute_dtype=compute_dtype,
            )

        if isinstance(module, nn.Linear) and name == layer:
            model._modules[name] = QLoraLinear(
                module.in_features,
                module.out_features,
                r=r,
                lora_alpha=lora_alpha,
                lora_dropout=lora_dropout,
                bias=module.bias is not None,
                quant_type=quant_type,
                use_ckpting=use_ckpting,
                threshold=threshold,
                compute_dtype=compute_dtype,
            )
    return model

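# Illustrative sketch: in an nn.Sequential the children are named "0", "1", ...,
# so layer="0" targets the first Linear of the toy model below. quant_type=None
# keeps a plain nn.Linear base (no bitsandbytes needed).
#
#   toy = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
#   toy = replace_lora_linear(toy, r=4, lora_alpha=8, layer="0", quant_type=None)
#   mark_only_lora_as_trainable(toy)
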
def replace_lora_embedding(model, r=2, lora_alpha=1):
    """
    Recursively replace nn.Embedding sub-modules with LoRA embeddings.
    Args:
        model: model whose sub-modules are replaced in place
        r: rank of the low-rank matrices
        lora_alpha: scaling factor (cf. the LoRA paper)
    """
    for name, module in model.named_children():
        if len(list(module.children())) > 0:
            replace_lora_embedding(module, r, lora_alpha)
        if isinstance(module, nn.Embedding):
            model._modules[name] = Embedding(
                module.num_embeddings,
                module.embedding_dim,
                r=r,
                lora_alpha=lora_alpha,
                padding_idx=module.padding_idx,
                sparse=module.sparse,
            )
    return model

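# Illustrative sketch mirroring replace_lora_linear above (the toy model is
# hypothetical): every nn.Embedding gets a rank-4 LoRA update, and only
# lora_A / lora_B require gradients after mark_only_lora_as_trainable.
#
#   toy = nn.Sequential(nn.Embedding(100, 32), nn.Linear(32, 32))
#   toy = replace_lora_embedding(toy, r=4, lora_alpha=8)
#   mark_only_lora_as_trainable(toy)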