from dataclasses import dataclass, field
from typing import Optional

import torch
import pyrallis
from transformers import PretrainedConfig


@dataclass
class TrainingConfig:
    # Model settings
    model_name: str = "unsloth/Meta-Llama-3.1-8B"
    layer: int = 12
    hook_point: str = "resid_mid"
    act_size: Optional[int] = None  # Will be set after model initialization

    # SAE settings
    sae_type: str = "batchtopk"
    dict_size: int = 2**15
    aux_penalty: float = 1 / 32
    input_unit_norm: bool = True

    # TopK specific settings
    top_k: int = 50
    top_k_warmup_steps_fraction: float = 0.1
    start_top_k: int = 4096
    top_k_aux: int = 512
    n_batches_to_dead: int = 10

    # Training settings
    lr: float = 3e-4
    bandwidth: float = 0.001
    l1_coeff: float = 0.0018
    num_tokens: int = int(1e9)
    seq_len: int = 1024
    model_batch_size: int = 16
    num_batches_in_buffer: int = 5
    max_grad_norm: float = 1.0
    batch_size: int = 8192

    # Scheduler settings
    warmup_fraction: float = 0.1
    scheduler_type: str = "linear"

    # Hardware settings
    device: str = "cuda"
    dtype: torch.dtype = field(default=torch.float32)
    sae_dtype: torch.dtype = field(default=torch.float32)

    # Dataset settings
    dataset_path: str = "cerebras/SlimPajama-627B"

    # Logging settings
    wandb_project: str = "turbo-llama-lens"
    performance_log_steps: int = 100
    save_checkpoint_steps: int = 10_000

    def __post_init__(self):
        if self.device == "cuda" and not torch.cuda.is_available():
            print("CUDA not available, falling back to CPU")
            self.device = "cpu"

        # Convert string dtypes to torch.dtype if needed (e.g. when parsed from a config file)
        if isinstance(self.dtype, str):
            self.dtype = getattr(torch, self.dtype)
        if isinstance(self.sae_dtype, str):
            self.sae_dtype = getattr(torch, self.sae_dtype)


class SAEConfig(PretrainedConfig):
    model_type = "sae"

    def __init__(
        self,
        # SAE architecture
        act_size: Optional[int] = None,
        dict_size: int = 2**15,
        sae_type: str = "batchtopk",
        input_unit_norm: bool = True,
        # TopK specific settings
        top_k: int = 50,
        top_k_aux: int = 512,
        n_batches_to_dead: int = 10,
        # Training hyperparameters
        aux_penalty: float = 1 / 32,
        l1_coeff: float = 0.0018,
        bandwidth: float = 0.001,
        # Hardware settings
        dtype: str = "float32",
        sae_dtype: str = "float32",
        # Optional parent model info
        parent_model_name: Optional[str] = None,
        parent_layer: Optional[int] = None,
        parent_hook_point: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.act_size = act_size
        self.dict_size = dict_size
        self.sae_type = sae_type
        self.input_unit_norm = input_unit_norm

        self.top_k = top_k
        self.top_k_aux = top_k_aux
        self.n_batches_to_dead = n_batches_to_dead

        self.aux_penalty = aux_penalty
        self.l1_coeff = l1_coeff
        self.bandwidth = bandwidth

        self.dtype = dtype
        self.sae_dtype = sae_dtype

        self.parent_model_name = parent_model_name
        self.parent_layer = parent_layer
        self.parent_hook_point = parent_hook_point

    def get_torch_dtype(self, dtype_str: str) -> torch.dtype:
        dtype_map = {
            "float32": torch.float32,
            "float16": torch.float16,
            "bfloat16": torch.bfloat16,
        }
        return dtype_map.get(dtype_str, torch.float32)

    @classmethod
    def from_training_config(cls, cfg: TrainingConfig):
        """Convert a TrainingConfig to an SAEConfig."""
        return cls(
            act_size=cfg.act_size,
            dict_size=cfg.dict_size,
            sae_type=cfg.sae_type,
            input_unit_norm=cfg.input_unit_norm,
            top_k=cfg.top_k,
            top_k_aux=cfg.top_k_aux,
            n_batches_to_dead=cfg.n_batches_to_dead,
            aux_penalty=cfg.aux_penalty,
            l1_coeff=cfg.l1_coeff,
            bandwidth=cfg.bandwidth,
            dtype=str(cfg.dtype).split(".")[-1],
            sae_dtype=str(cfg.sae_dtype).split(".")[-1],
            parent_model_name=cfg.model_name,
            parent_layer=cfg.layer,
            parent_hook_point=cfg.hook_point,
        )

    def to_training_config(self) -> TrainingConfig:
"""Convert SAEConfig back to TrainingConfig""" config_dict = asdict(self) config_dict['dtype'] = self.get_torch_dtype(self.dtype) config_dict['sae_dtype'] = self.get_torch_dtype(self.sae_dtype) config_dict['model_name'] = self.parent_model_name config_dict['layer'] = self.parent_layer config_dict['hook_point'] = self.parent_hook_point return TrainingConfig(**config_dict) @pyrallis.wrap() def get_config() -> TrainingConfig: return TrainingConfig() # For backward compatibility def get_default_cfg() -> TrainingConfig: return get_config() def post_init_cfg(cfg: TrainingConfig) -> TrainingConfig: """ Any additional configuration setup that needs to happen after model initialization """ return cfg