---
library_name: transformers
license: apache-2.0
tags:
- jamba
- mamba
- moe
---

# Model Card for Jamba

Jamba is a state-of-the-art, hybrid SSM-Transformer LLM. It delivers throughput gains over traditional Transformer-based models, while outperforming or matching the leading models of its size class on most common benchmarks.

Jamba is the first production-scale Mamba implementation, which opens up interesting research and application opportunities. While this initial experimentation shows encouraging gains, we expect these to be further enhanced with future optimizations and explorations.

This model card is for the base version of Jamba. It’s a pretrained, mixture-of-experts (MoE) generative text model, with 12B active parameters and a total of 52B parameters across all experts. It supports a 256K context length, and can fit up to 140K tokens on a single 80GB GPU.

For full details of this model please read the [white paper](https://arxiv.org/abs/2403.19887) and the [release blog post](https://www.ai21.com/blog/announcing-jamba).

## Model Details

- **Developed by:** [AI21](https://www.ai21.com)
- **Model type:** Joint Attention and Mamba (Jamba)
- **License:** Apache 2.0
- **Context length:** 256K
- **Knowledge cutoff date:** March 5, 2024

## Usage
### Prerequisites
Jamba requires `transformers` version 4.39.0 or higher:
```bash
pip install "transformers>=4.39.0"
```
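
As a quick sanity check before loading the model, the installed version can be compared against the 4.39.0 minimum. The sketch below is illustrative only: `parse_release` and `meets_minimum` are hypothetical helper names, and the parser looks only at the first three numeric components, so pre-release suffixes such as `rc1` or `.dev0` are ignored.

```python
from importlib.metadata import PackageNotFoundError, version


def parse_release(version_string: str) -> tuple:
    """Return the leading (major, minor, patch) integers of a version string."""
    numbers = []
    for piece in version_string.split(".")[:3]:
        digits = ""
        for char in piece:
            if not char.isdigit():
                break  # stop at suffixes such as "rc1"
            digits += char
        numbers.append(int(digits) if digits else 0)
    while len(numbers) < 3:
        numbers.append(0)  # treat "4.39" as "4.39.0"
    return tuple(numbers)


def meets_minimum(installed: str, minimum: str = "4.39.0") -> bool:
    """True if the installed release is at least the minimum release."""
    return parse_release(installed) >= parse_release(minimum)


try:
    installed = version("transformers")
    status = "OK" if meets_minimum(installed) else "too old for Jamba"
    print(f"transformers {installed}: {status}")
except PackageNotFoundError:
    print("transformers is not installed")
```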

In order to run optimized Mamba implementations, you first need to install `mamba-ssm` and `causal-conv1d`:
```bash
pip install mamba-ssm "causal-conv1d>=1.2.0"
```
The model must also be on a CUDA device.

You can run the model without the optimized Mamba kernels, but this is **not** recommended, as it results in significantly higher latency. To do so, specify `use_mamba_kernels=False` when loading the model.
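
One way to choose the value of `use_mamba_kernels` is to check whether the two kernel packages are importable and whether CUDA is available. This is a minimal sketch under stated assumptions: `mamba_kernel_kwargs` and `detect_cuda` are hypothetical helper names, and the import names `mamba_ssm` and `causal_conv1d` are assumed to be the module names installed by the packages above.

```python
from importlib.util import find_spec


def mamba_kernel_kwargs(cuda_available: bool) -> dict:
    """Build the `use_mamba_kernels` loading argument from the local environment."""
    packages_present = all(
        find_spec(name) is not None for name in ("mamba_ssm", "causal_conv1d")
    )
    # The optimized kernels need both packages and a CUDA device.
    return {"use_mamba_kernels": packages_present and cuda_available}


def detect_cuda() -> bool:
    """Report CUDA availability, or False when torch itself is missing."""
    if find_spec("torch") is None:
        return False
    import torch

    return torch.cuda.is_available()


load_kwargs = mamba_kernel_kwargs(detect_cuda())
print(load_kwargs)

# The result could then be passed along when loading, for example:
# AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
#                                      trust_remote_code=True, **load_kwargs)
```

Falling back to `use_mamba_kernels=False` keeps the model loadable on machines without the kernels, at the latency cost described above.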

### Run the model
Please note that, at the moment, `trust_remote_code=True` is required for running the new Jamba architecture.
```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
                                             trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")

input_ids = tokenizer("In the recent Super Bowl LVIII,", return_tensors='pt').to(model.device)["input_ids"]

outputs = model.generate(input_ids, max_new_tokens=216)

print(tokenizer.batch_decode(outputs))
# ["<|startoftext|>In the recent Super Bowl LVIII, the Kansas City Chiefs emerged victorious, defeating the San Francisco 49ers in a thrilling overtime showdown. The game was a nail-biter, with both teams showcasing their skills and determination.\n\nThe Chiefs, led by their star quarterback Patrick Mahomes, displayed their offensive prowess, while the 49ers, led by their strong defense, put up a tough fight. The game went into overtime, with the Chiefs ultimately securing the win with a touchdown.\n\nThe victory marked the Chiefs' second Super Bowl win in four years, solidifying their status as one of the top teams in the NFL. The game was a testament to the skill and talent of both teams, and a thrilling end to the NFL season.\n\nThe Super Bowl is not just about the game itself, but also about the halftime show and the commercials. This year's halftime show featured a star-studded lineup, including Usher, Alicia Keys, and Lil Jon. The show was a spectacle of music and dance, with the performers delivering an energetic and entertaining performance.\n"]
```

<details>
<summary><strong>Loading the model in half precision</strong></summary>

The published checkpoint is saved in BF16. In order to load it into RAM in BF16/FP16, you need to specify `torch_dtype`:

```python
from transformers import AutoModelForCausalLM
import torch
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
                                             trust_remote_code=True,
                                             torch_dtype=torch.bfloat16)    # you can also use torch_dtype=torch.float16
```

When using half precision, you can enable the [FlashAttention2](https://github.com/Dao-AILab/flash-attention) implementation of the Attention blocks. In order to use it, you also need the model on a CUDA device. Since in this precision the model is too big to fit on a single 80GB GPU, you'll also need to parallelize it using [accelerate](https://huggingface.co/docs/accelerate/index):
```python
from transformers import AutoModelForCausalLM
import torch
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
                                             trust_remote_code=True,
                                             torch_dtype=torch.bfloat16,
                                             attn_implementation="flash_attention_2",
                                             device_map="auto")
```
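
The statement that the half-precision model does not fit on a single 80GB GPU follows from simple arithmetic on the 52B total parameter count given in this card. The sketch below counts weights only and assumes 1 GB = 10^9 bytes; activations, the attention KV cache, Mamba state, and framework overhead are ignored, so real usage is higher. The function name `weight_memory_gb` is hypothetical.

```python
TOTAL_PARAMS = 52e9   # total parameters across all experts (from this card)
GPU_MEMORY_GB = 80    # single-GPU budget discussed in this card


def weight_memory_gb(num_params: float, bytes_per_param: float) -> float:
    """Approximate memory needed for the weights alone, in GB (10^9 bytes)."""
    return num_params * bytes_per_param / 1e9


# The 8-bit row is a lower bound: any modules kept in higher precision
# take more than one byte per parameter.
for label, bytes_per_param in [("FP32", 4), ("BF16/FP16", 2), ("8-bit", 1)]:
    gb = weight_memory_gb(TOTAL_PARAMS, bytes_per_param)
    verdict = "fits" if gb < GPU_MEMORY_GB else "does not fit"
    print(f"{label:>9}: ~{gb:.0f} GB of weights, {verdict} in {GPU_MEMORY_GB} GB")
```

At 2 bytes per parameter the weights alone come to roughly 104 GB, which is why the example above spreads the model across devices with `device_map="auto"`.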

</details>
<details><summary><strong>Load the model in 8-bit</strong></summary>

**Using 8-bit precision, it is possible to fit sequences of up to 140K tokens on a single 80GB GPU.** You can easily quantize the model to 8-bit using [bitsandbytes](https://huggingface.co/docs/bitsandbytes/index). To avoid degrading model quality, we recommend excluding the Mamba blocks from quantization:

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
quantization_config = BitsAndBytesConfig(load_in_8bit=True,
                                         llm_int8_skip_modules=["mamba"])
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
                                             trust_remote_code=True,
                                             torch_dtype=torch.bfloat16,
                                             attn_implementation="flash_attention_2",
                                             quantization_config=quantization_config)
```
</details>

### Fine-tuning example
Jamba is a base model that can be fine-tuned for custom solutions (including for chat/instruct versions). You can fine-tune it using any technique of your choice. Here is an example of fine-tuning with the [PEFT](https://huggingface.co/docs/peft/index) library:

```python
from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments

tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1", trust_remote_code=True, device_map='auto')

dataset = load_dataset("Abirate/english_quotes", split="train")
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    logging_dir='./logs',
    logging_steps=10,
    learning_rate=2e-3
)
lora_config = LoraConfig(
    r=8,
    target_modules=["embed_tokens", "x_proj", "in_proj", "out_proj"],
    task_type="CAUSAL_LM",
    bias="none"
)
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    peft_config=lora_config,
    train_dataset=dataset,
    dataset_text_field="quote",
)

trainer.train()
```
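
For a sense of how small the trainable footprint is with LoRA: each adapted weight matrix of shape `d_out x d_in` gains two low-rank factors of shapes `d_out x r` and `r x d_in`, which adds `r * (d_in + d_out)` trainable parameters. The sketch below applies that formula with `r=8` from the config above. The layer shape is a made-up placeholder, not Jamba's actual dimensions, and `lora_added_params` is a hypothetical helper name.

```python
def lora_added_params(d_in: int, d_out: int, r: int) -> int:
    """Trainable parameters LoRA adds to one (d_out x d_in) weight matrix."""
    return r * (d_in + d_out)


# Hypothetical layer shape, chosen only for illustration.
d_in, d_out = 4096, 8192
full = d_in * d_out
added = lora_added_params(d_in, d_out, r=8)
print(f"full matrix: {full:,} params; LoRA adds {added:,} ({added / full:.2%})")
```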

## Results on common benchmarks

| Benchmark     | Score |
|---------------|:-----:|
| HellaSwag     | 87.1% |
| ARC Challenge | 64.4% |
| WinoGrande    | 82.5% |
| PIQA          | 83.2% |
| MMLU          | 67.4% |
| BBH           | 45.4% |
| TruthfulQA    | 46.4% |
| GSM8K (CoT)   | 59.9% |

It's crucial that the 'BOS' token is added to all prompts, which might not be enabled by default in all eval frameworks.
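
A minimal sketch of such a check on token IDs follows. It assumes a hypothetical BOS ID of 1 purely for illustration; with a real tokenizer you would typically read the ID from `tokenizer.bos_token_id` instead. The helper name `ensure_bos` is made up.

```python
from typing import List


def ensure_bos(token_ids: List[int], bos_token_id: int) -> List[int]:
    """Return a copy of token_ids that starts with exactly one leading BOS."""
    if token_ids and token_ids[0] == bos_token_id:
        return list(token_ids)  # already present, leave unchanged
    return [bos_token_id] + list(token_ids)


BOS_ID = 1  # hypothetical value for illustration only

print(ensure_bos([42, 7], BOS_ID))     # [1, 42, 7]
print(ensure_bos([1, 42, 7], BOS_ID))  # [1, 42, 7]
```

Running a check like this on a few prompts from an eval harness is a cheap way to confirm whether the framework adds BOS by default.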

## Notice
Jamba is a pretrained base model and did not undergo any alignment for instruct/chat interactions.

As a base model, Jamba is intended for use as a foundation layer for fine-tuning, training, and developing custom solutions. Jamba does not have safety moderation mechanisms, and guardrails should be added for responsible and safe use.

## About AI21
AI21 builds reliable, practical, and scalable AI solutions for the enterprise.

Jamba is the first in AI21’s new family of models, and the Instruct version of Jamba is coming soon to the [AI21 platform](https://www.ai21.com/studio).