Llama3-Athene-RM-8B

We introduce Llama3-Athene-RM-8B, an open-weights reward model based on Llama-3-8B-Instruct.

Usage

from typing import Dict

import torch
from torch import nn
from transformers import (
    AutoTokenizer,
    LlamaModel,
    LlamaPreTrainedModel,
    TextClassificationPipeline,
    pipeline,
)

class AtheneForSequenceClassification(LlamaPreTrainedModel):
    """Llama backbone with a scalar value head read at the last CLS token.

    Every position's final hidden state is projected to one scalar by
    ``v_head``; the score reported for a sequence is the scalar at the last
    position whose token id equals ``CLS_ID``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.model = LlamaModel(config)
        # Bias-free projection from a hidden state to a single scalar.
        self.v_head = nn.Linear(config.hidden_size, 1, bias=False)
        # Token id that marks where the score is read.
        # NOTE(review): presumably the id of the tokenizer's cls_token that the
        # accompanying pipeline appends -- confirm against the tokenizer.
        self.CLS_ID = 128003
        # Initialize weights and apply final processing
        self.post_init()

    def get_device(self):
        """Return the device reported by the wrapped Llama backbone."""
        return self.model.device

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        position_ids=None,
    ):
        """Score each sequence in the batch at its last ``CLS_ID`` position.

        Args:
            input_ids: Token ids indexed as ``[row, position]``. Required:
                they are searched for ``CLS_ID`` in addition to being fed to
                the backbone.
            past_key_values: Accepted for signature compatibility; not
                forwarded to the backbone.
            attention_mask: Forwarded to the backbone unchanged.
            position_ids: Forwarded to the backbone unchanged.

        Returns:
            ``{"scores": tensor}`` holding one scalar per row of ``input_ids``.

        Raises:
            IndexError: If a row contains no ``CLS_ID`` token.
        """
        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_hidden_states=True,
        )
        hidden_states = transformer_outputs.hidden_states[-1]
        # Drop the trailing size-1 dim so rewards is indexed [row, position].
        rewards = self.v_head(hidden_states).squeeze(-1)

        scores = []
        for i, row_ids in enumerate(input_ids):
            c_inds = (row_ids == self.CLS_ID).nonzero()
            if c_inds.numel() == 0:
                # Without this guard, c_inds[-1] fails with an opaque
                # "index -1 is out of bounds" message. Same exception type,
                # but the message now names the actual problem.
                raise IndexError(
                    f"Sequence {i} contains no CLS token (id {self.CLS_ID}); "
                    "it may have been dropped by truncation."
                )
            # Use the last occurrence in case the id appears more than once.
            c_ind = c_inds[-1].item()
            scores.append(rewards[i, c_ind])
        return {"scores": torch.stack(scores)}

# Make a pipeline to handle pre and post-processing
class AtheneRewardPipeline(TextClassificationPipeline):
    """Pipeline that turns a chat conversation into a single float score."""

    def preprocess(self, inputs, **tokenizer_kwargs) -> Dict[str, torch.Tensor]:
        """Render the conversation, append the CLS token, and tokenize.

        ``tokenizer_kwargs`` is accepted but not used; the tokenizer options
        below are fixed.
        """
        # Render to text rather than ids so the CLS token can be appended
        # as a string before tokenization.
        chat_text = self.tokenizer.apply_chat_template(inputs, tokenize=False)
        scored_text = chat_text + self.tokenizer.cls_token

        encode_options = {
            "return_tensors": self.framework,
            "max_length": 4096,
            "padding": "longest",
            "truncation": True,
        }
        return self.tokenizer(scored_text, **encode_options)

    def postprocess(self, model_outputs, function_to_apply=None, top_k=1, _legacy=True):
        """Return the model's "scores" entry as a Python float.

        ``function_to_apply``, ``top_k`` and ``_legacy`` are accepted but
        not used.
        """
        score_tensor = model_outputs["scores"]
        return score_tensor.cpu().float().item()

# Initialize the model
# torch.bfloat16 must be referenced through the torch module; a bare
# `bfloat16` is not defined in this file.
model = AtheneForSequenceClassification.from_pretrained(
    "Nexusflow/Athene-RM-8B", torch_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained("Nexusflow/Athene-RM-8B")

# Initialize the pipeline
# Pass the module-level `model` / `tokenizer` objects created above; there is
# no `self` at module scope.
pipe = pipeline(
    task="text-classification",
    model=model,
    tokenizer=tokenizer,
    pipeline_class=AtheneRewardPipeline,
    device_map="auto",
)

# One conversation: a user prompt followed by the assistant reply to score.
messages = [
    {
        "role": 'user',
        "content": "What is an Athene Noctura? Explain one sentence."
    },
    {
        "role": "assistant",
        "content": "The Athene noctua, also known as the little owl, is a small, nocturnal owl species native to Europe, Asia, and North Africa, characterized by its distinctive facial disk and piercing yellow eyes."
    }
]

print(pipe([messages]))  # Print the reward!

Citation

@misc{Athene2024,
    title = {Athene-70B: Redefining the Boundaries of Post-Training for Open Models},
    url = {https://nexusflow.ai/blogs/athene},
    author = {Frick, Evan and Jin, Peter and Li, Tianle and Ganesan, Karthik and Zhang, Jian and Jiao, Jiantao and Zhu, Banghua},    
    month = {July},
    year = {2024}
}
Downloads last month
12
Safetensors
Model size
7.5B params
Tensor type
BF16
·
Inference Examples
This model does not have enough activity to be deployed to Inference API (serverless) yet. Increase its social visibility and check back later, or deploy to Inference Endpoints (dedicated) instead.

Model tree for Nexusflow/Athene-RM-8B

Finetuned
(529)
this model