Upload BatchTopKSAE
Browse files- README.md +199 -0
- config.json +28 -0
- config.py +203 -0
- model.py +562 -0
- model.safetensors +3 -0
README.md
ADDED
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
library_name: transformers
|
3 |
+
tags: []
|
4 |
+
---
|
5 |
+
|
6 |
+
# Model Card for Model ID
|
7 |
+
|
8 |
+
<!-- Provide a quick summary of what the model is/does. -->
|
9 |
+
|
10 |
+
|
11 |
+
|
12 |
+
## Model Details
|
13 |
+
|
14 |
+
### Model Description
|
15 |
+
|
16 |
+
<!-- Provide a longer summary of what this model is. -->
|
17 |
+
|
18 |
+
This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
|
19 |
+
|
20 |
+
- **Developed by:** [More Information Needed]
|
21 |
+
- **Funded by [optional]:** [More Information Needed]
|
22 |
+
- **Shared by [optional]:** [More Information Needed]
|
23 |
+
- **Model type:** [More Information Needed]
|
24 |
+
- **Language(s) (NLP):** [More Information Needed]
|
25 |
+
- **License:** [More Information Needed]
|
26 |
+
- **Finetuned from model [optional]:** [More Information Needed]
|
27 |
+
|
28 |
+
### Model Sources [optional]
|
29 |
+
|
30 |
+
<!-- Provide the basic links for the model. -->
|
31 |
+
|
32 |
+
- **Repository:** [More Information Needed]
|
33 |
+
- **Paper [optional]:** [More Information Needed]
|
34 |
+
- **Demo [optional]:** [More Information Needed]
|
35 |
+
|
36 |
+
## Uses
|
37 |
+
|
38 |
+
<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
|
39 |
+
|
40 |
+
### Direct Use
|
41 |
+
|
42 |
+
<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
|
43 |
+
|
44 |
+
[More Information Needed]
|
45 |
+
|
46 |
+
### Downstream Use [optional]
|
47 |
+
|
48 |
+
<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
|
49 |
+
|
50 |
+
[More Information Needed]
|
51 |
+
|
52 |
+
### Out-of-Scope Use
|
53 |
+
|
54 |
+
<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
|
55 |
+
|
56 |
+
[More Information Needed]
|
57 |
+
|
58 |
+
## Bias, Risks, and Limitations
|
59 |
+
|
60 |
+
<!-- This section is meant to convey both technical and sociotechnical limitations. -->
|
61 |
+
|
62 |
+
[More Information Needed]
|
63 |
+
|
64 |
+
### Recommendations
|
65 |
+
|
66 |
+
<!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
|
67 |
+
|
68 |
+
Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
|
69 |
+
|
70 |
+
## How to Get Started with the Model
|
71 |
+
|
72 |
+
Use the code below to get started with the model.
|
73 |
+
|
74 |
+
[More Information Needed]
|
75 |
+
|
76 |
+
## Training Details
|
77 |
+
|
78 |
+
### Training Data
|
79 |
+
|
80 |
+
<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
|
81 |
+
|
82 |
+
[More Information Needed]
|
83 |
+
|
84 |
+
### Training Procedure
|
85 |
+
|
86 |
+
<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
|
87 |
+
|
88 |
+
#### Preprocessing [optional]
|
89 |
+
|
90 |
+
[More Information Needed]
|
91 |
+
|
92 |
+
|
93 |
+
#### Training Hyperparameters
|
94 |
+
|
95 |
+
- **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
|
96 |
+
|
97 |
+
#### Speeds, Sizes, Times [optional]
|
98 |
+
|
99 |
+
<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
|
100 |
+
|
101 |
+
[More Information Needed]
|
102 |
+
|
103 |
+
## Evaluation
|
104 |
+
|
105 |
+
<!-- This section describes the evaluation protocols and provides the results. -->
|
106 |
+
|
107 |
+
### Testing Data, Factors & Metrics
|
108 |
+
|
109 |
+
#### Testing Data
|
110 |
+
|
111 |
+
<!-- This should link to a Dataset Card if possible. -->
|
112 |
+
|
113 |
+
[More Information Needed]
|
114 |
+
|
115 |
+
#### Factors
|
116 |
+
|
117 |
+
<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
|
118 |
+
|
119 |
+
[More Information Needed]
|
120 |
+
|
121 |
+
#### Metrics
|
122 |
+
|
123 |
+
<!-- These are the evaluation metrics being used, ideally with a description of why. -->
|
124 |
+
|
125 |
+
[More Information Needed]
|
126 |
+
|
127 |
+
### Results
|
128 |
+
|
129 |
+
[More Information Needed]
|
130 |
+
|
131 |
+
#### Summary
|
132 |
+
|
133 |
+
|
134 |
+
|
135 |
+
## Model Examination [optional]
|
136 |
+
|
137 |
+
<!-- Relevant interpretability work for the model goes here -->
|
138 |
+
|
139 |
+
[More Information Needed]
|
140 |
+
|
141 |
+
## Environmental Impact
|
142 |
+
|
143 |
+
<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
|
144 |
+
|
145 |
+
Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
|
146 |
+
|
147 |
+
- **Hardware Type:** [More Information Needed]
|
148 |
+
- **Hours used:** [More Information Needed]
|
149 |
+
- **Cloud Provider:** [More Information Needed]
|
150 |
+
- **Compute Region:** [More Information Needed]
|
151 |
+
- **Carbon Emitted:** [More Information Needed]
|
152 |
+
|
153 |
+
## Technical Specifications [optional]
|
154 |
+
|
155 |
+
### Model Architecture and Objective
|
156 |
+
|
157 |
+
[More Information Needed]
|
158 |
+
|
159 |
+
### Compute Infrastructure
|
160 |
+
|
161 |
+
[More Information Needed]
|
162 |
+
|
163 |
+
#### Hardware
|
164 |
+
|
165 |
+
[More Information Needed]
|
166 |
+
|
167 |
+
#### Software
|
168 |
+
|
169 |
+
[More Information Needed]
|
170 |
+
|
171 |
+
## Citation [optional]
|
172 |
+
|
173 |
+
<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
|
174 |
+
|
175 |
+
**BibTeX:**
|
176 |
+
|
177 |
+
[More Information Needed]
|
178 |
+
|
179 |
+
**APA:**
|
180 |
+
|
181 |
+
[More Information Needed]
|
182 |
+
|
183 |
+
## Glossary [optional]
|
184 |
+
|
185 |
+
<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
|
186 |
+
|
187 |
+
[More Information Needed]
|
188 |
+
|
189 |
+
## More Information [optional]
|
190 |
+
|
191 |
+
[More Information Needed]
|
192 |
+
|
193 |
+
## Model Card Authors [optional]
|
194 |
+
|
195 |
+
[More Information Needed]
|
196 |
+
|
197 |
+
## Model Card Contact
|
198 |
+
|
199 |
+
[More Information Needed]
|
config.json
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"act_size": 512,
|
3 |
+
"architectures": [
|
4 |
+
"BatchTopKSAE"
|
5 |
+
],
|
6 |
+
"auto_map": {
|
7 |
+
"AutoConfig": "config.SAEConfig",
|
8 |
+
"AutoModel": "model.BatchTopKSAE"
|
9 |
+
},
|
10 |
+
"aux_penalty": 0.03125,
|
11 |
+
"bandwidth": 0.001,
|
12 |
+
"dict_size": 16384,
|
13 |
+
"dtype": "float32",
|
14 |
+
"input_mean": 5.984766175970435e-05,
|
15 |
+
"input_std": 3.1487300395965576,
|
16 |
+
"input_unit_norm": true,
|
17 |
+
"l1_coeff": 0.001,
|
18 |
+
"model_type": "sae",
|
19 |
+
"n_batches_to_dead": 10,
|
20 |
+
"parent_hook_point": "resid_post",
|
21 |
+
"parent_layer": 5,
|
22 |
+
"parent_model_name": "EleutherAI/pythia-70m",
|
23 |
+
"sae_dtype": "float32",
|
24 |
+
"sae_type": "batchtopk",
|
25 |
+
"top_k_aux": 512,
|
26 |
+
"torch_dtype": "float32",
|
27 |
+
"transformers_version": "4.47.1"
|
28 |
+
}
|
config.py
ADDED
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass, field
|
2 |
+
from typing import Optional
|
3 |
+
import torch
|
4 |
+
import pyrallis
|
5 |
+
from transformers import PretrainedConfig
|
6 |
+
from typing import Optional
|
7 |
+
|
8 |
+
|
9 |
+
@dataclass
class TrainingConfig:
    """Hyperparameters and runtime settings for one SAE training run.

    The fields are grouped by concern (model, SAE, top-k, optimisation,
    scheduler, hardware, dataset, logging). ``__post_init__`` normalises the
    hardware-related fields after construction.
    """

    # Model settings
    model_name: str = "unsloth/Meta-Llama-3.1-8B"
    layer: int = 12
    hook_point: str = "resid_post"
    act_size: Optional[int] = None  # Will be set after model initialization

    # SAE settings
    sae_type: str = "batchtopk"
    dict_size: int = 2**15
    aux_penalty: float = 1/32
    input_unit_norm: bool = True

    # TopK specific settings
    top_k: int = 50
    top_k_warmup_steps_fraction: float = 0.1
    start_top_k: int = 4096
    top_k_aux: int = 512

    n_batches_to_dead: int = 10

    # Training settings
    lr: float = 3e-4
    bandwidth: float = 0.001
    l1_coeff: float = 0.0018
    num_tokens: int = int(1e9)
    seq_len: int = 1024
    model_batch_size: int = 16
    num_batches_in_buffer: int = 5
    max_grad_norm: float = 1.0
    batch_size: int = 8192

    # scheduler
    warmup_fraction: float = 0.1
    scheduler_type: str = 'linear'

    # Hardware settings
    device: str = "cuda"
    dtype: torch.dtype = field(default=torch.float32)
    sae_dtype: torch.dtype = field(default=torch.float32)

    # Dataset settings
    dataset_path: str = "cerebras/SlimPajama-627B"

    # Logging settings
    wandb_project: str = "turbo-llama-lens"

    performance_log_steps: int = 100
    save_checkpoint_steps: int = 10_000

    def __post_init__(self):
        """Fall back to CPU when CUDA is missing and resolve string dtypes."""
        if self.device == "cuda" and not torch.cuda.is_available():
            print("CUDA not available, falling back to CPU")
            self.device = "cpu"

        # Both dtype fields may arrive as strings (e.g. "float16"); resolve
        # each of them to the torch.dtype attribute of the same name so the
        # two fields always hold the same kind of value.
        if isinstance(self.dtype, str):
            self.dtype = getattr(torch, self.dtype)
        if isinstance(self.sae_dtype, str):
            self.sae_dtype = getattr(torch, self.sae_dtype)
|
67 |
+
|
68 |
+
|
69 |
+
class SAEConfig(PretrainedConfig):
    """Serializable configuration of a sparse autoencoder (SAE).

    Stores the SAE architecture, top-k settings, loss coefficients, dtype
    names, optional information about the model the SAE was trained on, and
    optional input-normalization statistics. Dtypes are kept as strings; use
    ``get_torch_dtype`` to resolve them.
    """

    model_type = "sae"

    def __init__(
        self,
        # SAE architecture
        act_size: int = None,
        dict_size: int = 2**15,
        sae_type: str = "batchtopk",
        input_unit_norm: bool = True,
        # TopK specific settings
        top_k: int = 50,
        top_k_aux: int = 512,
        n_batches_to_dead: int = 10,
        # Training hyperparameters
        aux_penalty: float = 1/32,
        l1_coeff: float = 0.0018,
        bandwidth: float = 0.001,
        # Hardware settings
        dtype: str = "float32",
        sae_dtype: str = "float32",
        # Optional parent model info
        parent_model_name: Optional[str] = None,
        parent_layer: Optional[int] = None,
        parent_hook_point: Optional[str] = None,
        # Input normalization settings
        input_mean: Optional[float] = None,
        input_std: Optional[float] = None,
        **kwargs
    ):
        """Store every setting as an attribute; extra kwargs go to the base class."""
        super().__init__(**kwargs)

        # Architecture
        self.act_size = act_size
        self.dict_size = dict_size
        self.sae_type = sae_type
        self.input_unit_norm = input_unit_norm

        # Top-k selection and dead-feature tracking
        self.top_k = top_k
        self.top_k_aux = top_k_aux
        self.n_batches_to_dead = n_batches_to_dead

        # Loss coefficients
        self.aux_penalty = aux_penalty
        self.l1_coeff = l1_coeff
        self.bandwidth = bandwidth

        # Dtype names (strings, resolved on demand)
        self.dtype = dtype
        self.sae_dtype = sae_dtype

        # Where the activations came from
        self.parent_model_name = parent_model_name
        self.parent_layer = parent_layer
        self.parent_hook_point = parent_hook_point

        # Pre-computed normalization statistics
        self.input_mean = input_mean
        self.input_std = input_std

    def get_torch_dtype(self, dtype_str: str) -> torch.dtype:
        """Map a dtype name to a ``torch.dtype``; unknown names give float32."""
        known = {
            name: getattr(torch, name)
            for name in ("float32", "float16", "bfloat16")
        }
        return known.get(dtype_str, torch.float32)

    @classmethod
    def from_training_config(cls, cfg: TrainingConfig):
        """Build an ``SAEConfig`` from the matching ``TrainingConfig`` fields.

        ``input_mean`` / ``input_std`` are copied only when ``cfg`` carries
        them; otherwise they are left as ``None``.
        """
        dtype_name = str(cfg.dtype).rpartition('.')[2]
        sae_dtype_name = str(cfg.sae_dtype).rpartition('.')[2]
        return cls(
            act_size=cfg.act_size,
            dict_size=cfg.dict_size,
            sae_type=cfg.sae_type,
            input_unit_norm=cfg.input_unit_norm,
            top_k=cfg.top_k,
            top_k_aux=cfg.top_k_aux,
            n_batches_to_dead=cfg.n_batches_to_dead,
            aux_penalty=cfg.aux_penalty,
            l1_coeff=cfg.l1_coeff,
            bandwidth=cfg.bandwidth,
            dtype=dtype_name,
            sae_dtype=sae_dtype_name,
            parent_model_name=cfg.model_name,
            parent_layer=cfg.layer,
            parent_hook_point=cfg.hook_point,
            input_mean=getattr(cfg, 'input_mean', None),
            input_std=getattr(cfg, 'input_std', None),
        )

    def to_training_config(self) -> TrainingConfig:
        """Build a ``TrainingConfig`` from this config.

        Only the fields both classes share are transferred; every other
        ``TrainingConfig`` field keeps its default.
        """
        settings = {
            "dtype": self.get_torch_dtype(self.dtype),
            "sae_dtype": self.get_torch_dtype(self.sae_dtype),
            "model_name": self.parent_model_name,
            "layer": self.parent_layer,
            "hook_point": self.parent_hook_point,
            "act_size": self.act_size,
            "dict_size": self.dict_size,
            "sae_type": self.sae_type,
            "input_unit_norm": self.input_unit_norm,
            "top_k": self.top_k,
            "top_k_aux": self.top_k_aux,
            "n_batches_to_dead": self.n_batches_to_dead,
            "aux_penalty": self.aux_penalty,
            "l1_coeff": self.l1_coeff,
            "bandwidth": self.bandwidth,
        }
        return TrainingConfig(**settings)
|
179 |
+
|
180 |
+
|
181 |
+
def get_config() -> TrainingConfig:
    """Build a ``TrainingConfig`` from command-line arguments using pyrallis.

    ``pyrallis.parse`` is called directly rather than decorating this
    function with ``pyrallis.wrap``: the decorator passes the parsed config
    as the first positional argument of the wrapped function, which a
    zero-argument function cannot accept.

    Returns:
        The parsed training configuration.
    """
    return pyrallis.parse(config_class=TrainingConfig)
|
184 |
+
|
185 |
+
|
186 |
+
# For backward compatibility
|
187 |
+
def get_default_cfg() -> TrainingConfig:
    """Return whatever ``get_config`` produces (alias kept for older callers)."""
    cfg = get_config()
    return cfg
|
189 |
+
|
190 |
+
|
191 |
+
def post_init_cfg(cfg: TrainingConfig, activation_store = None) -> TrainingConfig:
    """Finish configuration work that has to wait until the model exists.

    Args:
        cfg: Training configuration; modified in place.
        activation_store: Optional object exposing ``mean`` and ``std``. When
            given, both values are copied onto ``cfg`` as ``input_mean`` and
            ``input_std`` and printed.

    Returns:
        The same ``cfg`` object.
    """
    if activation_store is None:
        # Nothing to copy; hand the configuration back untouched.
        return cfg

    cfg.input_mean = activation_store.mean
    cfg.input_std = activation_store.std
    print(f"Setting input statistics from activation store - Mean: {cfg.input_mean:.4f}, Std: {cfg.input_std:.4f}")
    return cfg
|
model.py
ADDED
@@ -0,0 +1,562 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import PreTrainedModel
|
2 |
+
from typing import Optional, Dict, Union
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
import torch.autograd as autograd
|
7 |
+
from .config import SAEConfig
|
8 |
+
|
9 |
+
|
10 |
+
class BaseSAE(PreTrainedModel):
    """Base class for autoencoder models.

    Owns the encoder/decoder weights and biases, the per-feature counter of
    consecutive batches without activation, optional input-normalization
    statistics, and helpers for folding normalization into the weights.

    NOTE(review): ``forward`` calls ``self.get_loss_dict``, which this class
    does not define; it is expected to be supplied by a subclass.
    """
    config_class = SAEConfig
    base_model_prefix = "sae"

    def __init__(self, config: SAEConfig):
        """Create parameters and buffers from ``config``.

        The decoder is initialised as the transpose of the encoder with each
        row scaled to unit norm. The global torch RNG is seeded with 42.
        """
        super().__init__(config)
        print(config)
        self.config = config
        torch.manual_seed(42)

        self.b_dec = nn.Parameter(torch.zeros(self.config.act_size))
        self.b_enc = nn.Parameter(torch.zeros(self.config.dict_size))
        self.W_enc = nn.Parameter(
            torch.nn.init.kaiming_uniform_(
                torch.empty(self.config.act_size, self.config.dict_size)
            )
        )
        self.W_dec = nn.Parameter(
            torch.nn.init.kaiming_uniform_(
                torch.empty(self.config.dict_size, self.config.act_size)
            )
        )
        # Tie the decoder to the encoder transpose, then make rows unit norm.
        self.W_dec.data[:] = self.W_enc.t().data
        self.W_dec.data[:] = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)
        self.register_buffer('num_batches_not_active', torch.zeros((self.config.dict_size,)))

        self.to(self.config.get_torch_dtype(self.config.dtype))

        # Initialize input normalization parameters if provided
        if config.input_mean is not None and config.input_std is not None:
            self.register_buffer('input_mean', torch.tensor(config.input_mean))
            self.register_buffer('input_std', torch.tensor(config.input_std))
        else:
            # Register empty buffers instead of assigning plain ``None``
            # attributes: ``register_buffer`` refuses to replace an existing
            # non-buffer attribute, which would make ``set_mean_std`` fail.
            # Reading a ``None`` buffer still yields ``None``.
            self.register_buffer('input_mean', None)
            self.register_buffer('input_std', None)

        # Initialize CUDA events for timing
        if torch.cuda.is_available():
            self.start_event = torch.cuda.Event(enable_timing=True)
            self.end_event = torch.cuda.Event(enable_timing=True)
        else:
            self.start_event = None
            self.end_event = None

    def preprocess_input(self, x):
        """Cast ``x`` to the SAE dtype and optionally normalize it.

        Returns:
            ``(x, mean, std)``. With ``input_unit_norm`` enabled the stored
            statistics are used when both exist, otherwise statistics are
            computed over the last dimension. With it disabled, ``mean`` and
            ``std`` are ``None``.
        """
        x = x.to(self.config.get_torch_dtype(self.config.sae_dtype))
        if self.config.input_unit_norm:
            if self.input_mean is not None and self.input_std is not None:
                # Use pre-computed statistics
                x = (x - self.input_mean) / (self.input_std + 1e-5)
                return x, self.input_mean, self.input_std
            else:
                # Compute statistics on the fly
                x_mean = x.mean(dim=-1, keepdim=True)
                x = x - x_mean
                x_std = x.std(dim=-1, keepdim=True)
                x = x / (x_std + 1e-5)
                return x, x_mean, x_std
        else:
            return x, None, None

    def postprocess_output(self, x_reconstruct, x_mean, x_std):
        """Undo input normalization on a reconstruction when it is enabled."""
        if self.config.input_unit_norm:
            x_reconstruct = x_reconstruct * x_std + x_mean
        return x_reconstruct

    @torch.no_grad()
    def make_decoder_weights_and_grad_unit_norm(self):
        """Renormalize decoder rows and drop the gradient part along each row.

        Requires ``self.W_dec.grad`` to be populated.
        """
        W_dec_normed = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)
        # Component of the gradient parallel to each (unit) decoder row.
        W_dec_grad_proj = (self.W_dec.grad * W_dec_normed).sum(
            -1, keepdim=True
        ) * W_dec_normed
        self.W_dec.grad -= W_dec_grad_proj
        self.W_dec.data = W_dec_normed

    def update_inactive_features(self, acts):
        """Update the per-feature count of consecutive inactive batches.

        A feature whose activations sum to zero over dim 0 is incremented;
        one with a positive sum is reset to zero.
        """
        column_totals = acts.sum(0)
        self.num_batches_not_active += (column_totals == 0).float()
        self.num_batches_not_active[column_totals > 0] = 0

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode input tensor to sparse features

        The input goes through ``preprocess_input`` so that the dtype cast
        and normalization (including stored statistics) match ``forward``.

        Args:
            x: Input tensor of shape (batch_size, act_size)
        Returns:
            Encoded features of shape (batch_size, dict_size)
        """
        x, _, _ = self.preprocess_input(x)
        return F.relu(x @ self.W_enc + self.b_enc)

    def decode(self, h: torch.Tensor) -> torch.Tensor:
        """
        Decode features back to input space

        No de-normalization is applied here.

        Args:
            h: Encoded features of shape (batch_size, dict_size)
        Returns:
            Reconstructed input of shape (batch_size, act_size)
        """
        return h @ self.W_dec + self.b_dec

    def forward(self, x):
        """Encode, decode, update dead-feature counts and build the loss dict.

        When CUDA is available the elapsed time is added to the output under
        ``"forward_time_ms"``; this synchronizes the device.
        """
        # Start timing if CUDA is available
        if self.start_event is not None:
            self.start_event.record()

        x, x_mean, x_std = self.preprocess_input(x)
        acts = F.relu(x @ self.W_enc + self.b_enc)
        x_reconstruct = acts @ self.W_dec + self.b_dec
        self.update_inactive_features(acts)
        # NOTE(review): get_loss_dict is not defined on this class.
        output = self.get_loss_dict(x, x_reconstruct, acts, x_mean, x_std)

        # End timing if CUDA is available
        if self.end_event is not None:
            self.end_event.record()
            torch.cuda.synchronize()
            output["forward_time_ms"] = self.start_event.elapsed_time(self.end_event)

        return output

    @torch.no_grad()
    def fold_stats_into_weights(self, mean: torch.Tensor = None, std: torch.Tensor = None) -> "BaseSAE":
        """
        Fold normalization statistics into the encoder weights and biases

        Afterwards ``config.input_unit_norm`` is switched off, since the
        weights then operate on un-normalized inputs.

        Args:
            mean: Mean to fold in; stored ``input_mean`` is used when either
                argument is omitted. Assumed scalar — TODO confirm.
            std: Standard deviation to fold in; stored ``input_std`` is used
                when either argument is omitted. Assumed scalar — TODO confirm.
        Returns:
            ``self``.
        Raises:
            ValueError: If no statistics are passed and none are stored.
        """
        print("Folding statistics into encoder...")

        if mean is not None and std is not None:
            mean = mean.to(self.W_enc.device)
            std = std.to(self.W_enc.device)
        else:
            mean = self.input_mean
            std = self.input_std

        if mean is None or std is None:
            raise ValueError(
                "No normalization statistics to fold: pass `mean` and `std` "
                "or set them first with `set_mean_std`."
            )

        # Original forward pass:
        # x_norm = (x - mean) / std
        # acts = relu(x_norm @ W_enc + b_enc)
        # x_hat = acts @ W_dec + b_dec

        # Folding steps:
        # 1. x_norm = (x - mean) / std
        # 2. acts = relu(x_norm @ W_enc + b_enc)
        #        = relu((x/std - mean/std) @ W_enc + b_enc)
        #        = relu(x @ (W_enc/std) - mean @ (W_enc/std) + b_enc)

        # First scale encoder weights
        self.W_enc.data = self.W_enc / std

        # Then adjust encoder bias (uses the already-scaled encoder weights)
        self.b_enc.data = self.b_enc - mean * (self.W_enc.sum(0))

        # Scale decoder to preserve reconstruction
        self.W_dec.data = self.W_dec * std
        self.b_dec.data = self.b_dec * std + mean

        # Turn off input normalization
        self.config.input_unit_norm = False

        return self

    @torch.no_grad()
    def fold_W_dec_norm(self):
        """
        Make decoder weights unit norm and adjust encoder accordingly

        Returns:
            ``self``.
        """
        # Get current decoder norms
        W_dec_norm = self.W_dec.norm(dim=-1, keepdim=True)

        # Original: acts @ W_dec + b_dec
        # After: acts @ (W_dec/norm) + b_dec
        # Need: (acts * norm) @ (W_dec/norm) + b_dec
        # So: scale W_enc by norm.T

        # Scale encoder weights first
        self.W_enc.data = self.W_enc * W_dec_norm.t()

        # Then normalize decoder weights
        self.W_dec.data = self.W_dec / W_dec_norm

        # Scale encoder bias to compensate for the scaling of activations
        self.b_enc.data = self.b_enc * W_dec_norm.squeeze(-1)

        return self

    def set_mean_std(self, mean: float, std: float):
        """
        Set input normalization statistics after model initialization

        Also switches ``config.input_unit_norm`` on.

        NOTE(review): ``config.input_mean`` / ``config.input_std`` are not
        updated here — confirm whether a saved config should carry them.

        Args:
            mean: Mean scalar value for input normalization
            std: Standard deviation scalar value for input normalization
        Returns:
            ``self``.
        """
        self.register_buffer('input_mean', torch.tensor(mean, device=self.device))
        self.register_buffer('input_std', torch.tensor(std, device=self.device))
        self.config.input_unit_norm = True
        return self
|
211 |
+
|
212 |
+
|
213 |
+
|
214 |
+
class BatchTopKSAE(BaseSAE):
    """SAE whose sparsity comes from a top-k taken over the whole batch.

    The ``top_k * batch_size`` largest activations of the flattened
    activation tensor are kept and all others are zeroed, so individual
    samples may keep more or fewer than ``top_k`` features.
    """

    def _batch_topk(self, acts: torch.Tensor, batch_size: int) -> torch.Tensor:
        """Zero everything except the ``top_k * batch_size`` largest entries."""
        flat = acts.flatten()
        # Never request more entries than exist.
        k = min(self.config.top_k * batch_size, flat.numel())
        top = torch.topk(flat, k, dim=-1)
        return (
            torch.zeros_like(flat)
            .scatter(-1, top.indices, top.values)
            .reshape(acts.shape)
        )

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode input tensor to sparse features with batch-wise top-k

        The input goes through ``preprocess_input`` so that the dtype cast
        and normalization (including stored statistics) match ``forward``.
        """
        x, _, _ = self.preprocess_input(x)
        acts = F.relu(x @ self.W_enc + self.b_enc)
        return self._batch_topk(acts, x.shape[0])

    def forward(self, x):
        """Run the SAE on ``x`` and return the dict built by ``get_loss_dict``."""
        x, x_mean, x_std = self.preprocess_input(x)
        acts = F.relu(x @ self.W_enc + self.b_enc)
        acts_topk = self._batch_topk(acts, x.shape[0])
        x_reconstruct = acts_topk @ self.W_dec + self.b_dec
        self.update_inactive_features(acts_topk)
        output = self.get_loss_dict(x, x_reconstruct, acts, acts_topk, x_mean, x_std)
        return output

    def get_loss_dict(self, x, x_reconstruct, acts, acts_topk, x_mean, x_std):
        """Compute losses and metrics for one forward pass.

        ``x`` and ``x_reconstruct`` are compared as given (after input
        preprocessing); only ``sae_out`` is passed through
        ``postprocess_output``. The optimized ``loss`` is ``l2_loss +
        aux_loss``; ``l1_loss`` is returned but not added to it.
        """
        l2_loss = (x_reconstruct.float() - x.float()).pow(2).mean()
        l1_norm = acts_topk.float().abs().sum(-1).mean()
        l1_loss = self.config.l1_coeff * l1_norm
        l0_norm = (acts_topk > 0).float().sum(-1).mean()
        aux_loss = self.get_auxiliary_loss(x, x_reconstruct, acts)
        loss = l2_loss + aux_loss
        # NOTE(review): this count uses ``>`` while get_auxiliary_loss treats
        # a feature as dead with ``>=`` — confirm which is intended.
        num_dead_features = (
            self.num_batches_not_active > self.config.n_batches_to_dead
        ).sum()
        sae_out = self.postprocess_output(x_reconstruct, x_mean, x_std)
        per_token_l2_loss_A = (x_reconstruct.float() - x.float()).pow(2).sum(-1).squeeze()
        total_variance_A = (x.float() - x.float().mean(0)).pow(2).sum(-1).squeeze()
        explained_variance = (1 - per_token_l2_loss_A / total_variance_A).mean()
        output = {
            "sae_out": sae_out,
            "feature_acts": acts_topk,
            "num_dead_features": num_dead_features,
            "loss": loss,
            "l1_loss": l1_loss,
            "l2_loss": l2_loss,
            "l0_norm": l0_norm,
            "l1_norm": l1_norm,
            "aux_loss": aux_loss,
            "explained_variance": explained_variance,
            "top_k": self.config.top_k
        }
        return output

    def get_auxiliary_loss(self, x, x_reconstruct, acts):
        """Auxiliary loss: reconstruct the residual from dead features only.

        A feature is dead once it has been inactive for at least
        ``n_batches_to_dead`` consecutive batches. Per sample, the
        ``top_k_aux`` largest dead-feature activations (or all of them when
        fewer exist) are decoded and compared to ``x - x_reconstruct``. The
        result is scaled by ``aux_penalty``; with no dead features it is a
        zero scalar tensor.
        """
        dead_features = self.num_batches_not_active >= self.config.n_batches_to_dead
        # Plain int so that ``min`` and ``torch.topk`` get an integer k.
        n_dead = int(dead_features.sum())
        if n_dead > 0:
            residual = x.float() - x_reconstruct.float()
            dead_acts = acts[:, dead_features]
            acts_topk_aux = torch.topk(
                dead_acts,
                min(self.config.top_k_aux, n_dead),
                dim=-1,
            )
            acts_aux = torch.zeros_like(dead_acts).scatter(
                -1, acts_topk_aux.indices, acts_topk_aux.values
            )
            x_reconstruct_aux = acts_aux @ self.W_dec[dead_features]
            l2_loss_aux = (
                self.config.aux_penalty
                * (x_reconstruct_aux.float() - residual.float()).pow(2).mean()
            )
            return l2_loss_aux
        else:
            return torch.tensor(0, dtype=x.dtype, device=x.device)
|
296 |
+
|
297 |
+
|
298 |
+
class TopKSAE(BaseSAE):
    """Sparse autoencoder that keeps the ``config.top_k`` largest activations per row.

    The parameters (``W_enc``, ``b_enc``, ``W_dec``, ``b_dec``), the
    ``num_batches_not_active`` counter and the ``preprocess_input`` /
    ``postprocess_output`` / ``update_inactive_features`` helpers are used
    here but come from ``BaseSAE``; they are not defined in this class.
    """

    def _keep_top_k(self, acts: torch.Tensor) -> torch.Tensor:
        """Zero every entry of ``acts`` except the ``config.top_k`` largest along the last dim."""
        top = torch.topk(acts, self.config.top_k, dim=-1)
        return torch.zeros_like(acts).scatter(-1, top.indices, top.values)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode input tensor to sparse features with per-sample top-k

        When ``config.input_unit_norm`` is set, ``x`` is first centred and
        scaled along its last dimension (std + 1e-5 to avoid dividing by zero).
        """
        if self.config.input_unit_norm:
            x = x - x.mean(dim=-1, keepdim=True)
            x = x / (x.std(dim=-1, keepdim=True) + 1e-5)

        acts = F.relu(x @ self.W_enc + self.b_enc)
        return self._keep_top_k(acts)

    def forward(self, x):
        """Run encode -> top-k -> decode and return the loss/metric dictionary.

        Side effect: calls ``update_inactive_features`` with the sparse
        activations of this batch.
        """
        x, x_mean, x_std = self.preprocess_input(x)
        acts = F.relu(x @ self.W_enc + self.b_enc)
        acts_topk = self._keep_top_k(acts)
        x_reconstruct = acts_topk @ self.W_dec + self.b_dec
        self.update_inactive_features(acts_topk)
        return self.get_loss_dict(x, x_reconstruct, acts, acts_topk, x_mean, x_std)

    def get_loss_dict(self, x, x_reconstruct, acts, acts_topk, x_mean, x_std):
        """Assemble the losses and monitoring metrics for one forward pass.

        Args:
            x: preprocessed input the reconstruction is compared against.
            x_reconstruct: decoder output for ``acts_topk``.
            acts: dense ReLU activations (before top-k), used for the auxiliary loss.
            acts_topk: sparse activations after top-k.
            x_mean, x_std: values returned by ``preprocess_input``; forwarded
                unchanged to ``postprocess_output``.

        Returns:
            dict with keys ``sae_out``, ``feature_acts``, ``num_dead_features``,
            ``loss``, ``l1_loss``, ``l2_loss``, ``l0_norm``, ``l1_norm``,
            ``explained_variance`` and ``aux_loss``.
        """
        l2_loss = (x_reconstruct.float() - x.float()).pow(2).mean()
        l1_norm = acts_topk.float().abs().sum(-1).mean()
        l1_loss = self.config.l1_coeff * l1_norm
        l0_norm = (acts_topk > 0).float().sum(-1).mean()
        aux_loss = self.get_auxiliary_loss(x, x_reconstruct, acts)
        loss = l2_loss + l1_loss + aux_loss
        # NOTE(review): strict ">" here, while get_auxiliary_loss selects dead
        # features with ">="; kept as-is so the reported metric does not change.
        num_dead_features = (
            self.num_batches_not_active > self.config.n_batches_to_dead
        ).sum()
        sae_out = self.postprocess_output(x_reconstruct, x_mean, x_std)
        # Explained variance: 1 - (squared error per row) / (squared deviation
        # from the mean over dim 0 per row), averaged over rows.
        # NOTE(review): no guard against a zero denominator.
        per_token_l2_loss_A = (x_reconstruct.float() - x.float()).pow(2).sum(-1).squeeze()
        total_variance_A = (x.float() - x.float().mean(0)).pow(2).sum(-1).squeeze()
        explained_variance = (1 - per_token_l2_loss_A / total_variance_A).mean()
        output = {
            "sae_out": sae_out,
            "feature_acts": acts_topk,
            "num_dead_features": num_dead_features,
            "loss": loss,
            "l1_loss": l1_loss,
            "l2_loss": l2_loss,
            "l0_norm": l0_norm,
            "l1_norm": l1_norm,
            "explained_variance": explained_variance,
            "aux_loss": aux_loss,
        }
        return output

    def get_auxiliary_loss(self, x, x_reconstruct, acts):
        """Auxiliary loss: reconstruct the residual using only dead features.

        A feature counts as dead once ``num_batches_not_active`` has reached
        ``config.n_batches_to_dead``. The top ``config.top_k_aux`` dead
        activations per row (or all of them if fewer are dead) are decoded with
        the matching rows of ``W_dec`` and compared to ``x - x_reconstruct``.

        Returns:
            ``config.aux_penalty`` times the mean squared error of that
            reconstruction, or a scalar zero tensor (dtype/device of ``x``)
            when no feature is dead.
        """
        dead_features = self.num_batches_not_active >= self.config.n_batches_to_dead
        # Reduce once and convert to a Python int so torch.topk receives a
        # plain integer k instead of a 0-dim tensor.
        num_dead = int(dead_features.sum())
        if num_dead == 0:
            return torch.tensor(0, dtype=x.dtype, device=x.device)

        residual = x.float() - x_reconstruct.float()
        dead_acts = acts[:, dead_features]  # indexing assumes acts is 2-D
        top_aux = torch.topk(dead_acts, min(self.config.top_k_aux, num_dead), dim=-1)
        acts_aux = torch.zeros_like(dead_acts).scatter(
            -1, top_aux.indices, top_aux.values
        )
        x_reconstruct_aux = acts_aux @ self.W_dec[dead_features]
        return (
            self.config.aux_penalty
            * (x_reconstruct_aux.float() - residual.float()).pow(2).mean()
        )
|
375 |
+
|
376 |
+
|
377 |
+
class VanillaSAE(BaseSAE):
    """ReLU sparse autoencoder whose loss is reconstruction error plus an L1 penalty.

    Parameters, ``num_batches_not_active`` and the ``preprocess_input`` /
    ``postprocess_output`` / ``update_inactive_features`` helpers are provided
    by ``BaseSAE`` and are not defined in this class.
    """

    def forward(self, x):
        """Encode, decode and return the loss/metric dictionary for ``x``.

        Side effect: calls ``update_inactive_features`` with this batch's activations.
        """
        x, x_mean, x_std = self.preprocess_input(x)
        hidden = F.relu(x @ self.W_enc + self.b_enc)
        reconstruction = hidden @ self.W_dec + self.b_dec
        self.update_inactive_features(hidden)
        return self.get_loss_dict(x, reconstruction, hidden, x_mean, x_std)

    def get_loss_dict(self, x, x_reconstruct, acts, x_mean, x_std):
        """Build the dictionary of losses and monitoring metrics.

        Returns a dict with keys ``sae_out``, ``feature_acts``,
        ``num_dead_features``, ``loss``, ``l1_loss``, ``l2_loss``, ``l0_norm``,
        ``l1_norm`` and ``explained_variance``; ``forward_time_ms`` is added
        when the instance has a non-None ``start_event``.
        """
        x_f32 = x.float()
        squared_error = (x_reconstruct.float() - x_f32).pow(2)

        l2_loss = squared_error.mean()
        l1_norm = acts.float().abs().sum(-1).mean()
        l1_loss = self.config.l1_coeff * l1_norm
        l0_norm = (acts > 0).float().sum(-1).mean()
        loss = l2_loss + l1_loss

        dead_mask = self.num_batches_not_active > self.config.n_batches_to_dead
        num_dead_features = dead_mask.sum()

        sae_out = self.postprocess_output(x_reconstruct, x_mean, x_std)

        # Explained variance per row, then averaged.
        per_token_l2_loss_A = squared_error.sum(-1).squeeze()
        total_variance_A = (x_f32 - x_f32.mean(0)).pow(2).sum(-1).squeeze()
        explained_variance = (1 - per_token_l2_loss_A / total_variance_A).mean()

        output = dict(
            sae_out=sae_out,
            feature_acts=acts,
            num_dead_features=num_dead_features,
            loss=loss,
            l1_loss=l1_loss,
            l2_loss=l2_loss,
            l0_norm=l0_norm,
            l1_norm=l1_norm,
            explained_variance=explained_variance,
        )

        # Add timing if available
        start_event = getattr(self, "start_event", None)
        if start_event is not None:
            output["forward_time_ms"] = start_event.elapsed_time(self.end_event)

        return output
|
417 |
+
|
418 |
+
|
419 |
+
import torch
|
420 |
+
import torch.nn as nn
|
421 |
+
|
422 |
+
class RectangleFunction(autograd.Function):
    """Indicator of the open interval (-0.5, 0.5) with a custom backward pass.

    Forward returns 1.0 where ``-0.5 < x < 0.5`` and 0.0 elsewhere (as a float
    tensor). Backward lets the incoming gradient through inside that interval
    and zeroes it where ``x <= -0.5`` or ``x >= 0.5``.
    """

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        # |x| < 0.5 is the same open interval as (x > -0.5) & (x < 0.5).
        inside = x.abs() < 0.5
        return inside.float()

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        outside = x.abs() >= 0.5
        # masked_fill returns a new tensor, so grad_output itself is untouched.
        return grad_output.masked_fill(outside, 0)
|
434 |
+
|
435 |
+
class JumpReLUFunction(autograd.Function):
    """JumpReLU activation with a hand-written gradient for the threshold.

    Forward: ``x`` where ``x > exp(log_threshold)``, zero elsewhere.
    Backward: ``x`` receives the incoming gradient where it passed the
    threshold; ``log_threshold`` receives
    ``-(threshold / bandwidth) * rect((x - threshold) / bandwidth) * grad_output``,
    where ``rect`` is ``RectangleFunction``; ``bandwidth`` gets no gradient.
    """

    @staticmethod
    def forward(ctx, x, log_threshold, bandwidth):
        ctx.save_for_backward(x, log_threshold)
        # bandwidth is a plain number, not a tensor to differentiate through:
        # keep it on ctx as a Python float instead of round-tripping it through
        # a float32 tensor (which silently rounds it) via save_for_backward.
        ctx.bandwidth = float(bandwidth)
        threshold = torch.exp(log_threshold)
        return x * (x > threshold).float()

    @staticmethod
    def backward(ctx, grad_output):
        x, log_threshold = ctx.saved_tensors
        bandwidth = ctx.bandwidth
        threshold = torch.exp(log_threshold)
        x_grad = (x > threshold).float() * grad_output
        # NOTE(review): this has the broadcast shape of x, not of
        # log_threshold; it presumably relies on autograd reducing it to the
        # parameter's shape - confirm.
        threshold_grad = (
            -(threshold / bandwidth)
            * RectangleFunction.apply((x - threshold) / bandwidth)
            * grad_output
        )
        return x_grad, threshold_grad, None  # None for bandwidth
|
454 |
+
|
455 |
+
class JumpReLU(nn.Module):
    """Module wrapper around ``JumpReLUFunction`` with a learnable log-threshold.

    Args:
        feature_size: number of entries in the ``log_threshold`` parameter.
        bandwidth: value passed to ``JumpReLUFunction`` on every call.
        device: device on which ``log_threshold`` is created.
    """

    def __init__(self, feature_size, bandwidth, device='cpu'):
        super().__init__()
        self.bandwidth = bandwidth
        # Zero log-threshold, i.e. an initial threshold of exp(0) = 1.
        initial_log_threshold = torch.zeros(feature_size, device=device)
        self.log_threshold = nn.Parameter(initial_log_threshold)

    def forward(self, x):
        """Apply the JumpReLU activation to ``x`` using the current threshold."""
        activated = JumpReLUFunction.apply(x, self.log_threshold, self.bandwidth)
        return activated
|
463 |
+
|
464 |
+
class StepFunction(autograd.Function):
    """Heaviside step at ``exp(log_threshold)`` with a hand-written threshold gradient.

    Forward: 1.0 where ``x > exp(log_threshold)``, 0.0 elsewhere.
    Backward: ``x`` receives a zero gradient; ``log_threshold`` receives
    ``-(1 / bandwidth) * rect((x - threshold) / bandwidth) * grad_output``,
    where ``rect`` is ``RectangleFunction``; ``bandwidth`` gets no gradient.
    """

    @staticmethod
    def forward(ctx, x, log_threshold, bandwidth):
        ctx.save_for_backward(x, log_threshold)
        # Keep the scalar bandwidth on ctx as a Python float rather than
        # wrapping it in a float32 tensor, which would round its value.
        ctx.bandwidth = float(bandwidth)
        threshold = torch.exp(log_threshold)
        return (x > threshold).float()

    @staticmethod
    def backward(ctx, grad_output):
        x, log_threshold = ctx.saved_tensors
        bandwidth = ctx.bandwidth
        threshold = torch.exp(log_threshold)
        x_grad = torch.zeros_like(x)
        # NOTE(review): shaped like x, not like log_threshold; presumably
        # reduced to the parameter's shape by autograd - confirm.
        threshold_grad = (
            -(1.0 / bandwidth)
            * RectangleFunction.apply((x - threshold) / bandwidth)
            * grad_output
        )
        return x_grad, threshold_grad, None  # None for bandwidth
|
483 |
+
|
484 |
+
class JumpReLUSAE(BaseSAE):
    """Sparse autoencoder that uses a JumpReLU activation with learnable thresholds.

    Parameters, ``num_batches_not_active`` and the ``preprocess_input`` /
    ``postprocess_output`` / ``update_inactive_features`` /
    ``fold_W_dec_norm`` helpers are provided by ``BaseSAE`` and are not
    defined in this class.
    """

    def __init__(self, config: SAEConfig):
        super().__init__(config)
        self.jumprelu = JumpReLU(
            feature_size=config.dict_size,
            bandwidth=config.bandwidth,
            # Fall back to CPU when the config carries no device attribute.
            device=getattr(config, "device", "cpu"),
        )

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode input tensor to sparse features using JumpReLU

        When ``config.input_unit_norm`` is set, ``x`` is first centred and
        scaled along its last dimension (std + 1e-5 to avoid dividing by zero).
        """
        if self.config.input_unit_norm:
            x = x - x.mean(dim=-1, keepdim=True)
            x = x / (x.std(dim=-1, keepdim=True) + 1e-5)

        pre_activations = F.relu(x @ self.W_enc + self.b_enc)
        return self.jumprelu(pre_activations)

    def forward(self, x):
        """Encode with JumpReLU, decode, and return the loss/metric dictionary.

        Side effect: calls ``update_inactive_features`` with this batch's
        activations, as ``TopKSAE.forward`` and ``VanillaSAE.forward`` do, so
        that ``num_dead_features`` reflects this model's activity.
        """
        x, x_mean, x_std = self.preprocess_input(x)
        pre_activations = F.relu(x @ self.W_enc + self.b_enc)
        feature_magnitudes = self.jumprelu(pre_activations)
        x_reconstructed = feature_magnitudes @ self.W_dec + self.b_dec
        self.update_inactive_features(feature_magnitudes)
        return self.get_loss_dict(x, x_reconstructed, feature_magnitudes, x_mean, x_std)

    def get_loss_dict(self, x, x_reconstruct, acts, x_mean, x_std):
        """Build the dictionary of losses and monitoring metrics.

        The sparsity term is an L0 penalty computed with ``StepFunction`` and
        weighted by ``config.l1_coeff``; it is reported under the ``l1_loss``
        key, and both ``l0_norm`` and ``l1_norm`` carry the L0 value.

        Returns a dict with keys ``sae_out``, ``feature_acts``,
        ``num_dead_features``, ``loss``, ``l1_loss``, ``l2_loss``, ``l0_norm``,
        ``l1_norm`` and ``explained_variance``.
        """
        l2_loss = (x_reconstruct.float() - x.float()).pow(2).mean()

        l0 = StepFunction.apply(acts, self.jumprelu.log_threshold, self.config.bandwidth).sum(dim=-1).mean()
        l0_loss = self.config.l1_coeff * l0
        l1_loss = l0_loss

        loss = l2_loss + l1_loss
        num_dead_features = (
            self.num_batches_not_active > self.config.n_batches_to_dead
        ).sum()

        sae_out = self.postprocess_output(x_reconstruct, x_mean, x_std)
        # Explained variance per row, then averaged.
        # NOTE(review): no guard against a zero denominator.
        per_token_l2_loss_A = (x_reconstruct.float() - x.float()).pow(2).sum(-1).squeeze()
        total_variance_A = (x.float() - x.float().mean(0)).pow(2).sum(-1).squeeze()
        explained_variance = (1 - per_token_l2_loss_A / total_variance_A).mean()
        output = {
            "sae_out": sae_out,
            "feature_acts": acts,
            "num_dead_features": num_dead_features,
            "loss": loss,
            "l1_loss": l1_loss,
            "l2_loss": l2_loss,
            "l0_norm": l0,
            "l1_norm": l0,
            "explained_variance": explained_variance,
        }
        return output

    @torch.no_grad()
    def fold_W_dec_norm(self):
        """
        Make decoder weights unit norm and adjust encoder and thresholds accordingly

        Returns ``self``.
        """
        # Get current decoder norms (must be read before the parent call,
        # which handles the weights)
        W_dec_norm = self.W_dec.norm(dim=-1, keepdim=True)

        # Call parent method to handle weights
        super().fold_W_dec_norm()

        # Scale thresholds to preserve sparsity (remove keepdim to match threshold shape);
        # multiplying a threshold by the norm is adding log(norm) in log space.
        self.jumprelu.log_threshold.data = self.jumprelu.log_threshold + torch.log(W_dec_norm.squeeze(-1))

        return self
|
557 |
+
|
558 |
+
|
559 |
+
# Register the config and SAE classes via register_for_auto_class
# (presumably the transformers auto-class mechanism, so the Hub upload can be
# loaded with custom code - confirm).
# NOTE(review): TopKSAE is not registered here - confirm whether intentional.
SAEConfig.register_for_auto_class("AutoConfig")
|
560 |
+
BatchTopKSAE.register_for_auto_class()
|
561 |
+
JumpReLUSAE.register_for_auto_class()
|
562 |
+
VanillaSAE.register_for_auto_class()
|
model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c57832cc7dfbb94a1acde90650666bf23e5edc0dc76717f4ada0da7a8648aac9
|
3 |
+
size 67242576
|