Welcome Gemma 2 - Google’s new open LLM

Published June 27, 2024
Update on GitHub

Google released Gemma 2, the latest addition to its family of state-of-the-art open LLMs, and we are excited to collaborate with Google to ensure the best integration in the Hugging Face ecosystem. You can find the 4 open-weight models (2 base models & 2 fine-tuned ones) on the Hub. Among the features and integrations being released, we have:

Table of contents

What is Gemma 2?

Gemma 2 is Google's latest iteration of open LLMs. It comes in two sizes, 9 billion and 27 billion parameters with base (pre-trained) and instruction-tuned versions. Gemma is based on Google Deepmind Gemini and has a context length of 8K tokens:

The Gemma 2 models were trained on ~2x more data than their first iteration, totaling 13 trillion tokens for the 27B version and 8 trillion tokens for the 9B version of web data (primarily English), code, and math. We don’t know the exact details of the training mix, and we can only guess that bigger and more careful data curation was a big factor in the improved performance.

Gemma 2 comes with the same license as the first iteration, which is a permissive license that allows redistribution, fine-tuning, commercial use, and derivative works.

Technical advances in Gemma 2

Gemma 2 has many similarities with the first iteration. It has a context length of 8192 tokens and uses Rotary Position Embedding (RoPE). There are four main advances in Gemma 2 compared to the original Gemma:

  • Sliding window attention: Interleave sliding window and full-quadratic attention for quality generation.
  • Logit soft-capping: Prevents logits from growing excessively by scaling them to a fixed range, improving training.
  • Knowledge Distillation: Leverage a larger teacher model to train a smaller model (for the 9B model).
  • Model Merging: Combines two or more LLMs into a single new model

Gemma 2 was trained on Google Cloud TPU (27B on v5p, 9B on TPU v4) using JAX and ML Pathways. Gemma 2 Instruct has been optimized for dialogue applications and trained on a mix of synthetic and human-generated prompt-response pairs using Supervised Fine-Tuning (SFT), Distillation from a larger model, Reinforcement Learning from Human Feedback (RLHF) using a reward model oriented more towards conversational capabilities, and model merging using WARP to improve overall performance.

Similar to the pre-training mix, no details about the fine-tuning datasets or the hyperparameters associated with SFT and RLHF have been shared.

Sliding window attention

Sliding window attention is a method to reduce the memory and time requirements of the attention computations in transformer models and has been used in models such as Mistral. The novelty of Gemma 2 is that a sliding window is applied to every other layer (local - 4096 tokens), while the layers in between still use full quadratic global attention (8192 tokens). We suppose this is a way to increase quality in long context situations (half of the layers still attend to all tokens) while partially benefiting from the advantages of sliding attention.

Soft-capping and attention implementations

Soft capping is a technique that prevents logits from growing excessively large without truncating them. It works by dividing the logits by a maximum value threshold (soft_cap), then passing them through a tanh layer (ensuring they are in the (-1, 1) range), and finally multiplying by the threshold again. This guarantees that the final values will be in the (-soft_cap, +soft_cap) interval without losing much information but stabilizing the training.

Putting it all together, the logits are calculated by: logits ← soft_cap ∗ tanh(logits/soft_cap)

Gemma 2 employs soft capping for the final layer and for every attention layer. The attention logits are capped at 50.0, and the final logits at 30.0.

At the time of release, soft-capping is incompatible with Flash Attention / SDPA, but they can still be used in inference for maximum efficiency. The Gemma 2 team observed very minor differences when soft-capping is removed during inference.

Note: For stable fine-tuning runs, you still need to enable soft-capping and hence, we recommend fine-tuning with eager attention instead of SDPA.

Knowledge Distillation

Knowledge distillation is a popular technique for training a smaller student model to mimic the behavior of a larger but better-performing teacher. This works by augmenting the next-token prediction task of LLMs with a distribution of token probabilities from the teacher (e.g., GPT-4, Claude, or Gemini), which provides a richer signal for the student to learn from.

According to the Gemma 2 tech report, knowledge distillation was used to pre-train the 9B model, while the 27B model was pre-trained from scratch.

For post-training, the Gemma 2 team generated a diverse set of completions from a teacher (unspecified in the report, but presumably Gemini Ultra), and then trained the student models on this synthetic data with SFT. This is the basis of many open models, such as Zephyr and OpenHermes, which are trained entirely on synthetic data from larger LLMs.

Although effective, this method has drawbacks since the model capacity mismatch between the student and teacher can lead to a train-inference mismatch, where the text generated by the student during inference is out-of-distribution compared to that seen during training.

To handle this issue, the Gemma 2 team used “on-policy distillation”, where the student generates completions from the SFT prompts. These completions are then used to compute the KL divergence between the teacher’s and student’s logits. By minimizing the KL divergence throughout training, the student learns to model the behavior of the teacher accurately while also minimizing the train-inference mismatch.

This approach is quite interesting, as we’ve seen in the community that on-policy methods like online DPO produce stronger models, and one advantage of on-policy distillation is that you only need the logits from the teacher, so you don’t need to rely on reward models or LLM-as-a-judge to improve the model. It will be exciting to see if this method becomes more popular among fine-tuners in the coming months!

Model Merging

Model merging is a technique that combines two or more LLMs into a single new model. It's relatively new and experimental and can be used without accelerators. Mergekit is a popular open-source toolkit for merging LLMs. It implements linear, SLERP, TIES, DARE, and other merging techniques.

According to the Technical Report, Gemma 2 used Warp, a new merging technique that merges models in three distinct stages:

  1. Exponential Moving Average (EMA): This is applied during the reinforcement learning (RL) fine-tuning process.
  2. Spherical Linear intERPolation (SLERP): This is applied after the RL fine-tuning of multiple policies.
  3. Linear Interpolation Towards Initialization (LITI): This stage is applied after the SLERP stage.

Gemma 2 evaluation

How good are the Gemma models? Below are performance comparisons to other open models based on the Technical Report and the new version of the open LLM Leaderboard.

Technical Report results

This Technical Report of Gemma 2 compares the performance of different open LLMs on the previous Open LLM Leaderboard benchmarks.

Llama 3 (70B) Qwen 1.5 (32B) Gemma 2 (27B)
MMLU 79.2 74.3 75.2
GSM8K 76.9 61.1 75.1
ARC-c 68.8 63.6 71.4
HellaSwag 88.0 85.0 86.4
Winogrande 85.3 81.5 83.7

The Report also compares the performance of Small Language Models.

Benchmark Mistral (7B) Llama 3 (8B) Gemma (8B) Gemma 2 (9B)
MMLU 62.5 66.6 64.4 71.3
GSM8K 34.5 45.7 50.9 62.3
ARC-C 60.5 59.2 61.1 68.4
HellaSwag 83.0 82.0 82.3 81.9
Winogrande 78.5 78.5 79.0 80.6

Open LLM Leaderboard results

Note: We are currently evaluating Google Gemma 2 individually on the new Open LLM Leaderboard benchmark and will update this section later today.

How to prompt Gemma 2

The base models have no prompt format. Like other base models, they can be used to continue an input sequence with a plausible continuation or for zero-shot/few-shot inference. The Instruct versions have a very simple conversation structure:

<start_of_turn>user
knock knock<end_of_turn>
<start_of_turn>model
who is there<end_of_turn>
<start_of_turn>user
LaMDA<end_of_turn>
<start_of_turn>model
LaMDA who?<end_of_turn><eos>

This format has to be exactly reproduced for effective use. We’ll later show how easy it is to reproduce the instruct prompt with the chat template available in transformers.

Demo

You can chat with the Gemma 27B Instruct model on Hugging Chat! Check out the link here: https://huggingface.co/chat/models/google/gemma-2-27b-it.

Using Hugging Face Transformers

With Transformers release 4.42, you can use Gemma and leverage all the tools within the Hugging Face ecosystem. To use Gemma models with transformers, make sure to use the latest transformers release:

pip install "transformers>=4.42.3" --upgrade

The following snippet shows how to use gemma-2-9b-it with transformers. It requires about 18 GB of RAM, which fits many consumer GPUs. The same snippet works for gemma-2-27b-it, which, at 56GB of RAM, makes it a very interesting model for production use cases. Memory consumption can be further reduced by loading in 8-bit or 4-bit mode.

from transformers import pipeline
import torch

pipe = pipeline(
    "text-generation",
    model="google/gemma-2-9b-it",
    model_kwargs={"torch_dtype": torch.bfloat16},
    device="cuda",
)

messages = [
    {"role": "user", "content": "Who are you? Please, answer in pirate-speak."},
]
outputs = pipe(
    messages,
    max_new_tokens=256,
    do_sample=False,
)
assistant_response = outputs[0]["generated_text"][-1]["content"]
print(assistant_response)

Ahoy, matey! I be a humble ship o' words, sailin' the digital seas. They call me Gemma, a creation o' the fine folks at Google DeepMind. I be trained on a treasure trove o' texts, learnin' to speak and write like a true scallywag.

Ask me yer questions, and I'll do me best to answer 'em, aye! 🦜📚

We used bfloat16 because that’s the reference precision for the instruction-tuned model. Running in float16 may be faster on your hardware, and results should be similar on the 9B model. Do note, however, that the 27B instruction-tuned model produces erratic outputs when using float16: you must use bfloat16 for that model weight.

You can also automatically quantize the model, loading it in 8-bit or even 4-bit mode. 4-bit loading of the large 27B version takes about 18 GB of memory to run, making it compatible with a lot of consumer cards and GPUs in Google Colab. This is how you’d load the generation pipeline in 4-bit:

pipeline = pipeline(
    "text-generation",
    model=model,
    model_kwargs={
        "torch_dtype": torch.bfloat16,
        "quantization_config": {"load_in_4bit": True}
    },
)

For more details on using the models with transformers, please check the model cards.

Integration with Google Cloud

Note: We are currently working on adding new containers to GKE and Vertex AI to run Google Gemma 2 efficiently. We will update this section as soon as the containers are available.

Fine-tuning with 🤗 TRL

Training LLMs can be technically and computationally challenging. In this section, we’ll look at the tools available in the Hugging Face ecosystem to efficiently train Gemma on consumer-size GPUs

An example command to fine-tune Gemma on OpenAssistant’s chat dataset can be found below. We use 4-bit quantization and QLoRA to conserve memory to target all the attention blocks' linear layers. Note that, unlike dense transformers, one should not target the MLP layers as they are sparse and don’t interact well with PEFT.

First, install the nightly version of 🤗 TRL and clone the repo to access the training script:

pip install "transformers>=4.42.3" --upgrade
pip install --upgrade bitsandbytes
pip install --ugprade peft
pip install git+https://github.com/huggingface/trl
git clone https://github.com/huggingface/trl
cd trl

Then you can run the script:

# peft tuning; single GPU; https://wandb.ai/costa-huang/huggingface/runs/l1l53cst
python \
    examples/scripts/sft.py \
    --model_name google/gemma-2-27b \
    --dataset_name OpenAssistant/oasst_top1_2023-08-25 \
    --dataset_text_field="text" \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 4 \
    --learning_rate 2e-4 \
    --report_to wandb \
    --bf16 \
    --max_seq_length 1024 \
    --lora_r 16 --lora_alpha 32 \
    --lora_target_modules q_proj k_proj v_proj o_proj \
    --load_in_4bit \
    --use_peft \
    --attn_implementation eager \
    --logging_steps=10 \
    --gradient_checkpointing \
    --output_dir models/gemma2

alt_text

If you have more GPUs to spare, you can run training with DeepSpeed and ZeRO Stage 3:

accelerate launch --config_file=examples/accelerate_configs/deepspeed_zero3.yaml \
    examples/scripts/sft.py \
    --model_name google/gemma-2-27b \
    --dataset_name OpenAssistant/oasst_top1_2023-08-25 \
    --dataset_text_field="text" \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 4 \
    --learning_rate 2e-5 \
    --report_to wandb \
    --bf16 \
    --max_seq_length 1024 \
    --attn_implementation eager \
    --logging_steps=10 \
    --gradient_checkpointing \
    --output_dir models/gemma2

alt_text

Integration with Inference Endpoints

You can deploy Gemma 2 on Hugging Face's Inference Endpoints using Text Generation Inference as the backend. Text Generation Inference is a production-ready inference container developed by Hugging Face to enable easy deployment of large language models. It has features such as continuous batching, token streaming, tensor parallelism for fast inference on multiple GPUs, and production-ready logging and tracing.

To deploy a Gemma 2 model, go to the model page and click on the Deploy -> Inference Endpoints widget. Inference Endpoints supports OpenAI compatible Messages API that allows you to switch from another closed model to an open one by simply changing the URL.

from openai import OpenAI

# initialize the client but point it to TGI
client = OpenAI(
    base_url="<ENDPOINT_URL>" + "/v1/",  # replace with your endpoint url
    api_key="<HF_API_TOKEN>",  # replace with your token
)
chat_completion = client.chat.completions.create(
    model="tgi",
    messages=[
        {"role": "user", "content": "Why is open-source software important?"},
    ],
    stream=True,
    max_tokens=500
)

# iterate and print stream
for message in chat_completion:
    print(message.choices[0].delta.content, end="")

Additional Resources

Acknowledgments

Releasing such models with support and evaluations in the ecosystem would not be possible without the contributions of many community members, including Clémentine and Nathan for LLM evaluations; Nicolas for Text Generation Inference Support; Arthur, Sanchit, Joao, and Lysandre for integrating Gemma 2 into transformers; Nathan and Victor for making Gemma 2 available in Hugging Chat.

And Thank you to the Google Team for releasing Gemma 2 and making it available to the open-source AI community!