Updates

Hi everyone, thanks for using the stella models. After six months of work, I have trained the jasper model on top of the stella model. Jasper is a multimodal model, and it ranks 2nd on MTEB according to the results I submitted on 2024-12-11, which may still need official review (https://github.com/embeddings-benchmark/results/pull/68).

Model link: https://huggingface.co/infgrad/jasper_en_vision_language_v1

Next, I'll focus on the technical report, the training data and the related code. Hopefully the tricks I've used will be of some help to you!

This work was done in my free time as a personal hobby. One person's time and energy are limited, so contributions are very welcome!

You can also find these models on my homepage.

Introduction

The models are trained on top of Alibaba-NLP/gte-large-en-v1.5 and Alibaba-NLP/gte-Qwen2-1.5B-instruct. Thanks to their authors for their contributions!

We simplify prompt usage by providing two prompts that cover most general tasks: one for s2p (sentence-to-passage) tasks and one for s2s (sentence-to-sentence) tasks.

Prompt for s2p tasks (e.g. retrieval):

Instruct: Given a web search query, retrieve relevant passages that answer the query.\nQuery: {query}

Prompt for s2s tasks (e.g. semantic textual similarity):

Instruct: Retrieve semantically similar text.\nQuery: {query}
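As a minimal sketch of how the two prompts above might be applied by hand, the helper below prepends the chosen prompt to a query string. The function name `build_query` and the `PROMPTS` dictionary are hypothetical names introduced for illustration; the keys `s2p_query` and `s2s_query` follow the prompt names used in the Sentence Transformers example of this card.

```python
# Prompt prefixes copied from this model card; the text before "{query}" is the prefix.
PROMPTS = {
    "s2p_query": "Instruct: Given a web search query, retrieve relevant passages that answer the query.\nQuery: ",
    "s2s_query": "Instruct: Retrieve semantically similar text.\nQuery: ",
}


def build_query(text, task="s2p_query"):
    """Return `text` with the prompt for `task` prepended (hypothetical helper)."""
    if task not in PROMPTS:
        raise ValueError(f"Unknown task {task!r}; expected one of {sorted(PROMPTS)}")
    return PROMPTS[task] + text


if __name__ == "__main__":
    # Queries get a prompt; documents are encoded as-is, without any prompt.
    print(build_query("What are some ways to reduce stress?"))
    print(build_query("A man is playing a guitar.", task="s2s_query"))
```

Only queries are wrapped this way; documents are passed to the model unchanged.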

The final training stage uses MRL (Matryoshka Representation Learning), so the models support multiple output dimensions: 512, 768, 1024, 2048, 4096, 6144 and 8192.

Higher dimensions generally perform better, but 1024d is usually good enough: its MTEB score is only 0.001 lower than that of 8192d.

Model directory structure

The model directory structure is simple: it is a standard SentenceTransformer directory plus a series of 2_Dense_{dims} folders, where dims is the final vector dimension.

For example, the 2_Dense_256 folder stores the Linear weights that project vectors to 256 dimensions. See the Usage section below for specific instructions on how to use them.
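Switching the output dimension of a cloned model comes down to pointing `modules.json` at a different 2_Dense_{dims} folder. The sketch below shows one way to do that, assuming `modules.json` follows the usual SentenceTransformers layout (a JSON list of module entries, each with a `path` field). The function name `set_dense_dimension` is hypothetical, and the layout should be checked against your local copy before relying on it.

```python
import json
from pathlib import Path


def set_dense_dimension(model_dir, dim):
    """Point the Dense module in modules.json at the 2_Dense_{dim} folder.

    Assumes modules.json is a list of dicts, each with a "path" key, and that
    the Dense module's path starts with "2_Dense_" (an assumption about the layout).
    """
    model_path = Path(model_dir)
    dense_dir = f"2_Dense_{dim}"
    if not (model_path / dense_dir).is_dir():
        raise FileNotFoundError(f"{dense_dir} not found in {model_path}")

    modules_file = model_path / "modules.json"
    modules = json.loads(modules_file.read_text(encoding="utf-8"))

    replaced = False
    for module in modules:
        if module.get("path", "").startswith("2_Dense_"):
            module["path"] = dense_dir
            replaced = True
    if not replaced:
        raise ValueError("No module with a path starting with '2_Dense_' was found")

    modules_file.write_text(json.dumps(modules, indent=2), encoding="utf-8")
    return dense_dir
```

After calling, for example, `set_dense_dimension("path/to/cloned/model", 256)` on a local clone, loading that directory with SentenceTransformer should produce 256-dimensional vectors.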

Usage

You can use either the SentenceTransformers or the transformers library to encode text.

Sentence Transformers

from sentence_transformers import SentenceTransformer

# This model supports two prompts: "s2p_query" and "s2s_query" for sentence-to-passage and sentence-to-sentence tasks, respectively.
# They are defined in `config_sentence_transformers.json`
query_prompt_name = "s2p_query"
queries = [
    "What are some ways to reduce stress?",
    "What are the benefits of drinking green tea?",
]
# docs do not need any prompts
docs = [
    "There are many effective ways to reduce stress. Some common techniques include deep breathing, meditation, and physical activity. Engaging in hobbies, spending time in nature, and connecting with loved ones can also help alleviate stress. Additionally, setting boundaries, practicing self-care, and learning to say no can prevent stress from building up.",
    "Green tea has been consumed for centuries and is known for its potential health benefits. It contains antioxidants that may help protect the body against damage caused by free radicals. Regular consumption of green tea has been associated with improved heart health, enhanced cognitive function, and a reduced risk of certain types of cancer. The polyphenols in green tea may also have anti-inflammatory and weight loss properties.",
]

# NOTE: The default dimension is 1024. If you need another dimension, clone the model and edit `modules.json`,
# replacing `2_Dense_1024` with another folder, e.g. `2_Dense_256` or `2_Dense_8192`.
model = SentenceTransformer("dunzhang/stella_en_1.5B_v5", trust_remote_code=True).cuda()
query_embeddings = model.encode(queries, prompt_name=query_prompt_name)
doc_embeddings = model.encode(docs)
print(query_embeddings.shape, doc_embeddings.shape)
# (2, 1024) (2, 1024)

similarities = model.similarity(query_embeddings, doc_embeddings)
print(similarities)
# tensor([[0.8179, 0.2958],
#         [0.3194, 0.7854]])

Transformers

import os
import torch
from transformers import AutoModel, AutoTokenizer
from sklearn.preprocessing import normalize

query_prompt = "Instruct: Given a web search query, retrieve relevant passages that answer the query.\nQuery: "
queries = [
    "What are some ways to reduce stress?",
    "What are the benefits of drinking green tea?",
]
queries = [query_prompt + query for query in queries]
# docs do not need any prompts
docs = [
    "There are many effective ways to reduce stress. Some common techniques include deep breathing, meditation, and physical activity. Engaging in hobbies, spending time in nature, and connecting with loved ones can also help alleviate stress. Additionally, setting boundaries, practicing self-care, and learning to say no can prevent stress from building up.",
    "Green tea has been consumed for centuries and is known for its potential health benefits. It contains antioxidants that may help protect the body against damage caused by free radicals. Regular consumption of green tea has been associated with improved heart health, enhanced cognitive function, and a reduced risk of certain types of cancer. The polyphenols in green tea may also have anti-inflammatory and weight loss properties.",
]

# The path of your model after cloning it
model_dir = "{Your MODEL_PATH}"

vector_dim = 1024
vector_linear_directory = f"2_Dense_{vector_dim}"
model = AutoModel.from_pretrained(model_dir, trust_remote_code=True).cuda().eval()
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
vector_linear = torch.nn.Linear(in_features=model.config.hidden_size, out_features=vector_dim)
vector_linear_dict = {
    k.replace("linear.", ""): v for k, v in
    torch.load(os.path.join(model_dir, f"{vector_linear_directory}/pytorch_model.bin")).items()
}
vector_linear.load_state_dict(vector_linear_dict)
vector_linear.cuda()

# Embed the queries
with torch.no_grad():
    input_data = tokenizer(queries, padding="longest", truncation=True, max_length=512, return_tensors="pt")
    input_data = {k: v.cuda() for k, v in input_data.items()}
    attention_mask = input_data["attention_mask"]
    last_hidden_state = model(**input_data)[0]
    last_hidden = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
    query_vectors = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
    query_vectors = normalize(vector_linear(query_vectors).cpu().numpy())

# Embed the documents
with torch.no_grad():
    input_data = tokenizer(docs, padding="longest", truncation=True, max_length=512, return_tensors="pt")
    input_data = {k: v.cuda() for k, v in input_data.items()}
    attention_mask = input_data["attention_mask"]
    last_hidden_state = model(**input_data)[0]
    last_hidden = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
    docs_vectors = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
    docs_vectors = normalize(vector_linear(docs_vectors).cpu().numpy())

print(query_vectors.shape, docs_vectors.shape)
# (2, 1024) (2, 1024)

similarities = query_vectors @ docs_vectors.T
print(similarities)
# [[0.8178789  0.2958377 ]
#  [0.31938642 0.7853526 ]]
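The pooling step in the Transformers example (zero out padded positions, sum over tokens, divide by the number of real tokens, then L2-normalize) can be illustrated on toy data. The following is a pure-Python sketch with hypothetical helper names; it leaves out the `vector_linear` projection that the real example applies between pooling and normalization.

```python
import math


def masked_mean_pool(hidden, mask):
    """Average the token vectors whose mask value is 1, ignoring padding."""
    pooled = []
    for token_vectors, token_mask in zip(hidden, mask):
        dim = len(token_vectors[0])
        kept = sum(token_mask)  # number of real (non-padding) tokens
        sums = [0.0] * dim
        for vec, m in zip(token_vectors, token_mask):
            if m:
                for i, x in enumerate(vec):
                    sums[i] += x
        pooled.append([s / kept for s in sums])
    return pooled


def l2_normalize(vectors):
    """Scale each vector to unit length; all-zero vectors are returned unchanged."""
    out = []
    for vec in vectors:
        norm = math.sqrt(sum(x * x for x in vec))
        out.append([x / norm for x in vec] if norm > 0 else list(vec))
    return out


# Toy batch: 2 sequences, 3 token positions, hidden size 2.
hidden = [
    [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]],  # last position is padding
    [[2.0, 0.0], [0.0, 2.0], [4.0, 4.0]],
]
mask = [[1, 1, 0], [1, 1, 1]]

pooled = masked_mean_pool(hidden, mask)
unit = l2_normalize(pooled)
print(pooled)  # [[2.0, 3.0], [2.0, 2.0]]
```

The padded position in the first sequence does not affect the result, which is why the attention mask is applied before averaging.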

FAQ

Q: What are the training details?

A: The training method and datasets will be released in the future. The date is not yet known, and they may be published in a paper.

Q: How do I choose a suitable prompt for my own task?

A: In most cases, please use the s2p and s2s prompts. These two prompts account for the vast majority of the training data.

Q: How can I reproduce the MTEB results?

A: Please use the evaluation scripts from Alibaba-NLP/gte-Qwen2-1.5B-instruct or intfloat/e5-mistral-7b-instruct.

Q: Why does each dimension have its own linear weight?

A: MRL can be trained in several ways; we chose this one because it performed best.

Q: What is the sequence length of the models?

A: 512 is recommended. In our experiments, almost all models perform poorly on specialized long-text retrieval datasets, and this model was trained on data with a length of 512. This may be an area for future optimization.

If you have any questions, please start a discussion in the community tab.
