# config.py -- training and SAE configuration for the BatchTopK SAE
from dataclasses import dataclass, field, fields
from typing import Optional

import torch
import pyrallis
from transformers import PretrainedConfig


@dataclass
class TrainingConfig:
# Model settings
model_name: str = "unsloth/Meta-Llama-3.1-8B"
layer: int = 12
hook_point: str = "resid_mid"
act_size: Optional[int] = None # Will be set after model initialization
# SAE settings
sae_type: str = "batchtopk"
dict_size: int = 2**15
aux_penalty: float = 1/32
input_unit_norm: bool = True
# TopK specific settings
top_k: int = 50
top_k_warmup_steps_fraction: float = 0.1
start_top_k: int = 4096
top_k_aux: int = 512
n_batches_to_dead: int = 10
# Training settings
lr: float = 3e-4
bandwidth: float = 0.001
l1_coeff: float = 0.0018
num_tokens: int = int(1e9)
seq_len: int = 1024
model_batch_size: int = 16
num_batches_in_buffer: int = 5
max_grad_norm: float = 1.0
batch_size: int = 8192
# scheduler
warmup_fraction: float = 0.1
scheduler_type: str = 'linear'
# Hardware settings
device: str = "cuda"
dtype: torch.dtype = field(default=torch.float32)
sae_dtype: torch.dtype = field(default=torch.float32)
# Dataset settings
dataset_path: str = "cerebras/SlimPajama-627B"
# Logging settings
wandb_project: str = "turbo-llama-lens"
performance_log_steps: int = 100
save_checkpoint_steps: int = 10_000
def __post_init__(self):
if self.device == "cuda" and not torch.cuda.is_available():
print("CUDA not available, falling back to CPU")
self.device = "cpu"
        # Convert string dtypes to torch.dtype if needed
        if isinstance(self.dtype, str):
            self.dtype = getattr(torch, self.dtype)
        if isinstance(self.sae_dtype, str):
            self.sae_dtype = getattr(torch, self.sae_dtype)


class SAEConfig(PretrainedConfig):
model_type = "sae"
def __init__(
self,
# SAE architecture
        act_size: Optional[int] = None,
dict_size: int = 2**15,
sae_type: str = "batchtopk",
input_unit_norm: bool = True,
# TopK specific settings
top_k: int = 50,
top_k_aux: int = 512,
n_batches_to_dead: int = 10,
# Training hyperparameters
aux_penalty: float = 1/32,
l1_coeff: float = 0.0018,
bandwidth: float = 0.001,
# Hardware settings
dtype: str = "float32",
sae_dtype: str = "float32",
# Optional parent model info
parent_model_name: Optional[str] = None,
parent_layer: Optional[int] = None,
parent_hook_point: Optional[str] = None,
**kwargs
):
super().__init__(**kwargs)
self.act_size = act_size
self.dict_size = dict_size
self.sae_type = sae_type
self.input_unit_norm = input_unit_norm
self.top_k = top_k
self.top_k_aux = top_k_aux
self.n_batches_to_dead = n_batches_to_dead
self.aux_penalty = aux_penalty
self.l1_coeff = l1_coeff
self.bandwidth = bandwidth
self.dtype = dtype
self.sae_dtype = sae_dtype
self.parent_model_name = parent_model_name
self.parent_layer = parent_layer
self.parent_hook_point = parent_hook_point
def get_torch_dtype(self, dtype_str: str) -> torch.dtype:
dtype_map = {
"float32": torch.float32,
"float16": torch.float16,
"bfloat16": torch.bfloat16,
}
return dtype_map.get(dtype_str, torch.float32)
@classmethod
def from_training_config(cls, cfg: TrainingConfig):
"""Convert TrainingConfig to SAEConfig"""
return cls(
act_size=cfg.act_size,
dict_size=cfg.dict_size,
sae_type=cfg.sae_type,
input_unit_norm=cfg.input_unit_norm,
top_k=cfg.top_k,
top_k_aux=cfg.top_k_aux,
n_batches_to_dead=cfg.n_batches_to_dead,
aux_penalty=cfg.aux_penalty,
l1_coeff=cfg.l1_coeff,
bandwidth=cfg.bandwidth,
dtype=str(cfg.dtype).split('.')[-1],
sae_dtype=str(cfg.sae_dtype).split('.')[-1],
parent_model_name=cfg.model_name,
parent_layer=cfg.layer,
parent_hook_point=cfg.hook_point,
)
    def to_training_config(self) -> TrainingConfig:
        """Convert SAEConfig back to TrainingConfig"""
        # SAEConfig is a PretrainedConfig rather than a dataclass, so
        # dataclasses.asdict would fail here; instead take the config dict
        # and keep only the keys that TrainingConfig actually defines.
        valid_keys = {f.name for f in fields(TrainingConfig)}
        config_dict = {k: v for k, v in self.to_dict().items() if k in valid_keys}
        config_dict['dtype'] = self.get_torch_dtype(self.dtype)
        config_dict['sae_dtype'] = self.get_torch_dtype(self.sae_dtype)
        config_dict['model_name'] = self.parent_model_name
        config_dict['layer'] = self.parent_layer
        config_dict['hook_point'] = self.parent_hook_point
        return TrainingConfig(**config_dict)
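

# Because SAEConfig subclasses PretrainedConfig, instances can be saved and
# reloaded with the standard Hugging Face config helpers (a usage sketch; the
# directory name is a placeholder):
#
#     sae_cfg = SAEConfig.from_training_config(train_cfg)
#     sae_cfg.save_pretrained("path/to/sae_checkpoint")     # writes config.json
#     sae_cfg = SAEConfig.from_pretrained("path/to/sae_checkpoint")
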
@pyrallis.wrap()
def get_config(cfg: TrainingConfig) -> TrainingConfig:
    # pyrallis.wrap() parses CLI flags into the annotated first argument, so the
    # wrapped function receives the config instead of constructing it directly.
    return cfg

# For backward compatibility
def get_default_cfg() -> TrainingConfig:
return get_config()
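
# Typical use from a training script (an illustrative sketch; the script name
# and flag values are placeholders -- pyrallis maps dataclass fields to CLI flags):
#
#     cfg = get_config()   # e.g. `python train.py --lr 1e-4 --top_k 64 --layer 12`
#     sae_cfg = SAEConfig.from_training_config(cfg)
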
def post_init_cfg(cfg: TrainingConfig) -> TrainingConfig:
"""
Any additional configuration setup that needs to happen after model initialization
"""
return cfg
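

# ---------------------------------------------------------------------------
# Minimal round-trip sketch (assumption: run as a standalone script; the
# act_size value below is illustrative and would normally be read from the
# hooked model's hidden size after initialization).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    train_cfg = TrainingConfig()
    train_cfg.act_size = 4096  # placeholder activation width

    # TrainingConfig -> SAEConfig (Hugging Face serializable) ...
    sae_cfg = SAEConfig.from_training_config(train_cfg)
    print(sae_cfg.sae_type, sae_cfg.dict_size, sae_cfg.parent_model_name)

    # ... and back again; dtype strings are converted back to torch dtypes.
    recovered = post_init_cfg(sae_cfg.to_training_config())
    print(recovered.model_name, recovered.layer, recovered.dtype)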