---
license: llama3.3
language:
- en
base_model:
- meta-llama/Llama-3.3-70B-Instruct
---

## Model Information

The Goodfire SAE (Sparse Autoencoder) for Llama 3.3 70B is an interpreter model designed to analyze 
the internal representations of Llama-3.3-70B-Instruct. The SAE is trained on layer 50 of 
Llama 3.3 70B and achieves an L0 count of 121, decomposing complex neural activations 
into interpretable features. It is optimized for interpretability and model-steering applications, 
allowing researchers and developers to gain insight into the model's internal processing and behavior. 
As an open-source tool, it serves as a foundation for advancing interpretability research and enhancing control 
over large language model operations.
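
Concretely, the SAE encodes a residual-stream activation into a sparse vector of feature activations and decodes it back. The sketch below is illustrative only: the hidden width `d_hidden` is an assumption, and the weights here are random rather than the trained checkpoint. It shows the computation and what the L0 count measures.

```python
import torch

# Minimal sketch: f = ReLU(W_enc x + b_enc), x_hat = W_dec f + b_dec.
# d_model = 8192 is the hidden size of Llama 3.3 70B; d_hidden is an assumption.
d_model, d_hidden = 8192, 8192 * 8
encoder = torch.nn.Linear(d_model, d_hidden)
decoder = torch.nn.Linear(d_hidden, d_model)

x = torch.randn(1, d_model)   # a layer-50 residual-stream activation
f = torch.relu(encoder(x))    # sparse feature activations
x_hat = decoder(f)            # reconstruction of the activation
l0 = (f != 0).sum().item()    # "L0 count": active features per token (~121 for the trained SAE)
```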

## Intended Use

By open-sourcing SAEs for leading open models, especially large-scale 
models like Llama 3.3 70B, we aim to accelerate progress in interpretability research. 

Our initial work with these SAEs has revealed promising applications in model steering, 
strengthening safeguards against jailbreaks, and interpretable classification methods (see [docs.goodfire.ai](https://docs.goodfire.ai)). 
We look forward to seeing how the research community builds upon these 
foundations and uncovers new applications.

## How to use

```python
import torch
from typing import Optional, Callable

import nnsight
from nnsight.intervention import InterventionProxy


# Autoencoder


class SparseAutoEncoder(torch.nn.Module):
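    """A single-layer ReLU sparse autoencoder whose parameter names match the
    checkpoint that `load_sae` below loads."""
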
    def __init__(
        self,
        d_in: int,
        d_hidden: int,
        device: torch.device,
        dtype: torch.dtype = torch.bfloat16,
    ):
        super().__init__()
        self.d_in = d_in
        self.d_hidden = d_hidden
        self.device = device
        self.encoder_linear = torch.nn.Linear(d_in, d_hidden)
        self.decoder_linear = torch.nn.Linear(d_hidden, d_in)
        self.dtype = dtype
        self.to(self.device, self.dtype)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a batch of data using a linear, followed by a ReLU."""
        return torch.nn.functional.relu(self.encoder_linear(x))

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        """Decode a batch of data using a linear."""
        return self.decoder_linear(x)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """SAE forward pass. Returns the reconstruction and the encoded features."""
        f = self.encode(x)
        return self.decode(f), f


def load_sae(
    path: str,
    d_model: int,
    expansion_factor: int,
    device: torch.device = torch.device("cpu"),
):
    sae = SparseAutoEncoder(
        d_model,
        d_model * expansion_factor,
        device,
    )
    sae_dict = torch.load(
        path, weights_only=True, map_location=device
    )
    sae.load_state_dict(sae_dict)

    return sae


# Language model


# An intervention receives a module's activations and returns the modified
# activations together with an optional "direct effect" tensor (or None),
# which `forward` subtracts from a later cleanup layer's output.
InterventionInterface = Callable[
    [InterventionProxy], tuple[InterventionProxy, Optional[torch.Tensor]]
]


class ObservableLanguageModel:
    def __init__(
        self,
        model: str,
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
    ):
        self.dtype = dtype
        self.device = device
        self._original_model = model


        self._model = nnsight.LanguageModel(
            self._original_model,
            device_map=device,
            torch_dtype=getattr(torch, dtype) if isinstance(dtype, str) else dtype
        )

        self.tokenizer = self._model.tokenizer

        self.d_model = self._attempt_to_infer_hidden_layer_dimensions()

        # Hook point whose output is corrected by subtracting an intervention's
        # direct-effect tensor; None disables the correction.
        self.cleanup_intervention_layer: Optional[str] = None

        self.safe_mode = False  # nnsight validation is disabled by default because it slows inference considerably; enable it to debug.

    def _attempt_to_infer_hidden_layer_dimensions(self):
        config = self._model.config
        if hasattr(config, "hidden_size"):
            return int(config.hidden_size)

        raise Exception(
            "Could not infer the hidden dimension from the model config"
        )

    def _find_module(self, hook_point: str):
        submodules = hook_point.split(".")
        module = self._model
        while submodules:
            module = getattr(module, submodules.pop(0))
        return module

    def forward(
        self,
        inputs: torch.Tensor,
        cache_activations_at: Optional[list[str]] = None,
        interventions: Optional[dict[str, InterventionInterface]] = None,
        use_cache: bool = True,
        past_key_values: Optional[tuple[torch.Tensor]] = None,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor], dict[str, torch.Tensor]]:
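        """Run a forward pass, optionally applying interventions at given hook
        points and caching activations at others.

        Returns (final-position logits, past_key_values, activation cache).
        """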
        cache: dict[str, torch.Tensor] = {}
        with self._model.trace(
            inputs,
            scan=self.safe_mode,
            validate=self.safe_mode,
            use_cache=use_cache,
            past_key_values=past_key_values,
        ):
            # Apply any requested interventions
            if interventions:
                for hook_site in interventions.keys():
                    if interventions[hook_site] is None:
                        continue

                    module = self._find_module(hook_site)

                    if self.cleanup_intervention_layer:
                        last_layer = self._find_module(
                            self.cleanup_intervention_layer
                        )
                    else:
                        last_layer = None

                    intervened_acts, direct_effect_tensor = interventions[
                        hook_site
                    ](module.output[0])
                    # Add direct effect tensor as 0 if it is None
                    if direct_effect_tensor is None:
                        direct_effect_tensor = 0
                    # We only modify module.output[0]
                    if use_cache:
                        module.output = (
                            intervened_acts,
                            module.output[1],
                        )
                        if last_layer:
                            last_layer.output = (
                                last_layer.output[0] - direct_effect_tensor,
                                last_layer.output[1],
                            )
                    else:
                        module.output = (intervened_acts,)
                        if last_layer:
                            last_layer.output = (
                                last_layer.output[0] - direct_effect_tensor,
                            )

            if cache_activations_at is not None:
                for hook_point in cache_activations_at:
                    module = self._find_module(hook_point)
                    cache[hook_point] = module.output.save()

            # Without a supplied KV cache the model emits logits for every
            # position, so keep only the final token's; with a cache it emits
            # a single new position, which we squeeze out.
            if not past_key_values:
                logits = self._model.output[0][:, -1, :].save()
            else:
                logits = self._model.output[0].squeeze(1).save()

            kv_cache = self._model.output.past_key_values.save()

        return (
            logits.value.detach(),
            kv_cache.value,
            {k: v[0].detach() for k, v in cache.items()},
        )


# Reading out features from the model

llama_3_1_8b = ObservableLanguageModel(
    "meta-llama/Llama-3.1-8B-Instruct",
)

input_tokens = llama_3_1_8b.tokenizer.apply_chat_template(
    [
        {"role": "user", "content": "Hello, how are you?"},
    ],
    return_tensors="pt",
)
logits, kv_cache, features = llama_3_1_8b.forward(
    input_tokens,
    cache_activations_at=["model.layers.19"],
)

print(features["model.layers.19"].shape)


# Intervention example

sae = load_sae(
    path="./llama-3-8b-d-hidden.pth",
    d_model=4096,
    expansion_factor=16,
)

PIRATE_FEATURE_INDEX = 0
VALUE_TO_MODIFY = 0.1

def example_intervention(activations: InterventionProxy):
    features = sae.encode(activations).detach()
    reconstructed_acts = sae.decode(features).detach()
    error = activations - reconstructed_acts

    # Modify the chosen feature across all token positions
    # (the last dimension indexes SAE features)
    features[..., PIRATE_FEATURE_INDEX] += VALUE_TO_MODIFY

    # Very important to add the reconstruction error back in!
    # Return None as the direct-effect tensor, matching the tuple
    # that `forward` unpacks.
    return sae.decode(features) + error, None


logits, kv_cache, features = llama_3_1_8b.forward(
    input_tokens,
    interventions={"model.layers.19": example_intervention},
)

print(llama_3_1_8b.tokenizer.decode(logits[-1].argmax(-1)))
```
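
The walkthrough above uses Llama 3.1 8B and layer 19 purely for illustration. For the Llama 3.3 70B SAE this card describes, the same pipeline applies with the 70B model and its training layer. In the sketch below, the checkpoint filename and expansion factor are assumptions to adjust to the released checkpoint; the hidden size of 8192 and layer 50 follow from the model architecture and this card.

```python
llama_3_3_70b = ObservableLanguageModel(
    "meta-llama/Llama-3.3-70B-Instruct",
)

sae_70b = load_sae(
    path="./llama-3.3-70b-layer-50.pth",  # hypothetical filename
    d_model=8192,                         # hidden size of Llama 3.3 70B
    expansion_factor=8,                   # assumption; match the released checkpoint
)

input_tokens = llama_3_3_70b.tokenizer.apply_chat_template(
    [{"role": "user", "content": "Hello, how are you?"}],
    return_tensors="pt",
)

logits, kv_cache, features = llama_3_3_70b.forward(
    input_tokens,
    cache_activations_at=["model.layers.50"],  # the layer this SAE was trained on
)
```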

## Responsibility & Safety

Safety is at the core of everything we do at Goodfire. As a public benefit 
corporation, we’re dedicated to understanding AI models to enable safer, more reliable 
generative AI. You can read more about our comprehensive approach to 
safety and responsible development in our detailed [safety overview](https://www.goodfire.ai/blog/our-approach-to-safety/).