Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,624 Bytes
6ed1db6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
from peft.tuners.tuners_utils import BaseTunerLayer
from typing import List, Any, Optional, Type
class enable_lora:
    """Context manager that temporarily silences LoRA layers.

    When ``activated`` is False, entering the context calls ``scale_layer(0)``
    on every ``BaseTunerLayer`` found in ``lora_modules``. Leaving it writes
    the per-adapter ``scaling`` values captured on entry back onto each module.
    When ``activated`` is True the context is a no-op and the modules are not
    touched at all.

    Exceptions raised inside the ``with`` body are never suppressed
    (``__exit__`` returns None).

    Args:
        lora_modules: Candidate modules. Anything that is not a
            ``BaseTunerLayer`` is ignored.
        activated: If True, do nothing and leave LoRA as it is.
    """

    def __init__(self, lora_modules: List[BaseTunerLayer], activated: bool) -> None:
        self.activated: bool = activated
        # Define both attributes even in the no-op case so they always exist.
        self.lora_modules: List[BaseTunerLayer] = []
        self.scales: List[dict] = []
        if activated:
            return
        self.lora_modules = [
            each for each in lora_modules if isinstance(each, BaseTunerLayer)
        ]
        self.scales = self._snapshot_scales()

    def _snapshot_scales(self) -> List[dict]:
        """Return one ``{adapter_name: scaling}`` dict per module, in module order."""
        return [
            {
                adapter: lora_module.scaling[adapter]
                for adapter in lora_module.active_adapters
            }
            for lora_module in self.lora_modules
        ]

    def __enter__(self) -> None:
        if self.activated:
            return
        # Snapshot again here instead of relying only on __init__. That way the
        # values in effect at entry (not at construction, or left over from a
        # previous use of this object) are what gets restored on exit.
        self.scales = self._snapshot_scales()
        for lora_module in self.lora_modules:
            # NOTE(review): peft's scale_layer rescales the active adapters'
            # `scaling` entries (multiplicatively in current releases), so 0
            # disables them. Confirm against the pinned peft version.
            lora_module.scale_layer(0)

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[Any],
    ) -> None:
        if self.activated:
            return
        # Restore exactly the adapters captured on entry. Iterating the saved
        # dict, rather than the module's current active_adapters, avoids a
        # KeyError if the active adapter set changed inside the block.
        for lora_module, saved in zip(self.lora_modules, self.scales):
            for adapter, value in saved.items():
                lora_module.scaling[adapter] = value
class set_lora_scale:
    """Context manager that temporarily applies ``scale`` to LoRA layers.

    Entering the context calls ``scale_layer(scale)`` on every
    ``BaseTunerLayer`` found in ``lora_modules``. Leaving it writes the
    per-adapter ``scaling`` values captured on entry back onto each module.

    Exceptions raised inside the ``with`` body are never suppressed
    (``__exit__`` returns None).

    Args:
        lora_modules: Candidate modules. Anything that is not a
            ``BaseTunerLayer`` is ignored.
        scale: Value passed to each module's ``scale_layer``.
    """

    def __init__(self, lora_modules: List[BaseTunerLayer], scale: float) -> None:
        self.lora_modules: List[BaseTunerLayer] = [
            each for each in lora_modules if isinstance(each, BaseTunerLayer)
        ]
        self.scale = scale
        self.scales: List[dict] = self._snapshot_scales()

    def _snapshot_scales(self) -> List[dict]:
        """Return one ``{adapter_name: scaling}`` dict per module, in module order."""
        return [
            {
                adapter: lora_module.scaling[adapter]
                for adapter in lora_module.active_adapters
            }
            for lora_module in self.lora_modules
        ]

    def __enter__(self) -> None:
        # Snapshot again here instead of relying only on __init__. That way the
        # values in effect at entry (not at construction, or left over from a
        # previous use of this object) are what gets restored on exit.
        self.scales = self._snapshot_scales()
        for lora_module in self.lora_modules:
            # NOTE(review): in current peft releases scale_layer multiplies the
            # active adapters' `scaling` by `scale` rather than assigning it.
            # Confirm against the pinned peft version.
            lora_module.scale_layer(self.scale)

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[Any],
    ) -> None:
        # Restore exactly the adapters captured on entry. Iterating the saved
        # dict, rather than the module's current active_adapters, avoids a
        # KeyError if the active adapter set changed inside the block.
        for lora_module, saved in zip(self.lora_modules, self.scales):
            for adapter, value in saved.items():
                lora_module.scaling[adapter] = value
|