Sentence Similarity
Transformers
Safetensors
multilingual
nllb-llm2vec
feature-extraction
text-embedding
embeddings
information-retrieval
beir
text-classification
language-model
text-clustering
text-semantic-similarity
text-evaluation
text-reranking
Sentence Similarity
natural_questions
ms_marco
fever
hotpot_qa
mteb
custom_code
import math
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast

import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from packaging import version
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
from transformers.cache_utils import Cache
from transformers.modeling_outputs import (
    BaseModelOutputWithPooling,
    ModelOutput,
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.models.auto import (
    AutoModel,
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
)
from transformers.models.m2m_100.modeling_m2m_100 import M2M100Encoder
from transformers.tokenization_utils import BatchEncoding

from .configuration_nllbllm2vec import NLLBLLM2VecConfig
from .modeling_llama_encoder import LlamaEncoderModel
# Tokenizer call options used by `NLLBLLM2Vec.encode` unless the caller overrides them:
# pad/truncate every batch to at most 512 tokens and return PyTorch tensors.
DEFAULT_TOKENIZE_KWARGS = dict(
    padding=True,
    truncation=True,
    max_length=512,
    return_tensors="pt",
)

# DataLoader options used by `NLLBLLM2Vec.encode` unless the caller overrides them.
# `shuffle=False` keeps the output embeddings aligned with the input order.
DEFAULT_DATALOADER_KWARGS = dict(
    shuffle=False,
    batch_size=32,
    pin_memory=True,
)
def default_collate_fn_closure(tokenizer, tokenize_kwargs) -> Callable:
    """Build the default ``collate_fn`` used by ``NLLBLLM2Vec.encode``.

    Args:
        tokenizer: Callable tokenizer applied to a list of strings.
        tokenize_kwargs: Keyword arguments forwarded to every tokenizer call.

    Returns:
        A function mapping a list of strings to the tokenizer's output.
    """

    def _tokenize_batch(texts: list[str]) -> BatchEncoding:
        # Tokenize the whole batch in one call with the captured options.
        encoded = tokenizer(texts, **tokenize_kwargs)
        return encoded

    return _tokenize_batch
def defaulter(kwd_dict: Optional[Dict], default_dict: Dict) -> Dict:
    """Merge user-supplied keyword arguments over a dict of defaults.

    Args:
        kwd_dict: User overrides, or ``None`` to use only the defaults.
        default_dict: Default key/value pairs.

    Returns:
        A new dict holding ``default_dict`` updated with ``kwd_dict``; user
        values win on key collisions. A fresh dict is returned even when
        ``kwd_dict`` is ``None``, so callers (or user-supplied closures that
        receive the result) can mutate it without corrupting shared,
        module-level defaults.
    """
    if kwd_dict is None:
        return dict(default_dict)
    return {**default_dict, **kwd_dict}
@dataclass
class SequenceClassifierOutputWithPastAndPooler(ModelOutput):
    """Sequence-classification output that additionally carries the pooled embedding.

    ``ModelOutput`` subclasses must be dataclasses: the field declarations
    below only become constructor arguments / ordered output keys through
    the ``@dataclass`` decorator.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Mean-pooled sequence embedding produced by the backbone.
    pooler_output: Optional[torch.FloatTensor] = None
class NLLBLLM2Vec(PreTrainedModel):
    """
    NLLBLLM2Vec model combining NLLB and LLama encoders.

    The NLLB (M2M100) encoder embeds the input tokens; its last hidden state is
    linearly projected (``up_proj``) to the Llama encoder's hidden size and fed
    to the Llama encoder as ``inputs_embeds``. The Llama encoder's last hidden
    state is mean-pooled over the non-padded tokens.

    Args:
        config (Optional[NLLBLLM2VecConfig]): Configuration object.
        nllb_encoder (Optional[M2M100Encoder]): Pre-initialized NLLB encoder.
        llm2vec (Optional[LlamaEncoderModel]): Pre-initialized LLama encoder.
        *inputs: Additional positional arguments forwarded to ``PreTrainedModel``.
        **kwargs: Additional keyword arguments forwarded to ``PreTrainedModel``.

    Raises:
        ValueError: If ``config`` is missing and not both encoders are given.
    """

    config_class = NLLBLLM2VecConfig
    model_type = "nllb-llm2vec"
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def __init__(
        self,
        config: Optional[NLLBLLM2VecConfig] = None,
        nllb_encoder: Optional[M2M100Encoder] = None,
        llm2vec: Optional[LlamaEncoderModel] = None,
        *inputs,
        **kwargs,
    ):
        # Ensure that either config is not None or both encoders are provided
        if config is None and (nllb_encoder is None or llm2vec is None):
            raise ValueError(
                "Either `config` must be provided, or both `nllb_encoder` and `llm2vec` must be specified."
            )
        if config is not None:
            super().__init__(config, *inputs, **kwargs)
            # from_pretrained overwrites this after config instantiation, so we make sure it's correctly set
            config.nllb_config._attn_implementation = config._attn_implementation
            config.llm2vec_config._attn_implementation = config._attn_implementation
        else:
            # Both encoders are provided: derive the config from them. The config
            # must be built and `super().__init__` must run *before* the encoders
            # are attached, because `nn.Module.__setattr__` refuses to register
            # submodules prior to `Module.__init__`.
            config = NLLBLLM2VecConfig(
                nllb_config=cast(M2M100Encoder, nllb_encoder).config,  # type: ignore
                llm2vec_config=cast(LlamaEncoderModel, llm2vec).config,  # type: ignore
            )
            super().__init__(config, *inputs, **kwargs)
        self.config = config
        self.nllb_encoder = (
            nllb_encoder if nllb_encoder is not None else M2M100Encoder(config.nllb_config)
        )
        self.llm2vec = (
            llm2vec if llm2vec is not None else LlamaEncoderModel(config.llm2vec_config)
        )
        # Bridges the NLLB hidden size (d_model) to the Llama hidden size.
        self.up_proj = nn.Linear(
            self.nllb_encoder.config.d_model,
            self.llm2vec.config.hidden_size,
            bias=False,
        )
        # TODO: update this once commit is included
        min_version = "4.46.0"
        if self.config.nllb_config._attn_implementation == "flash_attention_2":
            if version.parse(transformers.__version__) < version.parse(min_version):
                warnings.warn(
                    f"Installed transformers version ({transformers.__version__}) never sets NLLB-encoder dropout to `False` with FlashAttention2. See https://github.com/huggingface/transformers/pull/33844 for more info. Consider upgrading to latest to {min_version} or master.",
                    UserWarning,
                )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        indices: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        *args,
        **kwargs,
    ) -> BaseModelOutputWithPooling:
        """
        Forward pass of the model.

        Args:
            input_ids (torch.Tensor): Input token IDs.
            attention_mask (Optional[torch.Tensor]): Attention mask; when ``None``
                every token is treated as non-padded.
            indices (Optional[Tuple[torch.Tensor, torch.Tensor]]): Precomputed
                ``(input_indices, offsets)`` as returned by ``_get_input_offsets``.
            *args, **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            BaseModelOutputWithPooling: Llama-encoder last hidden state and the
            mean-pooled output.
        """
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        # Compute input indices and offsets if not provided
        if indices is None:
            seq_indices, seq_offsets = self._get_input_offsets(attention_mask)
        else:
            seq_indices, seq_offsets = indices
        nllb_outputs = self.nllb_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        nllb_last_hidden_state = self.up_proj(nllb_outputs.last_hidden_state)
        outputs = self.llm2vec(
            inputs_embeds=nllb_last_hidden_state,
            attention_mask=attention_mask,
        )
        pooler_output = self._mean_embedding(
            hidden_states=outputs.last_hidden_state,
            input_indices=seq_indices,
            offsets=seq_offsets,
        )
        return BaseModelOutputWithPooling(
            last_hidden_state=outputs.last_hidden_state,
            pooler_output=pooler_output,
        )

    @property
    def tokenizer(self):
        """
        Get the tokenizer associated with the model.

        The tokenizer is loaded lazily on first access and cached on the instance.

        Returns:
            PreTrainedTokenizer: The tokenizer instance.
        """
        if not hasattr(self, "_tokenizer"):
            from transformers import AutoTokenizer

            self._tokenizer = AutoTokenizer.from_pretrained(
                "facebook/nllb-200-distilled-600M", padding_side="right"
            )
        return self._tokenizer

    def encode(
        self,
        inputs: List[str],
        src_lang: str = "eng_Latn",
        dataloader_kwargs: Optional[Dict[str, Any]] = None,
        tokenize_kwargs: Optional[Dict[str, Any]] = None,
        collate_fn_closure: Optional[Callable] = None,
    ) -> torch.Tensor:
        """
        Encode input texts into embeddings.

        Args:
            inputs (List[str]): List of input texts.
            src_lang (str): Source language code for the tokenizer (default: `"eng_Latn"`).
            dataloader_kwargs (Optional[Dict[str, Any]]): Additional keyword arguments for the dataloader excl. `collate_fn`.
                Defaults to:
                >> dataloader_kwargs = {
                >>     "shuffle": False,
                >>     "batch_size": 32,
                >>     "pin_memory": True,
                >> }
            tokenize_kwargs (Optional[Dict[str, Any]]): Additional keyword arguments for the tokenizer.
                Defaults to:
                >> tokenize_kwargs = {
                >>     "padding": True,
                >>     "truncation": True,
                >>     "max_length": 512,
                >>     "return_tensors": "pt",
                >> }
            collate_fn_closure (Optional[Callable]): Closure that should return a `collate_fn`.
                Defaults to:
                >> def default_collate_fn_closure(tokenizer, tokenize_kwargs) -> Callable:
                >>     def collate_fn(batch: list[str]) -> BatchEncoding:
                >>         return tokenizer(batch, **tokenize_kwargs)
                >>     return collate_fn

        Returns:
            torch.Tensor: Mean-pooled sequence embeddings of the inputs, one row
            per input (an empty ``(0, hidden_size)`` tensor for empty input).

        Raises:
            ValueError: If `dataloader_kwargs` contains `collate_fn`.
        """
        # merge user kwargs with defaults, giving priority to user kwargs
        tokenize_kwargs = defaulter(tokenize_kwargs, DEFAULT_TOKENIZE_KWARGS)
        dataloader_kwargs = defaulter(dataloader_kwargs, DEFAULT_DATALOADER_KWARGS)
        # `collate_fn` is passed to the DataLoader explicitly; a second one would
        # clash. Validated with a real exception since `assert` is stripped by -O.
        if "collate_fn" in dataloader_kwargs:
            raise ValueError("`collate_fn` should be created via `collate_fn_closure`")
        tokenizer = self.tokenizer
        tokenizer.src_lang = src_lang
        device = next(self.parameters()).device
        make_collate_fn = (
            default_collate_fn_closure if collate_fn_closure is None else collate_fn_closure
        )
        collate_fn = make_collate_fn(tokenizer, tokenize_kwargs)
        self.eval()
        if len(inputs) == 0:
            # Nothing to tokenize; avoid calling the tokenizer/model on an empty batch.
            return torch.empty(0, self.llm2vec.config.hidden_size, device=device)
        # bf16 autocast wraps both paths so the result dtype does not depend on
        # whether the inputs fit into a single batch.
        with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
            if len(inputs) <= dataloader_kwargs.get("batch_size", 1):
                return self._embed_batch(collate_fn(inputs), device)
            dataloader = DataLoader(inputs, collate_fn=collate_fn, **dataloader_kwargs)  # type: ignore
            all_embeddings = [
                self._embed_batch(batch, device)
                for batch in tqdm(dataloader, desc="Encoding")
            ]
        # Concatenate all pooled embeddings along the batch dimension
        return torch.cat(all_embeddings, dim=0)

    def _embed_batch(self, batch: Any, device: torch.device) -> torch.Tensor:
        """Move one collated batch to `device` and return its pooled embeddings."""
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.inference_mode():
            return self(**batch).pooler_output

    @staticmethod
    def _get_input_offsets(
        attention_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute indices and offsets for mean pooling using EmbeddingBag.

        Args:
            attention_mask (torch.Tensor): Attention mask of shape (batch_size, seq_len).

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
                - input_indices: 1-D indices of non-padded tokens in the flattened input.
                - offsets: Offsets indicating the start index of each sequence in `input_indices`.
        """
        keep = attention_mask != 0
        # `nonzero(as_tuple=True)[0]` is always 1-D; a bare `.squeeze()` would turn
        # a single match into a 0-d tensor that `embedding_bag` rejects.
        input_indices = keep.reshape(-1).nonzero(as_tuple=True)[0]
        # Row-major flattening keeps each row's kept tokens contiguous, so each
        # sequence starts after the cumulative count of the preceding rows.
        non_padded_lengths = keep.sum(dim=1)
        offsets = non_padded_lengths.cumsum(dim=0).roll(shifts=1)
        offsets[0] = 0
        return input_indices, offsets

    @staticmethod
    def _mean_embedding(
        hidden_states: torch.Tensor,
        input_indices: torch.Tensor,
        offsets: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute the mean of non-padded embeddings using `embedding_bag`,
        properly handling padding with offsets.

        Args:
            hidden_states (torch.Tensor): Hidden states of shape (batch_size, seq_len, embed_dim).
            input_indices (torch.Tensor): Indices of non-padded tokens in flattened form.
            offsets (torch.Tensor): Offsets specifying the start of each sequence.

        Returns:
            torch.Tensor: Pooled mean embeddings of shape (batch_size, embed_dim).
        """
        embed_dim = hidden_states.shape[-1]
        # `reshape` (unlike `view`) also works for non-contiguous hidden states.
        token_embeds = hidden_states.reshape(-1, embed_dim)
        return F.embedding_bag(
            input=input_indices,  # Indices of non-padded tokens in flattened form
            weight=token_embeds,  # The flattened hidden states as embedding matrix
            offsets=offsets,  # Offsets specifying start of each sequence
            mode="mean",  # Aggregation mode
        )
class NLLBLLM2VecForSequenceClassification(PreTrainedModel):
    """NLLBLLM2Vec backbone with a linear head on the mean-pooled embedding."""

    config_class = NLLBLLM2VecConfig
    model_type = "nllb-llm2vec"
    base_model_prefix = "model"
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = NLLBLLM2Vec(config)
        self.score = nn.Linear(
            config.llm2vec_config.hidden_size, self.num_labels, bias=False
        )
        # Initialize weights and apply final processing
        self.post_init()

    def _init_weights(self, module):
        """Initialize the classification head and generic Linear/Embedding modules."""
        if module is self.score:
            # INFO:
            # - critical that clf head is in float32 (NusaX perf. drops funky otherwise)
            # - Initialization needs to be redone, otherwise borked
            # - Use kaiming uniform, b/c Llama init (cf. `nn.Linear` below) performs worse
            self.score = self.score.to(torch.float32)
            torch.nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def get_input_embeddings(self):
        # The backbone stores the NLLB encoder as `nllb_encoder`.
        return self.model.nllb_encoder.embed_tokens

    def set_input_embeddings(self, value):
        self.model.nllb_encoder.embed_tokens = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPastAndPooler]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs.pooler_output
        # The head is kept in float32 (see `_init_weights`); match its dtype so a
        # half-precision backbone does not cause a matmul dtype mismatch.
        pooled_logits = self.score(hidden_states.to(self.score.weight.dtype))
        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(pooled_logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (
                    labels.dtype == torch.long or labels.dtype == torch.int
                ):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"
            if self.config.problem_type == "regression":
                if self.num_labels == 1:
                    loss = F.mse_loss(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = F.mse_loss(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss = F.cross_entropy(
                    pooled_logits.view(-1, self.num_labels), labels.view(-1)
                )
            elif self.config.problem_type == "multi_label_classification":
                loss = F.binary_cross_entropy_with_logits(pooled_logits, labels)
        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output
        # NOTE: `hidden_states` carries the pooled embedding here, not per-layer states.
        return SequenceClassifierOutputWithPastAndPooler(
            loss=loss,
            hidden_states=hidden_states,
            logits=pooled_logits,
            pooler_output=transformer_outputs.pooler_output,
        )
class NLLBLLM2VecForTokenClassification(PreTrainedModel):
    """NLLBLLM2Vec backbone with a per-token linear classification head."""

    config_class = NLLBLLM2VecConfig
    model_type = "nllb-llm2vec"
    base_model_prefix = "model"
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def __init__(self, config: NLLBLLM2VecConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = NLLBLLM2Vec(config)
        self.classifier = nn.Linear(
            config.llm2vec_config.hidden_size, self.num_labels, bias=False
        )
        # Initialize weights and apply final processing
        self.post_init()

    def _init_weights(self, module):
        """Initialize the classification head and generic Linear/Embedding modules."""
        if module is self.classifier:
            # INFO:
            # - critical that clf head is in float32 (NusaX perf. drops funky otherwise)
            # - Initialization needs to be redone, otherwise borked
            # - Use kaiming uniform, b/c Llama init (cf. `nn.Linear` below) performs worse
            self.classifier = self.classifier.to(torch.float32)
            torch.nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def get_input_embeddings(self):
        # The backbone stores the NLLB encoder as `nllb_encoder`.
        return self.model.nllb_encoder.embed_tokens

    def set_input_embeddings(self, value):
        self.model.nllb_encoder.embed_tokens = value

    # adapted from transformers.models.roberta.modeling_roberta.RobertaForTokenClassification
    # - removed classifier dropout
    # - use F.cross_entropy
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # First output field is the backbone's last hidden state (per token).
        sequence_output = outputs[0]
        # The head is kept in float32 (see `_init_weights`); match its dtype so a
        # half-precision backbone does not cause a matmul dtype mismatch.
        logits = self.classifier(sequence_output.to(self.classifier.weight.dtype))
        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)
            loss = F.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output
        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
# Expose the custom classes through the transformers Auto* factories. Each head
# must be registered under its own Auto class: registering the token-classification
# model under AutoModelForSequenceClassification would replace the
# sequence-classification mapping for this config.
AutoModel.register(NLLBLLM2VecConfig, NLLBLLM2Vec)
AutoModelForSequenceClassification.register(
    NLLBLLM2VecConfig, NLLBLLM2VecForSequenceClassification
)
AutoModelForTokenClassification.register(
    NLLBLLM2VecConfig, NLLBLLM2VecForTokenClassification
)