Our initial work with these SAEs has revealed promising applications in model steering,
enhancing jailbreak safeguards, and interpretable classification methods (docs.goodfire.ai).
We look forward to seeing how the research community builds upon these
foundations and uncovers new applications.

### How to use

The snippet below defines a minimal SAE module, wraps the base model with nnsight, reads activations out of a residual stream layer, and applies a feature-level intervention.

```
import torch
from typing import Optional, Callable

import nnsight
from nnsight.intervention import InterventionProxy


# Autoencoder


class SparseAutoEncoder(torch.nn.Module):
    def __init__(
        self,
        d_in: int,
        d_hidden: int,
        device: torch.device,
        dtype: torch.dtype = torch.bfloat16,
    ):
        super().__init__()
        self.d_in = d_in
        self.d_hidden = d_hidden
        self.device = device
        self.encoder_linear = torch.nn.Linear(d_in, d_hidden)
        self.decoder_linear = torch.nn.Linear(d_hidden, d_in)
        self.dtype = dtype
        self.to(self.device, self.dtype)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a batch of data using a linear, followed by a ReLU."""
        return torch.nn.functional.relu(self.encoder_linear(x))

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        """Decode a batch of data using a linear."""
        return self.decoder_linear(x)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """SAE forward pass. Returns the reconstruction and the encoded features."""
        f = self.encode(x)
        return self.decode(f), f


def load_sae(
    path: str,
    d_model: int,
    expansion_factor: int,
    device: torch.device = torch.device("cpu"),
):
    sae = SparseAutoEncoder(
        d_model,
        d_model * expansion_factor,
        device,
    )
    sae_dict = torch.load(
        path, weights_only=True, map_location=device
    )
    sae.load_state_dict(sae_dict)

    return sae


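# For the SAE in the example below, d_in is the model's residual stream width
# (d_model = 4096 for Llama-3.1-8B) and d_hidden is d_model * expansion_factor = 65536.
# If the SAE checkpoint lives on the Hugging Face Hub rather than on local disk, it can be
# fetched first with huggingface_hub (the repo id and filename below are placeholders):
#
#   from huggingface_hub import hf_hub_download
#   sae_path = hf_hub_download(repo_id="<sae-repo-id>", filename="<sae-checkpoint>.pth")

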
# Language model


InterventionInterface = Callable[
    [InterventionProxy], tuple[InterventionProxy, Optional[torch.Tensor]]
]


class ObservableLanguageModel:
    def __init__(
        self,
        model: str,
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
    ):
        self.dtype = dtype
        self.device = device
        self._original_model = model

        self._model = nnsight.LanguageModel(
            self._original_model,
            device_map=device,
            torch_dtype=getattr(torch, dtype) if isinstance(dtype, str) else dtype,
        )

        self.tokenizer = self._model.tokenizer

        self.d_model = self._attempt_to_infer_hidden_layer_dimensions()

        self.safe_mode = False  # nnsight validation is disabled by default because it slows down inference a lot. Turn it on to debug.

        # Optional hook point whose output should have an intervention's direct effect
        # subtracted (None disables this cleanup step).
        self.cleanup_intervention_layer: Optional[str] = None

    def _attempt_to_infer_hidden_layer_dimensions(self):
        config = self._model.config
        if hasattr(config, "hidden_size"):
            return int(config.hidden_size)

        raise Exception(
            "Could not infer the hidden dimension from the model config"
        )

    def _find_module(self, hook_point: str):
        submodules = hook_point.split(".")
        module = self._model
        while submodules:
            module = getattr(module, submodules.pop(0))
        return module

    def forward(
        self,
        inputs: torch.Tensor,
        cache_activations_at: Optional[list[str]] = None,
        interventions: Optional[dict[str, InterventionInterface]] = None,
        use_cache: bool = True,
        past_key_values: Optional[tuple[torch.Tensor]] = None,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor], dict[str, torch.Tensor]]:
        cache: dict[str, torch.Tensor] = {}
        with self._model.trace(
            inputs,
            scan=self.safe_mode,
            validate=self.safe_mode,
            use_cache=use_cache,
            past_key_values=past_key_values,
        ):
            # If interventions were provided, apply each one at its hook site
            if interventions:
                for hook_site in interventions.keys():
                    if interventions[hook_site] is None:
                        continue

                    module = self._find_module(hook_site)

                    if self.cleanup_intervention_layer:
                        last_layer = self._find_module(
                            self.cleanup_intervention_layer
                        )
                    else:
                        last_layer = None

                    intervened_acts, direct_effect_tensor = interventions[
                        hook_site
                    ](module.output[0])
                    # Treat the direct effect tensor as 0 if it is None
                    if direct_effect_tensor is None:
                        direct_effect_tensor = 0
                    # We only modify module.output[0]
                    if use_cache:
                        module.output = (
                            intervened_acts,
                            module.output[1],
                        )
                        if last_layer:
                            last_layer.output = (
                                last_layer.output[0] - direct_effect_tensor,
                                last_layer.output[1],
                            )
                    else:
                        module.output = (intervened_acts,)
                        if last_layer:
                            last_layer.output = (
                                last_layer.output[0] - direct_effect_tensor,
                            )

            if cache_activations_at is not None:
                for hook_point in cache_activations_at:
                    module = self._find_module(hook_point)
                    cache[hook_point] = module.output.save()

            if not past_key_values:
                logits = self._model.output[0][:, -1, :].save()
            else:
                logits = self._model.output[0].squeeze(1).save()

            kv_cache = self._model.output.past_key_values.save()

        return (
            logits.value.detach(),
            kv_cache.value,
            {k: v[0].detach() for k, v in cache.items()},
        )


# Reading out features from the model

llama_3_1_8b = ObservableLanguageModel(
    "meta-llama/Llama-3.1-8B-Instruct",
)

input_tokens = llama_3_1_8b.tokenizer.apply_chat_template(
    [
        {"role": "user", "content": "Hello, how are you?"},
    ],
    return_tensors="pt",
)
logits, kv_cache, features = llama_3_1_8b.forward(
    input_tokens,
    cache_activations_at=["model.layers.19"],
)

print(features["model.layers.19"].shape)

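# The cached tensor has shape (batch, sequence_length, d_model), i.e. (1, n_tokens, 4096)
# for Llama-3.1-8B. Passing it through sae.encode() (the SAE is loaded below) yields
# per-token SAE feature activations of shape (1, n_tokens, d_hidden).

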
# Intervention example

sae = load_sae(
    path="./llama-3-8b-d-hidden.pth",
    d_model=4096,
    expansion_factor=16,
    device=torch.device("cuda"),  # keep the SAE on the same device as the model
)

PIRATE_FEATURE_INDEX = 0
VALUE_TO_MODIFY = 0.1


def example_intervention(activations: InterventionProxy):
    features = sae.encode(activations).detach()
    reconstructed_acts = sae.decode(features).detach()
    error = activations - reconstructed_acts

    # Modify the chosen feature index across all token positions
    features[..., PIRATE_FEATURE_INDEX] += VALUE_TO_MODIFY

    # Very important to add the error term back in!
    # Return the intervened activations plus an optional direct-effect tensor (None here),
    # matching what ObservableLanguageModel.forward expects to unpack.
    return sae.decode(features) + error, None


logits, kv_cache, features = llama_3_1_8b.forward(
    input_tokens,
    interventions={"model.layers.19": example_intervention},
)

print(llama_3_1_8b.tokenizer.decode(logits[-1].argmax(-1)))
```
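
The `forward` call above also returns the model's KV cache, so the same interface can drive simple multi-token decoding. The loop below is a minimal greedy-decoding sketch built only from the classes defined above; it is not part of the released checkpoint, and the 20-token budget and reuse of the layer-19 intervention are arbitrary choices for illustration.

```
# Greedy decoding sketch that reuses the KV cache returned by forward().
generated = input_tokens
logits, kv_cache, _ = llama_3_1_8b.forward(
    generated,
    interventions={"model.layers.19": example_intervention},
)

for _ in range(20):  # arbitrary number of new tokens
    next_token = logits.argmax(-1).reshape(1, 1)
    generated = torch.cat([generated, next_token], dim=-1)
    logits, kv_cache, _ = llama_3_1_8b.forward(
        next_token,
        interventions={"model.layers.19": example_intervention},
        past_key_values=kv_cache,
    )

print(llama_3_1_8b.tokenizer.decode(generated[0]))
```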
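
The intro above also mentions interpretable classification. As a rough illustration of that idea (not part of this release), per-token SAE feature activations can be mean-pooled and fed to a simple probe; the prompts, labels, and scikit-learn classifier below are placeholders.

```
import numpy as np
from sklearn.linear_model import LogisticRegression

# Placeholder labeled prompts; substitute a real dataset.
prompts = ["Tell me a story about pirates.", "What is the capital of France?"]
labels = [1, 0]

feature_vectors = []
for prompt in prompts:
    tokens = llama_3_1_8b.tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt}], return_tensors="pt"
    )
    _, _, acts = llama_3_1_8b.forward(
        tokens, cache_activations_at=["model.layers.19"]
    )
    # Encode the cached layer-19 activations with the SAE and mean-pool over tokens.
    feats = sae.encode(acts["model.layers.19"].to(sae.device, sae.dtype)).detach()
    feature_vectors.append(feats.mean(dim=1).squeeze(0).float().cpu().numpy())

probe = LogisticRegression(max_iter=1000).fit(np.stack(feature_vectors), labels)
```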