jupyterjazz commited on
Commit
851aaca
·
1 Parent(s): 5418705

refactor: set task in lora class rather than xlm roberta

Browse files
Files changed (2) hide show
  1. modeling_lora.py +43 -7
  2. modeling_xlm_roberta.py +1 -24
modeling_lora.py CHANGED
@@ -1,18 +1,17 @@
1
  import math
2
  import os
 
3
  from functools import partial
4
- from typing import Iterator, Optional, Tuple, Union
5
 
 
6
  import torch
7
  import torch.nn.utils.parametrize as parametrize
8
  from torch import nn
9
  from torch.nn import Parameter
10
  from transformers import PretrainedConfig
11
 
12
- from .modeling_xlm_roberta import (
13
- XLMRobertaFlashConfig,
14
- XLMRobertaModel,
15
- )
16
 
17
 
18
  def initialized_weights(
@@ -231,7 +230,6 @@ class XLMRobertaLoRA(XLMRobertaModel):
231
  # By default, disable LoRA until it's specified which adapter/task to use
232
  self.current_task = None
233
 
234
-
235
  @property
236
  def main_params_trainable(self):
237
  return self._main_params_trainable
@@ -273,7 +271,8 @@ class XLMRobertaLoRA(XLMRobertaModel):
273
  pretrained_model_name_or_path, *model_args, **kwargs
274
  )
275
  else:
276
- torch.set_default_dtype(torch.float16)
 
277
  return cls(config)
278
 
279
  def _register_lora(self, num_adaptations, rank, dropout_p, alpha):
@@ -327,3 +326,40 @@ class XLMRobertaLoRA(XLMRobertaModel):
327
  ):
328
  if "lora" in name or self.main_params_trainable:
329
  yield name, param
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import math
2
  import os
3
+ import warnings
4
  from functools import partial
5
+ from typing import Iterator, List, Optional, Tuple, Union
6
 
7
+ import numpy as np
8
  import torch
9
  import torch.nn.utils.parametrize as parametrize
10
  from torch import nn
11
  from torch.nn import Parameter
12
  from transformers import PretrainedConfig
13
 
14
+ from .modeling_xlm_roberta import XLMRobertaFlashConfig, XLMRobertaModel
 
 
 
15
 
16
 
17
  def initialized_weights(
 
230
  # By default, disable LoRA until it's specified which adapter/task to use
231
  self.current_task = None
232
 
 
233
  @property
234
  def main_params_trainable(self):
235
  return self._main_params_trainable
 
271
  pretrained_model_name_or_path, *model_args, **kwargs
272
  )
273
  else:
274
+ dtype = config.torch_dtype if config.torch_dtype else torch.bfloat16
275
+ torch.set_default_dtype(dtype)
276
  return cls(config)
277
 
278
  def _register_lora(self, num_adaptations, rank, dropout_p, alpha):
 
326
  ):
327
  if "lora" in name or self.main_params_trainable:
328
  yield name, param
329
+
330
+ @torch.inference_mode()
331
+ def encode(
332
+ self,
333
+ *args,
334
+ task: Optional[str] = None,
335
+ **kwargs,
336
+ ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
337
+ """
338
+ Computes sentence embeddings
339
+
340
+ task(`str`, *optional*, defaults to None):
341
+ Specifies the task for which the encoding is intended. This
342
+ controls the use of specialized LoRA adapters that are tuned for specific tasks.
343
+ If provided, the corresponding LoRA adapter is enabled, enhancing the model's
344
+ performance for that task. If `None` or not provided, LoRA is disabled, and the
345
+ model uses its original, general-purpose weights.
346
+ """
347
+ lora_adapter_num = None
348
+ if self.config.lora_adaptations:
349
+ if task:
350
+ if task in self.config.lora_adaptations:
351
+ lora_adapter_num = self.config.lora_adaptations.index(task)
352
+ else:
353
+ raise ValueError(
354
+ f"Unsupported task '{task}'. "
355
+ f"Supported tasks are: {', '.join(self.config.lora_adaptations)}."
356
+ )
357
+ else:
358
+ warnings.warn(
359
+ f"Task-specific embeddings are disabled. To enable, specify the `task` "
360
+ f"argument with one of the supported tasks: {', '.join(self.config.lora_adaptations)}",
361
+ category=UserWarning,
362
+ )
363
+ self.current_task = lora_adapter_num
364
+
365
+ return super().encode(*args, **kwargs)
modeling_xlm_roberta.py CHANGED
@@ -452,7 +452,6 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
452
  convert_to_tensor: bool = False,
453
  device: Optional[torch.device] = None,
454
  normalize_embeddings: bool = False,
455
- task: Optional[str] = None,
456
  **tokenizer_kwargs,
457
  ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
458
  """
@@ -482,12 +481,6 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
482
  If set to true, returned vectors will have length 1. In that case, the
483
  faster dot-product (util.dot_score) instead of cosine similarity can
484
  be used.
485
- task(`str`, *optional*, defaults to None):
486
- Specifies the task for which the encoding is intended. This
487
- controls the use of specialized LoRA adapters that are tuned for specific tasks.
488
- If provided, the corresponding LoRA adapter is enabled, enhancing the model's
489
- performance for that task. If `None` or not provided, LoRA is disabled, and the
490
- model uses its original, general-purpose weights.
491
  tokenizer_kwargs(`Dict[str, Any]`, *optional*, defaults to {}):
492
  Keyword arguments for the tokenizer
493
  Returns:
@@ -525,22 +518,6 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
525
  if device is not None:
526
  self.to(device)
527
 
528
- lora_adapter_num = None
529
- if self.config.lora_adaptations:
530
- if task:
531
- if task in self.config.lora_adaptations:
532
- lora_adapter_num = self.config.lora_adaptations.index(task)
533
- else:
534
- raise ValueError(
535
- f"Unsupported task '{task}'. "
536
- f"Supported tasks are: {', '.join(self.config.lora_adaptations)}.")
537
- else:
538
- logger.warning(
539
- f"Task-specific embeddings are disabled. To enable, specify the `task` "
540
- f"argument with one of the supported tasks: {', '.join(self.config.lora_adaptations)}"
541
- )
542
-
543
-
544
  permutation = np.argsort([-len(i) for i in sentences])
545
  inverse_permutation = np.argsort(permutation)
546
  sentences = [sentences[idx] for idx in permutation]
@@ -570,7 +547,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
570
  return_tensors='pt',
571
  **tokenizer_kwargs,
572
  ).to(self.device)
573
- token_embs = self.forward(**encoded_input, lora_adaptation=lora_adapter_num)[0]
574
 
575
  # Accumulate in fp32 to avoid overflow
576
  token_embs = token_embs.float()
 
452
  convert_to_tensor: bool = False,
453
  device: Optional[torch.device] = None,
454
  normalize_embeddings: bool = False,
 
455
  **tokenizer_kwargs,
456
  ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
457
  """
 
481
  If set to true, returned vectors will have length 1. In that case, the
482
  faster dot-product (util.dot_score) instead of cosine similarity can
483
  be used.
 
 
 
 
 
 
484
  tokenizer_kwargs(`Dict[str, Any]`, *optional*, defaults to {}):
485
  Keyword arguments for the tokenizer
486
  Returns:
 
518
  if device is not None:
519
  self.to(device)
520
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
521
  permutation = np.argsort([-len(i) for i in sentences])
522
  inverse_permutation = np.argsort(permutation)
523
  sentences = [sentences[idx] for idx in permutation]
 
547
  return_tensors='pt',
548
  **tokenizer_kwargs,
549
  ).to(self.device)
550
+ token_embs = self.forward(**encoded_input)[0]
551
 
552
  # Accumulate in fp32 to avoid overflow
553
  token_embs = token_embs.float()