import math
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from packaging import version
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
from transformers.cache_utils import Cache
from transformers.modeling_outputs import (
BaseModelOutputWithPooling,
ModelOutput,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.models.auto import AutoModel, AutoModelForSequenceClassification, AutoModelForTokenClassification
from transformers.models.m2m_100.modeling_m2m_100 import M2M100Encoder
from transformers.tokenization_utils import BatchEncoding
from .configuration_nllbllm2vec import NLLBLLM2VecConfig
from .modeling_llama_encoder import LlamaEncoderModel
DEFAULT_TOKENIZE_KWARGS = {
"padding": True,
"truncation": True,
"max_length": 512,
"return_tensors": "pt",
}
DEFAULT_DATALOADER_KWARGS = {
"shuffle": False,
"batch_size": 32,
"pin_memory": True,
}
def default_collate_fn_closure(tokenizer, tokenize_kwargs) -> Callable:
def collate_fn(batch: list[str]) -> BatchEncoding:
return tokenizer(batch, **tokenize_kwargs)
return collate_fn
def defaulter(kwd_dict: Optional[Dict], default_dict: Dict) -> Dict:
return default_dict if kwd_dict is None else {**default_dict, **kwd_dict}
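# Illustrative example of the merge behavior above: user kwargs override defaults key-by-key, e.g.
#   defaulter({"batch_size": 8}, DEFAULT_DATALOADER_KWARGS)
#   -> {"shuffle": False, "batch_size": 8, "pin_memory": True}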
@dataclass
class SequenceClassifierOutputWithPastAndPooler(ModelOutput):
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
pooler_output: torch.FloatTensor = None
class NLLBLLM2Vec(PreTrainedModel):
    """
    NLLBLLM2Vec model combining the NLLB and Llama encoders.
    Args:
        config (Optional[NLLBLLM2VecConfig]): Configuration object.
        nllb_encoder (Optional[M2M100Encoder]): Pre-initialized NLLB encoder.
        llm2vec (Optional[LlamaEncoderModel]): Pre-initialized Llama encoder.
        *inputs: Additional positional arguments.
        **kwargs: Additional keyword arguments.
    """
    config_class = NLLBLLM2VecConfig
    model_type = "nllb-llm2vec"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
def __init__(
self,
config: Optional[NLLBLLM2VecConfig] = None,
nllb_encoder: Optional[M2M100Encoder] = None,
llm2vec: Optional[LlamaEncoderModel] = None,
*inputs,
**kwargs,
):
# Ensure that either config is not None or both encoders are provided
if config is None and (nllb_encoder is None or llm2vec is None):
raise ValueError(
"Either `config` must be provided, or both `nllb_encoder` and `llm2vec` must be specified."
)
if config is not None:
super().__init__(config, *inputs, **kwargs)
# from_pretrained overwrites this after config instantiation, so we make sure it's correctly set
config.nllb_config._attn_implementation = config._attn_implementation
config.llm2vec_config._attn_implementation = config._attn_implementation
self.nllb_encoder = nllb_encoder or M2M100Encoder(config.nllb_config)
self.llm2vec = llm2vec or LlamaEncoderModel(config.llm2vec_config)
self.config = config
else:
# Both encoders are provided
self.nllb_encoder = cast(M2M100Encoder, nllb_encoder)
self.llm2vec = cast(LlamaEncoderModel, llm2vec)
self.config = NLLBLLM2VecConfig(
nllb_config=self.nllb_encoder.config, # type: ignore
llm2vec_config=self.llm2vec.config, # type: ignore
)
super().__init__(self.config, *inputs, **kwargs)
self.up_proj = nn.Linear(
self.nllb_encoder.config.d_model,
self.llm2vec.config.hidden_size,
bias=False,
)
# TODO: update this once commit is included
min_version = "4.46.0"
if self.config.nllb_config._attn_implementation == "flash_attention_2":
if version.parse(transformers.__version__) < version.parse(min_version):
warnings.warn(
f"Installed transformers version ({transformers.__version__}) never sets NLLB-encoder dropout to `False` with FlashAttention2. See https://github.com/huggingface/transformers/pull/33844 for more info. Consider upgrading to latest to {min_version} or master.",
UserWarning,
)
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
indices: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
*args,
**kwargs,
) -> BaseModelOutputWithPooling:
"""
Forward pass of the model.
Args:
input_ids (torch.Tensor): Input token IDs.
attention_mask (torch.Tensor): Attention mask.
indices (Optional[Tuple[torch.Tensor, torch.Tensor]]): Precomputed input indices and offsets.
Returns:
BaseModelOutputWithPooling: Model outputs with last hidden state and pooled output.
"""
# Compute input indices and offsets if not provided
if indices is None:
seq_indices, seq_offsets = self._get_input_offsets(attention_mask)
else:
seq_indices, seq_offsets = indices
nllb_outputs = self.nllb_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
)
nllb_last_hidden_state = nllb_outputs.last_hidden_state
nllb_last_hidden_state = self.up_proj(nllb_last_hidden_state)
outputs = self.llm2vec(
inputs_embeds=nllb_last_hidden_state,
attention_mask=attention_mask,
)
pooler_output = self._mean_embedding(
hidden_states=outputs.last_hidden_state,
input_indices=seq_indices,
offsets=seq_offsets,
)
return BaseModelOutputWithPooling(
last_hidden_state=outputs.last_hidden_state,
pooler_output=pooler_output,
)
@property
def tokenizer(self):
"""
Get the tokenizer associated with the model.
Returns:
PreTrainedTokenizer: The tokenizer instance.
"""
if not hasattr(self, "_tokenizer"):
from transformers import AutoTokenizer
self._tokenizer = AutoTokenizer.from_pretrained(
"facebook/nllb-200-distilled-600M", padding_side="right"
)
return self._tokenizer
def encode(
self,
inputs: List[str],
src_lang: str = "eng_Latn",
dataloader_kwargs: Optional[Dict[str, Any]] = None,
tokenize_kwargs: Optional[Dict[str, Any]] = None,
collate_fn_closure: Optional[Callable] = None,
) -> torch.Tensor:
"""
Encode input texts into embeddings.
Args:
inputs (List[str]): List of input texts.
src_lang (str): Source language code for the tokenizer (default: `"eng_Latn"`).
            dataloader_kwargs (Optional[Dict[str, Any]]): Additional keyword arguments for the dataloader excl. `collate_fn`.
                Defaults to:
                >> dataloader_kwargs = {
                >>     "shuffle": False,
                >>     "batch_size": 32,
                >>     "pin_memory": True,
                >> }
tokenize_kwargs (Optional[Dict[str, Any]]): Additional keyword arguments for the tokenizer.
Defaults to:
>> tokenize_kwargs = {
>> "padding": True,
>> "truncation": True,
>> "max_length": 512,
>> "return_tensors": "pt",
>> }
collate_fn_closure (Optional[Callable]): Closure that should return a `collate_fn`.
Defaults to:
>> def default_collate_fn_closure(tokenizer, tokenize_kwargs) -> Callable:
>> def collate_fn(batch: list[str]) -> BatchEncoding:
>> return tokenizer(batch, **tokenize_kwargs)
>> return collate_fn
Returns:
torch.Tensor: Mean-pooled sequence embeddings of the inputs.
"""
# merge user kwargs with defaults, giving priority to user kwargs
tokenize_kwargs = defaulter(tokenize_kwargs, DEFAULT_TOKENIZE_KWARGS)
dataloader_kwargs = defaulter(dataloader_kwargs, DEFAULT_DATALOADER_KWARGS)
tokenizer = self.tokenizer
tokenizer.src_lang = src_lang
device = next(self.parameters()).device
if collate_fn_closure is None:
collate_fn = default_collate_fn_closure(tokenizer, tokenize_kwargs)
else:
collate_fn = collate_fn_closure(tokenizer, tokenize_kwargs)
assert (
"collate_fn" not in dataloader_kwargs
), "`collate_fn` should be created via `collate_fn_closure`"
self.eval()
if len(inputs) > dataloader_kwargs.get("batch_size", 1):
dataloader = DataLoader(inputs, collate_fn=collate_fn, **dataloader_kwargs) # type: ignore
all_embeddings = []
# Iterate through the dataloader with a progress bar and autocast
with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
for batch in tqdm(dataloader, desc="Encoding"):
# Move batch to device
batch = {k: v.to(device) for k, v in batch.items()}
# Forward pass through the model (assumes model returns embeddings)
with torch.inference_mode():
                        pooled_embeddings = cast(
                            BaseModelOutputWithPooling, self(**batch)
                        ).pooler_output
all_embeddings.append(pooled_embeddings)
# Concatenate all pooled embeddings along the batch dimension
all_embeddings = torch.cat(all_embeddings, dim=0)
else:
batch = {k: v.to(device) for k, v in collate_fn(inputs).items()}
with torch.inference_mode():
                all_embeddings = cast(
                    BaseModelOutputWithPooling, self(**batch)
                ).pooler_output
return all_embeddings
@staticmethod
def _get_input_offsets(
attention_mask: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Compute indices and offsets for mean pooling using EmbeddingBag.
Args:
attention_mask (torch.Tensor): Attention mask of shape (batch_size, seq_len).
Returns:
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
- input_indices: Indices of non-padded tokens in the flattened input.
- offsets: Offsets indicating the start index of each sequence in the flattened input.
"""
# Find the indices of non-padded tokens in flattened hidden_states
input_indices = attention_mask.view(-1).nonzero(as_tuple=False).squeeze()
# Compute the offsets: for each sequence, where it starts in the flattened input
non_padded_lengths = attention_mask.sum(
dim=1
) # Count non-padded tokens per sequence
offsets = non_padded_lengths.cumsum(dim=0).roll(shifts=1)
offsets[0] = 0
return input_indices, offsets
@staticmethod
def _mean_embedding(
hidden_states: torch.Tensor,
input_indices: torch.Tensor,
offsets: torch.Tensor,
) -> torch.Tensor:
"""
Compute the mean of non-padded embeddings using `embedding_bag`,
properly handling padding with offsets.
Args:
hidden_states (torch.Tensor): Hidden states of shape (batch_size, seq_len, embed_dim).
input_indices (torch.Tensor): Indices of non-padded tokens in flattened form.
offsets (torch.Tensor): Offsets specifying the start of each sequence.
Returns:
torch.Tensor: Pooled mean embeddings of shape (batch_size, embed_dim).
"""
# Flatten hidden_states to 2D: shape (batch_size * seq_len, embedding_dim)
batch_size, seq_len, embed_dim = hidden_states.shape
token_embeds = hidden_states.view(-1, embed_dim)
# Use embedding_bag with mode 'mean' and appropriate indices
return F.embedding_bag(
input=input_indices, # Indices of non-padded tokens in flattened form
weight=token_embeds, # The flattened hidden states as embedding matrix
offsets=offsets, # Offsets specifying start of each sequence
mode="mean", # Aggregation mode
)
class NLLBLLM2VecForSequenceClassification(PreTrainedModel):
config_class = NLLBLLM2VecConfig
model_type = "nllb-llm2vec"
base_model_prefix = "model"
_supports_flash_attn_2 = True
_supports_sdpa = True
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.model = NLLBLLM2Vec(config)
self.score = nn.Linear(
config.llm2vec_config.hidden_size, self.num_labels, bias=False
)
# Initialize weights and apply final processing
self.post_init()
def _init_weights(self, module):
if module is self.score:
# INFO:
# - critical that clf head is in float32 (NusaX perf. drops funky otherwise)
# - Initialization needs to be redone, otherwise borked
# - Use kaiming uniform, b/c Llama init (cf. `nn.Linear` below) performs worse
self.score = self.score.to(torch.float32)
torch.nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
    def get_input_embeddings(self):
        return self.model.nllb_encoder.embed_tokens
    def set_input_embeddings(self, value):
        self.model.nllb_encoder.embed_tokens = value
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
transformer_outputs = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs.pooler_output
pooled_logits = self.score(hidden_states)
loss = None
if labels is not None:
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (
labels.dtype == torch.long or labels.dtype == torch.int
):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
if self.num_labels == 1:
loss = F.mse_loss(pooled_logits.squeeze(), labels.squeeze())
else:
loss = F.mse_loss(pooled_logits, labels)
elif self.config.problem_type == "single_label_classification":
loss = F.cross_entropy(
pooled_logits.view(-1, self.num_labels), labels.view(-1)
)
elif self.config.problem_type == "multi_label_classification":
loss = F.binary_cross_entropy_with_logits(pooled_logits, labels)
if not return_dict:
output = (pooled_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutputWithPastAndPooler(
loss=loss,
hidden_states=hidden_states,
logits=pooled_logits,
pooler_output=transformer_outputs.pooler_output,
)
class NLLBLLM2VecForTokenClassification(PreTrainedModel):
config_class = NLLBLLM2VecConfig
model_type = "nllb-llm2vec"
base_model_prefix = "model"
_supports_flash_attn_2 = True
_supports_sdpa = True
def __init__(self, config: NLLBLLM2VecConfig):
super().__init__(config)
self.num_labels = config.num_labels
self.model = NLLBLLM2Vec(config)
self.classifier = nn.Linear(
config.llm2vec_config.hidden_size, self.num_labels, bias=False
)
# Initialize weights and apply final processing
self.post_init()
def _init_weights(self, module):
if module is self.classifier:
# INFO:
# - critical that clf head is in float32 (NusaX perf. drops funky otherwise)
# - Initialization needs to be redone, otherwise borked
# - Use kaiming uniform, b/c Llama init (cf. `nn.Linear` below) performs worse
self.classifier = self.classifier.to(torch.float32)
torch.nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
    def get_input_embeddings(self):
        return self.model.nllb_encoder.embed_tokens
    def set_input_embeddings(self, value):
        self.model.nllb_encoder.embed_tokens = value
# adapted from transformers.models.roberta.modeling_roberta.RobertaForTokenClassification
# - removed classifier dropout
# - use F.cross_entropy
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
"""
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
outputs = self.model(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
logits = self.classifier(sequence_output)
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(logits.device)
loss = F.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
AutoModel.register(NLLBLLM2VecConfig, NLLBLLM2Vec)
AutoModelForSequenceClassification.register(
NLLBLLM2VecConfig, NLLBLLM2VecForSequenceClassification
)
AutoModelForTokenClassification.register(
NLLBLLM2VecConfig, NLLBLLM2VecForTokenClassification
)
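# Usage sketch for the Auto-class registrations above (placeholder path/repo id; whether
# `trust_remote_code=True` is required depends on how the checkpoint is published):
#   from transformers import AutoModel, AutoModelForSequenceClassification
#   model = AutoModel.from_pretrained("<path-or-repo-id>", trust_remote_code=True)
#   clf = AutoModelForSequenceClassification.from_pretrained(
#       "<path-or-repo-id>", num_labels=3, trust_remote_code=True
#   )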