---
pipeline_tag: text-generation
tags:
- text-generation-inference
- backpack
- backpackmodel
library_name: transformers
license: apache-2.0
datasets:
- openwebtext
language:
- en
---

# Model Card for Levanter-Backpack-1.4B

This is a 1.4B-parameter version of the [Backpack architecture](https://arxiv.org/abs/2305.16765), intended to combine strong modeling performance with an interface for interpretability and control.

# Training Details

## Training Data

This model was trained on the [OpenWebText](https://huggingface.co/datasets/openwebtext) corpus.

## Training Procedure

This model was trained for 450k gradient steps. The learning rate followed a cosine decay from 1e-4 to zero, with a linear warmup of 5k steps.

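The schedule described in the training procedure can be sketched in a few lines of plain Python. This is an illustration only, not the Levanter training code: the function name `learning_rate` is hypothetical, and it assumes that the warmup starts from zero, that the 5k warmup steps count toward the 450k total, and that the cosine phase spans the remaining steps.

```python
import math

TOTAL_STEPS = 450_000   # from the model card
WARMUP_STEPS = 5_000    # from the model card
PEAK_LR = 1e-4          # from the model card


def learning_rate(step: int) -> float:
    """Linear warmup to PEAK_LR, then cosine decay to zero (assumed shape)."""
    if step < WARMUP_STEPS:
        # Assumption: warmup rises linearly from 0 to PEAK_LR.
        return PEAK_LR * step / WARMUP_STEPS
    # Assumption: the cosine phase covers all steps after warmup.
    progress = (step - WARMUP_STEPS) / (TOTAL_STEPS - WARMUP_STEPS)
    progress = min(progress, 1.0)  # clamp in case step exceeds TOTAL_STEPS
    return 0.5 * PEAK_LR * (1.0 + math.cos(math.pi * progress))
```

Under these assumptions, `learning_rate(5_000)` returns the peak value and `learning_rate(450_000)` returns zero.
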
# Environmental Impact

- **Hardware Type:** v3-128 TPU (128 cores, 2TB Memory)
- **Hours used:** Roughly 8.6 days (about 206 hours)
- **Cloud Provider:** Google Cloud Platform
- **Compute Region:** North America

## Model Architecture and Objective

This model is a [Backpack language model](https://arxiv.org/pdf/2305.16765.pdf), trained to minimize the cross-entropy loss.

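For a token sequence $x_1, \dots, x_T$, the next-token cross-entropy objective has the standard form below. The notation ($\theta$ for the model parameters, $T$ for the sequence length) is chosen here for illustration and is not taken from the training code; averaging over positions is likewise an assumption.

```latex
\mathcal{L}(\theta) = -\frac{1}{T} \sum_{t=1}^{T} \log p_\theta\left(x_t \mid x_{<t}\right)
```
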
### Software

This model was trained with [Levanter](https://github.com/stanford-crfm/levanter/) and [JAX](https://github.com/google/jax).

### Loss Curve

![Loss Curve](assets/train_loss.png)

# How to Get Started with the Model

Please install `transformers`, `safetensors` and `torch` to use this model.

```bash
pip install transformers safetensors torch
```

Run the following Python code:

```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM

model_id = "stanford-crfm/levanter-backpack-1b"

# trust_remote_code=True lets transformers load the Backpack model code
# from the model repository.
config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
torch_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    config=config,
    trust_remote_code=True,
)
torch_model.eval()

# Random token IDs stand in for real text: batch of 1, sequence length 512.
input_ids = torch.randint(0, 50264, (1, 512), dtype=torch.long)

# Inference only, so skip gradient tracking.
with torch.no_grad():
    torch_out = torch_model(input_ids, position_ids=None)

# Convert logits to next-token probabilities over the vocabulary.
probs = torch.nn.functional.softmax(torch_out.logits, dim=-1)
print(probs.shape)  # (batch, sequence length, vocabulary size)
```