pozoviy committed (verified) · Commit 4c95a08 · 1 Parent(s): 883f13e

Upload BatchTopKSAE

Files changed (5)
  1. README.md +199 -0
  2. config.json +28 -0
  3. config.py +203 -0
  4. model.py +555 -0
  5. model.safetensors +3 -0
README.md ADDED
@@ -0,0 +1,199 @@
+ ---
+ library_name: transformers
+ tags: []
+ ---
+
+ # Model Card for Model ID
+
+ <!-- Provide a quick summary of what the model is/does. -->
+
+
+
+ ## Model Details
+
+ ### Model Description
+
+ <!-- Provide a longer summary of what this model is. -->
+
+ This is the model card of a 🤗 transformers model that has been pushed to the Hub. This model card has been automatically generated.
+
+ - **Developed by:** [More Information Needed]
+ - **Funded by [optional]:** [More Information Needed]
+ - **Shared by [optional]:** [More Information Needed]
+ - **Model type:** [More Information Needed]
+ - **Language(s) (NLP):** [More Information Needed]
+ - **License:** [More Information Needed]
+ - **Finetuned from model [optional]:** [More Information Needed]
+
+ ### Model Sources [optional]
+
+ <!-- Provide the basic links for the model. -->
+
+ - **Repository:** [More Information Needed]
+ - **Paper [optional]:** [More Information Needed]
+ - **Demo [optional]:** [More Information Needed]
+
+ ## Uses
+
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
+
+ ### Direct Use
+
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
+
+ [More Information Needed]
+
+ ### Downstream Use [optional]
+
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
+
+ [More Information Needed]
+
+ ### Out-of-Scope Use
+
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
+
+ [More Information Needed]
+
+ ## Bias, Risks, and Limitations
+
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
+
+ [More Information Needed]
+
+ ### Recommendations
+
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
+
+ Users (both direct and downstream) should be made aware of the risks, biases, and limitations of the model. More information needed for further recommendations.
+
+ ## How to Get Started with the Model
+
+ Use the code below to get started with the model.
+
+ [More Information Needed]
+
+ ## Training Details
+
+ ### Training Data
+
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
+
+ [More Information Needed]
+
+ ### Training Procedure
+
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
+
+ #### Preprocessing [optional]
+
+ [More Information Needed]
+
+
+ #### Training Hyperparameters
+
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
+
+ #### Speeds, Sizes, Times [optional]
+
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
+
+ [More Information Needed]
+
+ ## Evaluation
+
+ <!-- This section describes the evaluation protocols and provides the results. -->
+
+ ### Testing Data, Factors & Metrics
+
+ #### Testing Data
+
+ <!-- This should link to a Dataset Card if possible. -->
+
+ [More Information Needed]
+
+ #### Factors
+
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
+
+ [More Information Needed]
+
+ #### Metrics
+
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
+
+ [More Information Needed]
+
+ ### Results
+
+ [More Information Needed]
+
+ #### Summary
+
+
+
+ ## Model Examination [optional]
+
+ <!-- Relevant interpretability work for the model goes here -->
+
+ [More Information Needed]
+
+ ## Environmental Impact
+
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
+
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
+
+ - **Hardware Type:** [More Information Needed]
+ - **Hours used:** [More Information Needed]
+ - **Cloud Provider:** [More Information Needed]
+ - **Compute Region:** [More Information Needed]
+ - **Carbon Emitted:** [More Information Needed]
+
+ ## Technical Specifications [optional]
+
+ ### Model Architecture and Objective
+
+ [More Information Needed]
+
+ ### Compute Infrastructure
+
+ [More Information Needed]
+
+ #### Hardware
+
+ [More Information Needed]
+
+ #### Software
+
+ [More Information Needed]
+
+ ## Citation [optional]
+
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
+
+ **BibTeX:**
+
+ [More Information Needed]
+
+ **APA:**
+
+ [More Information Needed]
+
+ ## Glossary [optional]
+
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
+
+ [More Information Needed]
+
+ ## More Information [optional]
+
+ [More Information Needed]
+
+ ## Model Card Authors [optional]
+
+ [More Information Needed]
+
+ ## Model Card Contact
+
+ [More Information Needed]
config.json ADDED
@@ -0,0 +1,28 @@
+ {
+   "act_size": 512,
+   "architectures": [
+     "BatchTopKSAE"
+   ],
+   "auto_map": {
+     "AutoConfig": "config.SAEConfig",
+     "AutoModel": "model.BatchTopKSAE"
+   },
+   "aux_penalty": 0.03125,
+   "bandwidth": 0.001,
+   "dict_size": 16384,
+   "dtype": "float32",
+   "input_mean": -3.768720489460975e-06,
+   "input_std": 0.3794997036457062,
+   "input_unit_norm": true,
+   "l1_coeff": 0.001,
+   "model_type": "sae",
+   "n_batches_to_dead": 10,
+   "parent_hook_point": "resid_post",
+   "parent_layer": 1,
+   "parent_model_name": "EleutherAI/pythia-70m",
+   "sae_dtype": "float32",
+   "sae_type": "batchtopk",
+   "top_k_aux": 512,
+   "torch_dtype": "float32",
+   "transformers_version": "4.47.1"
+ }
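
Because `config.json` carries an `auto_map` pointing at `config.SAEConfig` and `model.BatchTopKSAE`, the uploaded SAE can be loaded straight from the Hub with `trust_remote_code=True`. A minimal loading sketch (`user/sae-repo` is a placeholder for this repository's actual id; note `top_k` is not stored in `config.json`, so it falls back to the `SAEConfig` default of 50):

```python
import torch
from transformers import AutoModel

# "user/sae-repo" is a placeholder; substitute the actual Hub repo id.
sae = AutoModel.from_pretrained("user/sae-repo", trust_remote_code=True)

# act_size is 512 (the pythia-70m residual stream), so inputs are (n, 512).
acts = torch.randn(8, 512)
out = sae(acts)
print(out["sae_out"].shape)  # torch.Size([8, 512])
print(out["l0_norm"])        # mean number of active features per sample
```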
config.py ADDED
@@ -0,0 +1,203 @@
+ from dataclasses import dataclass, field
+ from typing import Optional
+ import torch
+ import pyrallis
+ from transformers import PretrainedConfig
+
+
+ @dataclass
+ class TrainingConfig:
+     # Model settings
+     model_name: str = "unsloth/Meta-Llama-3.1-8B"
+     layer: int = 12
+     hook_point: str = "resid_post"
+     act_size: Optional[int] = None  # Will be set after model initialization
+
+     # SAE settings
+     sae_type: str = "batchtopk"
+     dict_size: int = 2**15
+     aux_penalty: float = 1 / 32
+     input_unit_norm: bool = True
+
+     # TopK specific settings
+     top_k: int = 50
+     top_k_warmup_steps_fraction: float = 0.1
+     start_top_k: int = 4096
+     top_k_aux: int = 512
+
+     n_batches_to_dead: int = 10
+
+     # Training settings
+     lr: float = 3e-4
+     bandwidth: float = 0.001
+     l1_coeff: float = 0.0018
+     num_tokens: int = int(1e9)
+     seq_len: int = 1024
+     model_batch_size: int = 16
+     num_batches_in_buffer: int = 5
+     max_grad_norm: float = 1.0
+     batch_size: int = 8192
+
+     # scheduler
+     warmup_fraction: float = 0.1
+     scheduler_type: str = 'linear'
+
+     # Hardware settings
+     device: str = "cuda"
+     dtype: torch.dtype = field(default=torch.float32)
+     sae_dtype: torch.dtype = field(default=torch.float32)
+
+     # Dataset settings
+     dataset_path: str = "cerebras/SlimPajama-627B"
+
+     # Logging settings
+     wandb_project: str = "turbo-llama-lens"
+
+     performance_log_steps: int = 100
+     save_checkpoint_steps: int = 10_000
+
+     def __post_init__(self):
+         if self.device == "cuda" and not torch.cuda.is_available():
+             print("CUDA not available, falling back to CPU")
+             self.device = "cpu"
+
+         # Convert string dtype to torch.dtype if needed
+         if isinstance(self.dtype, str):
+             self.dtype = getattr(torch, self.dtype)
+
+
+ class SAEConfig(PretrainedConfig):
+     model_type = "sae"
+
+     def __init__(
+         self,
+         # SAE architecture
+         act_size: Optional[int] = None,
+         dict_size: int = 2**15,
+         sae_type: str = "batchtopk",
+         input_unit_norm: bool = True,
+
+         # TopK specific settings
+         top_k: int = 50,
+         top_k_aux: int = 512,
+         n_batches_to_dead: int = 10,
+
+         # Training hyperparameters
+         aux_penalty: float = 1 / 32,
+         l1_coeff: float = 0.0018,
+         bandwidth: float = 0.001,
+
+         # Hardware settings
+         dtype: str = "float32",
+         sae_dtype: str = "float32",
+
+         # Optional parent model info
+         parent_model_name: Optional[str] = None,
+         parent_layer: Optional[int] = None,
+         parent_hook_point: Optional[str] = None,
+
+         # Input normalization settings
+         input_mean: Optional[float] = None,
+         input_std: Optional[float] = None,
+
+         **kwargs
+     ):
+         super().__init__(**kwargs)
+         self.act_size = act_size
+         self.dict_size = dict_size
+         self.sae_type = sae_type
+         self.input_unit_norm = input_unit_norm
+
+         self.top_k = top_k
+         self.top_k_aux = top_k_aux
+         self.n_batches_to_dead = n_batches_to_dead
+
+         self.aux_penalty = aux_penalty
+         self.l1_coeff = l1_coeff
+         self.bandwidth = bandwidth
+
+         self.dtype = dtype
+         self.sae_dtype = sae_dtype
+
+         self.parent_model_name = parent_model_name
+         self.parent_layer = parent_layer
+         self.parent_hook_point = parent_hook_point
+
+         self.input_mean = input_mean
+         self.input_std = input_std
+
+     def get_torch_dtype(self, dtype_str: str) -> torch.dtype:
+         dtype_map = {
+             "float32": torch.float32,
+             "float16": torch.float16,
+             "bfloat16": torch.bfloat16,
+         }
+         return dtype_map.get(dtype_str, torch.float32)
+
+     @classmethod
+     def from_training_config(cls, cfg: TrainingConfig):
+         """Convert TrainingConfig to SAEConfig"""
+         return cls(
+             act_size=cfg.act_size,
+             dict_size=cfg.dict_size,
+             sae_type=cfg.sae_type,
+             input_unit_norm=cfg.input_unit_norm,
+             top_k=cfg.top_k,
+             top_k_aux=cfg.top_k_aux,
+             n_batches_to_dead=cfg.n_batches_to_dead,
+             aux_penalty=cfg.aux_penalty,
+             l1_coeff=cfg.l1_coeff,
+             bandwidth=cfg.bandwidth,
+             dtype=str(cfg.dtype).split('.')[-1],
+             sae_dtype=str(cfg.sae_dtype).split('.')[-1],
+             parent_model_name=cfg.model_name,
+             parent_layer=cfg.layer,
+             parent_hook_point=cfg.hook_point,
+             input_mean=cfg.input_mean if hasattr(cfg, 'input_mean') else None,
+             input_std=cfg.input_std if hasattr(cfg, 'input_std') else None,
+         )
+
+     def to_training_config(self) -> TrainingConfig:
+         """Convert SAEConfig back to TrainingConfig"""
+         return TrainingConfig(
+             dtype=self.get_torch_dtype(self.dtype),
+             sae_dtype=self.get_torch_dtype(self.sae_dtype),
+             model_name=self.parent_model_name,
+             layer=self.parent_layer,
+             hook_point=self.parent_hook_point,
+             act_size=self.act_size,
+             dict_size=self.dict_size,
+             sae_type=self.sae_type,
+             input_unit_norm=self.input_unit_norm,
+             top_k=self.top_k,
+             top_k_aux=self.top_k_aux,
+             n_batches_to_dead=self.n_batches_to_dead,
+             aux_penalty=self.aux_penalty,
+             l1_coeff=self.l1_coeff,
+             bandwidth=self.bandwidth,
+         )
+
+
+ @pyrallis.wrap()
+ def get_config() -> TrainingConfig:
+     return TrainingConfig()
+
+
+ # For backward compatibility
+ def get_default_cfg() -> TrainingConfig:
+     return get_config()
+
+
+ def post_init_cfg(cfg: TrainingConfig, activation_store=None) -> TrainingConfig:
+     """
+     Any additional configuration setup that needs to happen after model initialization
+     Args:
+         cfg: Training configuration
+         activation_store: Optional activation store to get input statistics
+     """
+     if activation_store is not None:
+         cfg.input_mean = activation_store.mean
+         cfg.input_std = activation_store.std
+         print(f"Setting input statistics from activation store - Mean: {cfg.input_mean:.4f}, Std: {cfg.input_std:.4f}")
+
+     return cfg
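
The two config classes above are designed to round-trip: `from_training_config` serializes `torch.dtype` objects to strings for the Hub, and `to_training_config` maps them back through `get_torch_dtype`. A small sketch, assuming `config.py` is on the import path:

```python
import torch
from config import TrainingConfig, SAEConfig

train_cfg = TrainingConfig(model_name="EleutherAI/pythia-70m", layer=1, act_size=512)
sae_cfg = SAEConfig.from_training_config(train_cfg)

assert sae_cfg.parent_model_name == "EleutherAI/pythia-70m"
assert sae_cfg.dtype == "float32"        # torch.float32 stored as a string

recovered = sae_cfg.to_training_config()
assert recovered.dtype is torch.float32  # mapped back via get_torch_dtype
```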
model.py ADDED
@@ -0,0 +1,555 @@
+ from transformers import PreTrainedModel
+ from typing import Optional
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.autograd as autograd
+ from .config import SAEConfig
+
+
+ class BaseSAE(PreTrainedModel):
+     """Base class for autoencoder models."""
+     config_class = SAEConfig
+     base_model_prefix = "sae"
+
+     def __init__(self, config: SAEConfig):
+         super().__init__(config)
+         print(config)
+         self.config = config
+         torch.manual_seed(42)
+
+         self.b_dec = nn.Parameter(torch.zeros(self.config.act_size))
+         self.b_enc = nn.Parameter(torch.zeros(self.config.dict_size))
+         self.W_enc = nn.Parameter(
+             torch.nn.init.kaiming_uniform_(
+                 torch.empty(self.config.act_size, self.config.dict_size)
+             )
+         )
+         self.W_dec = nn.Parameter(
+             torch.nn.init.kaiming_uniform_(
+                 torch.empty(self.config.dict_size, self.config.act_size)
+             )
+         )
+         self.W_dec.data[:] = self.W_enc.t().data
+         self.W_dec.data[:] = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)
+         self.register_buffer('num_batches_not_active', torch.zeros((self.config.dict_size,)))
+
+         self.to(self.config.get_torch_dtype(self.config.dtype))
+
+         # Initialize input normalization parameters if provided
+         if config.input_mean is not None and config.input_std is not None:
+             self.register_buffer('input_mean', torch.tensor(config.input_mean))
+             self.register_buffer('input_std', torch.tensor(config.input_std))
+         else:
+             self.register_buffer('input_mean', None)
+             self.register_buffer('input_std', None)
+
+         # Initialize CUDA events for timing
+         if torch.cuda.is_available():
+             self.start_event = torch.cuda.Event(enable_timing=True)
+             self.end_event = torch.cuda.Event(enable_timing=True)
+         else:
+             self.start_event = None
+             self.end_event = None
+
+     def preprocess_input(self, x):
+         x = x.to(self.config.get_torch_dtype(self.config.sae_dtype))
+         if self.config.input_unit_norm:
+             if self.input_mean is not None and self.input_std is not None:
+                 # Use pre-computed statistics
+                 x = (x - self.input_mean) / (self.input_std + 1e-5)
+                 return x, self.input_mean, self.input_std
+             else:
+                 # Compute statistics on the fly
+                 x_mean = x.mean(dim=-1, keepdim=True)
+                 x = x - x_mean
+                 x_std = x.std(dim=-1, keepdim=True)
+                 x = x / (x_std + 1e-5)
+                 return x, x_mean, x_std
+         else:
+             return x, None, None
+
+     def postprocess_output(self, x_reconstruct, x_mean, x_std):
+         if self.config.input_unit_norm:
+             x_reconstruct = x_reconstruct * x_std + x_mean
+         return x_reconstruct
+
+     @torch.no_grad()
+     def make_decoder_weights_and_grad_unit_norm(self):
+         W_dec_normed = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)
+         W_dec_grad_proj = (self.W_dec.grad * W_dec_normed).sum(
+             -1, keepdim=True
+         ) * W_dec_normed
+         self.W_dec.grad -= W_dec_grad_proj
+         self.W_dec.data = W_dec_normed
+
+     def update_inactive_features(self, acts):
+         self.num_batches_not_active += (acts.sum(0) == 0).float()
+         self.num_batches_not_active[acts.sum(0) > 0] = 0
+
+     def encode(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Encode input tensor to sparse features
+         Args:
+             x: Input tensor of shape (batch_size, act_size)
+         Returns:
+             Encoded features of shape (batch_size, dict_size)
+         """
+         if self.config.input_unit_norm:
+             x_mean = x.mean(dim=-1, keepdim=True)
+             x = x - x_mean
+             x_std = x.std(dim=-1, keepdim=True)
+             x = x / (x_std + 1e-5)
+
+         return F.relu(x @ self.W_enc + self.b_enc)
+
+     def decode(self, h: torch.Tensor) -> torch.Tensor:
+         """
+         Decode features back to input space
+         Args:
+             h: Encoded features of shape (batch_size, dict_size)
+         Returns:
+             Reconstructed input of shape (batch_size, act_size)
+         """
+         return h @ self.W_dec + self.b_dec
+
+     def forward(self, x):
+         # Start timing if CUDA is available
+         if self.start_event is not None:
+             self.start_event.record()
+
+         x, x_mean, x_std = self.preprocess_input(x)
+         acts = F.relu(x @ self.W_enc + self.b_enc)
+         x_reconstruct = acts @ self.W_dec + self.b_dec
+         self.update_inactive_features(acts)
+         output = self.get_loss_dict(x, x_reconstruct, acts, x_mean, x_std)
+
+         # End timing if CUDA is available
+         if self.end_event is not None:
+             self.end_event.record()
+             torch.cuda.synchronize()
+             output["forward_time_ms"] = self.start_event.elapsed_time(self.end_event)
+
+         return output
+
+     @torch.no_grad()
+     def fold_stats_into_weights(self, mean: Optional[torch.Tensor] = None, std: Optional[torch.Tensor] = None) -> "BaseSAE":
+         """
+         Fold normalization statistics into the encoder weights and biases
+         """
+         print("Folding statistics into encoder...")
+
+         if mean is not None and std is not None:
+             mean = mean.to(self.W_enc.device)
+             std = std.to(self.W_enc.device)
+         else:
+             mean = self.input_mean
+             std = self.input_std
+
+         # Original forward pass:
+         # x_norm = (x - mean) / std
+         # acts = relu(x_norm @ W_enc + b_enc)
+         # x_hat = acts @ W_dec + b_dec
+
+         # Folding steps:
+         # 1. x_norm = (x - mean) / std
+         # 2. acts = relu(x_norm @ W_enc + b_enc)
+         #         = relu((x/std - mean/std) @ W_enc + b_enc)
+         #         = relu(x @ (W_enc/std) - mean @ (W_enc/std) + b_enc)
+
+         # First scale encoder weights
+         self.W_enc.data = self.W_enc / std
+
+         # Then adjust encoder bias
+         self.b_enc.data = self.b_enc - mean * (self.W_enc.sum(0))
+
+         # Scale decoder to preserve reconstruction
+         self.W_dec.data = self.W_dec * std
+         self.b_dec.data = self.b_dec * std + mean
+
+         # Turn off input normalization
+         self.config.input_unit_norm = False
+
+         return self
+
+     @torch.no_grad()
+     def fold_W_dec_norm(self):
+         """
+         Make decoder weights unit norm and adjust encoder accordingly
+         """
+         # Get current decoder norms
+         W_dec_norm = self.W_dec.norm(dim=-1, keepdim=True)
+
+         # Original: acts @ W_dec + b_dec
+         # After: acts @ (W_dec/norm) + b_dec
+         # Need: (acts * norm) @ (W_dec/norm) + b_dec
+         # So: scale W_enc by norm.T
+
+         # Scale encoder weights first
+         self.W_enc.data = self.W_enc * W_dec_norm.t()
+
+         # Then normalize decoder weights
+         self.W_dec.data = self.W_dec / W_dec_norm
+
+         # Scale encoder bias to compensate for the scaling of activations
+         self.b_enc.data = self.b_enc * W_dec_norm.squeeze(-1)
+
+         return self
+
+     def set_mean_std(self, mean: float, std: float):
+         """
+         Set input normalization statistics after model initialization
+
+         Args:
+             mean: Mean scalar value for input normalization
+             std: Standard deviation scalar value for input normalization
+         """
+         self.register_buffer('input_mean', torch.tensor(mean, device=self.device))
+         self.register_buffer('input_std', torch.tensor(std, device=self.device))
+         self.config.input_unit_norm = True
+         return self
+
+
+
+ class BatchTopKSAE(BaseSAE):
+     def encode(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Encode input tensor to sparse features with batch-wise top-k
+         """
+         if self.config.input_unit_norm:
+             x_mean = x.mean(dim=-1, keepdim=True)
+             x = x - x_mean
+             x_std = x.std(dim=-1, keepdim=True)
+             x = x / (x_std + 1e-5)
+
+         acts = F.relu(x @ self.W_enc + self.b_enc)
+         acts_topk = torch.topk(acts.flatten(), self.config.top_k * x.shape[0], dim=-1)
+         return (
+             torch.zeros_like(acts.flatten())
+             .scatter(-1, acts_topk.indices, acts_topk.values)
+             .reshape(acts.shape)
+         )
+
+     def forward(self, x):
+         x, x_mean, x_std = self.preprocess_input(x)
+         acts = F.relu(x @ self.W_enc + self.b_enc)
+         acts_topk = torch.topk(acts.flatten(), self.config.top_k * x.shape[0], dim=-1)
+         acts_topk = (
+             torch.zeros_like(acts.flatten())
+             .scatter(-1, acts_topk.indices, acts_topk.values)
+             .reshape(acts.shape)
+         )
+         x_reconstruct = acts_topk @ self.W_dec + self.b_dec
+         self.update_inactive_features(acts_topk)
+         output = self.get_loss_dict(x, x_reconstruct, acts, acts_topk, x_mean, x_std)
+         return output
+
+     def get_loss_dict(self, x, x_reconstruct, acts, acts_topk, x_mean, x_std):
+         l2_loss = (x_reconstruct.float() - x.float()).pow(2).mean()
+         l1_norm = acts_topk.float().abs().sum(-1).mean()
+         l1_loss = self.config.l1_coeff * l1_norm
+         l0_norm = (acts_topk > 0).float().sum(-1).mean()
+         aux_loss = self.get_auxiliary_loss(x, x_reconstruct, acts)
+         loss = l2_loss + aux_loss
+         num_dead_features = (
+             self.num_batches_not_active > self.config.n_batches_to_dead
+         ).sum()
+         sae_out = self.postprocess_output(x_reconstruct, x_mean, x_std)
+         per_token_l2_loss_A = (x_reconstruct.float() - x.float()).pow(2).sum(-1).squeeze()
+         total_variance_A = (x.float() - x.float().mean(0)).pow(2).sum(-1).squeeze()
+         explained_variance = (1 - per_token_l2_loss_A / total_variance_A).mean()
+         output = {
+             "sae_out": sae_out,
+             "feature_acts": acts_topk,
+             "num_dead_features": num_dead_features,
+             "loss": loss,
+             "l1_loss": l1_loss,
+             "l2_loss": l2_loss,
+             "l0_norm": l0_norm,
+             "l1_norm": l1_norm,
+             "aux_loss": aux_loss,
+             "explained_variance": explained_variance,
+             "top_k": self.config.top_k
+         }
+         return output
+
+     def get_auxiliary_loss(self, x, x_reconstruct, acts):
+         dead_features = self.num_batches_not_active >= self.config.n_batches_to_dead
+         if dead_features.sum() > 0:
+             residual = x.float() - x_reconstruct.float()
+             acts_topk_aux = torch.topk(
+                 acts[:, dead_features],
+                 min(self.config.top_k_aux, int(dead_features.sum().item())),
+                 dim=-1,
+             )
+             acts_aux = torch.zeros_like(acts[:, dead_features]).scatter(
+                 -1, acts_topk_aux.indices, acts_topk_aux.values
+             )
+             x_reconstruct_aux = acts_aux @ self.W_dec[dead_features]
+             l2_loss_aux = (
+                 self.config.aux_penalty
+                 * (x_reconstruct_aux.float() - residual.float()).pow(2).mean()
+             )
+             return l2_loss_aux
+         else:
+             return torch.tensor(0, dtype=x.dtype, device=x.device)
+
+
+ class TopKSAE(BaseSAE):
+     def encode(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Encode input tensor to sparse features with per-sample top-k
+         """
+         if self.config.input_unit_norm:
+             x_mean = x.mean(dim=-1, keepdim=True)
+             x = x - x_mean
+             x_std = x.std(dim=-1, keepdim=True)
+             x = x / (x_std + 1e-5)
+
+         acts = F.relu(x @ self.W_enc + self.b_enc)
+         acts_topk = torch.topk(acts, self.config.top_k, dim=-1)
+         return torch.zeros_like(acts).scatter(
+             -1, acts_topk.indices, acts_topk.values
+         )
+
+     def forward(self, x):
+         x, x_mean, x_std = self.preprocess_input(x)
+         acts = F.relu(x @ self.W_enc + self.b_enc)
+         acts_topk = torch.topk(acts, self.config.top_k, dim=-1)
+         acts_topk = torch.zeros_like(acts).scatter(
+             -1, acts_topk.indices, acts_topk.values
+         )
+         x_reconstruct = acts_topk @ self.W_dec + self.b_dec
+         self.update_inactive_features(acts_topk)
+         output = self.get_loss_dict(x, x_reconstruct, acts, acts_topk, x_mean, x_std)
+         return output
+
+     def get_loss_dict(self, x, x_reconstruct, acts, acts_topk, x_mean, x_std):
+         l2_loss = (x_reconstruct.float() - x.float()).pow(2).mean()
+         l1_norm = acts_topk.float().abs().sum(-1).mean()
+         l1_loss = self.config.l1_coeff * l1_norm
+         l0_norm = (acts_topk > 0).float().sum(-1).mean()
+         aux_loss = self.get_auxiliary_loss(x, x_reconstruct, acts)
+         loss = l2_loss + l1_loss + aux_loss
+         num_dead_features = (
+             self.num_batches_not_active > self.config.n_batches_to_dead
+         ).sum()
+         sae_out = self.postprocess_output(x_reconstruct, x_mean, x_std)
+         per_token_l2_loss_A = (x_reconstruct.float() - x.float()).pow(2).sum(-1).squeeze()
+         total_variance_A = (x.float() - x.float().mean(0)).pow(2).sum(-1).squeeze()
+         explained_variance = (1 - per_token_l2_loss_A / total_variance_A).mean()
+         output = {
+             "sae_out": sae_out,
+             "feature_acts": acts_topk,
+             "num_dead_features": num_dead_features,
+             "loss": loss,
+             "l1_loss": l1_loss,
+             "l2_loss": l2_loss,
+             "l0_norm": l0_norm,
+             "l1_norm": l1_norm,
+             "explained_variance": explained_variance,
+             "aux_loss": aux_loss,
+         }
+         return output
+
+     def get_auxiliary_loss(self, x, x_reconstruct, acts):
+         dead_features = self.num_batches_not_active >= self.config.n_batches_to_dead
+         if dead_features.sum() > 0:
+             residual = x.float() - x_reconstruct.float()
+             acts_topk_aux = torch.topk(
+                 acts[:, dead_features],
+                 min(self.config.top_k_aux, int(dead_features.sum().item())),
+                 dim=-1,
+             )
+             acts_aux = torch.zeros_like(acts[:, dead_features]).scatter(
+                 -1, acts_topk_aux.indices, acts_topk_aux.values
+             )
+             x_reconstruct_aux = acts_aux @ self.W_dec[dead_features]
+             l2_loss_aux = (
+                 self.config.aux_penalty
+                 * (x_reconstruct_aux.float() - residual.float()).pow(2).mean()
+             )
+             return l2_loss_aux
+         else:
+             return torch.tensor(0, dtype=x.dtype, device=x.device)
+
+
+ class VanillaSAE(BaseSAE):
+     def forward(self, x):
+         x, x_mean, x_std = self.preprocess_input(x)
+         acts = F.relu(x @ self.W_enc + self.b_enc)
+         x_reconstruct = acts @ self.W_dec + self.b_dec
+         self.update_inactive_features(acts)
+         output = self.get_loss_dict(x, x_reconstruct, acts, x_mean, x_std)
+         return output
+
+     def get_loss_dict(self, x, x_reconstruct, acts, x_mean, x_std):
+         l2_loss = (x_reconstruct.float() - x.float()).pow(2).mean()
+         l1_norm = acts.float().abs().sum(-1).mean()
+         l1_loss = self.config.l1_coeff * l1_norm
+         l0_norm = (acts > 0).float().sum(-1).mean()
+         loss = l2_loss + l1_loss
+         num_dead_features = (
+             self.num_batches_not_active > self.config.n_batches_to_dead
+         ).sum()
+
+         sae_out = self.postprocess_output(x_reconstruct, x_mean, x_std)
+         per_token_l2_loss_A = (x_reconstruct.float() - x.float()).pow(2).sum(-1).squeeze()
+         total_variance_A = (x.float() - x.float().mean(0)).pow(2).sum(-1).squeeze()
+         explained_variance = (1 - per_token_l2_loss_A / total_variance_A).mean()
+         output = {
+             "sae_out": sae_out,
+             "feature_acts": acts,
+             "num_dead_features": num_dead_features,
+             "loss": loss,
+             "l1_loss": l1_loss,
+             "l2_loss": l2_loss,
+             "l0_norm": l0_norm,
+             "l1_norm": l1_norm,
+             "explained_variance": explained_variance,
+         }
+
+         return output
+
+
+ class RectangleFunction(autograd.Function):
+     @staticmethod
+     def forward(ctx, x):
+         ctx.save_for_backward(x)
+         return ((x > -0.5) & (x < 0.5)).float()
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         (x,) = ctx.saved_tensors
+         grad_input = grad_output.clone()
+         grad_input[(x <= -0.5) | (x >= 0.5)] = 0
+         return grad_input
+
+ class JumpReLUFunction(autograd.Function):
+     @staticmethod
+     def forward(ctx, x, log_threshold, bandwidth):
+         ctx.save_for_backward(x, log_threshold, torch.tensor(bandwidth))
+         threshold = torch.exp(log_threshold)
+         return x * (x > threshold).float()
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         x, log_threshold, bandwidth_tensor = ctx.saved_tensors
+         bandwidth = bandwidth_tensor.item()
+         threshold = torch.exp(log_threshold)
+         x_grad = (x > threshold).float() * grad_output
+         threshold_grad = (
+             -(threshold / bandwidth)
+             * RectangleFunction.apply((x - threshold) / bandwidth)
+             * grad_output
+         )
+         return x_grad, threshold_grad, None  # None for bandwidth
+
+ class JumpReLU(nn.Module):
+     def __init__(self, feature_size, bandwidth, device='cpu'):
+         super().__init__()
+         self.log_threshold = nn.Parameter(torch.zeros(feature_size, device=device))
+         self.bandwidth = bandwidth
+
+     def forward(self, x):
+         return JumpReLUFunction.apply(x, self.log_threshold, self.bandwidth)
+
+ class StepFunction(autograd.Function):
+     @staticmethod
+     def forward(ctx, x, log_threshold, bandwidth):
+         ctx.save_for_backward(x, log_threshold, torch.tensor(bandwidth))
+         threshold = torch.exp(log_threshold)
+         return (x > threshold).float()
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         x, log_threshold, bandwidth_tensor = ctx.saved_tensors
+         bandwidth = bandwidth_tensor.item()
+         threshold = torch.exp(log_threshold)
+         x_grad = torch.zeros_like(x)
+         threshold_grad = (
+             -(1.0 / bandwidth)
+             * RectangleFunction.apply((x - threshold) / bandwidth)
+             * grad_output
+         )
+         return x_grad, threshold_grad, None  # None for bandwidth
+
+ class JumpReLUSAE(BaseSAE):
+     def __init__(self, config: SAEConfig):
+         super().__init__(config)
+         self.jumprelu = JumpReLU(
+             feature_size=config.dict_size,
+             bandwidth=config.bandwidth,
+             device=config.device if hasattr(config, 'device') else 'cpu'
+         )
+
+     def encode(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Encode input tensor to sparse features using JumpReLU
+         """
+         if self.config.input_unit_norm:
+             x_mean = x.mean(dim=-1, keepdim=True)
+             x = x - x_mean
+             x_std = x.std(dim=-1, keepdim=True)
+             x = x / (x_std + 1e-5)
+
+         pre_activations = F.relu(x @ self.W_enc + self.b_enc)
+         return self.jumprelu(pre_activations)
+
+     def forward(self, x):
+         x, x_mean, x_std = self.preprocess_input(x)
+         pre_activations = torch.relu(x @ self.W_enc + self.b_enc)
+         feature_magnitudes = self.jumprelu(pre_activations)
+         x_reconstructed = feature_magnitudes @ self.W_dec + self.b_dec
+         return self.get_loss_dict(x, x_reconstructed, feature_magnitudes, x_mean, x_std)
+
+     def get_loss_dict(self, x, x_reconstruct, acts, x_mean, x_std):
+         l2_loss = (x_reconstruct.float() - x.float()).pow(2).mean()
+
+         l0 = StepFunction.apply(acts, self.jumprelu.log_threshold, self.config.bandwidth).sum(dim=-1).mean()
+         l0_loss = self.config.l1_coeff * l0
+         l1_loss = l0_loss
+
+         loss = l2_loss + l1_loss
+         num_dead_features = (
+             self.num_batches_not_active > self.config.n_batches_to_dead
+         ).sum()
+
+         sae_out = self.postprocess_output(x_reconstruct, x_mean, x_std)
+         per_token_l2_loss_A = (x_reconstruct.float() - x.float()).pow(2).sum(-1).squeeze()
+         total_variance_A = (x.float() - x.float().mean(0)).pow(2).sum(-1).squeeze()
+         explained_variance = (1 - per_token_l2_loss_A / total_variance_A).mean()
+         output = {
+             "sae_out": sae_out,
+             "feature_acts": acts,
+             "num_dead_features": num_dead_features,
+             "loss": loss,
+             "l1_loss": l1_loss,
+             "l2_loss": l2_loss,
+             "l0_norm": l0,
+             "l1_norm": l0,
+             "explained_variance": explained_variance,
+         }
+         return output
+
+     @torch.no_grad()
+     def fold_W_dec_norm(self):
+         """
+         Make decoder weights unit norm and adjust encoder and thresholds accordingly
+         """
+         # Get current decoder norms
+         W_dec_norm = self.W_dec.norm(dim=-1, keepdim=True)
+
+         # Call parent method to handle weights
+         super().fold_W_dec_norm()
+
+         # Scale thresholds to preserve sparsity (remove keepdim to match threshold shape)
+         self.jumprelu.log_threshold.data = self.jumprelu.log_threshold + torch.log(W_dec_norm.squeeze(-1))
+
+         return self
+
+
+ SAEConfig.register_for_auto_class("AutoConfig")
+ BatchTopKSAE.register_for_auto_class()
+ JumpReLUSAE.register_for_auto_class()
+ VanillaSAE.register_for_auto_class()
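
The difference between `BatchTopKSAE` and `TopKSAE` is where the top-k cut is applied: per sample, or over the flattened batch, so per-sample sparsity can vary while the batch average stays at `top_k`. A sketch of just the selection step, with illustrative values:

```python
import torch

acts = torch.tensor([[0.9, 0.7, 0.0],
                     [0.2, 0.0, 0.1]])
k = 1  # per-sample budget

# TopKSAE: exactly k features survive in every row
row_topk = torch.topk(acts, k, dim=-1)
per_sample = torch.zeros_like(acts).scatter(-1, row_topk.indices, row_topk.values)

# BatchTopKSAE: k * batch_size features survive across the flattened batch
flat_topk = torch.topk(acts.flatten(), k * acts.shape[0], dim=-1)
batch_level = (
    torch.zeros_like(acts.flatten())
    .scatter(-1, flat_topk.indices, flat_topk.values)
    .reshape(acts.shape)
)

print(per_sample)   # [[0.9, 0.0, 0.0], [0.2, 0.0, 0.0]]
print(batch_level)  # [[0.9, 0.7, 0.0], [0.0, 0.0, 0.0]] - row 0 keeps 2, row 1 none
```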
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6f618af78706dddf6d13d50b575599f3d86ea34b7e39613d188daaf5dcdfd1b3
+ size 67242576
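
End to end, the config's `parent_model_name`/`parent_layer`/`parent_hook_point` fields say which activations this SAE was trained on: the post-block residual stream of layer 1 of EleutherAI/pythia-70m. A hedged usage sketch with TransformerLens for activation capture (`user/sae-repo` is again a placeholder, and the cache key assumes TransformerLens hook naming):

```python
import torch
from transformer_lens import HookedTransformer
from transformers import AutoModel

lm = HookedTransformer.from_pretrained("EleutherAI/pythia-70m")
sae = AutoModel.from_pretrained("user/sae-repo", trust_remote_code=True)

tokens = lm.to_tokens("The quick brown fox jumps over the lazy dog")
_, cache = lm.run_with_cache(tokens)
resid = cache["resid_post", 1]                 # (batch, seq, 512) residual stream

out = sae(resid.reshape(-1, resid.shape[-1]))  # SAE expects (n_tokens, act_size)
recon = out["sae_out"].reshape(resid.shape)
print(out["explained_variance"])
```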