elephantmipt committed
Commit ca2139a · verified · 1 Parent(s): 0678243

Upload BatchTopKSAE

Files changed (3)
  1. config.json +4 -0
  2. config.py +177 -0
  3. sae.py +390 -0
config.json CHANGED
@@ -3,6 +3,10 @@
   "architectures": [
     "BatchTopKSAE"
   ],
+  "auto_map": {
+    "AutoConfig": "config.SAEConfig",
+    "AutoModel": "sae.BatchTopKSAE"
+  },
   "aux_penalty": 0.03125,
   "bandwidth": 0.001,
   "dict_size": 128,
config.py ADDED
@@ -0,0 +1,177 @@
from dataclasses import dataclass, field, fields
from typing import Optional

import torch
import pyrallis
from transformers import PretrainedConfig


@dataclass
class TrainingConfig:
    # Model settings
    model_name: str = "unsloth/Meta-Llama-3.1-8B"
    layer: int = 12
    hook_point: str = "resid_mid"
    act_size: Optional[int] = None  # Will be set after model initialization

    # SAE settings
    sae_type: str = "batchtopk"
    dict_size: int = 2**15
    aux_penalty: float = 1 / 32
    input_unit_norm: bool = True

    # TopK specific settings
    top_k: int = 50
    top_k_warmup_steps_fraction: float = 0.1
    start_top_k: int = 4096
    top_k_aux: int = 512

    n_batches_to_dead: int = 10

    # Training settings
    lr: float = 3e-4
    bandwidth: float = 0.001
    l1_coeff: float = 0.0018
    num_tokens: int = int(1e9)
    seq_len: int = 1024
    model_batch_size: int = 16
    num_batches_in_buffer: int = 5
    max_grad_norm: float = 1.0
    batch_size: int = 8192

    # Scheduler
    warmup_fraction: float = 0.1
    scheduler_type: str = "linear"

    # Hardware settings
    device: str = "cuda"
    dtype: torch.dtype = field(default=torch.float32)
    sae_dtype: torch.dtype = field(default=torch.float32)

    # Dataset settings
    dataset_path: str = "cerebras/SlimPajama-627B"

    # Logging settings
    wandb_project: str = "turbo-llama-lens"
    performance_log_steps: int = 100
    save_checkpoint_steps: int = 10_000

    def __post_init__(self):
        if self.device == "cuda" and not torch.cuda.is_available():
            print("CUDA not available, falling back to CPU")
            self.device = "cpu"

        # Convert string dtype to torch.dtype if needed
        if isinstance(self.dtype, str):
            self.dtype = getattr(torch, self.dtype)


class SAEConfig(PretrainedConfig):
    model_type = "sae"

    def __init__(
        self,
        # SAE architecture
        act_size: Optional[int] = None,
        dict_size: int = 2**15,
        sae_type: str = "batchtopk",
        input_unit_norm: bool = True,
        # TopK specific settings
        top_k: int = 50,
        top_k_aux: int = 512,
        n_batches_to_dead: int = 10,
        # Training hyperparameters
        aux_penalty: float = 1 / 32,
        l1_coeff: float = 0.0018,
        bandwidth: float = 0.001,
        # Hardware settings
        dtype: str = "float32",
        sae_dtype: str = "float32",
        # Optional parent model info
        parent_model_name: Optional[str] = None,
        parent_layer: Optional[int] = None,
        parent_hook_point: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.act_size = act_size
        self.dict_size = dict_size
        self.sae_type = sae_type
        self.input_unit_norm = input_unit_norm

        self.top_k = top_k
        self.top_k_aux = top_k_aux
        self.n_batches_to_dead = n_batches_to_dead

        self.aux_penalty = aux_penalty
        self.l1_coeff = l1_coeff
        self.bandwidth = bandwidth

        self.dtype = dtype
        self.sae_dtype = sae_dtype

        self.parent_model_name = parent_model_name
        self.parent_layer = parent_layer
        self.parent_hook_point = parent_hook_point

    def get_torch_dtype(self, dtype_str: str) -> torch.dtype:
        dtype_map = {
            "float32": torch.float32,
            "float16": torch.float16,
            "bfloat16": torch.bfloat16,
        }
        return dtype_map.get(dtype_str, torch.float32)

    @classmethod
    def from_training_config(cls, cfg: TrainingConfig):
        """Convert a TrainingConfig to an SAEConfig."""
        return cls(
            act_size=cfg.act_size,
            dict_size=cfg.dict_size,
            sae_type=cfg.sae_type,
            input_unit_norm=cfg.input_unit_norm,
            top_k=cfg.top_k,
            top_k_aux=cfg.top_k_aux,
            n_batches_to_dead=cfg.n_batches_to_dead,
            aux_penalty=cfg.aux_penalty,
            l1_coeff=cfg.l1_coeff,
            bandwidth=cfg.bandwidth,
            dtype=str(cfg.dtype).split(".")[-1],
            sae_dtype=str(cfg.sae_dtype).split(".")[-1],
            parent_model_name=cfg.model_name,
            parent_layer=cfg.layer,
            parent_hook_point=cfg.hook_point,
        )

    def to_training_config(self) -> TrainingConfig:
        """Convert an SAEConfig back to a TrainingConfig."""
        # PretrainedConfig is not a dataclass, so dataclasses.asdict would fail here;
        # build the dict from to_dict() and keep only the fields TrainingConfig defines.
        training_fields = {f.name for f in fields(TrainingConfig)}
        config_dict = {k: v for k, v in self.to_dict().items() if k in training_fields}
        config_dict["dtype"] = self.get_torch_dtype(self.dtype)
        config_dict["sae_dtype"] = self.get_torch_dtype(self.sae_dtype)
        config_dict["model_name"] = self.parent_model_name
        config_dict["layer"] = self.parent_layer
        config_dict["hook_point"] = self.parent_hook_point
        return TrainingConfig(**config_dict)


@pyrallis.wrap()
def get_config() -> TrainingConfig:
    return TrainingConfig()


# For backward compatibility
def get_default_cfg() -> TrainingConfig:
    return get_config()


def post_init_cfg(cfg: TrainingConfig) -> TrainingConfig:
    """Any additional configuration setup that needs to happen after model initialization."""
    return cfg
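
A quick illustration of the TrainingConfig to SAEConfig round trip defined above; the values are made up (not the uploaded checkpoint's) and it assumes config.py sits on the import path:

# Illustrative round trip between the two config classes; all values are placeholders.
import torch
from config import TrainingConfig, SAEConfig

train_cfg = TrainingConfig(act_size=4096, dict_size=2**15, top_k=50, device="cpu")
sae_cfg = SAEConfig.from_training_config(train_cfg)
assert sae_cfg.parent_model_name == train_cfg.model_name
assert sae_cfg.get_torch_dtype(sae_cfg.dtype) is torch.float32

restored = sae_cfg.to_training_config()
assert restored.dict_size == train_cfg.dict_size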
sae.py ADDED
@@ -0,0 +1,390 @@
from transformers import PreTrainedModel
from typing import Optional, Dict, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
from copy import deepcopy
from safetensors.torch import save_file, load_file
from .config import SAEConfig  # config.py in this repo (matches auto_map in config.json)
import os


class BaseSAE(PreTrainedModel):
    """Base class for autoencoder models."""
    config_class = SAEConfig
    base_model_prefix = "sae"

    def __init__(self, config: SAEConfig):
        super().__init__(config)
        print(config)
        self.config = config
        torch.manual_seed(42)

        self.b_dec = nn.Parameter(torch.zeros(self.config.act_size))
        self.b_enc = nn.Parameter(torch.zeros(self.config.dict_size))
        self.W_enc = nn.Parameter(
            torch.nn.init.kaiming_uniform_(
                torch.empty(self.config.act_size, self.config.dict_size)
            )
        )
        self.W_dec = nn.Parameter(
            torch.nn.init.kaiming_uniform_(
                torch.empty(self.config.dict_size, self.config.act_size)
            )
        )
        # Tie the decoder to the encoder transpose at init and normalize its rows.
        self.W_dec.data[:] = self.W_enc.t().data
        self.W_dec.data[:] = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)
        self.num_batches_not_active = torch.zeros((self.config.dict_size,))

        self.to(self.config.get_torch_dtype(self.config.dtype))

    def preprocess_input(self, x):
        x = x.to(self.config.get_torch_dtype(self.config.sae_dtype))
        if self.config.input_unit_norm:
            x_mean = x.mean(dim=-1, keepdim=True)
            x = x - x_mean
            x_std = x.std(dim=-1, keepdim=True)
            x = x / (x_std + 1e-5)
            return x, x_mean, x_std
        else:
            return x, None, None

    def postprocess_output(self, x_reconstruct, x_mean, x_std):
        if self.config.input_unit_norm:
            x_reconstruct = x_reconstruct * x_std + x_mean
        return x_reconstruct

    @torch.no_grad()
    def make_decoder_weights_and_grad_unit_norm(self):
        W_dec_normed = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)
        W_dec_grad_proj = (self.W_dec.grad * W_dec_normed).sum(
            -1, keepdim=True
        ) * W_dec_normed
        self.W_dec.grad -= W_dec_grad_proj
        self.W_dec.data = W_dec_normed

    def update_inactive_features(self, acts):
        self.num_batches_not_active += (acts.sum(0) == 0).float()
        self.num_batches_not_active[acts.sum(0) > 0] = 0

    # @classmethod
    # def from_pretrained(
    #     cls,
    #     pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
    #     *model_args,
    #     **kwargs
    # ) -> "BaseSAE":
    #     config = kwargs.pop("config", None)
    #     if config is None:
    #         config = SAEConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
    #
    #     model = cls(config)
    #     model.load_state_dict(
    #         load_file(os.path.join(pretrained_model_name_or_path, "model.safetensors"))
    #     )
    #     return model

    # def save_pretrained(
    #     self,
    #     save_directory: Union[str, os.PathLike],
    #     **kwargs
    # ):
    #     os.makedirs(save_directory, exist_ok=True)
    #
    #     # Save the config
    #     self.config.save_pretrained(save_directory)
    #
    #     # Save the model weights
    #     save_file(
    #         self.state_dict(),
    #         os.path.join(save_directory, "model.safetensors")
    #     )


class BatchTopKSAE(BaseSAE):
    """Top-k SAE where the top_k * batch_size largest activations are kept across the whole batch."""

    def forward(self, x):
        x, x_mean, x_std = self.preprocess_input(x)

        x_cent = x - self.b_dec
        acts = F.relu(x_cent @ self.W_enc)
        acts_topk = torch.topk(acts.flatten(), self.config.top_k * x.shape[0], dim=-1)
        acts_topk = (
            torch.zeros_like(acts.flatten())
            .scatter(-1, acts_topk.indices, acts_topk.values)
            .reshape(acts.shape)
        )
        x_reconstruct = acts_topk @ self.W_dec + self.b_dec

        self.update_inactive_features(acts_topk)
        output = self.get_loss_dict(x, x_reconstruct, acts, acts_topk, x_mean, x_std)
        return output

    def get_loss_dict(self, x, x_reconstruct, acts, acts_topk, x_mean, x_std):
        l2_loss = (x_reconstruct.float() - x.float()).pow(2).mean()
        l1_norm = acts_topk.float().abs().sum(-1).mean()
        l1_loss = self.config.l1_coeff * l1_norm
        l0_norm = (acts_topk > 0).float().sum(-1).mean()
        aux_loss = self.get_auxiliary_loss(x, x_reconstruct, acts)
        loss = l2_loss + aux_loss
        num_dead_features = (
            self.num_batches_not_active > self.config.n_batches_to_dead
        ).sum()
        sae_out = self.postprocess_output(x_reconstruct, x_mean, x_std)
        per_token_l2_loss_A = (x_reconstruct.float() - x.float()).pow(2).sum(-1).squeeze()
        total_variance_A = (x.float() - x.float().mean(0)).pow(2).sum(-1).squeeze()
        explained_variance = (1 - per_token_l2_loss_A / total_variance_A).mean()
        output = {
            "sae_out": sae_out,
            "feature_acts": acts_topk,
            "num_dead_features": num_dead_features,
            "loss": loss,
            "l1_loss": l1_loss,
            "l2_loss": l2_loss,
            "l0_norm": l0_norm,
            "l1_norm": l1_norm,
            "aux_loss": aux_loss,
            "explained_variance": explained_variance,
            "top_k": self.config.top_k,
        }
        return output

    def get_auxiliary_loss(self, x, x_reconstruct, acts):
        dead_features = self.num_batches_not_active >= self.config.n_batches_to_dead
        if dead_features.sum() > 0:
            residual = x.float() - x_reconstruct.float()
            acts_topk_aux = torch.topk(
                acts[:, dead_features],
                min(self.config.top_k_aux, int(dead_features.sum())),  # topk needs a Python int k
                dim=-1,
            )
            acts_aux = torch.zeros_like(acts[:, dead_features]).scatter(
                -1, acts_topk_aux.indices, acts_topk_aux.values
            )
            x_reconstruct_aux = acts_aux @ self.W_dec[dead_features]
            l2_loss_aux = (
                self.config.aux_penalty
                * (x_reconstruct_aux.float() - residual.float()).pow(2).mean()
            )
            return l2_loss_aux
        else:
            return torch.tensor(0, dtype=x.dtype, device=x.device)


class TopKSAE(BaseSAE):
    """Top-k SAE where the top_k largest activations are kept per sample."""

    def forward(self, x):
        x, x_mean, x_std = self.preprocess_input(x)

        x_cent = x - self.b_dec
        acts = F.relu(x_cent @ self.W_enc)
        acts_topk = torch.topk(acts, self.config.top_k, dim=-1)
        acts_topk = torch.zeros_like(acts).scatter(
            -1, acts_topk.indices, acts_topk.values
        )
        x_reconstruct = acts_topk @ self.W_dec + self.b_dec

        self.update_inactive_features(acts_topk)
        output = self.get_loss_dict(x, x_reconstruct, acts, acts_topk, x_mean, x_std)
        return output

    def get_loss_dict(self, x, x_reconstruct, acts, acts_topk, x_mean, x_std):
        l2_loss = (x_reconstruct.float() - x.float()).pow(2).mean()
        l1_norm = acts_topk.float().abs().sum(-1).mean()
        l1_loss = self.config.l1_coeff * l1_norm
        l0_norm = (acts_topk > 0).float().sum(-1).mean()
        aux_loss = self.get_auxiliary_loss(x, x_reconstruct, acts)
        loss = l2_loss + l1_loss + aux_loss
        num_dead_features = (
            self.num_batches_not_active > self.config.n_batches_to_dead
        ).sum()
        sae_out = self.postprocess_output(x_reconstruct, x_mean, x_std)
        per_token_l2_loss_A = (x_reconstruct.float() - x.float()).pow(2).sum(-1).squeeze()
        total_variance_A = (x.float() - x.float().mean(0)).pow(2).sum(-1).squeeze()
        explained_variance = (1 - per_token_l2_loss_A / total_variance_A).mean()
        output = {
            "sae_out": sae_out,
            "feature_acts": acts_topk,
            "num_dead_features": num_dead_features,
            "loss": loss,
            "l1_loss": l1_loss,
            "l2_loss": l2_loss,
            "l0_norm": l0_norm,
            "l1_norm": l1_norm,
            "explained_variance": explained_variance,
            "aux_loss": aux_loss,
        }
        return output

    def get_auxiliary_loss(self, x, x_reconstruct, acts):
        dead_features = self.num_batches_not_active >= self.config.n_batches_to_dead
        if dead_features.sum() > 0:
            residual = x.float() - x_reconstruct.float()
            acts_topk_aux = torch.topk(
                acts[:, dead_features],
                min(self.config.top_k_aux, int(dead_features.sum())),  # topk needs a Python int k
                dim=-1,
            )
            acts_aux = torch.zeros_like(acts[:, dead_features]).scatter(
                -1, acts_topk_aux.indices, acts_topk_aux.values
            )
            x_reconstruct_aux = acts_aux @ self.W_dec[dead_features]
            l2_loss_aux = (
                self.config.aux_penalty
                * (x_reconstruct_aux.float() - residual.float()).pow(2).mean()
            )
            return l2_loss_aux
        else:
            return torch.tensor(0, dtype=x.dtype, device=x.device)


class VanillaSAE(BaseSAE):
    """Standard ReLU SAE trained with an L1 sparsity penalty."""

    def forward(self, x):
        x, x_mean, x_std = self.preprocess_input(x)
        x_cent = x - self.b_dec
        acts = F.relu(x_cent @ self.W_enc + self.b_enc)
        x_reconstruct = acts @ self.W_dec + self.b_dec
        self.update_inactive_features(acts)
        output = self.get_loss_dict(x, x_reconstruct, acts, x_mean, x_std)
        return output

    def get_loss_dict(self, x, x_reconstruct, acts, x_mean, x_std):
        l2_loss = (x_reconstruct.float() - x.float()).pow(2).mean()
        l1_norm = acts.float().abs().sum(-1).mean()
        l1_loss = self.config.l1_coeff * l1_norm
        l0_norm = (acts > 0).float().sum(-1).mean()
        loss = l2_loss + l1_loss
        num_dead_features = (
            self.num_batches_not_active > self.config.n_batches_to_dead
        ).sum()

        sae_out = self.postprocess_output(x_reconstruct, x_mean, x_std)
        per_token_l2_loss_A = (x_reconstruct.float() - x.float()).pow(2).sum(-1).squeeze()
        total_variance_A = (x.float() - x.float().mean(0)).pow(2).sum(-1).squeeze()
        explained_variance = (1 - per_token_l2_loss_A / total_variance_A).mean()
        output = {
            "sae_out": sae_out,
            "feature_acts": acts,
            "num_dead_features": num_dead_features,
            "loss": loss,
            "l1_loss": l1_loss,
            "l2_loss": l2_loss,
            "l0_norm": l0_norm,
            "l1_norm": l1_norm,
            "explained_variance": explained_variance,
        }
        return output


class RectangleFunction(autograd.Function):
    """1 on (-0.5, 0.5) and 0 elsewhere; used as the surrogate gradient window."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return ((x > -0.5) & (x < 0.5)).float()

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[(x <= -0.5) | (x >= 0.5)] = 0
        return grad_input


class JumpReLUFunction(autograd.Function):
    """x * 1[x > threshold], with a surrogate gradient for the learned log-threshold."""

    @staticmethod
    def forward(ctx, x, log_threshold, bandwidth):
        ctx.save_for_backward(x, log_threshold, torch.tensor(bandwidth))
        threshold = torch.exp(log_threshold)
        return x * (x > threshold).float()

    @staticmethod
    def backward(ctx, grad_output):
        x, log_threshold, bandwidth_tensor = ctx.saved_tensors
        bandwidth = bandwidth_tensor.item()
        threshold = torch.exp(log_threshold)
        x_grad = (x > threshold).float() * grad_output
        threshold_grad = (
            -(threshold / bandwidth)
            * RectangleFunction.apply((x - threshold) / bandwidth)
            * grad_output
        )
        return x_grad, threshold_grad, None  # None for bandwidth


class JumpReLU(nn.Module):
    def __init__(self, feature_size, bandwidth, device="cpu"):
        super(JumpReLU, self).__init__()
        self.log_threshold = nn.Parameter(torch.zeros(feature_size, device=device))
        self.bandwidth = bandwidth

    def forward(self, x):
        return JumpReLUFunction.apply(x, self.log_threshold, self.bandwidth)


class StepFunction(autograd.Function):
    """1[x > threshold]; gradients flow only to the threshold, used for the L0 penalty."""

    @staticmethod
    def forward(ctx, x, log_threshold, bandwidth):
        ctx.save_for_backward(x, log_threshold, torch.tensor(bandwidth))
        threshold = torch.exp(log_threshold)
        return (x > threshold).float()

    @staticmethod
    def backward(ctx, grad_output):
        x, log_threshold, bandwidth_tensor = ctx.saved_tensors
        bandwidth = bandwidth_tensor.item()
        threshold = torch.exp(log_threshold)
        x_grad = torch.zeros_like(x)
        threshold_grad = (
            -(1.0 / bandwidth)
            * RectangleFunction.apply((x - threshold) / bandwidth)
            * grad_output
        )
        return x_grad, threshold_grad, None  # None for bandwidth


class JumpReLUSAE(BaseSAE):
    """SAE with a learned per-feature JumpReLU threshold and an L0 sparsity penalty."""

    def __init__(self, config: SAEConfig):
        super().__init__(config)
        self.jumprelu = JumpReLU(
            feature_size=config.dict_size,
            bandwidth=config.bandwidth,
            device=config.device if hasattr(config, "device") else "cpu",
        )

    def forward(self, x, use_pre_enc_bias=False):
        x, x_mean, x_std = self.preprocess_input(x)
        if use_pre_enc_bias:
            x = x - self.b_dec
        pre_activations = torch.relu(x @ self.W_enc + self.b_enc)
        feature_magnitudes = self.jumprelu(pre_activations)

        x_reconstructed = feature_magnitudes @ self.W_dec + self.b_dec

        return self.get_loss_dict(x, x_reconstructed, feature_magnitudes, x_mean, x_std)

    def get_loss_dict(self, x, x_reconstruct, acts, x_mean, x_std):
        l2_loss = (x_reconstruct.float() - x.float()).pow(2).mean()

        l0 = StepFunction.apply(acts, self.jumprelu.log_threshold, self.config.bandwidth).sum(dim=-1).mean()
        l0_loss = self.config.l1_coeff * l0
        l1_loss = l0_loss

        loss = l2_loss + l1_loss
        num_dead_features = (
            self.num_batches_not_active > self.config.n_batches_to_dead
        ).sum()

        sae_out = self.postprocess_output(x_reconstruct, x_mean, x_std)
        per_token_l2_loss_A = (x_reconstruct.float() - x.float()).pow(2).sum(-1).squeeze()
        total_variance_A = (x.float() - x.float().mean(0)).pow(2).sum(-1).squeeze()
        explained_variance = (1 - per_token_l2_loss_A / total_variance_A).mean()
        output = {
            "sae_out": sae_out,
            "feature_acts": acts,
            "num_dead_features": num_dead_features,
            "loss": loss,
            "l1_loss": l1_loss,
            "l2_loss": l2_loss,
            "l0_norm": l0,
            "l1_norm": l0,
            "explained_variance": explained_variance,
        }
        return output
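
For an end-to-end sanity check of the uploaded BatchTopKSAE, a forward pass over random activations returns the loss dictionary built by get_loss_dict. A minimal sketch, assuming act_size is set in the uploaded config.json and using a placeholder repo id:

# Forward-pass sketch; the repo id is a placeholder and the activations are random.
import torch
from transformers import AutoModel

sae = AutoModel.from_pretrained("elephantmipt/batch-topk-sae", trust_remote_code=True)
acts = torch.randn(32, sae.config.act_size)  # stand-in for residual-stream activations
out = sae(acts)
print(out["l2_loss"].item(), out["l0_norm"].item(), tuple(out["sae_out"].shape))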