namgoodfire committed
Commit 4e9deb7 · verified · 1 Parent(s): 72ed567

Update README.md

Files changed (1):
  1. README.md +235 -1

README.md CHANGED
@@ -24,4 +24,238 @@ models like Llama 3.3 70B, we aim to accelerate progress in interpretability res
  Our initial work with these SAEs has revealed promising applications in model steering,
  enhancing jailbreaking safeguards, and interpretable classification methods (docs.goodfire.ai).
  We look forward to seeing how the research community builds upon these
- foundations and uncovers new applications.
+ foundations and uncovers new applications.
+
+ ### How to use
+
+ ```python
+ import torch
+ from typing import Optional, Callable
+
+ import nnsight
+ from nnsight.intervention import InterventionProxy
+
+
+ # Autoencoder
+
+
+ class SparseAutoEncoder(torch.nn.Module):
+     def __init__(
+         self,
+         d_in: int,
+         d_hidden: int,
+         device: torch.device,
+         dtype: torch.dtype = torch.bfloat16,
+     ):
+         super().__init__()
+         self.d_in = d_in
+         self.d_hidden = d_hidden
+         self.device = device
+         self.encoder_linear = torch.nn.Linear(d_in, d_hidden)
+         self.decoder_linear = torch.nn.Linear(d_hidden, d_in)
+         self.dtype = dtype
+         self.to(self.device, self.dtype)
+
+     def encode(self, x: torch.Tensor) -> torch.Tensor:
+         """Encode a batch of data using a linear, followed by a ReLU."""
+         return torch.nn.functional.relu(self.encoder_linear(x))
+
+     def decode(self, x: torch.Tensor) -> torch.Tensor:
+         """Decode a batch of data using a linear."""
+         return self.decoder_linear(x)
+
+     def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+         """SAE forward pass. Returns the reconstruction and the encoded features."""
+         f = self.encode(x)
+         return self.decode(f), f
+
+
+ def load_sae(
+     path: str,
+     d_model: int,
+     expansion_factor: int,
+     device: torch.device = torch.device("cpu"),
+ ):
+     sae = SparseAutoEncoder(
+         d_model,
+         d_model * expansion_factor,
+         device,
+     )
+     sae_dict = torch.load(
+         path, weights_only=True, map_location=device
+     )
+     sae.load_state_dict(sae_dict)
+
+     return sae
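+
+ # Optional: one way to fetch the SAE weights from the Hugging Face Hub before
+ # calling load_sae. The repo id and filename below are illustrative assumptions;
+ # check this repository's file listing for the actual names.
+ # from huggingface_hub import hf_hub_download
+ # sae_path = hf_hub_download(
+ #     repo_id="Goodfire/Llama-3.1-8B-Instruct-SAE-l19",
+ #     filename="llama-3-8b-d-hidden.pth",
+ # )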
+
+
+ # Language model
+
+
+ # An intervention takes a module's activations and returns
+ # (intervened_activations, direct_effect_tensor_or_None).
+ InterventionInterface = Callable[
+     [InterventionProxy], tuple[InterventionProxy, Optional[InterventionProxy]]
+ ]
+
+
+ class ObservableLanguageModel:
+     def __init__(
+         self,
+         model: str,
+         device: str = "cuda",
+         dtype: torch.dtype = torch.bfloat16,
+     ):
+         self.dtype = dtype
+         self.device = device
+         self._original_model = model
+
+         self._model = nnsight.LanguageModel(
+             self._original_model,
+             device_map=device,
+             torch_dtype=getattr(torch, dtype) if isinstance(dtype, str) else dtype,
+         )
+
+         self.tokenizer = self._model.tokenizer
+
+         self.d_model = self._attempt_to_infer_hidden_layer_dimensions()
+
+         # Optional hook point whose output has an intervention's direct effect
+         # subtracted back out; None disables this cleanup step.
+         self.cleanup_intervention_layer: Optional[str] = None
+
+         self.safe_mode = False  # nnsight validation is disabled by default; it slows inference down a lot. Turn it on to debug.
+
+     def _attempt_to_infer_hidden_layer_dimensions(self):
+         config = self._model.config
+         if hasattr(config, "hidden_size"):
+             return int(config.hidden_size)
+
+         raise Exception(
+             "Could not infer the hidden dimension from the model config"
+         )
+
+     def _find_module(self, hook_point: str):
+         submodules = hook_point.split(".")
+         module = self._model
+         while submodules:
+             module = getattr(module, submodules.pop(0))
+         return module
+
+     def forward(
+         self,
+         inputs: torch.Tensor,
+         cache_activations_at: Optional[list[str]] = None,
+         interventions: Optional[dict[str, InterventionInterface]] = None,
+         use_cache: bool = True,
+         past_key_values: Optional[tuple[torch.Tensor]] = None,
+     ) -> tuple[torch.Tensor, tuple[torch.Tensor], dict[str, torch.Tensor]]:
+         cache: dict[str, torch.Tensor] = {}
+         with self._model.trace(
+             inputs,
+             scan=self.safe_mode,
+             validate=self.safe_mode,
+             use_cache=use_cache,
+             past_key_values=past_key_values,
+         ):
+             # Apply any user-supplied interventions
+             if interventions:
+                 for hook_site in interventions.keys():
+                     if interventions[hook_site] is None:
+                         continue
+
+                     module = self._find_module(hook_site)
+
+                     if self.cleanup_intervention_layer:
+                         last_layer = self._find_module(
+                             self.cleanup_intervention_layer
+                         )
+                     else:
+                         last_layer = None
+
+                     intervened_acts, direct_effect_tensor = interventions[
+                         hook_site
+                     ](module.output[0])
+                     # Treat the direct effect tensor as 0 if it is None
+                     if direct_effect_tensor is None:
+                         direct_effect_tensor = 0
+                     # We only modify module.output[0]
+                     if use_cache:
+                         module.output = (
+                             intervened_acts,
+                             module.output[1],
+                         )
+                         if last_layer:
+                             last_layer.output = (
+                                 last_layer.output[0] - direct_effect_tensor,
+                                 last_layer.output[1],
+                             )
+                     else:
+                         module.output = (intervened_acts,)
+                         if last_layer:
+                             last_layer.output = (
+                                 last_layer.output[0] - direct_effect_tensor,
+                             )
+
+             if cache_activations_at is not None:
+                 for hook_point in cache_activations_at:
+                     module = self._find_module(hook_point)
+                     cache[hook_point] = module.output.save()
+
+             if not past_key_values:
+                 logits = self._model.output[0][:, -1, :].save()
+             else:
+                 logits = self._model.output[0].squeeze(1).save()
+
+             kv_cache = self._model.output.past_key_values.save()
+
+         return (
+             logits.value.detach(),
+             kv_cache.value,
+             {k: v[0].detach() for k, v in cache.items()},
+         )
+
+
+ # Reading out features from the model
+
+ llama_3_1_8b = ObservableLanguageModel(
+     "meta-llama/Llama-3.1-8B-Instruct",
+ )
+
+ input_tokens = llama_3_1_8b.tokenizer.apply_chat_template(
+     [
+         {"role": "user", "content": "Hello, how are you?"},
+     ],
+     return_tensors="pt",
+ )
+ logits, kv_cache, features = llama_3_1_8b.forward(
+     input_tokens,
+     cache_activations_at=["model.layers.19"],
+ )
+
+ print(features["model.layers.19"].shape)
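+ # Expected: a (batch, sequence_length, 4096) activation tensor from layer 19;
+ # 4096 matches the d_model used when loading the SAE below.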
+
+
+ # Intervention example
+
+ sae = load_sae(
+     path="./llama-3-8b-d-hidden.pth",
+     d_model=4096,
+     expansion_factor=16,
+     device=torch.device("cuda"),  # keep the SAE on the same device as the model
+ )
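+
+ # Illustrative addition (not part of the original snippet): encode the cached
+ # layer-19 activations with the SAE and list the strongest features at the
+ # final token position.
+ layer_19_acts = features["model.layers.19"].to(sae.device, sae.dtype)
+ feature_acts = sae.encode(layer_19_acts)
+ top_features = feature_acts[0, -1].topk(10).indices
+ print(top_features.tolist())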
+
+ PIRATE_FEATURE_INDEX = 0
+ VALUE_TO_MODIFY = 0.1
+
+ def example_intervention(activations: nnsight.InterventionProxy):
+     features = sae.encode(activations).detach()
+     reconstructed_acts = sae.decode(features).detach()
+     error = activations - reconstructed_acts
+
+     # Modify the chosen feature across all token positions
+     features[..., PIRATE_FEATURE_INDEX] += VALUE_TO_MODIFY
+
+     # Very important to add the error term back in!
+     # forward() expects (intervened_activations, direct_effect_tensor);
+     # returning None for the second element is treated as a zero direct effect.
+     return sae.decode(features) + error, None
+
+
+ logits, kv_cache, features = llama_3_1_8b.forward(
+     input_tokens,
+     interventions={"model.layers.19": example_intervention},
+ )
+
+ # Decode the most likely next token produced under the intervention
+ print(llama_3_1_8b.tokenizer.decode(logits[-1].argmax(-1)))
+ ```
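+
+ Two details in the intervention example are worth calling out. Adding the SAE
+ reconstruction error back in (`sae.decode(features) + error`) means the edit only
+ changes the feature you steered, rather than also imposing the SAE's
+ reconstruction error on the model. `PIRATE_FEATURE_INDEX` and `VALUE_TO_MODIFY`
+ control which feature is steered and by how much; swap in the feature index and
+ steering strength appropriate for your use case.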