from dataclasses import dataclass
from typing import Dict, Optional

import torch
import torch.nn.functional as F
from torch import nn, Tensor
from transformers import MistralModel, MistralPreTrainedModel
from transformers.utils import ModelOutput

# Cosine distance between two batches of vectors: 0 for identical direction, 2 for opposite.
COSINE_DISTANCE = lambda x, y: 1 - F.cosine_similarity(x, y)


@dataclass
class EncoderOutput(ModelOutput):
    loss: Optional[Tensor] = None


class MistralModelEmbedding(MistralPreTrainedModel):
    """Mistral backbone with a linear projection head, trained as a bi-encoder
    with a margin-based contrastive loss on query/passage pairs."""

    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        self.model = MistralModel(config)
        # Projects the last-token hidden state down to the embedding size.
        # `embedding_size` is not a standard MistralConfig field; it must be set
        # on the config before the model is instantiated.
        self.dense_layer = nn.Linear(
            self.config.hidden_size,
            self.config.embedding_size,
            bias=False,
        )
        self.post_init()

    def encode(self, features: Optional[Dict[str, Tensor]]) -> Optional[Tensor]:
        if features is None:
            return None
        psg_out = self.model(**features, return_dict=True)
        logits = self.dense_layer(psg_out.last_hidden_state)

        # Last-token pooling for right-padded batches: the position of the first pad
        # token minus one is the last real token. If a sequence has no padding,
        # argmax returns 0 and the modulo wraps -1 around to the final position.
        # Requires `config.pad_token_id` to be set.
        input_ids = features["input_ids"]
        batch_size = input_ids.shape[0]
        sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
        sequence_lengths = sequence_lengths % input_ids.shape[-1]
        sequence_lengths = sequence_lengths.to(logits.device)

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
        return pooled_logits

    def forward(
        self,
        query: Dict[str, Tensor] = None,
        passage: Dict[str, Tensor] = None,
        labels: Optional[Tensor] = None,
        margin: float = 1.0,
    ) -> EncoderOutput:
        q_reps = self.encode(query)
        p_reps = self.encode(passage)

        loss = None
        if labels is not None:
            # Contrastive loss on cosine distance: pull positive pairs (label 1)
            # together, push negative pairs (label 0) at least `margin` apart.
            distances = COSINE_DISTANCE(q_reps, p_reps)
            losses = 0.5 * (
                labels.float() * distances.pow(2)
                + (1 - labels).float() * F.relu(margin - distances).pow(2)
            )
            loss = losses.mean()

        return EncoderOutput(
            loss=loss,
        )
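

# --- Minimal usage sketch (illustrative; not part of the model definition above). ---
# The checkpoint name, embedding size, and pad-token choice below are assumptions
# made for the example, not requirements of MistralModelEmbedding. Batches must be
# right-padded so that last-token pooling in `encode` finds the final real token.
if __name__ == "__main__":
    from transformers import AutoConfig, AutoTokenizer

    model_name = "mistralai/Mistral-7B-v0.1"  # assumed checkpoint

    tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="right")
    tokenizer.pad_token = tokenizer.eos_token  # Mistral ships without a pad token

    config = AutoConfig.from_pretrained(model_name)
    config.embedding_size = 768                     # custom field read by the projection head
    config.pad_token_id = tokenizer.pad_token_id    # required by the pooling logic

    model = MistralModelEmbedding.from_pretrained(model_name, config=config)

    queries = tokenizer(["how do transformers work?"], padding=True, return_tensors="pt")
    passages = tokenizer(["Transformers use self-attention over tokens."], padding=True, return_tensors="pt")
    labels = torch.tensor([1])  # 1 = relevant (positive) pair, 0 = irrelevant (negative) pair

    out = model(query=queries, passage=passages, labels=labels, margin=1.0)
    print(out.loss)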