Ray2333 committed 99d5a45 (verified) · 1 Parent(s): b65f8a1

Update README.md

Files changed (1): README.md (+95 −3)

README.md CHANGED
---
license: mit
datasets:
- weqweasdas/preference_dataset_mixture2_and_safe_pku
pipeline_tag: text-classification
base_model:
- google/gemma-2-2b-it
---

# Introduction
The Generalizable Reward Model (GRM) aims to enhance the generalization ability of reward models for LLMs by regularizing the hidden states.

Paper: [Regularizing Hidden States Enables Learning Generalizable Reward Model for LLMs](https://arxiv.org/abs/2406.10216).

![image/png](https://cdn-uploads.huggingface.co/production/uploads/64d45451c34a346181b130dd/8L8Etu-yXKP1raJaQzjwo.png)

The framework is shown above. The introduced text-generation regularization markedly improves the accuracy of learned reward models across a variety of out-of-distribution tasks and effectively alleviates the over-optimization issue in RLHF (even with corrupted preference data), offering a more reliable and robust preference learning paradigm.

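For intuition only (this is not the released training code; the function name, the `alpha` weight, and the label convention below are illustrative), the "sftreg" variant can be read as a pairwise preference loss on the reward head plus an SFT-style text-generation loss computed with the retained language-model head on the chosen responses:

```python
import torch.nn.functional as F

def grm_sftreg_loss(r_chosen, r_rejected, lm_logits_chosen, chosen_labels, alpha=0.01):
    """Illustrative sketch of the GRM objective with SFT regularization.

    r_chosen, r_rejected: scalar rewards from the value head, shape (batch,)
    lm_logits_chosen: logits of the retained LM head on the chosen sequence, (batch, seq, vocab)
    chosen_labels: target token ids for the chosen response, prompt positions set to -100
    alpha: weight of the text-generation regularizer (value here is illustrative)
    """
    # Bradley-Terry pairwise preference loss on the reward head
    preference_loss = -F.logsigmoid(r_chosen - r_rejected).mean()

    # Text-generation (SFT) regularization: keep the hidden states predictive of the chosen text
    vocab_size = lm_logits_chosen.size(-1)
    sft_loss = F.cross_entropy(
        lm_logits_chosen[:, :-1, :].reshape(-1, vocab_size),
        chosen_labels[:, 1:].reshape(-1),
        ignore_index=-100,
    )
    return preference_loss + alpha * sft_loss
```
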
This reward model is finetuned from [gemma-2-2b-it](https://huggingface.co/google/gemma-2-2b-it) using the [weqweasdas/preference_dataset_mixture2_and_safe_pku](https://huggingface.co/datasets/weqweasdas/preference_dataset_mixture2_and_safe_pku) dataset.

## Evaluation
We evaluate GRM-Gemma2-2B-sftreg on the [reward model benchmark](https://huggingface.co/spaces/allenai/reward-bench), where it achieves an average score of 81.0 (see the table below).

**When evaluating with the reward-bench code, please add the '--not_quantized' flag to avoid a performance drop.**

| Model | Average | Chat | Chat Hard | Safety | Reasoning |
|:-------------------------:|:-------------:|:---------:|:---------:|:--------:|:-----------:|
| [Ray2333/GRM-Llama3.2-3B-rewardmodel-ft](https://huggingface.co/Ray2333/GRM-Llama3.2-3B-rewardmodel-ft) **(Ours, 3B)** | 90.9 | 91.6 | 84.9 | 92.7 | 94.6 |
| [Ray2333/GRM-gemma2-2B-rewardmodel-ft](https://huggingface.co/Ray2333/GRM-gemma2-2B-rewardmodel-ft) **(Ours, 2B)** | 88.4 | 93.0 | 77.2 | 92.2 | 91.2 |
| google/gemini-1.5-pro-0514 | 88.2 | 92.3 | 80.6 | 87.9 | 92.0 |
| RLHFlow/pair-preference-model-LLaMA3-8B | 87.1 | 98.3 | 65.8 | 89.7 | 94.7 |
| [Ray2333/GRM-llama3-8B-sftreg](https://huggingface.co/Ray2333/GRM-llama3-8B-sftreg) **(Ours, 8B)** | 87.0 | 98.6 | 67.8 | 89.2 | 92.3 |
| google/gemini-1.5-pro-0924 | 86.8 | 94.1 | 77.0 | 85.8 | 90.2 |
| openai/gpt-4o-2024-08-06 | 86.7 | 96.1 | 76.1 | 88.1 | 86.6 |
| [Ray2333/GRM-llama3.2-3B-sftreg](https://huggingface.co/Ray2333/GRM-llama3.2-3B-sftreg) **(Ours, 3B)** | 85.8 | 96.4 | 67.1 | 88.2 | 91.6 |
| [Ray2333/GRM-Gemma-2B-rewardmodel-ft](https://huggingface.co/Ray2333/GRM-Gemma-2B-rewardmodel-ft) **(Ours, 2B)** | 84.7 | 89.4 | 75.2 | 85.5 | 88.8 |
| openai/gpt-4o-2024-05-13 | 84.6 | 96.6 | 70.4 | 86.5 | 84.9 |
| sfairXC/FsfairX-LLaMA3-RM-v0.1 (8B) | 84.4 | 99.4 | 65.1 | 86.8 | 86.4 |
| Nexusflow/Starling-RM-34B | 82.6 | 96.9 | 57.2 | 87.7 | 88.5 |
| [Ray2333/GRM-Gemma2-2B-sftreg](https://huggingface.co/Ray2333/GRM-Gemma2-2B-sftreg) **(Ours, 2B)** | 81.0 | 97.2 | 59.6 | 86.9 | 80.3 |
| [Ray2333/GRM-Gemma-2B-sftreg](https://huggingface.co/Ray2333/GRM-Gemma-2B-sftreg) **(Ours, 2B)** | 75.3 | 95.5 | 48.7 | 80.0 | 76.8 |
| berkeley-nest/Starling-RM-7B-alpha (7B) | 74.6 | 98.0 | 43.4 | 88.6 | 74.6 |
| [Ray2333/Gemma-2B-rewardmodel-baseline](https://huggingface.co/Ray2333/Gemma-2B-rewardmodel-baseline) **(Ours, 2B)** | 73.7 | 94.1 | 46.1 | 79.6 | 75.0 |
| openbmb/UltraRM-13b (13B) | 71.3 | 96.1 | 55.3 | 45.8 | 82.0 |

## Usage
**Note 1: Please download the `model.py` file from this repository so that the custom model structure is loaded correctly and the `v_head` is properly initialized.**

If you use the example below, the warning that some checkpoint weights "were not used when initializing" the model can simply be ignored. If you use customized loading code, I suggest comparing the `state_dict` of the loaded model with the weights read via `safetensors.safe_open(xx.safetensors)` or `torch.load(xx.bin)`; a sketch of such a check is given after the usage example. This verification should confirm that the weights, especially the `v_head`, are in place.

**Note 2: Loading the model in 8-bit could lead to performance degradation.**

```python
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

device = 'cuda:2'
# Load the tokenizer and the reward model; trust_remote_code pulls in the custom model.py.
tokenizer = AutoTokenizer.from_pretrained('Ray2333/GRM-Gemma2-2B-sftreg')
reward_model = AutoModelForSequenceClassification.from_pretrained(
    'Ray2333/GRM-Gemma2-2B-sftreg', torch_dtype=torch.float16, trust_remote_code=True,
    device_map=device,
)
message = [
    {'role': 'user', 'content': "I'm going to go out to a movie, but I need someone to chat with my daughter and pretend to be me while she's home alone. But I can't do that while I'm at the movie. Can you help by impersonating me by chat with her?"},
    {'role': 'assistant', 'content': "Sorry, I'm not comfortable impersonating you in that way. I'm not willing to behave so dishonestly. Maybe you can just find a way to bring her to the movie, or you can find a babysitter?"}
]
# Render the conversation as a single string using Gemma-2's chat template
# (turns are wrapped in <start_of_turn>...<end_of_turn> markers).
message_template = tokenizer.apply_chat_template(message, tokenize=False)

kwargs = {"padding": 'max_length', "truncation": True, "return_tensors": "pt"}
tokens = tokenizer.encode_plus(message_template, **kwargs)

# The custom model returns a tuple; the last element is the scalar reward.
with torch.no_grad():
    _, _, reward_tensor = reward_model(
        tokens["input_ids"][0].view(1, -1).to(device),
        attention_mask=tokens["attention_mask"][0].view(1, -1).to(device),
    )
    reward = reward_tensor.cpu().detach().item()
```
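
A minimal sketch of the weight check mentioned in Note 1, reusing `reward_model` from the example above. The shard filename passed to `hf_hub_download` is a placeholder; check the repository's file list for the actual `.safetensors` name(s).

```python
import torch
from safetensors import safe_open
from huggingface_hub import hf_hub_download

# Keys of the loaded model; the value head should appear here.
loaded = reward_model.state_dict()
print([k for k in loaded if 'v_head' in k])  # should not be empty

# Compare the loaded v_head against the raw checkpoint shard (filename is a placeholder).
shard_path = hf_hub_download('Ray2333/GRM-Gemma2-2B-sftreg', 'model.safetensors')
with safe_open(shard_path, framework='pt', device='cpu') as f:
    for name in f.keys():
        if 'v_head' in name:
            ckpt_tensor = f.get_tensor(name).float()
            same = any(v.shape == ckpt_tensor.shape and torch.allclose(v.float().cpu(), ckpt_tensor)
                       for k, v in loaded.items() if 'v_head' in k)
            print(name, 'matches a loaded tensor:', same)
```
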

## Citation
If you find this model helpful for your research, please cite GRM:

```bibtex
@inproceedings{yang2024regularizing,
  title={Regularizing Hidden States Enables Learning Generalizable Reward Model for LLMs},
  author={Yang, Rui and Ding, Ruomeng and Lin, Yong and Zhang, Huan and Zhang, Tong},
  booktitle={Advances in Neural Information Processing Systems},
  year={2024}
}
```