Korean Reranker Training on Amazon SageMaker

ํ•œ๊ตญ์–ด Reranker ๊ฐœ๋ฐœ์„ ์œ„ํ•œ ํŒŒ์ธํŠœ๋‹ ๊ฐ€์ด๋“œ๋ฅผ ์ œ์‹œํ•ฉ๋‹ˆ๋‹ค.

ko-reranker๋Š” BAAI/bge-reranker-larger ๊ธฐ๋ฐ˜ ํ•œ๊ตญ์–ด ๋ฐ์ดํ„ฐ์— ๋Œ€ํ•œ fine-tuned model ์ž…๋‹ˆ๋‹ค.
๋ณด๋‹ค ์ž์„ธํ•œ ์‚ฌํ•ญ์€ korean-reranker-git์„ ์ฐธ๊ณ ํ•˜์„ธ์š”


0. Features

  • Reranker๋Š” ์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ๊ณผ ๋‹ฌ๋ฆฌ ์งˆ๋ฌธ๊ณผ ๋ฌธ์„œ๋ฅผ ์ž…๋ ฅ์œผ๋กœ ์‚ฌ์šฉํ•˜๋ฉฐ ์ž„๋ฒ ๋”ฉ ๋Œ€์‹  ์œ ์‚ฌ๋„๋ฅผ ์ง์ ‘ ์ถœ๋ ฅํ•ฉ๋‹ˆ๋‹ค.

  • Reranker์— ์งˆ๋ฌธ๊ณผ ๊ตฌ์ ˆ์„ ์ž…๋ ฅํ•˜๋ฉด ์—ฐ๊ด€์„ฑ ์ ์ˆ˜๋ฅผ ์–ป์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

  • Reranker๋Š” CrossEntropy loss๋ฅผ ๊ธฐ๋ฐ˜์œผ๋กœ ์ตœ์ ํ™”๋˜๋ฏ€๋กœ ๊ด€๋ จ์„ฑ ์ ์ˆ˜๊ฐ€ ํŠน์ • ๋ฒ”์œ„์— ๊ตญํ•œ๋˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค.

1.Usage

  • using Transformers
    def exp_normalize(x):
      b = x.max()
      y = np.exp(x - b)
      return y / y.sum()
    
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    model.eval()

    pairs = [["๋‚˜๋Š” ๋„ˆ๋ฅผ ์‹ซ์–ดํ•ด", "๋‚˜๋Š” ๋„ˆ๋ฅผ ์‚ฌ๋ž‘ํ•ด"], \
             ["๋‚˜๋Š” ๋„ˆ๋ฅผ ์ข‹์•„ํ•ด", "๋„ˆ์— ๋Œ€ํ•œ ๋‚˜์˜ ๊ฐ์ •์€ ์‚ฌ๋ž‘ ์ผ ์ˆ˜๋„ ์žˆ์–ด"]]

    with torch.no_grad():
        inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
        scores = model(**inputs, return_dict=True).logits.view(-1, ).float()
        scores = exp_normalize(scores.numpy())
        print (f'first: {scores[0]}, second: {scores[1]}')
  • using SageMaker
import sagemaker
import boto3
from sagemaker.huggingface import HuggingFaceModel

try:
    role = sagemaker.get_execution_role()
except ValueError:
    iam = boto3.client('iam')
    role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']

# Hub Model configuration. https://huggingface.co/models
hub = {
    'HF_MODEL_ID':'Dongjin-kr/ko-reranker',
    'HF_TASK':'text-classification'
}

# create Hugging Face Model Class
huggingface_model = HuggingFaceModel(
    transformers_version='4.28.1',
    pytorch_version='2.0.0',
    py_version='py310',
    env=hub,
    role=role, 
)

# deploy model to SageMaker Inference
predictor = huggingface_model.deploy(
    initial_instance_count=1, # number of instances
    instance_type='ml.g5.large' # ec2 instance type
)

runtime_client = boto3.Session().client('sagemaker-runtime')
payload = json.dumps(
    {
        "inputs": [
            {"text": "๋‚˜๋Š” ๋„ˆ๋ฅผ ์‹ซ์–ดํ•ด", "text_pair": "๋‚˜๋Š” ๋„ˆ๋ฅผ ์‚ฌ๋ž‘ํ•ด"},
            {"text": "๋‚˜๋Š” ๋„ˆ๋ฅผ ์ข‹์•„ํ•ด", "text_pair": "๋„ˆ์— ๋Œ€ํ•œ ๋‚˜์˜ ๊ฐ์ •์€ ์‚ฌ๋ž‘ ์ผ ์ˆ˜๋„ ์žˆ์–ด"}
        ]
    }
)

response = runtime_client.invoke_endpoint(
    EndpointName="<endpoint-name>",
    ContentType="application/json",
    Accept="application/json",
    Body=payload
)

## deserialization
out = json.loads(response['Body'].read().decode()) ## for json
print (f'Response: {out}')

2. Backgound

  • ์ปจํƒ์ŠคํŠธ ์ˆœ์„œ๊ฐ€ ์ •ํ™•๋„์— ์˜ํ–ฅ ์ค€๋‹ค(Lost in Middel, Liu et al., 2023)

  • Reranker ์‚ฌ์šฉํ•ด์•ผ ํ•˜๋Š” ์ด์œ 

    • ํ˜„์žฌ LLM์€ context ๋งŽ์ด ๋„ฃ๋Š”๋‹ค๊ณ  ์ข‹์€๊ฑฐ ์•„๋‹˜, relevantํ•œ๊ฒŒ ์ƒ์œ„์— ์žˆ์–ด์•ผ ์ •๋‹ต์„ ์ž˜ ๋งํ•ด์ค€๋‹ค
    • Semantic search์—์„œ ์‚ฌ์šฉํ•˜๋Š” similarity(relevant) score๊ฐ€ ์ •๊ตํ•˜์ง€ ์•Š๋‹ค. (์ฆ‰, ์ƒ์œ„ ๋žญ์ปค๋ฉด ํ•˜์œ„ ๋žญ์ปค๋ณด๋‹ค ํ•ญ์ƒ ๋” ์งˆ๋ฌธ์— ์œ ์‚ฌํ•œ ์ •๋ณด๊ฐ€ ๋งž์•„?)
      • Embedding์€ meaning behind document๋ฅผ ๊ฐ€์ง€๋Š” ๊ฒƒ์— ํŠนํ™”๋˜์–ด ์žˆ๋‹ค.
      • ์งˆ๋ฌธ๊ณผ ์ •๋‹ต์ด ์˜๋ฏธ์ƒ ๊ฐ™์€๊ฑด ์•„๋‹ˆ๋‹ค. (Hypothetical Document Embeddings)
      • ANNs(Approximate Nearest Neighbors) ์‚ฌ์šฉ์— ๋”ฐ๋ฅธ ํŒจ๋„ํ‹ฐ

3. Reranker models


4. Dataset

  • msmarco-triplets

    • (Question, Answer, Negative)-Triplets from MS MARCO Passages dataset, 499,184 samples
    • ํ•ด๋‹น ๋ฐ์ดํ„ฐ ์…‹์€ ์˜๋ฌธ์œผ๋กœ ๊ตฌ์„ฑ๋˜์–ด ์žˆ์Šต๋‹ˆ๋‹ค.
    • Amazon Translate ๊ธฐ๋ฐ˜์œผ๋กœ ๋ฒˆ์—ญํ•˜์—ฌ ํ™œ์šฉํ•˜์˜€์Šต๋‹ˆ๋‹ค.
  • Format

{"query": str, "pos": List[str], "neg": List[str]}
  • Query๋Š” ์งˆ๋ฌธ์ด๊ณ , pos๋Š” ๊ธ์ • ํ…์ŠคํŠธ ๋ชฉ๋ก, neg๋Š” ๋ถ€์ • ํ…์ŠคํŠธ ๋ชฉ๋ก์ž…๋‹ˆ๋‹ค. ์ฟผ๋ฆฌ์— ๋Œ€ํ•œ ๋ถ€์ • ํ…์ŠคํŠธ๊ฐ€ ์—†๋Š” ๊ฒฝ์šฐ ์ „์ฒด ๋ง๋ญ‰์น˜์—์„œ ์ผ๋ถ€๋ฅผ ๋ฌด์ž‘์œ„๋กœ ์ถ”์ถœํ•˜์—ฌ ๋ถ€์ • ํ…์ŠคํŠธ๋กœ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

  • Example

{"query": "๋Œ€ํ•œ๋ฏผ๊ตญ์˜ ์ˆ˜๋„๋Š”?", "pos": ["๋ฏธ๊ตญ์˜ ์ˆ˜๋„๋Š” ์›Œ์‹ฑํ„ด์ด๊ณ , ์ผ๋ณธ์€ ๋„๊ต์ด๋ฉฐ ํ•œ๊ตญ์€ ์„œ์šธ์ด๋‹ค."], "neg": ["๋ฏธ๊ตญ์˜ ์ˆ˜๋„๋Š” ์›Œ์‹ฑํ„ด์ด๊ณ , ์ผ๋ณธ์€ ๋„๊ต์ด๋ฉฐ ๋ถํ•œ์€ ํ‰์–‘์ด๋‹ค."]}

5. Performance

Model has-right-in-contexts mrr (mean reciprocal rank)
without-reranker (default) 0.93 0.80
with-reranker (bge-reranker-large) 0.95 0.84
with-reranker (fine-tuned using korean) 0.96 0.87
  • evaluation set:
./dataset/evaluation/eval_dataset.csv
  • training parameters:
{
    "learning_rate": 5e-6,
    "fp16": True,
    "num_train_epochs": 3,
    "per_device_train_batch_size": 1,
    "gradient_accumulation_steps": 32,
    "train_group_size": 3,
    "max_len": 512,
    "weight_decay": 0.01,
}

6. Acknowledgement


7. Citation

  • If you find this repository useful, please consider giving a like โญ and citation

8. Contributors:

  • Dongjin Jang, Ph.D. (AWS AI/ML Specislist Solutions Architect) | Mail | Linkedin | Git |

9. License

Downloads last month

-

Downloads are not tracked for this model. How to track
Inference Examples
Unable to determine this model's library. Check the docs .