Better Implementation of PairRM
Introduction
This version of PairRM has several fixes to the training process, which improve the model's average Reward-Bench score by about 15% (from 0.600 to 0.690).
Minor Fixes
- Longer Context Length (2048 -> 3370)
Thanks to DeBERTa's tokenizer, the original PairRM model already had enough context length.
But the longer, the better :>
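The total budget is source + 2 × candidate (2030 + 2 × 670 = 3370 tokens). In `tokenize_pair` in the Example Code section, whatever the source does not use is split between the two candidates. A minimal sketch of that arithmetic in plain Python; `candidate_budget` is a hypothetical helper name, not part of `llm_blender`:

```python
SOURCE_MAX = 2030     # max source (prompt) tokens for Better-PairRM
CANDIDATE_MAX = 670   # nominal max tokens per candidate
TOTAL_MAX = SOURCE_MAX + 2 * CANDIDATE_MAX  # 3370

def candidate_budget(source_len: int) -> int:
    """Per-candidate token budget once the source has been encoded.

    Hypothetical helper mirroring tokenize_pair(): the source is capped at
    SOURCE_MAX tokens and the leftover space is split evenly between the
    two candidates.
    """
    used = min(source_len, SOURCE_MAX)
    return (TOTAL_MAX - used) // 2

print(TOTAL_MAX)               # 3370
print(candidate_budget(2030))  # 670
print(candidate_budget(100))   # 1635
```

So a short prompt leaves much more room for the responses than the nominal 670 tokens each.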
Major Fixes
- Change Prompt Format
Why use something like `<Response i + 1> {response}`?
I switched to a format based on Vicuna 1.1, with `USER:` and `ASSISTANT:` turns.
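A minimal plain-Python sketch of that layout, assuming the role labels and whitespace produced by `MY_JINJA_TEMPLATE` in the Example Code section (each turn rendered as `ROLE: content` followed by a blank line, then an open `ASSISTANT: ` turn). `render_vicuna_style` is a hypothetical helper that approximates the template's output, not a replacement for it:

```python
from typing import Dict, List

# Role labels assumed from MY_JINJA_TEMPLATE ("user_context" is also shown as USER)
ROLE_LABELS = {
    "system": "SYSTEM MESSAGE",
    "user": "USER",
    "user_context": "USER",
    "assistant": "ASSISTANT",
}

def render_vicuna_style(messages: List[Dict[str, str]], add_generation_prompt: bool = True) -> str:
    """Render chat turns as 'ROLE: content' blocks separated by blank lines."""
    text = ""
    for message in messages:
        label = ROLE_LABELS[message["role"]]  # raises KeyError on unknown roles
        text += f"{label}: {message['content'].strip()}\n\n"
    # Open an assistant turn for the response that will be scored
    if add_generation_prompt and messages and messages[-1]["role"] != "assistant":
        text += "ASSISTANT: "
    return text

print(repr(render_vicuna_style([{"role": "user", "content": " hello! "}])))
# 'USER: hello!\n\nASSISTANT: '
```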
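- Change Truncation Side
The original process used right-side truncation even on the input. This can cause serious problems when the input exceeds the model's context length: right-side truncation cuts off the end of the input, which is usually the latest user turn that the responses are answering. The input is now truncated from the left.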
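The difference is easy to see on a plain token list. This sketch uses hypothetical helpers `truncate_left` / `truncate_right` instead of a real tokenizer; with Hugging Face tokenizers the same choice is controlled by the tokenizer's `truncation_side` attribute, as in `truncate_texts` in the Example Code section.

```python
from typing import List, Sequence

def truncate_right(tokens: Sequence[int], max_length: int) -> List[int]:
    """Keep the first max_length tokens (drops the end)."""
    return list(tokens[:max_length])

def truncate_left(tokens: Sequence[int], max_length: int) -> List[int]:
    """Keep the last max_length tokens (drops the beginning)."""
    if max_length <= 0:
        return []  # tokens[-0:] would return everything, so guard explicitly
    return list(tokens[-max_length:])

# Stand-in for 10 prompt tokens; the last ones belong to the latest user turn
prompt = list(range(1, 11))
print(truncate_right(prompt, 4))  # [1, 2, 3, 4]  -> latest turn lost
print(truncate_left(prompt, 4))   # [7, 8, 9, 10] -> latest turn kept
```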
- Dataset Filter
The original dataset contained a decent number of empty assistant responses, so I dropped them.
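A minimal sketch of such a filter, assuming each training example is a dict with a `conversation` list of `{"role", "content"}` messages. This schema and the helpers `has_empty_assistant_turn` / `drop_empty_responses` are hypothetical, since the original dataset format is not described here.

```python
from typing import Dict, List

def has_empty_assistant_turn(conversation: List[Dict[str, str]]) -> bool:
    """True if any assistant message is empty or whitespace-only."""
    return any(
        m["role"] == "assistant" and not m["content"].strip()
        for m in conversation
    )

def drop_empty_responses(examples: List[dict]) -> List[dict]:
    """Keep only examples whose assistant turns all contain text."""
    return [ex for ex in examples if not has_empty_assistant_turn(ex["conversation"])]

data = [
    {"conversation": [{"role": "user", "content": "hi"}, {"role": "assistant", "content": "hello!"}]},
    {"conversation": [{"role": "user", "content": "hi"}, {"role": "assistant", "content": "   "}]},
]
print(len(drop_empty_responses(data)))  # 1
```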
Example Code
The code below is modified from the [PairRM-hf repo](https://huggingface.co/llm-blender/PairRM-hf).
```python
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

from typing import List

from llm_blender.pair_ranker.pairrm import DebertaV2PairRM
from transformers import AutoTokenizer

pairrm = DebertaV2PairRM.from_pretrained("maywell/Better-PairRM", device_map="cuda:0").eval()
tokenizer = AutoTokenizer.from_pretrained("maywell/Better-PairRM")

# Prefixes that mark the source (prompt) and the two candidate responses
source_prefix = "<|source|>"
cand1_prefix = "<|candidate1|>"
cand2_prefix = "<|candidate2|>"

inputs = ["hello!", "I love you!"]
candidates_A = ["hi!", "I hate you!"]
candidates_B = ["f**k off!", "I love you, too!"]

def tokenize_pair(sources: List[str], candidate1s: List[str], candidate2s: List[str],
                  source_max_length=2030, candidate_max_length=670):
    ids = []
    assert len(sources) == len(candidate1s) == len(candidate2s)
    # Total budget: 2030 + 2 * 670 = 3370 tokens with the default arguments
    max_length = source_max_length + 2 * candidate_max_length
    for i in range(len(sources)):
        source_ids = tokenizer.encode(source_prefix + sources[i], max_length=source_max_length, truncation=True)
        # Split whatever the source left unused evenly between the two candidates
        candidate_max_length = (max_length - len(source_ids)) // 2
        candidate1_ids = tokenizer.encode(cand1_prefix + candidate1s[i], max_length=candidate_max_length, truncation=True)
        candidate2_ids = tokenizer.encode(cand2_prefix + candidate2s[i], max_length=candidate_max_length, truncation=True)
        ids.append(source_ids + candidate1_ids + candidate2_ids)
    # Pad every sequence to max_length (3370 with the defaults)
    encodings = tokenizer.pad({"input_ids": ids}, return_tensors="pt", padding="max_length", max_length=max_length)
    return encodings

encodings = tokenize_pair(inputs, candidates_A, candidates_B)
encodings = {k: v.to(pairrm.device) for k, v in encodings.items()}
outputs = pairrm(**encodings)
logits = outputs.logits.tolist()
# True means candidate A is judged better than candidate B for that input
comparison_results = outputs.logits > 0
print(logits)
print(comparison_results)
```
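The example compares exactly two candidates per input. With more than two, one simple option is a round-robin: score every ordered pair with a pairwise function and rank candidates by how many comparisons they win. This is a hypothetical sketch, not part of `llm_blender`; `compare` stands in for running the model on `tokenize_pair([source], [a], [b])` and reading `outputs.logits[0]` (positive means `a` is preferred), and a toy scorer is used so the sketch runs without a GPU.

```python
from itertools import permutations
from typing import Callable, List

def rank_by_wins(source: str, candidates: List[str],
                 compare: Callable[[str, str, str], float]) -> List[int]:
    """Return candidate indices sorted from most to fewest pairwise wins.

    compare(source, a, b) > 0 is taken to mean that a is preferred over b.
    Every ordered pair is compared, so each pair is judged in both positions.
    """
    wins = [0] * len(candidates)
    for i, j in permutations(range(len(candidates)), 2):
        if compare(source, candidates[i], candidates[j]) > 0:
            wins[i] += 1
        else:
            wins[j] += 1  # ties count for the second candidate
    # sorted() is stable, so candidates with equal wins keep their original order
    return sorted(range(len(candidates)), key=lambda k: -wins[k])

# Toy stand-in for the model: prefers the longer response (illustration only)
def toy_compare(source: str, a: str, b: str) -> float:
    return float(len(a) - len(b))

print(rank_by_wins("hello!", ["hi!", "hello, how are you?", "hey"], toy_compare))  # [1, 0, 2]
```

Comparing both orders doubles the number of model calls, but it can help if the model has a preference for one position.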
You can also easily compare two conversations as follows (this reuses `tokenize_pair` from the first example):
```python
import jinja2
from typing import List
from transformers import AutoTokenizer

# Separate tokenizer used only to measure and truncate raw text, so the
# `tokenizer` used by tokenize_pair() in the first example is not overwritten
trunc_tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-large")

def truncate_texts(text, max_length, truncate_side):
    # Keep at most max_length tokens, cutting from the given side ("left" or "right")
    trunc_tokenizer.truncation_side = truncate_side
    tokens = trunc_tokenizer.encode(text, add_special_tokens=False, max_length=max_length, truncation=True)
    truncated_text = trunc_tokenizer.decode(tokens, skip_special_tokens=True)
    return truncated_text

MY_JINJA_TEMPLATE = """{% for message in messages -%}
{% if message['role'] == 'user' -%}
USER: {{ message['content']|trim -}}
{% if not loop.last -%}
{% endif %}
{% elif message['role'] == 'assistant' -%}
ASSISTANT: {{ message['content']|trim -}}
{% if not loop.last -%}
{% endif %}
{% elif message['role'] == 'user_context' -%}
USER: {{ message['content']|trim -}}
{% if not loop.last -%}
{% endif %}
{% elif message['role'] == 'system' -%}
SYSTEM MESSAGE: {{ message['content']|trim -}}
{% if not loop.last -%}
{% endif %}
{% endif %}
{% endfor -%}
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' -%}
ASSISTANT: {% endif -%}"""

my_jinja2_env = jinja2.Environment()
my_jinja2_template = my_jinja2_env.from_string(MY_JINJA_TEMPLATE)

def tokenize_conv_pair(convAs: List[List[dict]], convBs: List[List[dict]]):
    # Check that the conversations are comparable
    assert len(convAs) == len(convBs), "Number of conversations must be the same"
    for c_a, c_b in zip(convAs, convBs):
        assert len(c_a) == len(c_b), "Number of turns in each conversation must be the same"
        assert all([c_a[i]['content'] == c_b[i]['content'] for i in range(0, len(c_a), 2)]), "USER turns must be the same"

    # Shared context: every turn except the final assistant reply, truncated from the left
    inputs = [
        truncate_texts(my_jinja2_template.render(messages=x[:-1], add_generation_prompt=True), 2030, "left") for x in convAs
    ]
    # Final assistant reply of each conversation, truncated from the right
    cand1_texts = [
        truncate_texts(x[-1]['content'], 670, "right") for x in convAs
    ]
    cand2_texts = [
        truncate_texts(x[-1]['content'], 670, "right") for x in convBs
    ]
    encodings = tokenize_pair(inputs, cand1_texts, cand2_texts)
    return encodings
```
Statistics
Context length
PairRanker type | Source max length (tokens) | Candidate max length (tokens) | Total max length (tokens) |
---|---|---|---|
pair-ranker | 128 | 128 | 384 |
PairRM | 1224 | 412 | 2048 |
Better-PairRM (This model) | 2030 | 670 | 3370 |
Performance
Reward-Bench by AllenAI
Metric | llm-blender/PairRM-hf | maywell/Better-PairRM |
---|---|---|
model_type | Custom Classifier | Custom Classifier |
alpacaeval-length | 0.758 | 0.863 |
alpacaeval-hard | 0.979 | 1.000 |
alpacaeval-easy | 0.970 | 0.990 |
donotanswer | 0.360 | 0.522 |
hep-cpp | 0.628 | 0.646 |
hep-go | 0.689 | 0.713 |
hep-java | 0.628 | 0.713 |
hep-js | 0.604 | 0.707 |
hep-python | 0.646 | 0.713 |
hep-rust | 0.652 | 0.726 |
llmbar-adver-GPTInst | 0.304 | 0.141 |
llmbar-adver-GPTOut | 0.596 | 0.447 |
llmbar-adver-manual | 0.500 | 0.261 |
llmbar-adver-neighbor | 0.433 | 0.276 |
llmbar-natural | 0.800 | 0.720 |
math-prm | 0.333 | 0.295 |
mt-bench-hard | 0.649 | 0.703 |
mt-bench-med | 0.900 | 1.000 |
mt-bench-easy | 0.964 | 0.929 |
refusals-dangerous | 0.080 | 0.730 |
refusals-offensive | 0.010 | 0.940 |
xstest-should-refuse | 0.370 | 0.968 |
xstest-should-respond | 0.952 | 0.876 |
average | 0.600 | 0.690 |
Note: Better-PairRM scores lower than the original on all llmbar subsets, but llmbar test scores look a bit odd across all models on Reward-Bench.
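The `average` row is consistent with a plain unweighted mean over the 23 subset scores, and the roughly 15% improvement mentioned in the Introduction is the relative change of that mean. A quick check, with the scores copied from the table above (this only reproduces this table's `average` row; other Reward-Bench summaries may aggregate subsets differently):

```python
# Per-subset Reward-Bench scores copied from the table (same row order)
pairrm_hf = [0.758, 0.979, 0.970, 0.360, 0.628, 0.689, 0.628, 0.604, 0.646, 0.652,
             0.304, 0.596, 0.500, 0.433, 0.800, 0.333, 0.649, 0.900, 0.964,
             0.080, 0.010, 0.370, 0.952]
better_pairrm = [0.863, 1.000, 0.990, 0.522, 0.646, 0.713, 0.713, 0.707, 0.713, 0.726,
                 0.141, 0.447, 0.261, 0.276, 0.720, 0.295, 0.703, 1.000, 0.929,
                 0.730, 0.940, 0.968, 0.876]

# Unweighted mean over subsets
mean_old = sum(pairrm_hf) / len(pairrm_hf)
mean_new = sum(better_pairrm) / len(better_pairrm)
print(f"{mean_old:.3f} {mean_new:.3f}")                # 0.600 0.690
print(f"relative gain: {mean_new / mean_old - 1:.1%}")  # relative gain: 15.0%
```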
Thanks to
- Sionic AI for providing the A100 cluster.
Contact
Original Paper
```bibtex
@inproceedings{llm-blender-2023,
    title = "LLM-Blender: Ensembling Large Language Models with Pairwise Comparison and Generative Fusion",
    author = "Jiang, Dongfu and Ren, Xiang and Lin, Bill Yuchen",
    booktitle = "Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (ACL 2023)",
    year = "2023"
}
```