metadata
language:
  - en
library_name: transformers
tags:
  - cross-encoder
  - search
  - product-search
base_model: cross-encoder/ms-marco-MiniLM-L-12-v2
model-index:
  - name: esci-ms-marco-MiniLM-L-12-v2
    results:
      - task:
          type: text-classification
        metrics:
          - type: mrr@10
            value: 91.81
          - type: ndcg@10
            value: 85.46

Model Description

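Fine-tunes the cross-encoder cross-encoder/ms-marco-MiniLM-L-12-v2 on the Amazon ESCI (Exact, Substitute, Complement, Irrelevant) shopping-queries dataset for product search.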
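The `model-index` metadata reports MRR@10 and nDCG@10. The card does not say how these were computed (which query set, or how ESCI grades map to gains). The following is therefore only a minimal sketch of the standard per-query definitions. It assumes each query's results are given as a list of relevance grades in the order the model ranked them, and it uses linear gain with a log2(rank + 1) discount. The function names are illustrative.

```python
import math
from typing import Sequence


def mrr_at_k(relevances: Sequence[float], k: int = 10) -> float:
    """Reciprocal rank of the first relevant (grade > 0) result in the top k, else 0."""
    for rank, rel in enumerate(relevances[:k], start=1):
        if rel > 0:
            return 1.0 / rank
    return 0.0


def dcg_at_k(relevances: Sequence[float], k: int = 10) -> float:
    """Discounted cumulative gain: linear gain, log2(rank + 1) discount."""
    return sum(rel / math.log2(rank + 1) for rank, rel in enumerate(relevances[:k], start=1))


def ndcg_at_k(relevances: Sequence[float], k: int = 10) -> float:
    """DCG normalized by the DCG of the ideal (descending) ordering of the same grades."""
    ideal = dcg_at_k(sorted(relevances, reverse=True), k)
    return dcg_at_k(relevances, k) / ideal if ideal > 0 else 0.0


# Grades of one query's results, in the order the model ranked them
ranked = [0.0, 1.0, 0.1, 0.0]
print(mrr_at_k(ranked))  # 0.5
print(round(ndcg_at_k(ranked), 4))
```

A dataset-level score would average these per-query values over all evaluation queries.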

Usage

Transformers

from transformers import AutoTokenizer, AutoModelForSequenceClassification
from torch import no_grad

model_name = "lv12/esci-ms-marco-MiniLM-L-12-v2"

queries = [
    "adidas shoes",
    "adidas shoes",
    "girls sandals",
    "backpacks",
    "shoes",
    "mustard sleeveless gown"
]
documents = [
    '{"title": "Nike Air Max", "description": "The best shoes you can get, with air cushion", "brand": "Nike", "color": "black"}',
    '{"title": "Adidas Ultraboost", "description": "The shoes that represent the world", "brand": "Adidas", "color": "white"}',
    '{"title": "Womens sandals", "description": "Sandals:  wide width 9", "brand": "Chacos", "color": "blue"}',
    '{"title": "Girls surf backpack", "description": "The best backpack in town", "brand": "Roxy", "color": "pink"}',
    '{"title": "Fresh watermelon", "description": "The best fruit in town, all you can eat", "brand": "Fruitsellers Inc.", "color": "green"}',
    '{"title": "Floral yellow dress with frills and lace", "description": "Brighten up your summers with a gorgeous dress", "brand": "Dressmakers Inc.", "color": "bright yellow"}'
]

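model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Tokenize queries and documents as aligned pairs:
# queries[i] is encoded together with documents[i]
inputs = tokenizer(
    queries,
    documents,
    padding=True,
    truncation=True,
    return_tensors="pt",
)

model.eval()
with no_grad():
    # The model outputs one relevance logit per pair; higher means more relevant
    scores = model(**inputs).logits.squeeze(-1).cpu().numpy()
print(scores)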
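The values printed above are raw logits, which means they are unbounded scores rather than probabilities. Logits are sufficient for ranking. If a score in (0, 1) is more convenient, for example for thresholding, one option is to pass the logits through a logistic sigmoid. This is only a suggested post-processing step and not something the card prescribes. Because the sigmoid is monotonic, it does not change the order of the pairs. The sketch below applies it to the six scores the card reports for the Sentence Transformers example. The helper names are hypothetical.

```python
import math
from typing import Iterable, List, Sequence


def sigmoid(x: float) -> float:
    """Logistic sigmoid, written to avoid overflow for large negative inputs."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def logits_to_probs(logits: Iterable[float]) -> List[float]:
    """Map raw relevance logits (e.g. a NumPy array of scores) into (0, 1)."""
    return [sigmoid(float(x)) for x in logits]


def ranking(scores: Sequence[float]) -> List[int]:
    """Indices of the scores ordered from highest to lowest."""
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)


# Scores the card reports for the Sentence Transformers example
logits = [1.057739, 1.6751697, 1.039221, 1.5969192, -0.8867093, 0.5035825]
probs = logits_to_probs(logits)
print([round(p, 3) for p in probs])

# The sigmoid is monotonic, so the order of the pairs does not change
assert ranking(probs) == ranking(logits)
```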

Sentence Transformers

from sentence_transformers import CrossEncoder

model_name = "lv12/esci-ms-marco-MiniLM-L-12-v2"


queries = [
    "adidas shoes",
    "adidas shoes",
    "girls sandals",
    "backpacks",
    "shoes",
    "mustard sleeveless gown"
]
documents = [
    '{"title": "Nike Air Max", "description": "The best shoes you can get, with air cushion", "brand": "Nike", "color": "black"}',
    '{"title": "Adidas Ultraboost", "description": "The shoes that represent the world", "brand": "Adidas", "color": "white"}',
    '{"title": "Womens sandals", "description": "Sandals:  wide width 9", "brand": "Chacos", "color": "blue"}',
    '{"title": "Girls surf backpack", "description": "The best backpack in town", "brand": "Roxy", "color": "pink"}',
    '{"title": "Fresh watermelon", "description": "The best fruit in town, all you can eat", "brand": "Fruitsellers Inc.", "color": "green"}',
    '{"title": "Floral yellow dress with frills and lace", "description": "Brighten up your summers with a gorgeous dress", "brand": "Dressmakers Inc.", "color": "bright yellow"}'
]
model = CrossEncoder(model_name, max_length=512)
# predict() takes (query, document) pairs and returns one relevance score per pair
scores = model.predict([(q, d) for q, d in zip(queries, documents)])
print(scores)
# Output:
# [ 1.057739   1.6751697  1.039221   1.5969192 -0.8867093  0.5035825 ]
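In product search, a cross-encoder is typically used to rerank a short list of candidate products for a single query. The following is a minimal sketch of that use case. It assumes products are serialized as JSON strings with the same four fields as the examples, and `json.dumps` with default settings reproduces that exact string format. `to_document`, `rerank`, and `score_pairs` are hypothetical names. `score_pairs` stands in for any callable that scores a list of (query, document) pairs, such as the `model.predict` method shown above.

```python
import json
from typing import Callable, Dict, List, Sequence, Tuple


def to_document(product: Dict[str, str]) -> str:
    """Serialize a product record as a JSON string, like the example documents."""
    return json.dumps(product)


def rerank(
    query: str,
    documents: Sequence[str],
    score_pairs: Callable[[List[Tuple[str, str]]], Sequence[float]],
) -> List[Tuple[str, float]]:
    """Score every (query, document) pair and return documents sorted best first."""
    scores = score_pairs([(query, doc) for doc in documents])
    ranked = sorted(zip(documents, scores), key=lambda item: item[1], reverse=True)
    return [(doc, float(score)) for doc, score in ranked]


candidates = [
    to_document({"title": "Nike Air Max", "description": "The best shoes you can get, with air cushion", "brand": "Nike", "color": "black"}),
    to_document({"title": "Adidas Ultraboost", "description": "The shoes that represent the world", "brand": "Adidas", "color": "white"}),
]

# With the CrossEncoder loaded as above:
# for doc, score in rerank("adidas shoes", candidates, model.predict):
#     print(f"{score:.3f}  {json.loads(doc)['title']}")
```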

Training

Trained with CrossEntropyLoss on <query, document> pairs, using the relevance grade as the label.

from sentence_transformers import InputExample

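# Each example is a (query, document) pair labeled with its relevance grade
# (a float between 0 and 1 in this illustration)
train_samples = [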
    InputExample(texts=["query 1", "document 1"], label=0.3),
    InputExample(texts=["query 1", "document 2"], label=0.8),
    InputExample(texts=["query 2", "document 2"], label=0.1),
]
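The example labels are fractional grades between 0 and 1, not class indices. The following is a minimal sketch of how a single relevance logit can be trained against such a soft label using binary cross-entropy on sigmoid(logit). This is an assumption made for illustration: the card does not show its exact loss configuration. For reference, in sentence-transformers 2.x, `CrossEncoder.fit` defaults to `BCEWithLogitsLoss` for a single-label model. The loss for one pair is smallest when sigmoid(logit) equals the grade, so higher grades push the model toward higher logits. The function names and the logits below are hypothetical.

```python
import math
from typing import List, Tuple


def bce_with_logit(logit: float, label: float) -> float:
    """Binary cross-entropy between sigmoid(logit) and a soft label in [0, 1].

    Uses the numerically stable form max(z, 0) - z * y + log(1 + exp(-|z|)),
    which equals -[y * log(sigmoid(z)) + (1 - y) * log(1 - sigmoid(z))].
    """
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


def mean_loss(batch: List[Tuple[float, float]]) -> float:
    """Average loss over (logit, label) pairs, the quantity a training step minimizes."""
    return sum(bce_with_logit(z, y) for z, y in batch) / len(batch)


# Hypothetical model logits paired with the labels of the three InputExamples above
batch = [(0.5, 0.3), (2.0, 0.8), (-1.5, 0.1)]
loss = mean_loss(batch)
print(round(loss, 4))
```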