Korean Reranker Training on Amazon SageMaker
ํ๊ตญ์ด Reranker ๊ฐ๋ฐ์ ์ํ ํ์ธํ๋ ๊ฐ์ด๋๋ฅผ ์ ์ํฉ๋๋ค.
ko-reranker๋ BAAI/bge-reranker-larger ๊ธฐ๋ฐ ํ๊ตญ์ด ๋ฐ์ดํฐ์ ๋ํ fine-tuned model ์
๋๋ค.
๋ณด๋ค ์์ธํ ์ฌํญ์ korean-reranker-git์ ์ฐธ๊ณ ํ์ธ์
0. Features
Reranker๋ ์๋ฒ ๋ฉ ๋ชจ๋ธ๊ณผ ๋ฌ๋ฆฌ ์ง๋ฌธ๊ณผ ๋ฌธ์๋ฅผ ์ ๋ ฅ์ผ๋ก ์ฌ์ฉํ๋ฉฐ ์๋ฒ ๋ฉ ๋์ ์ ์ฌ๋๋ฅผ ์ง์ ์ถ๋ ฅํฉ๋๋ค.
Reranker์ ์ง๋ฌธ๊ณผ ๊ตฌ์ ์ ์ ๋ ฅํ๋ฉด ์ฐ๊ด์ฑ ์ ์๋ฅผ ์ป์ ์ ์์ต๋๋ค.
Reranker๋ CrossEntropy loss๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ์ต์ ํ๋๋ฏ๋ก ๊ด๋ จ์ฑ ์ ์๊ฐ ํน์ ๋ฒ์์ ๊ตญํ๋์ง ์์ต๋๋ค.
1.Usage
- using Transformers
def exp_normalize(x):
b = x.max()
y = np.exp(x - b)
return y / y.sum()
from transformers import AutoModelForSequenceClassification, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path)
model.eval()
pairs = [["๋๋ ๋๋ฅผ ์ซ์ดํด", "๋๋ ๋๋ฅผ ์ฌ๋ํด"], \
["๋๋ ๋๋ฅผ ์ข์ํด", "๋์ ๋ํ ๋์ ๊ฐ์ ์ ์ฌ๋ ์ผ ์๋ ์์ด"]]
with torch.no_grad():
inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
scores = model(**inputs, return_dict=True).logits.view(-1, ).float()
scores = exp_normalize(scores.numpy())
print (f'first: {scores[0]}, second: {scores[1]}')
- using SageMaker
import sagemaker
import boto3
from sagemaker.huggingface import HuggingFaceModel
try:
role = sagemaker.get_execution_role()
except ValueError:
iam = boto3.client('iam')
role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']
# Hub Model configuration. https://huggingface.co/models
hub = {
'HF_MODEL_ID':'Dongjin-kr/ko-reranker',
'HF_TASK':'text-classification'
}
# create Hugging Face Model Class
huggingface_model = HuggingFaceModel(
transformers_version='4.28.1',
pytorch_version='2.0.0',
py_version='py310',
env=hub,
role=role,
)
# deploy model to SageMaker Inference
predictor = huggingface_model.deploy(
initial_instance_count=1, # number of instances
instance_type='ml.g5.large' # ec2 instance type
)
runtime_client = boto3.Session().client('sagemaker-runtime')
payload = json.dumps(
{
"inputs": [
{"text": "๋๋ ๋๋ฅผ ์ซ์ดํด", "text_pair": "๋๋ ๋๋ฅผ ์ฌ๋ํด"},
{"text": "๋๋ ๋๋ฅผ ์ข์ํด", "text_pair": "๋์ ๋ํ ๋์ ๊ฐ์ ์ ์ฌ๋ ์ผ ์๋ ์์ด"}
]
}
)
response = runtime_client.invoke_endpoint(
EndpointName="<endpoint-name>",
ContentType="application/json",
Accept="application/json",
Body=payload
)
## deserialization
out = json.loads(response['Body'].read().decode()) ## for json
print (f'Response: {out}')
2. Backgound
์ปจํ์คํธ ์์๊ฐ ์ ํ๋์ ์ํฅ ์ค๋ค(Lost in Middel, Liu et al., 2023)
Reranker ์ฌ์ฉํด์ผ ํ๋ ์ด์
- ํ์ฌ LLM์ context ๋ง์ด ๋ฃ๋๋ค๊ณ ์ข์๊ฑฐ ์๋, relevantํ๊ฒ ์์์ ์์ด์ผ ์ ๋ต์ ์ ๋งํด์ค๋ค
- Semantic search์์ ์ฌ์ฉํ๋ similarity(relevant) score๊ฐ ์ ๊ตํ์ง ์๋ค. (์ฆ, ์์ ๋ญ์ปค๋ฉด ํ์ ๋ญ์ปค๋ณด๋ค ํญ์ ๋ ์ง๋ฌธ์ ์ ์ฌํ ์ ๋ณด๊ฐ ๋ง์?)
- Embedding์ meaning behind document๋ฅผ ๊ฐ์ง๋ ๊ฒ์ ํนํ๋์ด ์๋ค.
- ์ง๋ฌธ๊ณผ ์ ๋ต์ด ์๋ฏธ์ ๊ฐ์๊ฑด ์๋๋ค. (Hypothetical Document Embeddings)
- ANNs(Approximate Nearest Neighbors) ์ฌ์ฉ์ ๋ฐ๋ฅธ ํจ๋ํฐ
3. Reranker models
[Cohere] Reranker
[BAAI] bge-reranker-large
[BAAI] bge-reranker-base
4. Dataset
msmarco-triplets
- (Question, Answer, Negative)-Triplets from MS MARCO Passages dataset, 499,184 samples
- ํด๋น ๋ฐ์ดํฐ ์ ์ ์๋ฌธ์ผ๋ก ๊ตฌ์ฑ๋์ด ์์ต๋๋ค.
- Amazon Translate ๊ธฐ๋ฐ์ผ๋ก ๋ฒ์ญํ์ฌ ํ์ฉํ์์ต๋๋ค.
Format
{"query": str, "pos": List[str], "neg": List[str]}
Query๋ ์ง๋ฌธ์ด๊ณ , pos๋ ๊ธ์ ํ ์คํธ ๋ชฉ๋ก, neg๋ ๋ถ์ ํ ์คํธ ๋ชฉ๋ก์ ๋๋ค. ์ฟผ๋ฆฌ์ ๋ํ ๋ถ์ ํ ์คํธ๊ฐ ์๋ ๊ฒฝ์ฐ ์ ์ฒด ๋ง๋ญ์น์์ ์ผ๋ถ๋ฅผ ๋ฌด์์๋ก ์ถ์ถํ์ฌ ๋ถ์ ํ ์คํธ๋ก ์ฌ์ฉํ ์ ์์ต๋๋ค.
Example
{"query": "๋ํ๋ฏผ๊ตญ์ ์๋๋?", "pos": ["๋ฏธ๊ตญ์ ์๋๋ ์์ฑํด์ด๊ณ , ์ผ๋ณธ์ ๋๊ต์ด๋ฉฐ ํ๊ตญ์ ์์ธ์ด๋ค."], "neg": ["๋ฏธ๊ตญ์ ์๋๋ ์์ฑํด์ด๊ณ , ์ผ๋ณธ์ ๋๊ต์ด๋ฉฐ ๋ถํ์ ํ์์ด๋ค."]}
5. Performance
Model | has-right-in-contexts | mrr (mean reciprocal rank) |
---|---|---|
without-reranker (default) | 0.93 | 0.80 |
with-reranker (bge-reranker-large) | 0.95 | 0.84 |
with-reranker (fine-tuned using korean) | 0.96 | 0.87 |
- evaluation set:
./dataset/evaluation/eval_dataset.csv
- training parameters:
{
"learning_rate": 5e-6,
"fp16": True,
"num_train_epochs": 3,
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 32,
"train_group_size": 3,
"max_len": 512,
"weight_decay": 0.01,
}
6. Acknowledgement
- Part of the code is developed based on FlagEmbedding and KoSimCSE-SageMaker.
7. Citation
- If you find this repository useful, please consider giving a like โญ and citation
8. Contributors:
9. License
- FlagEmbedding is licensed under the MIT License.