vaibhavad commited on
Commit
445e7cb
·
verified ·
1 Parent(s): f9a0043

Update modeling_mistral_encoder.py

Browse files
Files changed (1) hide show
  1. modeling_mistral_encoder.py +0 -66
modeling_mistral_encoder.py CHANGED
@@ -13,15 +13,6 @@ from .attn_mask_utils import _prepare_4d_causal_attention_mask
13
 
14
  logger = logging.get_logger(__name__)
15
 
16
- def batch_to_device(batch, target_device: device):
17
- """
18
- send a pytorch batch to a device (CPU/GPU)
19
- """
20
- for key in batch:
21
- if isinstance(batch[key], Tensor):
22
- batch[key] = batch[key].to(target_device)
23
- return batch
24
-
25
  class ModifiedMistralAttention(MistralAttention):
26
 
27
  def __init__(self, *args, **kwargs):
@@ -218,60 +209,3 @@ class MistralEncoderModel(MistralModel):
218
  hidden_states=all_hidden_states,
219
  attentions=all_self_attns,
220
  )
221
-
222
- def prepare_for_tokenization(self, text):
223
-
224
- text = '[INST] ' + text.strip() + ' [/INST]'
225
- # if self.pooling_mode == "eos_token":
226
- # text = text.strip() + ' </s>'
227
- return text
228
-
229
- def tokenize(self, texts):
230
- # return self.tokenizer(texts, return_tensors='pt', padding=True, truncation=True, max_length=self.max_length)
231
-
232
- texts_2 = []
233
- original_texts = []
234
- for text in texts:
235
- t = text.split("!@#$%^&*()")
236
- texts_2.append(t[1])
237
- original_texts.append("".join(t))
238
-
239
- original = self.tokenizer(original_texts, return_tensors='pt', padding=True, truncation=True, max_length=self.max_length)
240
- embed_mask = None
241
- for t_i, t in enumerate(texts_2):
242
- ids = self.tokenizer([t], return_tensors='pt', padding=True, truncation=True, max_length=self.max_length, add_special_tokens=False)
243
- if embed_mask is None:
244
- e_m = torch.zeros_like(original["attention_mask"][t_i])
245
- if len(ids["input_ids"][0]) > 0:
246
- e_m[-len(ids["input_ids"][0]):] = torch.ones(len(ids["input_ids"][0]))
247
- embed_mask = e_m.unsqueeze(0)
248
- else:
249
- e_m = torch.zeros_like(original["attention_mask"][t_i])
250
- if len(ids["input_ids"][0]) > 0:
251
- e_m[-len(ids["input_ids"][0]):] = torch.ones(len(ids["input_ids"][0]))
252
- embed_mask = torch.cat((embed_mask, e_m.unsqueeze(0)), dim=0)
253
-
254
- original["embed_mask"] = embed_mask
255
- return original
256
-
257
- def _skip_instruction(self, sentence_feature):
258
- assert sentence_feature["attention_mask"].shape == sentence_feature["embed_mask"].shape
259
- sentence_feature["attention_mask"] = sentence_feature["embed_mask"]
260
-
261
- def _encode(self, sentences_batch, device, convert_to_numpy, multiprocessing=False):
262
-
263
- if multiprocessing:
264
- rank = mp.current_process()._identity[0]
265
- if device is None and torch.cuda.is_available():
266
- device = f"cuda:{rank % torch.cuda.device_count()}"
267
-
268
- self.to(device)
269
- features = self.tokenize([self.prepare_for_tokenization(sentence) for sentence in sentences_batch])
270
- features = batch_to_device(features, device)
271
-
272
- with torch.no_grad():
273
- embeddings = self.forward(features)
274
- embeddings = embeddings.detach()
275
- embeddings = embeddings.cpu()
276
-
277
- return embeddings
 
13
 
14
  logger = logging.get_logger(__name__)
15
 
 
 
 
 
 
 
 
 
 
16
  class ModifiedMistralAttention(MistralAttention):
17
 
18
  def __init__(self, *args, **kwargs):
 
209
  hidden_states=all_hidden_states,
210
  attentions=all_self_attns,
211
  )