---
license: llama3.3
language:
- en
base_model:
- meta-llama/Llama-3.3-70B-Instruct
---

### Model Information

The Goodfire SAE (Sparse Autoencoder) for Llama 3.3 70B is an interpreter model designed to analyze and understand the internal representations of Llama-3.3-70B-Instruct. This SAE is trained specifically on layer 50 of Llama 3.3 70B and achieves an L0 count of 121, enabling the decomposition of complex neural activations into interpretable features. The model is optimized for interpretability tasks and model steering applications, allowing researchers and developers to gain insights into the model's internal processing and behavior patterns. As an open-source tool, it serves as a foundation for advancing interpretability research and enhancing control over large language model operations.

### Intended Use

By open-sourcing SAEs for leading open models, especially large-scale models like Llama 3.3 70B, we aim to accelerate progress in interpretability research. Our initial work with these SAEs has revealed promising applications in model steering, enhancing jailbreak safeguards, and interpretable classification methods (see docs.goodfire.ai). We look forward to seeing how the research community builds upon these foundations and uncovers new applications.
### How to use

```python
import torch
from typing import Optional, Callable

import nnsight
from nnsight.intervention import InterventionProxy


# Autoencoder


class SparseAutoEncoder(torch.nn.Module):
    """Single-layer sparse autoencoder: linear encoder + ReLU, linear decoder."""

    def __init__(
        self,
        d_in: int,
        d_hidden: int,
        device: torch.device,
        dtype: torch.dtype = torch.bfloat16,
    ):
        """Build the encoder/decoder and move them to `device` / `dtype`.

        Args:
            d_in: Width of the activations being encoded.
            d_hidden: Number of SAE features.
            device: Device the SAE parameters are placed on.
            dtype: Parameter dtype.
        """
        super().__init__()
        self.d_in = d_in
        self.d_hidden = d_hidden
        self.device = device
        self.encoder_linear = torch.nn.Linear(d_in, d_hidden)
        self.decoder_linear = torch.nn.Linear(d_hidden, d_in)
        self.dtype = dtype
        self.to(self.device, self.dtype)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a batch of data using a linear, followed by a ReLU."""
        return torch.nn.functional.relu(self.encoder_linear(x))

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        """Decode a batch of data using a linear."""
        return self.decoder_linear(x)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """SAE forward pass. Returns the reconstruction and the encoded features."""
        f = self.encode(x)
        return self.decode(f), f


def load_sae(
    path: str,
    d_model: int,
    expansion_factor: int,
    device: torch.device = torch.device("cpu"),
):
    """Create a SparseAutoEncoder and load its weights from `path`.

    Args:
        path: Location of the saved state dict.
        d_model: Width of the activations the SAE was trained on.
        expansion_factor: The SAE has `d_model * expansion_factor` features.
        device: Device to load the weights onto. This must match the device
            of the activations later passed to the SAE.

    Returns:
        The SparseAutoEncoder with the loaded weights.
    """
    sae = SparseAutoEncoder(
        d_model,
        d_model * expansion_factor,
        device,
    )
    sae_dict = torch.load(path, weights_only=True, map_location=device)
    sae.load_state_dict(sae_dict)

    return sae


# Language model


InterventionInterface = Callable[[InterventionProxy], InterventionProxy]


class ObservableLanguageModel:
    """nnsight wrapper that can cache module outputs and apply interventions."""

    def __init__(
        self,
        model: str,
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
        cleanup_intervention_layer: Optional[str] = None,
    ):
        """Load `model` through nnsight.

        Args:
            model: Model identifier passed to `nnsight.LanguageModel`.
            device: Passed to nnsight as `device_map`.
            dtype: A torch dtype, or the name of one (e.g. "bfloat16").
            cleanup_intervention_layer: Optional dotted path of a module whose
                output has an intervention's "direct effect" subtracted from
                it in `forward`. Leave as None to disable that step.
        """
        self.dtype = dtype
        self.device = device
        self._original_model = model

        self._model = nnsight.LanguageModel(
            self._original_model,
            device_map=device,
            torch_dtype=getattr(torch, dtype) if isinstance(dtype, str) else dtype,
        )

        self.tokenizer = self._model.tokenizer

        self.d_model = self._attempt_to_infer_hidden_layer_dimensions()

        # nnsight validation is disabled by default because it slows down
        # inference a lot. Turn on to debug.
        self.safe_mode = False

        # `forward` reads this attribute, so it must always exist.
        self.cleanup_intervention_layer = cleanup_intervention_layer

    def _attempt_to_infer_hidden_layer_dimensions(self):
        """Return `hidden_size` from the model config, or raise if it is absent."""
        config = self._model.config
        if hasattr(config, "hidden_size"):
            return int(config.hidden_size)

        raise Exception(
            "Could not infer hidden number of layer dimensions from model config"
        )

    def _find_module(self, hook_point: str):
        """Resolve a dotted path such as "model.layers.19" to a submodule."""
        module = self._model
        for name in hook_point.split("."):
            module = getattr(module, name)
        return module

    def forward(
        self,
        inputs: torch.Tensor,
        cache_activations_at: Optional[list[str]] = None,
        interventions: Optional[dict[str, InterventionInterface]] = None,
        use_cache: bool = True,
        past_key_values: Optional[tuple[torch.Tensor]] = None,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor], dict[str, torch.Tensor]]:
        """Run one traced forward pass.

        Args:
            inputs: Token ids passed to the nnsight trace.
            cache_activations_at: Dotted module paths whose output is saved.
            interventions: Maps a dotted module path to a callable that
                receives `module.output[0]` and returns the replacement. The
                callable may return either the replacement alone or a
                `(replacement, direct_effect)` tuple. Entries whose value is
                None are skipped.
            use_cache: Forwarded to the model. It also decides whether the
                module output is rewritten as a 2-tuple or a 1-tuple.
            past_key_values: Forwarded to the model.

        Returns:
            `(logits, kv_cache, cached)`. Without `past_key_values`, `logits`
            is the model output at the last sequence position; otherwise it is
            the output with dimension 1 squeezed. `cached` maps each requested
            path to the first element of that module's saved output.
        """
        cache: dict[str, torch.Tensor] = {}
        with self._model.trace(
            inputs,
            scan=self.safe_mode,
            validate=self.safe_mode,
            use_cache=use_cache,
            past_key_values=past_key_values,
        ):
            # If we input an intervention
            if interventions:
                # The cleanup layer does not depend on the hook site, so it is
                # resolved once instead of once per intervention.
                if self.cleanup_intervention_layer:
                    last_layer = self._find_module(self.cleanup_intervention_layer)
                else:
                    last_layer = None

                for hook_site, intervention in interventions.items():
                    if intervention is None:
                        continue

                    module = self._find_module(hook_site)

                    # Accept both a bare replacement (the shape described by
                    # InterventionInterface) and a (replacement, direct_effect)
                    # tuple.
                    result = intervention(module.output[0])
                    if isinstance(result, tuple):
                        intervened_acts, direct_effect_tensor = result
                    else:
                        intervened_acts, direct_effect_tensor = result, None

                    # Add direct effect tensor as 0 if it is None
                    if direct_effect_tensor is None:
                        direct_effect_tensor = 0

                    # We only modify module.output[0]
                    if use_cache:
                        module.output = (
                            intervened_acts,
                            module.output[1],
                        )
                        if last_layer is not None:
                            last_layer.output = (
                                last_layer.output[0] - direct_effect_tensor,
                                last_layer.output[1],
                            )
                    else:
                        module.output = (intervened_acts,)
                        if last_layer is not None:
                            last_layer.output = (
                                last_layer.output[0] - direct_effect_tensor,
                            )

            if cache_activations_at is not None:
                for hook_point in cache_activations_at:
                    module = self._find_module(hook_point)
                    cache[hook_point] = module.output.save()

            if not past_key_values:
                logits = self._model.output[0][:, -1, :].save()
            else:
                logits = self._model.output[0].squeeze(1).save()

            kv_cache = self._model.output.past_key_values.save()

        return (
            logits.value.detach(),
            kv_cache.value,
            {k: v[0].detach() for k, v in cache.items()},
        )


# Reading out features from the model
#
# NOTE: this walkthrough uses Llama 3.1 8B identifiers (model name, layer 19,
# d_model=4096, checkpoint path). To use the Llama 3.3 70B SAE described in
# this card, substitute the 70B model name, the layer-50 hook point, and the
# matching d_model / expansion_factor / checkpoint path.

llama_3_1_8b = ObservableLanguageModel(
    "meta-llama/Llama-3.1-8B-Instruct",
)

input_tokens = llama_3_1_8b.tokenizer.apply_chat_template(
    [
        {"role": "user", "content": "Hello, how are you?"},
    ],
    return_tensors="pt",
)

logits, kv_cache, features = llama_3_1_8b.forward(
    input_tokens,
    cache_activations_at=["model.layers.19"],
)

print(features["model.layers.19"].shape)


# Intervention example

sae = load_sae(
    path="./llama-3-8b-d-hidden.pth",
    d_model=4096,
    expansion_factor=16,
    # Keep the SAE on the same device as the model's activations.
    device=torch.device(llama_3_1_8b.device),
)

PIRATE_FEATURE_INDEX = 0
VALUE_TO_MODIFY = 0.1


def example_intervention(activations: InterventionProxy):
    """Add VALUE_TO_MODIFY to one SAE feature and decode back to activations."""
    features = sae.encode(activations).detach()
    reconstructed_acts = sae.decode(features).detach()
    error = activations - reconstructed_acts

    # Modify the chosen feature across all token positions. Features live on
    # the last dimension, so index with `...` to leave every leading
    # dimension untouched.
    features[..., PIRATE_FEATURE_INDEX] += VALUE_TO_MODIFY

    # Very important to add the error term back in!
    return sae.decode(features) + error


logits, kv_cache, features = llama_3_1_8b.forward(
    input_tokens,
    interventions={"model.layers.19": example_intervention},
)

print(llama_3_1_8b.tokenizer.decode(logits[-1].argmax(-1)))
```