winglian committed on
Commit
25afd35
·
unverified ·
1 Parent(s): da265dd

support layer replication for peft and fix rslora integration (#1445)

Browse files
src/axolotl/utils/config/models/input/v0_4_1/__init__.py CHANGED
@@ -6,7 +6,7 @@ Module for pydantic models for configuration
6
  import logging
7
  import os
8
  from enum import Enum
9
- from typing import Any, Dict, List, Literal, Optional, Union
10
 
11
  from pydantic import BaseModel, Field, conlist, field_validator, model_validator
12
  from transformers import SchedulerType
@@ -179,7 +179,8 @@ class LoraConfig(BaseModel):
179
  peft_layers_to_transform: Optional[List[int]] = None
180
  peft: Optional[PeftConfig] = None
181
  peft_use_dora: Optional[bool] = None
182
- peft_use_relora: Optional[bool] = None
 
183
 
184
  lora_on_cpu: Optional[bool] = None
185
  gptq: Optional[bool] = None
 
6
  import logging
7
  import os
8
  from enum import Enum
9
+ from typing import Any, Dict, List, Literal, Optional, Tuple, Union
10
 
11
  from pydantic import BaseModel, Field, conlist, field_validator, model_validator
12
  from transformers import SchedulerType
 
179
  peft_layers_to_transform: Optional[List[int]] = None
180
  peft: Optional[PeftConfig] = None
181
  peft_use_dora: Optional[bool] = None
182
+ peft_use_rslora: Optional[bool] = None
183
+ peft_layer_replication: Optional[List[Tuple[int, int]]] = None
184
 
185
  lora_on_cpu: Optional[bool] = None
186
  gptq: Optional[bool] = None
src/axolotl/utils/models.py CHANGED
@@ -849,7 +849,9 @@ def load_lora(model, cfg, inference=False, config_only=False):
849
  if cfg.peft_use_dora:
850
  lora_config_kwargs["use_dora"] = cfg.peft_use_dora
851
  if cfg.peft_use_rslora:
852
- lora_config_kwargs["use_rslora"] = cfg.use_rslora
 
 
853
 
854
  lora_config = LoraConfig(
855
  r=cfg.lora_r,
 
849
  if cfg.peft_use_dora:
850
  lora_config_kwargs["use_dora"] = cfg.peft_use_dora
851
  if cfg.peft_use_rslora:
852
+ lora_config_kwargs["use_rslora"] = cfg.peft_use_rslora
853
+ if cfg.peft_layer_replication:
854
+ lora_config_kwargs["peft_layer_replication"] = cfg.peft_layer_replication
855
 
856
  lora_config = LoraConfig(
857
  r=cfg.lora_r,