File size: 22,331 Bytes
838c37a
 
 
 
b0221f6
 
 
 
838c37a
 
 
 
 
c90eb91
 
838c37a
c90eb91
838c37a
c90eb91
b0221f6
838c37a
b0221f6
838c37a
b0221f6
 
 
 
838c37a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b0221f6
 
 
f19606f
838c37a
 
b0221f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
838c37a
 
 
b0221f6
 
 
838c37a
b0221f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
838c37a
 
 
 
 
 
 
 
 
b0221f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
838c37a
 
 
 
 
 
b0221f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
838c37a
b0221f6
838c37a
b0221f6
 
 
 
 
 
838c37a
 
 
 
 
 
 
b0221f6
 
 
 
 
 
 
 
838c37a
 
 
 
 
 
b0221f6
 
 
838c37a
 
 
b0221f6
 
 
 
 
838c37a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
954e323
838c37a
 
 
 
 
b0221f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
838c37a
 
b0221f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f19606f
 
 
 
838c37a
 
f19606f
 
 
 
838c37a
f19606f
 
 
 
 
 
 
 
838c37a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f19606f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
838c37a
f19606f
8f57823
f19606f
838c37a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f19606f
 
 
b0221f6
f19606f
 
 
838c37a
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
import math
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast

import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from packaging import version
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
from transformers.cache_utils import Cache
from transformers.modeling_outputs import (
    BaseModelOutputWithPooling,
    ModelOutput,
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.models.auto import (
    AutoModel,
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
)
from transformers.models.m2m_100.modeling_m2m_100 import M2M100Encoder
from transformers.tokenization_utils import BatchEncoding

from .configuration_nllbllm2vec import NLLBLLM2VecConfig
from .modeling_llama_encoder import LlamaEncoderModel

# Default keyword arguments for the tokenizer call made by the default
# `collate_fn`; `NLLBLLM2Vec.encode` lets user-supplied values override them.
DEFAULT_TOKENIZE_KWARGS: Dict[str, Any] = dict(
    padding=True,
    truncation=True,
    max_length=512,
    return_tensors="pt",
)

# Default keyword arguments for the `DataLoader` built by `NLLBLLM2Vec.encode`;
# user-supplied values take precedence.
DEFAULT_DATALOADER_KWARGS: Dict[str, Any] = dict(
    shuffle=False,
    batch_size=32,
    pin_memory=True,
)


def default_collate_fn_closure(tokenizer, tokenize_kwargs) -> Callable:
    """
    Build the default batch collation function for `NLLBLLM2Vec.encode`.

    Args:
        tokenizer: Callable tokenizer applied to each batch of strings.
        tokenize_kwargs: Keyword arguments forwarded to every tokenizer call.
            The dict is read at call time, not copied here.

    Returns:
        Callable: ``collate_fn(batch)`` returning ``tokenizer(batch, **tokenize_kwargs)``.
    """

    def collate_fn(batch: List[str]) -> BatchEncoding:
        # Tokenize the whole list of strings in one call.
        encoded = tokenizer(batch, **tokenize_kwargs)
        return encoded

    return collate_fn


def defaulter(kwd_dict: Optional[Dict], default_dict: Dict) -> Dict:
    """
    Merge user-supplied keyword arguments on top of a dict of defaults.

    Args:
        kwd_dict: User overrides, or ``None`` to use only the defaults.
        default_dict: Default values. Never mutated.

    Returns:
        Dict: A NEW dict with ``default_dict``'s entries updated by ``kwd_dict``
        (``kwd_dict`` wins on key collisions). A fresh dict is returned even when
        ``kwd_dict`` is ``None``. Callers may therefore mutate the result without
        corrupting shared module-level defaults such as ``DEFAULT_TOKENIZE_KWARGS``.
    """
    overrides = {} if kwd_dict is None else kwd_dict
    return {**default_dict, **overrides}


@dataclass
class SequenceClassifierOutputWithPastAndPooler(ModelOutput):
    """
    Sequence-classification model output with an additional `pooler_output` field.

    Every field defaults to ``None``.
    """

    # Loss value; `NLLBLLM2VecForSequenceClassification` only sets it when labels are given.
    loss: Optional[torch.FloatTensor] = None
    # Output of the classification/regression head.
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    # NOTE(review): `NLLBLLM2VecForSequenceClassification.forward` stores the pooled
    # embedding tensor here rather than a tuple of per-layer hidden states.
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Pooled sequence embedding produced by the underlying encoder.
    pooler_output: Optional[torch.FloatTensor] = None


class NLLBLLM2Vec(PreTrainedModel):
    """
    NLLBLLM2Vec model combining NLLB and LLama encoders.

    The NLLB encoder embeds the input tokens, `up_proj` maps its last hidden
    state to the LLama encoder's hidden size, and the LLama encoder (`llm2vec`)
    runs on those embeddings. The pooled output is the mean of the LLama
    encoder's last hidden state over non-padded positions.

    Args:
        config (Optional[NLLBLLM2VecConfig]): Configuration object.
        nllb_encoder (Optional[M2M100Encoder]): Pre-initialized NLLB encoder.
        llm2vec (Optional[LlamaEncoderModel]): Pre-initialized LLama encoder.
        *inputs: Additional positional arguments forwarded to `PreTrainedModel`.
        **kwargs: Additional keyword arguments forwarded to `PreTrainedModel`.

    Raises:
        ValueError: If `config` is None and not both encoders are given.
    """

    config_class = NLLBLLM2VecConfig
    model_type = "nllb-llm2vec"
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def __init__(
        self,
        config: Optional[NLLBLLM2VecConfig] = None,
        nllb_encoder: Optional[M2M100Encoder] = None,
        llm2vec: Optional[LlamaEncoderModel] = None,
        *inputs,
        **kwargs,
    ):
        # Ensure that either config is not None or both encoders are provided
        if config is None and (nllb_encoder is None or llm2vec is None):
            raise ValueError(
                "Either `config` must be provided, or both `nllb_encoder` and `llm2vec` must be specified."
            )

        if config is not None:
            super().__init__(config, *inputs, **kwargs)
            # from_pretrained overwrites this after config instantiation, so we make sure it's correctly set
            config.nllb_config._attn_implementation = config._attn_implementation
            config.llm2vec_config._attn_implementation = config._attn_implementation
            self.nllb_encoder = (
                nllb_encoder
                if nllb_encoder is not None
                else M2M100Encoder(config.nllb_config)
            )
            self.llm2vec = (
                llm2vec
                if llm2vec is not None
                else LlamaEncoderModel(config.llm2vec_config)
            )
            self.config = config

        else:
            # Both encoders are provided. Build the config from them first, because
            # `nn.Module.__init__` (reached via `super().__init__`) must run before
            # any submodule is assigned. Otherwise torch raises
            # "cannot assign module before Module.__init__() call".
            nllb_encoder = cast(M2M100Encoder, nllb_encoder)
            llm2vec = cast(LlamaEncoderModel, llm2vec)
            derived_config = NLLBLLM2VecConfig(
                nllb_config=nllb_encoder.config,  # type: ignore
                llm2vec_config=llm2vec.config,  # type: ignore
            )
            super().__init__(derived_config, *inputs, **kwargs)
            self.nllb_encoder = nllb_encoder
            self.llm2vec = llm2vec

        # Maps NLLB encoder states (d_model) to the LLama encoder's hidden size.
        self.up_proj = nn.Linear(
            self.nllb_encoder.config.d_model,
            self.llm2vec.config.hidden_size,
            bias=False,
        )

        # TODO: update this once commit is included
        min_version = "4.46.0"
        if self.config.nllb_config._attn_implementation == "flash_attention_2":
            if version.parse(transformers.__version__) < version.parse(min_version):
                warnings.warn(
                    f"Installed transformers version ({transformers.__version__}) never sets NLLB-encoder dropout to `False` with FlashAttention2. See https://github.com/huggingface/transformers/pull/33844 for more info. Consider upgrading to latest to {min_version} or master.",
                    UserWarning,
                )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        indices: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        *args,
        **kwargs,
    ) -> BaseModelOutputWithPooling:
        """
        Forward pass of the model.

        Args:
            input_ids (torch.Tensor): Input token IDs.
            attention_mask (torch.Tensor): Attention mask (non-zero = real token).
            indices (Optional[Tuple[torch.Tensor, torch.Tensor]]): Precomputed
                `(input_indices, offsets)` as returned by `_get_input_offsets`.
                Computed from `attention_mask` when omitted.
            *args, **kwargs: Accepted for API compatibility and ignored.

        Returns:
            BaseModelOutputWithPooling: Last hidden state of the LLama encoder and
            the mean-pooled embedding over non-padded positions.
        """
        # Compute input indices and offsets if not provided
        if indices is None:
            seq_indices, seq_offsets = self._get_input_offsets(attention_mask)
        else:
            seq_indices, seq_offsets = indices

        nllb_outputs = self.nllb_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        projected = self.up_proj(nllb_outputs.last_hidden_state)
        outputs = self.llm2vec(
            inputs_embeds=projected,
            attention_mask=attention_mask,
        )
        last_hidden_state = outputs.last_hidden_state
        pooler_output = self._mean_embedding(
            hidden_states=last_hidden_state,
            input_indices=seq_indices,
            offsets=seq_offsets,
        )
        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooler_output,
        )

    @property
    def tokenizer(self):
        """
        Get the tokenizer associated with the model.

        Lazily loads and caches the `facebook/nllb-200-distilled-600M` tokenizer
        with right padding on first access.

        Returns:
            PreTrainedTokenizer: The tokenizer instance.
        """
        if not hasattr(self, "_tokenizer"):
            from transformers import AutoTokenizer

            self._tokenizer = AutoTokenizer.from_pretrained(
                "facebook/nllb-200-distilled-600M", padding_side="right"
            )
        return self._tokenizer

    def encode(
        self,
        inputs: List[str],
        src_lang: str = "eng_Latn",
        dataloader_kwargs: Optional[Dict[str, Any]] = None,
        tokenize_kwargs: Optional[Dict[str, Any]] = None,
        collate_fn_closure: Optional[Callable] = None,
    ) -> torch.Tensor:
        """
        Encode input texts into embeddings.

        Puts the model in eval mode and runs every batch under inference mode
        with bfloat16 autocast, regardless of how many inputs are given.

        Args:
            inputs (List[str]): List of input texts.
            src_lang (str): Source language code for the tokenizer (default: `"eng_Latn"`).
            dataloader_kwargs (Optional[Dict[str, Any]]): Additional keyword arguments for the dataloader excl. `collate_fn`.
                Defaults to:
                >>    dataloader_kwargs = {
                >>        "shuffle": False,
                >>        "batch_size": 32,
                >>        "pin_memory": True,
                >>    }
            tokenize_kwargs (Optional[Dict[str, Any]]): Additional keyword arguments for the tokenizer.
                Defaults to:
                >>    tokenize_kwargs = {
                >>        "padding": True,
                >>        "truncation": True,
                >>        "max_length": 512,
                >>        "return_tensors": "pt",
                >>    }
            collate_fn_closure (Optional[Callable]): Closure that should return a `collate_fn`.
                Defaults to:
                >>    def default_collate_fn_closure(tokenizer, tokenize_kwargs) -> Callable:
                >>        def collate_fn(batch: list[str]) -> BatchEncoding:
                >>            return tokenizer(batch, **tokenize_kwargs)
                >>        return collate_fn
        Returns:
            torch.Tensor: Mean-pooled sequence embeddings of the inputs, one row per
            input. An empty tensor of shape `(0, hidden_size)` is returned for empty `inputs`.

        Raises:
            AssertionError: If `dataloader_kwargs` contains `collate_fn`.
        """
        # merge user kwargs with defaults, giving priority to user kwargs
        tokenize_kwargs = defaulter(tokenize_kwargs, DEFAULT_TOKENIZE_KWARGS)
        dataloader_kwargs = defaulter(dataloader_kwargs, DEFAULT_DATALOADER_KWARGS)
        # Validate before any expensive work (e.g. loading the tokenizer).
        assert (
            "collate_fn" not in dataloader_kwargs
        ), "`collate_fn` should be created via `collate_fn_closure`"

        device = next(self.parameters()).device
        self.eval()
        if len(inputs) == 0:
            # Avoid running the tokenizer and model on an empty batch.
            return torch.empty((0, self.llm2vec.config.hidden_size), device=device)

        tokenizer = self.tokenizer
        tokenizer.src_lang = src_lang
        make_collate_fn = (
            default_collate_fn_closure
            if collate_fn_closure is None
            else collate_fn_closure
        )
        collate_fn = make_collate_fn(tokenizer, tokenize_kwargs)

        if len(inputs) > dataloader_kwargs.get("batch_size", 1):
            dataloader = DataLoader(inputs, collate_fn=collate_fn, **dataloader_kwargs)  # type: ignore
            all_embeddings = [
                self._embed_batch(batch, device)
                for batch in tqdm(dataloader, desc="Encoding")
            ]
            # Concatenate all pooled embeddings along the batch dimension
            return torch.cat(all_embeddings, dim=0)
        # Everything fits in a single batch: skip the DataLoader and progress bar.
        return self._embed_batch(collate_fn(inputs), device)

    def _embed_batch(self, batch, device: torch.device) -> torch.Tensor:
        """Move one tokenized batch to `device` and return its pooled embeddings
        (inference mode, bfloat16 autocast)."""
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.inference_mode(), torch.autocast(
            device_type=device.type, dtype=torch.bfloat16
        ):
            return self(**batch).pooler_output

    @staticmethod
    def _get_input_offsets(
        attention_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute indices and offsets for mean pooling using EmbeddingBag.

        Args:
            attention_mask (torch.Tensor): Attention mask of shape (batch_size, seq_len).

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
                - input_indices: 1-D indices of non-padded tokens in the flattened
                  (batch_size * seq_len) input.
                - offsets: Start position of each sequence within `input_indices`.
        """
        # `reshape` (unlike `view`) also works for non-contiguous masks. `squeeze(-1)`
        # keeps the result 1-D even when exactly one token is unpadded; a bare
        # `squeeze()` would produce a 0-d tensor that `embedding_bag` rejects.
        input_indices = attention_mask.reshape(-1).nonzero(as_tuple=False).squeeze(-1)

        # Exclusive cumulative sum of per-sequence token counts = start of each bag.
        non_padded_lengths = attention_mask.sum(dim=1)
        offsets = non_padded_lengths.cumsum(dim=0).roll(shifts=1)
        offsets[0] = 0
        return input_indices, offsets

    @staticmethod
    def _mean_embedding(
        hidden_states: torch.Tensor,
        input_indices: torch.Tensor,
        offsets: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute the mean of non-padded embeddings using `embedding_bag`,
        properly handling padding with offsets.

        Args:
            hidden_states (torch.Tensor): Hidden states of shape (batch_size, seq_len, embed_dim).
            input_indices (torch.Tensor): Indices of non-padded tokens in flattened form.
            offsets (torch.Tensor): Offsets specifying the start of each sequence.

        Returns:
            torch.Tensor: Pooled mean embeddings of shape (batch_size, embed_dim).
        """
        # Treat every (batch, position) hidden state as one row of an embedding
        # table; `reshape` tolerates non-contiguous inputs.
        token_embeds = hidden_states.reshape(-1, hidden_states.size(-1))
        return F.embedding_bag(
            input=input_indices,
            weight=token_embeds,
            offsets=offsets,
            mode="mean",
        )


class NLLBLLM2VecForSequenceClassification(PreTrainedModel):
    """
    `NLLBLLM2Vec` with a bias-free linear head on the mean-pooled embedding,
    for sequence classification or regression.
    """

    config_class = NLLBLLM2VecConfig
    model_type = "nllb-llm2vec"
    base_model_prefix = "model"
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.model = NLLBLLM2Vec(config)
        self.score = nn.Linear(
            config.llm2vec_config.hidden_size, self.num_labels, bias=False
        )

        # Initialize weights and apply final processing
        self.post_init()

    def _init_weights(self, module):
        """Initialize `module`; the classification head gets special treatment."""
        if module is self.score:
            # INFO:
            # - critical that clf head is in float32 (NusaX perf. drops funky otherwise)
            # - Initialization needs to be redone, otherwise borked
            #   - Use kaiming uniform, b/c Llama init (cf. `nn.Linear` below) performs worse
            self.score = self.score.to(torch.float32)
            torch.nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def get_input_embeddings(self):
        # The NLLB encoder is stored as `nllb_encoder` on `NLLBLLM2Vec`.
        return self.model.nllb_encoder.embed_tokens

    def set_input_embeddings(self, value):
        self.model.nllb_encoder.embed_tokens = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPastAndPooler]:
        r"""
        attention_mask (`torch.Tensor`, *optional*):
            Padding mask. When omitted, every token is treated as non-padded.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Extra arguments (`position_ids`, `past_key_values`, ...) are forwarded to the
        base model, which ignores them.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        if attention_mask is None:
            # Mean pooling in the base model needs an explicit mask.
            attention_mask = torch.ones_like(input_ids)

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs.pooler_output
        # The head is kept in float32 (see `_init_weights`). Cast its input so it also
        # works for reduced-precision activations when autocast is not active.
        pooled_logits = self.score(hidden_states.to(self.score.weight.dtype))

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(pooled_logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (
                    labels.dtype == torch.long or labels.dtype == torch.int
                ):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                if self.num_labels == 1:
                    loss = F.mse_loss(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = F.mse_loss(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss = F.cross_entropy(
                    pooled_logits.view(-1, self.num_labels), labels.view(-1)
                )
            elif self.config.problem_type == "multi_label_classification":
                loss = F.binary_cross_entropy_with_logits(pooled_logits, labels)
        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        # NOTE: `hidden_states` carries the pooled embedding (kept for compatibility).
        return SequenceClassifierOutputWithPastAndPooler(
            loss=loss,
            hidden_states=hidden_states,
            logits=pooled_logits,
            pooler_output=transformer_outputs.pooler_output,
        )


class NLLBLLM2VecForTokenClassification(PreTrainedModel):
    """
    `NLLBLLM2Vec` with a bias-free linear head applied to every position of the
    LLama encoder's last hidden state, for token classification.
    """

    config_class = NLLBLLM2VecConfig
    model_type = "nllb-llm2vec"
    base_model_prefix = "model"
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def __init__(self, config: NLLBLLM2VecConfig):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.model = NLLBLLM2Vec(config)
        self.classifier = nn.Linear(
            config.llm2vec_config.hidden_size, self.num_labels, bias=False
        )

        # Initialize weights and apply final processing
        self.post_init()

    def _init_weights(self, module):
        """Initialize `module`; the classification head gets special treatment."""
        if module is self.classifier:
            # INFO:
            # - critical that clf head is in float32 (NusaX perf. drops funky otherwise)
            # - Initialization needs to be redone, otherwise borked
            #   - Use kaiming uniform, b/c Llama init (cf. `nn.Linear` below) performs worse
            self.classifier = self.classifier.to(torch.float32)
            torch.nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def get_input_embeddings(self):
        # The NLLB encoder is stored as `nllb_encoder` on `NLLBLLM2Vec`.
        return self.model.nllb_encoder.embed_tokens

    def set_input_embeddings(self, value):
        self.model.nllb_encoder.embed_tokens = value

    # adapted from transformers.models.roberta.modeling_roberta.RobertaForTokenClassification
    # - removed classifier dropout
    # - use F.cross_entropy
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        r"""
        attention_mask (`torch.FloatTensor`, *optional*):
            Padding mask. When omitted, every token is treated as non-padded.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.

        Extra arguments (`token_type_ids`, `head_mask`, ...) are forwarded to the base
        model, which ignores them.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        if attention_mask is None:
            # Mean pooling in the base model needs an explicit mask.
            attention_mask = torch.ones_like(input_ids)

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        # The head is kept in float32 (see `_init_weights`). Cast its input so it also
        # works for reduced-precision activations when autocast is not active.
        logits = self.classifier(sequence_output.to(self.classifier.weight.dtype))

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)
            loss = F.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


# Register the model classes with the Transformers Auto factories so that
# `AutoModel*.from_config` / `from_pretrained` resolve `NLLBLLM2VecConfig`.
# The token-classification head must be registered under
# `AutoModelForTokenClassification`. Registering it under
# `AutoModelForSequenceClassification` would overwrite the sequence-classification
# mapping for the same config class.
AutoModel.register(NLLBLLM2VecConfig, NLLBLLM2Vec)
AutoModelForSequenceClassification.register(
    NLLBLLM2VecConfig, NLLBLLM2VecForSequenceClassification
)
AutoModelForTokenClassification.register(
    NLLBLLM2VecConfig, NLLBLLM2VecForTokenClassification
)