Lakoc commited on
Commit
d7361d8
·
verified ·
1 Parent(s): 0558854

Upload JointCTCAttentionEncoderDecoder

Browse files
Files changed (3) hide show
  1. config.json +1 -1
  2. generation.py +61 -0
  3. modeling_decred.py +10 -6
config.json CHANGED
@@ -5,7 +5,7 @@
5
  ],
6
  "auto_map": {
7
  "AutoConfig": "configuration_decred.JointCTCAttentionEncoderDecoderConfig",
8
- "AutoModelForSpeechSeq2Seq": "modeling_decred.JointCTCAttentionEncoderDecoder"
9
  },
10
  "ctc_weight": 0.3,
11
  "decoder": {
 
5
  ],
6
  "auto_map": {
7
  "AutoConfig": "configuration_decred.JointCTCAttentionEncoderDecoderConfig",
8
+ "AutoModel": "modeling_decred.JointCTCAttentionEncoderDecoder"
9
  },
10
  "ctc_weight": 0.3,
11
  "decoder": {
generation.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import GenerationConfig
2
+
3
+
4
+ class GenerationConfigCustom(GenerationConfig):
5
+ def __init__(
6
+ self,
7
+ ctc_weight=0.0,
8
+ ctc_margin=0,
9
+ lm_weight=0,
10
+ lm_model=None,
11
+ space_token_id=-1,
12
+ eos_space_trick_weight=0,
13
+ apply_eos_space_trick=False,
14
+ **kwargs,
15
+ ):
16
+ super().__init__(**kwargs)
17
+ self.ctc_weight = ctc_weight
18
+ self.ctc_margin = ctc_margin
19
+ self.lm_weight = lm_weight
20
+ self.lm_model = lm_model
21
+ self.space_token_id = space_token_id
22
+ self.eos_space_trick_weight = eos_space_trick_weight
23
+ self.apply_eos_space_trick = apply_eos_space_trick
24
+
25
+ def update_from_string(self, update_str: str):
26
+ """
27
+ Updates attributes of this class with attributes from `update_str`.
28
+
29
+ The expected format is ints, floats and strings as is, and for booleans use `true` or `false`. For example:
30
+ "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
31
+
32
+ The keys to change have to already exist in the config object.
33
+
34
+ Args:
35
+ update_str (`str`): String with attributes that should be updated for this class.
36
+
37
+ """
38
+
39
+ d = dict(x.split("=") for x in update_str.split(";"))
40
+ for k, v in d.items():
41
+ if not hasattr(self, k):
42
+ raise ValueError(f"key {k} isn't in the original config dict")
43
+
44
+ old_v = getattr(self, k)
45
+ if isinstance(old_v, bool):
46
+ if v.lower() in ["true", "1", "y", "yes"]:
47
+ v = True
48
+ elif v.lower() in ["false", "0", "n", "no"]:
49
+ v = False
50
+ else:
51
+ raise ValueError(f"can't derive true or false from {v} (key {k})")
52
+ elif isinstance(old_v, int):
53
+ v = int(v)
54
+ elif isinstance(old_v, float):
55
+ v = float(v)
56
+ elif not isinstance(old_v, str):
57
+ raise ValueError(
58
+ f"You can only update int, float, bool or string values in the config, got {v} for key {k}"
59
+ )
60
+
61
+ setattr(self, k, v)
modeling_decred.py CHANGED
@@ -8,7 +8,6 @@ from transformers import (
8
  AutoConfig,
9
  AutoModelForCausalLM,
10
  AutoModelForSpeechSeq2Seq,
11
- GenerationConfig,
12
  LogitsProcessor,
13
  PretrainedConfig,
14
  PreTrainedModel,
@@ -28,6 +27,7 @@ from .auto_wrappers import CustomAutoModelForCTC
28
  from .configuration_decred import JointCTCAttentionEncoderDecoderConfig
29
  from .ctc_scorer import CTCRescorerLogitsProcessor, LogSoftmaxProcessor
30
  from .embeddings import AdaptiveEmbedding, PositionalEmbedding
 
31
  from .multi_head_gpt2 import GPT2LMMultiHeadModel
32
 
33
  logger = logging.get_logger("transformers")
@@ -433,7 +433,7 @@ class JointCTCAttentionEncoderDecoder(SpeechEncoderDecoderModel):
433
 
434
  def _get_logits_processor(
435
  self,
436
- generation_config: GenerationConfig,
437
  input_ids_seq_length: int,
438
  encoder_input_ids: torch.LongTensor,
439
  prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
@@ -464,9 +464,13 @@ class JointCTCAttentionEncoderDecoder(SpeechEncoderDecoderModel):
464
  self.generation_config.ctc_margin,
465
  self.generation_config.ctc_weight,
466
  self.generation_config.num_beams,
467
- self.generation_config.space_token_id,
468
- self.generation_config.apply_eos_space_trick,
469
- self.generation_config.eos_space_trick_weight,
 
 
 
 
470
  )
471
  processors.append(self.ctc_rescorer)
472
  if hasattr(generation_config, "lm_weight") and generation_config.lm_weight > 0:
@@ -524,7 +528,7 @@ class JointCTCAttentionEncoderDecoder(SpeechEncoderDecoderModel):
524
  def generate(
525
  self,
526
  inputs: Optional[torch.Tensor] = None,
527
- generation_config: Optional[GenerationConfig] = None,
528
  logits_processor: Optional[LogitsProcessorList] = None,
529
  stopping_criteria: Optional[StoppingCriteriaList] = None,
530
  prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
 
8
  AutoConfig,
9
  AutoModelForCausalLM,
10
  AutoModelForSpeechSeq2Seq,
 
11
  LogitsProcessor,
12
  PretrainedConfig,
13
  PreTrainedModel,
 
27
  from .configuration_decred import JointCTCAttentionEncoderDecoderConfig
28
  from .ctc_scorer import CTCRescorerLogitsProcessor, LogSoftmaxProcessor
29
  from .embeddings import AdaptiveEmbedding, PositionalEmbedding
30
+ from .generation import GenerationConfigCustom
31
  from .multi_head_gpt2 import GPT2LMMultiHeadModel
32
 
33
  logger = logging.get_logger("transformers")
 
433
 
434
  def _get_logits_processor(
435
  self,
436
+ generation_config: GenerationConfigCustom,
437
  input_ids_seq_length: int,
438
  encoder_input_ids: torch.LongTensor,
439
  prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
 
464
  self.generation_config.ctc_margin,
465
  self.generation_config.ctc_weight,
466
  self.generation_config.num_beams,
467
+ self.generation_config.space_token_id if hasattr(self.generation_config, "space_token_id") else None,
468
+ self.generation_config.apply_eos_space_trick
469
+ if hasattr(self.generation_config, "apply_eos_space_trick")
470
+ else False,
471
+ self.generation_config.eos_space_trick_weight
472
+ if hasattr(self.generation_config, "eos_space_trick_weight")
473
+ else 0.0,
474
  )
475
  processors.append(self.ctc_rescorer)
476
  if hasattr(generation_config, "lm_weight") and generation_config.lm_weight > 0:
 
528
  def generate(
529
  self,
530
  inputs: Optional[torch.Tensor] = None,
531
+ generation_config: Optional[GenerationConfigCustom] = None,
532
  logits_processor: Optional[LogitsProcessorList] = None,
533
  stopping_criteria: Optional[StoppingCriteriaList] = None,
534
  prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,