Porjaz committed on
Commit
747d7b4
·
verified ·
1 Parent(s): b8a6501

Update custom_interface.py

Browse files
Files changed (1) hide show
  1. custom_interface.py +6 -72
custom_interface.py CHANGED
@@ -3,67 +3,18 @@ from speechbrain.inference.interfaces import Pretrained
3
 
4
 
5
  class ASR(Pretrained):
6
- """A ready-to-use class for utterance-level classification (e.g., speaker-id,
7
- language-id, emotion recognition, keyword spotting, etc).
8
- The class assumes that a self-supervised encoder like wav2vec2/hubert and a classifier model
9
- are defined in the yaml file. If you want to
10
- convert the predicted index into a corresponding text label, please
11
- provide the path of the label_encoder in a variable called 'lab_encoder_file'
12
- within the yaml.
13
- The class can be used either to run only the encoder (encode_batch()) to
14
- extract embeddings or to run a classification step (classify_batch()).
15
- ```
16
- Example
17
- -------
18
- >>> import torchaudio
19
- >>> from speechbrain.pretrained import EncoderClassifier
20
- >>> # Model is downloaded from the speechbrain HuggingFace repo
21
- >>> tmpdir = getfixture("tmpdir")
22
- >>> classifier = EncoderClassifier.from_hparams(
23
- ... source="speechbrain/spkrec-ecapa-voxceleb",
24
- ... savedir=tmpdir,
25
- ... )
26
- >>> # Compute embeddings
27
- >>> signal, fs = torchaudio.load("samples/audio_samples/example1.wav")
28
- >>> embeddings = classifier.encode_batch(signal)
29
- >>> # Classification
30
- >>> prediction = classifier .classify_batch(signal)
31
- """
32
-
33
  def __init__(self, *args, **kwargs):
34
  super().__init__(*args, **kwargs)
35
 
36
  def encode_batch(self, wavs, wav_lens=None, normalize=False):
37
- """Encodes the input audio into a single vector embedding.
38
- The waveforms should already be in the model's desired format.
39
- You can call:
40
- ``normalized = <this>.normalizer(signal, sample_rate)``
41
- to get a correctly converted signal in most cases.
42
- Arguments
43
- ---------
44
- wavs : torch.tensor
45
- Batch of waveforms [batch, time, channels] or [batch, time]
46
- depending on the model. Make sure the sample rate is fs=16000 Hz.
47
- wav_lens : torch.tensor
48
- Lengths of the waveforms relative to the longest one in the
49
- batch, tensor of shape [batch]. The longest one should have
50
- relative length 1.0 and others len(waveform) / max_length.
51
- Used for ignoring padding.
52
- normalize : bool
53
- If True, it normalizes the embeddings with the statistics
54
- contained in mean_var_norm_emb.
55
- Returns
56
- -------
57
- torch.tensor
58
- The encoded batch
59
- """
60
- batch = batch.to(self.device)
61
- sig, self.sig_lens = batch.sig
62
- tokens_bos, _ = batch.tokens_bos
63
- sig, self.sig_lens = sig.to(self.device), self.sig_lens.to(self.device)
64
 
65
  # Forward pass
66
- encoded_outputs = self.modules.encoder_w2v2(sig.detach())
 
 
 
67
  embedded_tokens = self.modules.embedding(tokens_bos)
68
  decoder_outputs, _ = self.modules.decoder(embedded_tokens, encoded_outputs, self.sig_lens)
69
 
@@ -76,23 +27,6 @@ class ASR(Pretrained):
76
 
77
 
78
  def classify_file(self, path):
79
- """Classifies the given audiofile into the given set of labels.
80
- Arguments
81
- ---------
82
- path : str
83
- Path to audio file to classify.
84
- Returns
85
- -------
86
- out_prob
87
- The log posterior probabilities of each class ([batch, N_class])
88
- score:
89
- It is the value of the log-posterior for the best class ([batch,])
90
- index
91
- The indexes of the best class ([batch,])
92
- text_lab:
93
- List with the text labels corresponding to the indexes.
94
- (label encoder should be provided).
95
- """
96
  waveform = self.load_audio(path)
97
  # Fake a batch:
98
  batch = waveform.unsqueeze(0)
 
3
 
4
 
5
  class ASR(Pretrained):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  def __init__(self, *args, **kwargs):
7
  super().__init__(*args, **kwargs)
8
 
9
  def encode_batch(self, wavs, wav_lens=None, normalize=False):
10
+ wavs = wavs.to(self.device)
11
+ wav_lens = wav_lens.to(self.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  # Forward pass
14
+ encoded_outputs = self.modules.encoder_w2v2(wavs.detach())
15
+ # append
16
+ tokens_bos = torch.zeros((wavs.size(0), 1), dtype=torch.long).to(self.device)
17
+ print(tokens_bos.size())
18
  embedded_tokens = self.modules.embedding(tokens_bos)
19
  decoder_outputs, _ = self.modules.decoder(embedded_tokens, encoded_outputs, self.sig_lens)
20
 
 
27
 
28
 
29
  def classify_file(self, path):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  waveform = self.load_audio(path)
31
  # Fake a batch:
32
  batch = waveform.unsqueeze(0)