from dataclasses import dataclass, field, fields
from typing import Optional

import pyrallis
import torch
from transformers import PretrainedConfig


@dataclass
class TrainingConfig:
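    """Configuration for a single SAE training run on activations of a base language model."""

    # Base model, layer and hook point whose activations the SAE is trained on.
    # act_size (the activation width) is left as None and is expected to be
    # filled in once the base model has been loaded.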
    model_name: str = "unsloth/Meta-Llama-3.1-8B"
    layer: int = 12
    hook_point: str = "resid_mid"
    act_size: Optional[int] = None

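    # SAE variant, dictionary size, auxiliary-loss weight and input normalization.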
    sae_type: str = "batchtopk"
    dict_size: int = 2**15
    aux_penalty: float = 1 / 32
    input_unit_norm: bool = True

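    # Top-k sparsity settings for the (batch)top-k SAE variants. The field names
    # suggest k is annealed from start_top_k down to top_k over the first
    # top_k_warmup_steps_fraction of training; top_k_aux is the k used by the
    # auxiliary loss on dead features.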
    top_k: int = 50
    top_k_warmup_steps_fraction: float = 0.1
    start_top_k: int = 4096
    top_k_aux: int = 512

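    # A feature counts as dead once it has not fired for this many batches.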
    n_batches_to_dead: int = 10

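    # Optimization and data-throughput settings. model_batch_size counts sequences
    # fed through the base model, while batch_size counts activation vectors per SAE
    # update; num_batches_in_buffer presumably sizes the activation buffer that
    # decouples the two.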
    lr: float = 3e-4
    bandwidth: float = 0.001
    l1_coeff: float = 0.0018
    num_tokens: int = int(1e9)
    seq_len: int = 1024
    model_batch_size: int = 16
    num_batches_in_buffer: int = 5
    max_grad_norm: float = 1.0
    batch_size: int = 8192

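    # Learning-rate schedule.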
    warmup_fraction: float = 0.1
    scheduler_type: str = "linear"

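    # Device and precision. As the names suggest, dtype applies to the base model /
    # activations and sae_dtype to the SAE itself.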
    device: str = "cuda"
    dtype: torch.dtype = field(default=torch.float32)
    sae_dtype: torch.dtype = field(default=torch.float32)

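    # Token source used to harvest activations.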
    dataset_path: str = "cerebras/SlimPajama-627B"

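    # Weights & Biases project plus logging / checkpoint cadence.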
    wandb_project: str = "turbo-llama-lens"

    performance_log_steps: int = 100
    save_checkpoint_steps: int = 10_000

    def __post_init__(self):
        if self.device == "cuda" and not torch.cuda.is_available():
            print("CUDA not available, falling back to CPU")
            self.device = "cpu"

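        # Dtypes may arrive as strings (e.g. from a YAML or CLI config); convert
        # them to torch dtypes here.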
        if isinstance(self.dtype, str):
            self.dtype = getattr(torch, self.dtype)
        if isinstance(self.sae_dtype, str):
            self.sae_dtype = getattr(torch, self.sae_dtype)


class SAEConfig(PretrainedConfig):
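    """Hugging Face-style config object for a trained SAE.

    Subclassing PretrainedConfig lets the SAE hyperparameters be written and read
    with the standard save_pretrained / from_pretrained machinery; the parent_*
    fields record which model, layer and hook point the SAE was trained on.

    Illustrative round trip (assumes a populated TrainingConfig ``cfg``; the
    checkpoint path is just an example):

        sae_cfg = SAEConfig.from_training_config(cfg)
        sae_cfg.save_pretrained("checkpoints/my_sae")   # writes config.json
        restored = SAEConfig.from_pretrained("checkpoints/my_sae").to_training_config()
    """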
model_type = "sae" |
|
|
|
    def __init__(
        self,
        act_size: Optional[int] = None,
        dict_size: int = 2**15,
        sae_type: str = "batchtopk",
        input_unit_norm: bool = True,

        top_k: int = 50,
        top_k_aux: int = 512,
        n_batches_to_dead: int = 10,

        aux_penalty: float = 1 / 32,
        l1_coeff: float = 0.0018,
        bandwidth: float = 0.001,

        dtype: str = "float32",
        sae_dtype: str = "float32",

        parent_model_name: Optional[str] = None,
        parent_layer: Optional[int] = None,
        parent_hook_point: Optional[str] = None,

        **kwargs,
    ):
        super().__init__(**kwargs)
        self.act_size = act_size
        self.dict_size = dict_size
        self.sae_type = sae_type
        self.input_unit_norm = input_unit_norm

        self.top_k = top_k
        self.top_k_aux = top_k_aux
        self.n_batches_to_dead = n_batches_to_dead

        self.aux_penalty = aux_penalty
        self.l1_coeff = l1_coeff
        self.bandwidth = bandwidth

        self.dtype = dtype
        self.sae_dtype = sae_dtype

        self.parent_model_name = parent_model_name
        self.parent_layer = parent_layer
        self.parent_hook_point = parent_hook_point

    def get_torch_dtype(self, dtype_str: str) -> torch.dtype:
        dtype_map = {
            "float32": torch.float32,
            "float16": torch.float16,
            "bfloat16": torch.bfloat16,
        }
        return dtype_map.get(dtype_str, torch.float32)

    @classmethod
    def from_training_config(cls, cfg: TrainingConfig):
        """Convert a TrainingConfig to an SAEConfig."""
        return cls(
            act_size=cfg.act_size,
            dict_size=cfg.dict_size,
            sae_type=cfg.sae_type,
            input_unit_norm=cfg.input_unit_norm,
            top_k=cfg.top_k,
            top_k_aux=cfg.top_k_aux,
            n_batches_to_dead=cfg.n_batches_to_dead,
            aux_penalty=cfg.aux_penalty,
            l1_coeff=cfg.l1_coeff,
            bandwidth=cfg.bandwidth,
            dtype=str(cfg.dtype).split(".")[-1],
            sae_dtype=str(cfg.sae_dtype).split(".")[-1],
            parent_model_name=cfg.model_name,
            parent_layer=cfg.layer,
            parent_hook_point=cfg.hook_point,
        )

    def to_training_config(self) -> TrainingConfig:
        """Convert an SAEConfig back to a TrainingConfig (training-only fields keep their defaults)."""
        # SAEConfig is not a dataclass, so dataclasses.asdict() would raise here;
        # build the dict from PretrainedConfig.to_dict() and keep only keys that
        # are actual TrainingConfig fields.
        valid_fields = {f.name for f in fields(TrainingConfig)}
        config_dict = {k: v for k, v in self.to_dict().items() if k in valid_fields}
        config_dict["dtype"] = self.get_torch_dtype(self.dtype)
        config_dict["sae_dtype"] = self.get_torch_dtype(self.sae_dtype)
        config_dict["model_name"] = self.parent_model_name
        config_dict["layer"] = self.parent_layer
        config_dict["hook_point"] = self.parent_hook_point
        return TrainingConfig(**config_dict)


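# pyrallis.wrap() builds a TrainingConfig from command-line flags (one per field,
# e.g. `--lr 1e-4 --top_k 32`) and an optional `--config_path` YAML/JSON file;
# any field not given on the command line keeps the default defined above.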
@pyrallis.wrap()
def get_config() -> TrainingConfig:
    return TrainingConfig()


def get_default_cfg() -> TrainingConfig:
    return get_config()


def post_init_cfg(cfg: TrainingConfig) -> TrainingConfig:
    """
    Any additional configuration setup that needs to happen after model initialization.
    """
    return cfg
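

if __name__ == "__main__":
    # Quick sanity check (illustrative only): parse any CLI flags into a
    # TrainingConfig and print it. The real entry point presumably lives in the
    # training script, not here.
    print(get_default_cfg())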