elephantmipt commited on
Commit
b405a20
·
verified ·
1 Parent(s): 8b0cce1

Upload JumpReLUSAE

Browse files
Files changed (5) hide show
  1. README.md +199 -0
  2. config.json +28 -0
  3. config.py +203 -0
  4. model.py +474 -0
  5. model.safetensors +3 -0
README.md ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: transformers
3
+ tags: []
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+ This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
config.json ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "act_size": 2304,
3
+ "architectures": [
4
+ "JumpReLUSAE"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "config.SAEConfig",
8
+ "AutoModel": "model.JumpReLUSAE"
9
+ },
10
+ "aux_penalty": 0.03125,
11
+ "bandwidth": 0.001,
12
+ "dict_size": 16384,
13
+ "dtype": "float32",
14
+ "input_mean": -0.0031099182087928057,
15
+ "input_std": 0.8221414089202881,
16
+ "input_unit_norm": false,
17
+ "l1_coeff": 0.0025,
18
+ "model_type": "sae",
19
+ "n_batches_to_dead": 10,
20
+ "parent_hook_point": "attn_out",
21
+ "parent_layer": 6,
22
+ "parent_model_name": "google/gemma-2-2b",
23
+ "sae_dtype": "float32",
24
+ "sae_type": "jumprelu",
25
+ "top_k_aux": 512,
26
+ "torch_dtype": "float32",
27
+ "transformers_version": "4.47.0"
28
+ }
config.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from typing import Optional
3
+ import torch
4
+ import pyrallis
5
+ from transformers import PretrainedConfig
6
+ from typing import Optional
7
+
8
+
9
+ @dataclass
10
+ class TrainingConfig:
11
+ # Model settings
12
+ model_name: str = "unsloth/Meta-Llama-3.1-8B"
13
+ layer: int = 12
14
+ hook_point: str = "resid_post"
15
+ act_size: Optional[int] = None # Will be set after model initialization
16
+
17
+ # SAE settings
18
+ sae_type: str = "batchtopk"
19
+ dict_size: int = 2**15
20
+ aux_penalty: float = 1/32
21
+ input_unit_norm: bool = True
22
+
23
+ # TopK specific settings
24
+ top_k: int = 50
25
+ top_k_warmup_steps_fraction: float = 0.1
26
+ start_top_k: int = 4096
27
+ top_k_aux: int = 512
28
+
29
+ n_batches_to_dead: int = 10
30
+
31
+ # Training settings
32
+ lr: float = 3e-4
33
+ bandwidth: float = 0.001
34
+ l1_coeff: float = 0.0018
35
+ num_tokens: int = int(1e9)
36
+ seq_len: int = 1024
37
+ model_batch_size: int = 16
38
+ num_batches_in_buffer: int = 5
39
+ max_grad_norm: float = 1.0
40
+ batch_size: int = 8192
41
+
42
+ # scheduler
43
+ warmup_fraction: float = 0.1
44
+ scheduler_type: str = 'linear'
45
+
46
+ # Hardware settings
47
+ device: str = "cuda"
48
+ dtype: torch.dtype = field(default=torch.float32)
49
+ sae_dtype: torch.dtype = field(default=torch.float32)
50
+
51
+ # Dataset settings
52
+ dataset_path: str = "cerebras/SlimPajama-627B"
53
+
54
+ # Logging settings
55
+ wandb_project: str = "turbo-llama-lens"
56
+
57
+ performance_log_steps: int = 100
58
+ save_checkpoint_steps: int = 10_000
59
+ def __post_init__(self):
60
+ if self.device == "cuda" and not torch.cuda.is_available():
61
+ print("CUDA not available, falling back to CPU")
62
+ self.device = "cpu"
63
+
64
+ # Convert string dtype to torch.dtype if needed
65
+ if isinstance(self.dtype, str):
66
+ self.dtype = getattr(torch, self.dtype)
67
+
68
+
69
+ class SAEConfig(PretrainedConfig):
70
+ model_type = "sae"
71
+
72
+ def __init__(
73
+ self,
74
+ # SAE architecture
75
+ act_size: int = None,
76
+ dict_size: int = 2**15,
77
+ sae_type: str = "batchtopk",
78
+ input_unit_norm: bool = True,
79
+
80
+ # TopK specific settings
81
+ top_k: int = 50,
82
+ top_k_aux: int = 512,
83
+ n_batches_to_dead: int = 10,
84
+
85
+ # Training hyperparameters
86
+ aux_penalty: float = 1/32,
87
+ l1_coeff: float = 0.0018,
88
+ bandwidth: float = 0.001,
89
+
90
+ # Hardware settings
91
+ dtype: str = "float32",
92
+ sae_dtype: str = "float32",
93
+
94
+ # Optional parent model info
95
+ parent_model_name: Optional[str] = None,
96
+ parent_layer: Optional[int] = None,
97
+ parent_hook_point: Optional[str] = None,
98
+
99
+ # Input normalization settings
100
+ input_mean: Optional[float] = None,
101
+ input_std: Optional[float] = None,
102
+
103
+ **kwargs
104
+ ):
105
+ super().__init__(**kwargs)
106
+ self.act_size = act_size
107
+ self.dict_size = dict_size
108
+ self.sae_type = sae_type
109
+ self.input_unit_norm = input_unit_norm
110
+
111
+ self.top_k = top_k
112
+ self.top_k_aux = top_k_aux
113
+ self.n_batches_to_dead = n_batches_to_dead
114
+
115
+ self.aux_penalty = aux_penalty
116
+ self.l1_coeff = l1_coeff
117
+ self.bandwidth = bandwidth
118
+
119
+ self.dtype = dtype
120
+ self.sae_dtype = sae_dtype
121
+
122
+ self.parent_model_name = parent_model_name
123
+ self.parent_layer = parent_layer
124
+ self.parent_hook_point = parent_hook_point
125
+
126
+ self.input_mean = input_mean
127
+ self.input_std = input_std
128
+
129
+ def get_torch_dtype(self, dtype_str: str) -> torch.dtype:
130
+ dtype_map = {
131
+ "float32": torch.float32,
132
+ "float16": torch.float16,
133
+ "bfloat16": torch.bfloat16,
134
+ }
135
+ return dtype_map.get(dtype_str, torch.float32)
136
+
137
+ @classmethod
138
+ def from_training_config(cls, cfg: TrainingConfig):
139
+ """Convert TrainingConfig to SAEConfig"""
140
+ return cls(
141
+ act_size=cfg.act_size,
142
+ dict_size=cfg.dict_size,
143
+ sae_type=cfg.sae_type,
144
+ input_unit_norm=cfg.input_unit_norm,
145
+ top_k=cfg.top_k,
146
+ top_k_aux=cfg.top_k_aux,
147
+ n_batches_to_dead=cfg.n_batches_to_dead,
148
+ aux_penalty=cfg.aux_penalty,
149
+ l1_coeff=cfg.l1_coeff,
150
+ bandwidth=cfg.bandwidth,
151
+ dtype=str(cfg.dtype).split('.')[-1],
152
+ sae_dtype=str(cfg.sae_dtype).split('.')[-1],
153
+ parent_model_name=cfg.model_name,
154
+ parent_layer=cfg.layer,
155
+ parent_hook_point=cfg.hook_point,
156
+ input_mean=cfg.input_mean if hasattr(cfg, 'input_mean') else None,
157
+ input_std=cfg.input_std if hasattr(cfg, 'input_std') else None,
158
+ )
159
+
160
+ def to_training_config(self) -> TrainingConfig:
161
+ """Convert SAEConfig back to TrainingConfig"""
162
+ return TrainingConfig(
163
+ dtype=self.get_torch_dtype(self.dtype),
164
+ sae_dtype=self.get_torch_dtype(self.sae_dtype),
165
+ model_name=self.parent_model_name,
166
+ layer=self.parent_layer,
167
+ hook_point=self.parent_hook_point,
168
+ act_size=self.act_size,
169
+ dict_size=self.dict_size,
170
+ sae_type=self.sae_type,
171
+ input_unit_norm=self.input_unit_norm,
172
+ top_k=self.top_k,
173
+ top_k_aux=self.top_k_aux,
174
+ n_batches_to_dead=self.n_batches_to_dead,
175
+ aux_penalty=self.aux_penalty,
176
+ l1_coeff=self.l1_coeff,
177
+ bandwidth=self.bandwidth,
178
+ )
179
+
180
+
181
+ @pyrallis.wrap()
182
+ def get_config() -> TrainingConfig:
183
+ return TrainingConfig()
184
+
185
+
186
+ # For backward compatibility
187
+ def get_default_cfg() -> TrainingConfig:
188
+ return get_config()
189
+
190
+
191
+ def post_init_cfg(cfg: TrainingConfig, activation_store = None) -> TrainingConfig:
192
+ """
193
+ Any additional configuration setup that needs to happen after model initialization
194
+ Args:
195
+ cfg: Training configuration
196
+ activation_store: Optional activation store to get input statistics
197
+ """
198
+ if activation_store is not None:
199
+ cfg.input_mean = activation_store.mean
200
+ cfg.input_std = activation_store.std
201
+ print(f"Setting input statistics from activation store - Mean: {cfg.input_mean:.4f}, Std: {cfg.input_std:.4f}")
202
+
203
+ return cfg
model.py ADDED
@@ -0,0 +1,474 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PreTrainedModel
2
+ from typing import Optional, Dict, Union
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import torch.autograd as autograd
7
+ from .config import SAEConfig
8
+
9
+
10
+ class BaseSAE(PreTrainedModel):
11
+ """Base class for autoencoder models."""
12
+ config_class = SAEConfig
13
+ base_model_prefix = "sae"
14
+
15
+ def __init__(self, config: SAEConfig):
16
+ super().__init__(config)
17
+ print(config)
18
+ self.config = config
19
+ torch.manual_seed(42)
20
+
21
+ self.b_dec = nn.Parameter(torch.zeros(self.config.act_size))
22
+ self.b_enc = nn.Parameter(torch.zeros(self.config.dict_size))
23
+ self.W_enc = nn.Parameter(
24
+ torch.nn.init.kaiming_uniform_(
25
+ torch.empty(self.config.act_size, self.config.dict_size)
26
+ )
27
+ )
28
+ self.W_dec = nn.Parameter(
29
+ torch.nn.init.kaiming_uniform_(
30
+ torch.empty(self.config.dict_size, self.config.act_size)
31
+ )
32
+ )
33
+ self.W_dec.data[:] = self.W_enc.t().data
34
+ self.W_dec.data[:] = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)
35
+ self.register_buffer('num_batches_not_active', torch.zeros((self.config.dict_size,)))
36
+
37
+ self.to(self.config.get_torch_dtype(self.config.dtype))
38
+
39
+ # Initialize input normalization parameters if provided
40
+ if config.input_mean is not None and config.input_std is not None:
41
+ self.register_buffer('input_mean', torch.tensor(config.input_mean))
42
+ self.register_buffer('input_std', torch.tensor(config.input_std))
43
+ else:
44
+ self.input_mean = None
45
+ self.input_std = None
46
+
47
+ def preprocess_input(self, x):
48
+ x = x.to(self.config.get_torch_dtype(self.config.sae_dtype))
49
+ if self.config.input_unit_norm:
50
+ if self.input_mean is not None and self.input_std is not None:
51
+ # Use pre-computed statistics
52
+ x = (x - self.input_mean) / (self.input_std + 1e-5)
53
+ return x, self.input_mean, self.input_std
54
+ else:
55
+ # Compute statistics on the fly
56
+ x_mean = x.mean(dim=-1, keepdim=True)
57
+ x = x - x_mean
58
+ x_std = x.std(dim=-1, keepdim=True)
59
+ x = x / (x_std + 1e-5)
60
+ return x, x_mean, x_std
61
+ else:
62
+ return x, None, None
63
+
64
+ def postprocess_output(self, x_reconstruct, x_mean, x_std):
65
+ if self.config.input_unit_norm:
66
+ x_reconstruct = x_reconstruct * x_std + x_mean
67
+ return x_reconstruct
68
+
69
+ @torch.no_grad()
70
+ def make_decoder_weights_and_grad_unit_norm(self):
71
+ W_dec_normed = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)
72
+ W_dec_grad_proj = (self.W_dec.grad * W_dec_normed).sum(
73
+ -1, keepdim=True
74
+ ) * W_dec_normed
75
+ self.W_dec.grad -= W_dec_grad_proj
76
+ self.W_dec.data = W_dec_normed
77
+
78
+ def update_inactive_features(self, acts):
79
+ self.num_batches_not_active += (acts.sum(0) == 0).float()
80
+ self.num_batches_not_active[acts.sum(0) > 0] = 0
81
+
82
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
83
+ """
84
+ Encode input tensor to sparse features
85
+ Args:
86
+ x: Input tensor of shape (batch_size, act_size)
87
+ Returns:
88
+ Encoded features of shape (batch_size, dict_size)
89
+ """
90
+ if self.config.input_unit_norm:
91
+ x_mean = x.mean(dim=-1, keepdim=True)
92
+ x = x - x_mean
93
+ x_std = x.std(dim=-1, keepdim=True)
94
+ x = x / (x_std + 1e-5)
95
+
96
+ x_cent = x - self.b_dec
97
+ return F.relu(x_cent @ self.W_enc + self.b_enc)
98
+
99
+ def decode(self, h: torch.Tensor) -> torch.Tensor:
100
+ """
101
+ Decode features back to input space
102
+ Args:
103
+ h: Encoded features of shape (batch_size, dict_size)
104
+ Returns:
105
+ Reconstructed input of shape (batch_size, act_size)
106
+ """
107
+ return h @ self.W_dec + self.b_dec
108
+
109
+ def forward(self, x):
110
+ x, x_mean, x_std = self.preprocess_input(x)
111
+ x_cent = x - self.b_dec
112
+ acts = F.relu(x_cent @ self.W_enc + self.b_enc)
113
+ x_reconstruct = acts @ self.W_dec + self.b_dec
114
+ self.update_inactive_features(acts)
115
+ output = self.get_loss_dict(x, x_reconstruct, acts, x_mean, x_std)
116
+ return output
117
+
118
+
119
+ class BatchTopKSAE(BaseSAE):
120
+ def encode(self, x: torch.Tensor, use_pre_enc_bias: bool = False) -> torch.Tensor:
121
+ """
122
+ Encode input tensor to sparse features with batch-wise top-k
123
+ Args:
124
+ x: Input tensor of shape (batch_size, act_size)
125
+ use_pre_enc_bias: Whether to use pre-encoder bias
126
+ Returns:
127
+ Encoded features of shape (batch_size, dict_size)
128
+ """
129
+ if self.config.input_unit_norm:
130
+ x_mean = x.mean(dim=-1, keepdim=True)
131
+ x = x - x_mean
132
+ x_std = x.std(dim=-1, keepdim=True)
133
+ x = x / (x_std + 1e-5)
134
+
135
+ if use_pre_enc_bias:
136
+ x = x - self.b_dec
137
+
138
+ acts = F.relu(x @ self.W_enc + self.b_enc)
139
+ acts_topk = torch.topk(acts.flatten(), self.config.top_k * x.shape[0], dim=-1)
140
+ return (
141
+ torch.zeros_like(acts.flatten())
142
+ .scatter(-1, acts_topk.indices, acts_topk.values)
143
+ .reshape(acts.shape)
144
+ )
145
+
146
+ def forward(self, x):
147
+ x, x_mean, x_std = self.preprocess_input(x)
148
+
149
+ x_cent = x - self.b_dec
150
+ acts = F.relu(x_cent @ self.W_enc)
151
+ acts_topk = torch.topk(acts.flatten(), self.config.top_k * x.shape[0], dim=-1)
152
+ acts_topk = (
153
+ torch.zeros_like(acts.flatten())
154
+ .scatter(-1, acts_topk.indices, acts_topk.values)
155
+ .reshape(acts.shape)
156
+ )
157
+ x_reconstruct = acts_topk @ self.W_dec + self.b_dec
158
+
159
+ self.update_inactive_features(acts_topk)
160
+ output = self.get_loss_dict(x, x_reconstruct, acts, acts_topk, x_mean, x_std)
161
+ return output
162
+
163
+ def get_loss_dict(self, x, x_reconstruct, acts, acts_topk, x_mean, x_std):
164
+ l2_loss = (x_reconstruct.float() - x.float()).pow(2).mean()
165
+ l1_norm = acts_topk.float().abs().sum(-1).mean()
166
+ l1_loss = self.config.l1_coeff * l1_norm
167
+ l0_norm = (acts_topk > 0).float().sum(-1).mean()
168
+ aux_loss = self.get_auxiliary_loss(x, x_reconstruct, acts)
169
+ loss = l2_loss + aux_loss
170
+ num_dead_features = (
171
+ self.num_batches_not_active > self.config.n_batches_to_dead
172
+ ).sum()
173
+ sae_out = self.postprocess_output(x_reconstruct, x_mean, x_std)
174
+ per_token_l2_loss_A = (x_reconstruct.float() - x.float()).pow(2).sum(-1).squeeze()
175
+ total_variance_A = (x.float() - x.float().mean(0)).pow(2).sum(-1).squeeze()
176
+ explained_variance = (1 - per_token_l2_loss_A / total_variance_A).mean()
177
+ output = {
178
+ "sae_out": sae_out,
179
+ "feature_acts": acts_topk,
180
+ "num_dead_features": num_dead_features,
181
+ "loss": loss,
182
+ "l1_loss": l1_loss,
183
+ "l2_loss": l2_loss,
184
+ "l0_norm": l0_norm,
185
+ "l1_norm": l1_norm,
186
+ "aux_loss": aux_loss,
187
+ "explained_variance": explained_variance,
188
+ "top_k": self.config.top_k
189
+ }
190
+ return output
191
+
192
+ def get_auxiliary_loss(self, x, x_reconstruct, acts):
193
+ dead_features = self.num_batches_not_active >= self.config.n_batches_to_dead
194
+ if dead_features.sum() > 0:
195
+ residual = x.float() - x_reconstruct.float()
196
+ acts_topk_aux = torch.topk(
197
+ acts[:, dead_features],
198
+ min(self.config.top_k_aux, dead_features.sum()),
199
+ dim=-1,
200
+ )
201
+ acts_aux = torch.zeros_like(acts[:, dead_features]).scatter(
202
+ -1, acts_topk_aux.indices, acts_topk_aux.values
203
+ )
204
+ x_reconstruct_aux = acts_aux @ self.W_dec[dead_features]
205
+ l2_loss_aux = (
206
+ self.config.aux_penalty
207
+ * (x_reconstruct_aux.float() - residual.float()).pow(2).mean()
208
+ )
209
+ return l2_loss_aux
210
+ else:
211
+ return torch.tensor(0, dtype=x.dtype, device=x.device)
212
+
213
+
214
+ class TopKSAE(BaseSAE):
215
+ def encode(self, x: torch.Tensor, use_pre_enc_bias: bool = False) -> torch.Tensor:
216
+ """
217
+ Encode input tensor to sparse features with per-sample top-k
218
+ Args:
219
+ x: Input tensor of shape (batch_size, act_size)
220
+ use_pre_enc_bias: Whether to use pre-encoder bias
221
+ Returns:
222
+ Encoded features of shape (batch_size, dict_size)
223
+ """
224
+ if self.config.input_unit_norm:
225
+ x_mean = x.mean(dim=-1, keepdim=True)
226
+ x = x - x_mean
227
+ x_std = x.std(dim=-1, keepdim=True)
228
+ x = x / (x_std + 1e-5)
229
+
230
+ if use_pre_enc_bias:
231
+ x = x - self.b_dec
232
+
233
+ acts = F.relu(x @ self.W_enc + self.b_enc)
234
+ acts_topk = torch.topk(acts, self.config.top_k, dim=-1)
235
+ return torch.zeros_like(acts).scatter(
236
+ -1, acts_topk.indices, acts_topk.values
237
+ )
238
+
239
+ def forward(self, x):
240
+ x, x_mean, x_std = self.preprocess_input(x)
241
+
242
+ x_cent = x - self.b_dec
243
+ acts = F.relu(x_cent @ self.W_enc)
244
+ acts_topk = torch.topk(acts, self.config.top_k, dim=-1)
245
+ acts_topk = torch.zeros_like(acts).scatter(
246
+ -1, acts_topk.indices, acts_topk.values
247
+ )
248
+ x_reconstruct = acts_topk @ self.W_dec + self.b_dec
249
+
250
+ self.update_inactive_features(acts_topk)
251
+ output = self.get_loss_dict(x, x_reconstruct, acts, acts_topk, x_mean, x_std)
252
+ return output
253
+
254
+ def get_loss_dict(self, x, x_reconstruct, acts, acts_topk, x_mean, x_std):
255
+ l2_loss = (x_reconstruct.float() - x.float()).pow(2).mean()
256
+ l1_norm = acts_topk.float().abs().sum(-1).mean()
257
+ l1_loss = self.config.l1_coeff * l1_norm
258
+ l0_norm = (acts_topk > 0).float().sum(-1).mean()
259
+ aux_loss = self.get_auxiliary_loss(x, x_reconstruct, acts)
260
+ loss = l2_loss + l1_loss + aux_loss
261
+ num_dead_features = (
262
+ self.num_batches_not_active > self.config.n_batches_to_dead
263
+ ).sum()
264
+ sae_out = self.postprocess_output(x_reconstruct, x_mean, x_std)
265
+ per_token_l2_loss_A = (x_reconstruct.float() - x.float()).pow(2).sum(-1).squeeze()
266
+ total_variance_A = (x.float() - x.float().mean(0)).pow(2).sum(-1).squeeze()
267
+ explained_variance = (1 - per_token_l2_loss_A / total_variance_A).mean()
268
+ output = {
269
+ "sae_out": sae_out,
270
+ "feature_acts": acts_topk,
271
+ "num_dead_features": num_dead_features,
272
+ "loss": loss,
273
+ "l1_loss": l1_loss,
274
+ "l2_loss": l2_loss,
275
+ "l0_norm": l0_norm,
276
+ "l1_norm": l1_norm,
277
+ "explained_variance": explained_variance,
278
+ "aux_loss": aux_loss,
279
+ }
280
+ return output
281
+
282
+ def get_auxiliary_loss(self, x, x_reconstruct, acts):
283
+ dead_features = self.num_batches_not_active >= self.config.n_batches_to_dead
284
+ if dead_features.sum() > 0:
285
+ residual = x.float() - x_reconstruct.float()
286
+ acts_topk_aux = torch.topk(
287
+ acts[:, dead_features],
288
+ min(self.config.top_k_aux, dead_features.sum()),
289
+ dim=-1,
290
+ )
291
+ acts_aux = torch.zeros_like(acts[:, dead_features]).scatter(
292
+ -1, acts_topk_aux.indices, acts_topk_aux.values
293
+ )
294
+ x_reconstruct_aux = acts_aux @ self.W_dec[dead_features]
295
+ l2_loss_aux = (
296
+ self.config.aux_penalty
297
+ * (x_reconstruct_aux.float() - residual.float()).pow(2).mean()
298
+ )
299
+ return l2_loss_aux
300
+ else:
301
+ return torch.tensor(0, dtype=x.dtype, device=x.device)
302
+
303
+
304
+ class VanillaSAE(BaseSAE):
305
+ def forward(self, x):
306
+ x, x_mean, x_std = self.preprocess_input(x)
307
+ x_cent = x - self.b_dec
308
+ acts = F.relu(x_cent @ self.W_enc + self.b_enc)
309
+ x_reconstruct = acts @ self.W_dec + self.b_dec
310
+ self.update_inactive_features(acts)
311
+ output = self.get_loss_dict(x, x_reconstruct, acts, x_mean, x_std)
312
+ return output
313
+
314
+ def get_loss_dict(self, x, x_reconstruct, acts, x_mean, x_std):
315
+ l2_loss = (x_reconstruct.float() - x.float()).pow(2).mean()
316
+ l1_norm = acts.float().abs().sum(-1).mean()
317
+ l1_loss = self.config.l1_coeff * l1_norm
318
+ l0_norm = (acts > 0).float().sum(-1).mean()
319
+ loss = l2_loss + l1_loss
320
+ num_dead_features = (
321
+ self.num_batches_not_active > self.config.n_batches_to_dead
322
+ ).sum()
323
+
324
+ sae_out = self.postprocess_output(x_reconstruct, x_mean, x_std)
325
+ per_token_l2_loss_A = (x_reconstruct.float() - x.float()).pow(2).sum(-1).squeeze()
326
+ total_variance_A = (x.float() - x.float().mean(0)).pow(2).sum(-1).squeeze()
327
+ explained_variance = (1 - per_token_l2_loss_A / total_variance_A).mean()
328
+ output = {
329
+ "sae_out": sae_out,
330
+ "feature_acts": acts,
331
+ "num_dead_features": num_dead_features,
332
+ "loss": loss,
333
+ "l1_loss": l1_loss,
334
+ "l2_loss": l2_loss,
335
+ "l0_norm": l0_norm,
336
+ "l1_norm": l1_norm,
337
+ "explained_variance": explained_variance,
338
+ }
339
+ return output
340
+
341
+
342
+ import torch
343
+ import torch.nn as nn
344
+
345
+ class RectangleFunction(autograd.Function):
346
+ @staticmethod
347
+ def forward(ctx, x):
348
+ ctx.save_for_backward(x)
349
+ return ((x > -0.5) & (x < 0.5)).float()
350
+
351
+ @staticmethod
352
+ def backward(ctx, grad_output):
353
+ (x,) = ctx.saved_tensors
354
+ grad_input = grad_output.clone()
355
+ grad_input[(x <= -0.5) | (x >= 0.5)] = 0
356
+ return grad_input
357
+
358
+ class JumpReLUFunction(autograd.Function):
359
+ @staticmethod
360
+ def forward(ctx, x, log_threshold, bandwidth):
361
+ ctx.save_for_backward(x, log_threshold, torch.tensor(bandwidth))
362
+ threshold = torch.exp(log_threshold)
363
+ return x * (x > threshold).float()
364
+
365
+ @staticmethod
366
+ def backward(ctx, grad_output):
367
+ x, log_threshold, bandwidth_tensor = ctx.saved_tensors
368
+ bandwidth = bandwidth_tensor.item()
369
+ threshold = torch.exp(log_threshold)
370
+ x_grad = (x > threshold).float() * grad_output
371
+ threshold_grad = (
372
+ -(threshold / bandwidth)
373
+ * RectangleFunction.apply((x - threshold) / bandwidth)
374
+ * grad_output
375
+ )
376
+ return x_grad, threshold_grad, None # None for bandwidth
377
+
378
+ class JumpReLU(nn.Module):
379
+ def __init__(self, feature_size, bandwidth, device='cpu'):
380
+ super(JumpReLU, self).__init__()
381
+ self.log_threshold = nn.Parameter(torch.zeros(feature_size, device=device))
382
+ self.bandwidth = bandwidth
383
+
384
+ def forward(self, x):
385
+ return JumpReLUFunction.apply(x, self.log_threshold, self.bandwidth)
386
+
387
+ class StepFunction(autograd.Function):
388
+ @staticmethod
389
+ def forward(ctx, x, log_threshold, bandwidth):
390
+ ctx.save_for_backward(x, log_threshold, torch.tensor(bandwidth))
391
+ threshold = torch.exp(log_threshold)
392
+ return (x > threshold).float()
393
+
394
+ @staticmethod
395
+ def backward(ctx, grad_output):
396
+ x, log_threshold, bandwidth_tensor = ctx.saved_tensors
397
+ bandwidth = bandwidth_tensor.item()
398
+ threshold = torch.exp(log_threshold)
399
+ x_grad = torch.zeros_like(x)
400
+ threshold_grad = (
401
+ -(1.0 / bandwidth)
402
+ * RectangleFunction.apply((x - threshold) / bandwidth)
403
+ * grad_output
404
+ )
405
+ return x_grad, threshold_grad, None # None for bandwidth
406
+
407
+ class JumpReLUSAE(BaseSAE):
408
+ def __init__(self, config: SAEConfig):
409
+ super().__init__(config)
410
+ self.jumprelu = JumpReLU(
411
+ feature_size=config.dict_size,
412
+ bandwidth=config.bandwidth,
413
+ device=config.device if hasattr(config, 'device') else 'cpu'
414
+ )
415
+
416
+ def encode(self, x: torch.Tensor, use_pre_enc_bias: bool = False) -> torch.Tensor:
417
+ """
418
+ Encode input tensor to sparse features using JumpReLU
419
+ """
420
+ if self.config.input_unit_norm:
421
+ x_mean = x.mean(dim=-1, keepdim=True)
422
+ x = x - x_mean
423
+ x_std = x.std(dim=-1, keepdim=True)
424
+ x = x / (x_std + 1e-5)
425
+
426
+ if use_pre_enc_bias:
427
+ x = x - self.b_dec
428
+ pre_activations = F.relu(x @ self.W_enc + self.b_enc)
429
+ return self.jumprelu(pre_activations)
430
+
431
+ def forward(self, x, use_pre_enc_bias=False):
432
+ x, x_mean, x_std = self.preprocess_input(x)
433
+ if use_pre_enc_bias:
434
+ x = x - self.b_dec
435
+ pre_activations = torch.relu(x @ self.W_enc + self.b_enc)
436
+ feature_magnitudes = self.jumprelu(pre_activations)
437
+
438
+ x_reconstructed = feature_magnitudes @ self.W_dec + self.b_dec
439
+
440
+ return self.get_loss_dict(x, x_reconstructed, feature_magnitudes, x_mean, x_std)
441
+
442
+ def get_loss_dict(self, x, x_reconstruct, acts, x_mean, x_std):
443
+ l2_loss = (x_reconstruct.float() - x.float()).pow(2).mean()
444
+
445
+ l0 = StepFunction.apply(acts, self.jumprelu.log_threshold, self.config.bandwidth).sum(dim=-1).mean()
446
+ l0_loss = self.config.l1_coeff * l0
447
+ l1_loss = l0_loss
448
+
449
+ loss = l2_loss + l1_loss
450
+ num_dead_features = (
451
+ self.num_batches_not_active > self.config.n_batches_to_dead
452
+ ).sum()
453
+
454
+ sae_out = self.postprocess_output(x_reconstruct, x_mean, x_std)
455
+ per_token_l2_loss_A = (x_reconstruct.float() - x.float()).pow(2).sum(-1).squeeze()
456
+ total_variance_A = (x.float() - x.float().mean(0)).pow(2).sum(-1).squeeze()
457
+ explained_variance = (1 - per_token_l2_loss_A / total_variance_A).mean()
458
+ output = {
459
+ "sae_out": sae_out,
460
+ "feature_acts": acts,
461
+ "num_dead_features": num_dead_features,
462
+ "loss": loss,
463
+ "l1_loss": l1_loss,
464
+ "l2_loss": l2_loss,
465
+ "l0_norm": l0,
466
+ "l1_norm": l0,
467
+ "explained_variance": explained_variance,
468
+ }
469
+ return output
470
+
471
+ SAEConfig.register_for_auto_class("AutoConfig")
472
+ BatchTopKSAE.register_for_auto_class()
473
+ JumpReLUSAE.register_for_auto_class()
474
+ VanillaSAE.register_for_auto_class()
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e583aa1318beed78cca624139ba6bf76e634db7d5c14233223b01d8e0672a8b9
3
+ size 302196416