|
--- |
|
license: llama3.3 |
|
language: |
|
- en |
|
base_model: |
|
- meta-llama/Llama-3.3-70B-Instruct |
|
tags: |
|
- mechanistic interpretability |
|
- sparse autoencoder |
|
--- |
|
|
|
## Model Information |
|
|
|
The Goodfire SAE (Sparse Autoencoder) for [meta-llama/Llama-3.3-70B-Instruct](https://huggingface.co/meta-llama/Llama-3.3-70B-Instruct) is an interpreter model that decomposes the base model's internal activations into human-interpretable features. It is trained on the activations at layer 50 of Llama 3.3 70B and achieves an L0 of 121 (on average, roughly 121 features are active per token), so complex neural activations can be expressed as a sparse combination of interpretable features. The SAE is intended for interpretability and model-steering work, letting researchers and developers inspect and influence the model's internal processing. As an open-source tool, it provides a foundation for interpretability research and for finer-grained control over large language models.
|
|
|
__Base Model Creator__: [meta-llama](https://huggingface.co/meta-llama)
|
|
|
By using Goodfire/Llama-3.3-70B-Instruct__model.layers.50, you agree to the [LLAMA 3.3 COMMUNITY LICENSE AGREEMENT](https://huggingface.co/meta-llama/Llama-3.3-70B-Instruct/blob/main/LICENSE).
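
To experiment with this SAE, you can fetch the checkpoint from the Hub and load it with the `load_sae` helper shown in the "How to use" section below. A minimal sketch follows; the `repo_id` and `filename` are placeholders, so check this repository's file listing for the actual names.

```python
from huggingface_hub import hf_hub_download

# Placeholder repo_id and filename -- verify both against this repository's file listing.
sae_path = hf_hub_download(
    repo_id="Goodfire/Llama-3.3-70B-Instruct-SAE-l50",
    filename="Llama-3.3-70B-Instruct-SAE-l50.pth",
)
print(sae_path)  # local path to the downloaded SAE checkpoint
```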
|
|
|
|
|
## Intended Use |
|
|
|
By open-sourcing SAEs for leading open models, especially large-scale |
|
models like Llama 3.3 70B, we aim to accelerate progress in interpretability research. |
|
|
|
Our initial work with these SAEs has revealed promising applications in model steering, strengthening safeguards against jailbreaks, and building interpretable classifiers (a toy sketch of the classification idea appears below).
|
We look forward to seeing how the research community builds upon these |
|
foundations and uncovers new applications. |
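
One concrete way to read "interpretable classification": pool the SAE's feature activations over tokens for each example and fit a simple linear probe on top. The toy sketch below uses random stand-in data and scikit-learn purely to illustrate the idea; it is not a recipe shipped with this release.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# Stand-in data: in practice each row would be the pooled SAE feature
# activations for one example (see the "How to use" section below).
rng = np.random.default_rng(0)
n_examples, d_hidden = 64, 1024  # d_hidden is a placeholder SAE width
X = rng.random((n_examples, d_hidden)).astype(np.float32)
y = rng.integers(0, 2, size=n_examples)  # binary labels

probe = LogisticRegression(max_iter=1000).fit(X, y)

# The largest-magnitude coefficients single out individual, labeled SAE
# features, which is what makes this style of classifier interpretable.
top_feature_indices = np.argsort(np.abs(probe.coef_[0]))[::-1][:10]
print(top_feature_indices)
```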
|
|
|
### Feature labels
|
|
|
To explore the feature labels, check out the [Goodfire Ember SDK](https://www.goodfire.ai/blog/announcing-goodfire-ember/).
|
The SDK provides an intuitive interface for interacting with these |
|
features, allowing you to investigate how Llama processes information |
|
and even steer its behavior. Get started with feature |
|
exploration at [docs.goodfire.ai](https://docs.goodfire.ai). |
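
As a rough sketch of what that looks like (the client calls and variant name below follow the SDK quick-start and may have changed, so treat them as assumptions and defer to docs.goodfire.ai):

```python
import goodfire

# Assumed API surface from the Ember SDK quick-start -- verify at docs.goodfire.ai.
client = goodfire.Client(api_key="YOUR_GOODFIRE_API_KEY")
variant = goodfire.Variant("meta-llama/Llama-3.3-70B-Instruct")

# Search the labeled SAE features for a concept and inspect the top matches
features = client.features.search("pirate speech", model=variant, top_k=5)
for feature in features:
    print(feature)
```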
|
|
|
## How to use |
|
|
|
```python |
|
import torch |
|
from typing import Optional, Callable |
|
|
|
import nnsight |
|
from nnsight.intervention import InterventionProxy |
|
|
|
|
|
# Autoencoder |
|
|
|
|
|
class SparseAutoEncoder(torch.nn.Module): |
|
def __init__( |
|
self, |
|
d_in: int, |
|
d_hidden: int, |
|
device: torch.device, |
|
dtype: torch.dtype = torch.bfloat16, |
|
): |
|
super().__init__() |
|
self.d_in = d_in |
|
self.d_hidden = d_hidden |
|
self.device = device |
|
self.encoder_linear = torch.nn.Linear(d_in, d_hidden) |
|
self.decoder_linear = torch.nn.Linear(d_hidden, d_in) |
|
self.dtype = dtype |
|
self.to(self.device, self.dtype) |
|
|
|
def encode(self, x: torch.Tensor) -> torch.Tensor: |
|
"""Encode a batch of data using a linear, followed by a ReLU.""" |
|
return torch.nn.functional.relu(self.encoder_linear(x)) |
|
|
|
def decode(self, x: torch.Tensor) -> torch.Tensor: |
|
"""Decode a batch of data using a linear.""" |
|
return self.decoder_linear(x) |
|
|
|
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: |
|
"""SAE forward pass. Returns the reconstruction and the encoded features.""" |
|
f = self.encode(x) |
|
return self.decode(f), f |
|
|
|
|
|
def load_sae( |
|
path: str, |
|
d_model: int, |
|
expansion_factor: int, |
|
device: torch.device = torch.device("cpu"), |
|
): |
|
sae = SparseAutoEncoder( |
|
d_model, |
|
d_model * expansion_factor, |
|
device, |
|
) |
|
sae_dict = torch.load( |
|
path, weights_only=True, map_location=device |
|
) |
|
sae.load_state_dict(sae_dict) |
|
|
|
return sae |
|
|
|
|
|
# Language model
|
|
|
|
|
# An intervention receives a module's activations and returns the modified activations
# together with an optional "direct effect" tensor to subtract at a later layer.
InterventionInterface = Callable[
    [InterventionProxy], tuple[InterventionProxy, Optional[InterventionProxy]]
]
|
|
|
|
|
class ObservableLanguageModel: |
|
def __init__( |
|
self, |
|
model: str, |
|
device: str = "cuda", |
|
dtype: torch.dtype = torch.bfloat16, |
|
): |
|
self.dtype = dtype |
|
self.device = device |
|
self._original_model = model |
|
|
|
|
|
self._model = nnsight.LanguageModel( |
|
self._original_model, |
|
device_map=device, |
|
torch_dtype=getattr(torch, dtype) if isinstance(dtype, str) else dtype |
|
) |
|
|
|
self.tokenizer = self._model.tokenizer |
|
|
|
self.d_model = self._attempt_to_infer_hidden_layer_dimensions() |
|
|
|
        self.safe_mode = False  # nnsight validation is disabled by default because it slows down inference a lot; turn it on to debug.
        self.cleanup_intervention_layer: Optional[str] = None  # Optional hook point whose output is corrected by the direct-effect tensor
|
|
|
def _attempt_to_infer_hidden_layer_dimensions(self): |
|
config = self._model.config |
|
if hasattr(config, "hidden_size"): |
|
return int(config.hidden_size) |
|
|
|
        raise Exception(
            "Could not infer the model's hidden dimension (d_model) from its config"
        )
|
|
|
def _find_module(self, hook_point: str): |
|
submodules = hook_point.split(".") |
|
module = self._model |
|
while submodules: |
|
module = getattr(module, submodules.pop(0)) |
|
return module |
|
|
|
def forward( |
|
self, |
|
inputs: torch.Tensor, |
|
cache_activations_at: Optional[list[str]] = None, |
|
interventions: Optional[dict[str, InterventionInterface]] = None, |
|
use_cache: bool = True, |
|
past_key_values: Optional[tuple[torch.Tensor]] = None, |
|
) -> tuple[torch.Tensor, tuple[torch.Tensor], dict[str, torch.Tensor]]: |
|
cache: dict[str, torch.Tensor] = {} |
|
with self._model.trace( |
|
inputs, |
|
scan=self.safe_mode, |
|
validate=self.safe_mode, |
|
use_cache=use_cache, |
|
past_key_values=past_key_values, |
|
): |
|
            # Apply any interventions that were passed in
|
if interventions: |
|
for hook_site in interventions.keys(): |
|
if interventions[hook_site] is None: |
|
continue |
|
|
|
module = self._find_module(hook_site) |
|
|
|
if self.cleanup_intervention_layer: |
|
last_layer = self._find_module( |
|
self.cleanup_intervention_layer |
|
) |
|
else: |
|
last_layer = None |
|
|
|
intervened_acts, direct_effect_tensor = interventions[ |
|
hook_site |
|
](module.output[0]) |
|
                    # Treat a missing direct-effect tensor as zero
|
if direct_effect_tensor is None: |
|
direct_effect_tensor = 0 |
|
# We only modify module.output[0] |
|
if use_cache: |
|
module.output = ( |
|
intervened_acts, |
|
module.output[1], |
|
) |
|
if last_layer: |
|
last_layer.output = ( |
|
last_layer.output[0] - direct_effect_tensor, |
|
last_layer.output[1], |
|
) |
|
else: |
|
module.output = (intervened_acts,) |
|
if last_layer: |
|
last_layer.output = ( |
|
last_layer.output[0] - direct_effect_tensor, |
|
) |
|
|
|
if cache_activations_at is not None: |
|
for hook_point in cache_activations_at: |
|
module = self._find_module(hook_point) |
|
cache[hook_point] = module.output.save() |
|
|
|
if not past_key_values: |
|
logits = self._model.output[0][:, -1, :].save() |
|
else: |
|
logits = self._model.output[0].squeeze(1).save() |
|
|
|
kv_cache = self._model.output.past_key_values.save() |
|
|
|
return ( |
|
logits.value.detach(), |
|
kv_cache.value, |
|
{k: v[0].detach() for k, v in cache.items()}, |
|
) |
|
|
|
|
|
# Reading out features from the model |
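# (This walkthrough uses the smaller Llama 3.1 8B Instruct model and layer 19 for
#  illustration; a sketch adapting it to this card's Llama 3.3 70B layer-50 SAE
#  follows the code block.)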
|
|
|
llama_3_1_8b = ObservableLanguageModel( |
|
"meta-llama/Llama-3.1-8B-Instruct", |
|
) |
|
|
|
input_tokens = llama_3_1_8b.tokenizer.apply_chat_template( |
|
[ |
|
{"role": "user", "content": "Hello, how are you?"}, |
|
], |
|
return_tensors="pt", |
|
) |
|
logits, kv_cache, features = llama_3_1_8b.forward( |
|
input_tokens, |
|
cache_activations_at=["model.layers.19"], |
|
) |
|
|
|
print(features["model.layers.19"].shape) |
|
|
|
|
|
# Intervention example |
|
|
|
sae = load_sae( |
|
path="./llama-3-8b-d-hidden.pth", |
|
d_model=4096, |
|
expansion_factor=16, |
|
) |
|
|
|
PIRATE_FEATURE_INDEX = 0 |
|
VALUE_TO_MODIFY = 0.1 |
|
|
|
def example_intervention(activations: InterventionProxy):
|
features = sae.encode(activations).detach() |
|
reconstructed_acts = sae.decode(features).detach() |
|
error = activations - reconstructed_acts |
|
|
|
    # Modify the chosen feature across all token positions
    features[..., PIRATE_FEATURE_INDEX] += VALUE_TO_MODIFY
|
|
|
    # Very important to add the SAE reconstruction error back in!
    # Return the modified activations plus an optional direct-effect tensor (None here).
    return sae.decode(features) + error, None
|
|
|
|
|
logits, kv_cache, features = llama_3_1_8b.forward( |
|
input_tokens, |
|
interventions={"model.layers.19": example_intervention}, |
|
) |
|
|
|
print(llama_3_1_8b.tokenizer.decode(logits[-1].argmax(-1))) |
|
``` |
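
The walkthrough above uses Llama 3.1 8B at layer 19 purely for illustration. A minimal sketch of adapting it to this card's SAE (Llama 3.3 70B, layer 50) is below; it reuses the classes defined above, the checkpoint path is a placeholder, and the SAE width is read from the checkpoint rather than assumed.

```python
# Hedged sketch: adapt the walkthrough to Llama 3.3 70B and layer 50.
llama_3_3_70b = ObservableLanguageModel("meta-llama/Llama-3.3-70B-Instruct")

# Placeholder path -- point this at the downloaded SAE checkpoint.
state_dict = torch.load("path/to/llama-3.3-70b-layer-50-sae.pth", weights_only=True, map_location="cpu")
# nn.Linear(d_hidden, d_model) stores its weight as (d_model, d_hidden)
d_model, d_hidden = state_dict["decoder_linear.weight"].shape
sae_70b = SparseAutoEncoder(d_model, d_hidden, device=torch.device("cpu"))
sae_70b.load_state_dict(state_dict)

input_tokens = llama_3_3_70b.tokenizer.apply_chat_template(
    [{"role": "user", "content": "Hello, how are you?"}],
    return_tensors="pt",
)
logits, kv_cache, activations = llama_3_3_70b.forward(
    input_tokens,
    cache_activations_at=["model.layers.50"],
)

# Encode the cached layer-50 activations into SAE feature activations
layer_50_acts = activations["model.layers.50"].to(sae_70b.device, sae_70b.dtype)
feature_acts = sae_70b.encode(layer_50_acts)
print(feature_acts.shape)  # (batch, seq_len, d_hidden)
```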
|
|
|
## Responsibility & Safety |
|
|
|
Safety is at the core of everything we do at Goodfire. As a public benefit |
|
corporation, we’re dedicated to understanding AI models to enable safer, more reliable |
|
generative AI. You can read more about our comprehensive approach to |
|
safety and responsible development in our detailed [safety overview](https://www.goodfire.ai/blog/our-approach-to-safety/). |