Sentence Similarity · Transformers · Safetensors · English · mistral · feature-extraction · text-embedding · embeddings · information-retrieval · beir · text-classification · language-model · text-clustering · text-semantic-similarity · text-evaluation · text-reranking · natural_questions · ms_marco · fever · hotpot_qa · mteb · custom_code · text-generation-inference · text-embeddings-inference · Inference Endpoints
Update modeling_mistral_encoder.py

modeling_mistral_encoder.py  CHANGED  (+0 -66)
@@ -13,15 +13,6 @@ from .attn_mask_utils import _prepare_4d_causal_attention_mask
 
 logger = logging.get_logger(__name__)
 
-def batch_to_device(batch, target_device: device):
-    """
-    send a pytorch batch to a device (CPU/GPU)
-    """
-    for key in batch:
-        if isinstance(batch[key], Tensor):
-            batch[key] = batch[key].to(target_device)
-    return batch
-
 class ModifiedMistralAttention(MistralAttention):
 
     def __init__(self, *args, **kwargs):
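For reference, the batch_to_device helper removed above only copied each tensor in a tokenized batch onto a target device. Below is a minimal sketch of the same move done directly on the tokenizer output, assuming the batch comes from a Hugging Face tokenizer called with return_tensors="pt" (which returns a BatchEncoding with its own .to() method); the checkpoint name is a placeholder, not taken from this repository:

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")  # placeholder checkpoint
tokenizer.pad_token = tokenizer.eos_token  # the Mistral tokenizer defines no pad token by default

features = tokenizer(["a sentence to embed"], return_tensors="pt", padding=True)
device = "cuda" if torch.cuda.is_available() else "cpu"
features = features.to(device)  # moves input_ids and attention_mask, as the removed helper did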
@@ -218,60 +209,3 @@ class MistralEncoderModel(MistralModel):
             hidden_states=all_hidden_states,
             attentions=all_self_attns,
         )
-
-    def prepare_for_tokenization(self, text):
-
-        text = '[INST] ' + text.strip() + ' [/INST]'
-        # if self.pooling_mode == "eos_token":
-        #     text = text.strip() + ' </s>'
-        return text
-
-    def tokenize(self, texts):
-        # return self.tokenizer(texts, return_tensors='pt', padding=True, truncation=True, max_length=self.max_length)
-
-        texts_2 = []
-        original_texts = []
-        for text in texts:
-            t = text.split("!@#$%^&*()")
-            texts_2.append(t[1])
-            original_texts.append("".join(t))
-
-        original = self.tokenizer(original_texts, return_tensors='pt', padding=True, truncation=True, max_length=self.max_length)
-        embed_mask = None
-        for t_i, t in enumerate(texts_2):
-            ids = self.tokenizer([t], return_tensors='pt', padding=True, truncation=True, max_length=self.max_length, add_special_tokens=False)
-            if embed_mask is None:
-                e_m = torch.zeros_like(original["attention_mask"][t_i])
-                if len(ids["input_ids"][0]) > 0:
-                    e_m[-len(ids["input_ids"][0]):] = torch.ones(len(ids["input_ids"][0]))
-                embed_mask = e_m.unsqueeze(0)
-            else:
-                e_m = torch.zeros_like(original["attention_mask"][t_i])
-                if len(ids["input_ids"][0]) > 0:
-                    e_m[-len(ids["input_ids"][0]):] = torch.ones(len(ids["input_ids"][0]))
-                embed_mask = torch.cat((embed_mask, e_m.unsqueeze(0)), dim=0)
-
-        original["embed_mask"] = embed_mask
-        return original
-
-    def _skip_instruction(self, sentence_feature):
-        assert sentence_feature["attention_mask"].shape == sentence_feature["embed_mask"].shape
-        sentence_feature["attention_mask"] = sentence_feature["embed_mask"]
-
-    def _encode(self, sentences_batch, device, convert_to_numpy, multiprocessing=False):
-
-        if multiprocessing:
-            rank = mp.current_process()._identity[0]
-            if device is None and torch.cuda.is_available():
-                device = f"cuda:{rank % torch.cuda.device_count()}"
-
-        self.to(device)
-        features = self.tokenize([self.prepare_for_tokenization(sentence) for sentence in sentences_batch])
-        features = batch_to_device(features, device)
-
-        with torch.no_grad():
-            embeddings = self.forward(features)
-            embeddings = embeddings.detach()
-            embeddings = embeddings.cpu()
-
-        return embeddings
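The bulk of the removal is the instruction-aware encoding path: prepare_for_tokenization wrapped each input in Mistral's [INST] ... [/INST] markers, tokenize split each string on the literal separator "!@#$%^&*()" and built an extra embed_mask covering only the tokens after the separator, _skip_instruction swapped that mask in for the attention mask, and _encode ran the tokenized batch through the model under torch.no_grad(). Below is a minimal standalone sketch of the embed-mask construction, assuming a left-padding Hugging Face tokenizer so that, as in the removed method's indexing, the text-to-embed tokens sit at the end of each row; the function name and checkpoint are placeholders for illustration:

import torch
from transformers import AutoTokenizer

SEPARATOR = "!@#$%^&*()"  # literal separator used by the removed tokenize()

def tokenize_with_embed_mask(texts, tokenizer, max_length=512):
    # Split each input into (instruction, text to embed) on the separator,
    # tokenize the re-joined string, and mark only the trailing text-to-embed
    # tokens in an extra "embed_mask" tensor.
    passages, full_texts = [], []
    for text in texts:
        parts = text.split(SEPARATOR)
        passages.append(parts[1] if len(parts) > 1 else "")
        full_texts.append("".join(parts))

    batch = tokenizer(full_texts, return_tensors="pt", padding=True,
                      truncation=True, max_length=max_length)

    embed_mask = torch.zeros_like(batch["attention_mask"])
    for i, passage in enumerate(passages):
        ids = tokenizer([passage], return_tensors="pt", truncation=True,
                        max_length=max_length, add_special_tokens=False)
        n = ids["input_ids"].shape[1]
        if n > 0:
            embed_mask[i, -n:] = 1  # text-to-embed tokens sit at the end of the row
    batch["embed_mask"] = embed_mask
    return batch

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", padding_side="left")  # placeholder
tokenizer.pad_token = tokenizer.eos_token  # no pad token is defined by default

features = tokenize_with_embed_mask(
    ["[INST] Given a query, retrieve relevant passages: " + SEPARATOR
     + "cats purr when they are content [/INST]"],
    tokenizer,
)
# The removed _encode would then move these features to the model's device and
# run a forward pass under torch.no_grad() to obtain the embeddings.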