Porjaz committed on
Commit
171bc27
·
verified ·
1 Parent(s): 23ab61f

Upload custom_interface.py

Browse files
Files changed (1) hide show
  1. custom_interface.py +105 -0
custom_interface.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from speechbrain.inference.interfaces import Pretrained
3
+
4
+
5
class ASR(Pretrained):
    """Inference interface for a speech recogniser with a wav2vec2-style encoder.

    The hyperparameter yaml loaded through ``Pretrained`` is expected to
    provide:

    * a module named ``encoder_w2v2`` that is called on the raw waveforms, and
    * an hparam named ``test_search`` that is called with the encoder output
      and the relative lengths and returns four values, the first of which is
      the decoded hypotheses.

    Use ``encode_batch()`` (or call the object directly) on a batch of
    waveforms, or ``classify_file()`` on the path of a single audio file.

    Example
    -------
    >>> from speechbrain.inference.interfaces import foreign_class  # doctest: +SKIP
    >>> asr = foreign_class(
    ...     source="<namespace>/<model-repo>",
    ...     pymodule_file="custom_interface.py",
    ...     classname="ASR",
    ... )  # doctest: +SKIP
    >>> hyps = asr.classify_file("path/to/audio.wav")  # doctest: +SKIP
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def encode_batch(self, wavs, wav_lens=None, normalize=False):
        """Runs the encoder and the search decoder on a batch of waveforms.

        The waveforms should already be in the model's desired format.
        ``load_audio()`` can be used to read and convert a file.

        Arguments
        ---------
        wavs : torch.Tensor
            Batch of waveforms, presumably [batch, time] (confirm against the
            encoder defined in the yaml). A 1-D tensor is treated as a batch
            containing a single waveform.
        wav_lens : torch.Tensor
            Lengths of the waveforms relative to the longest one in the
            batch, tensor of shape [batch]. The longest one should have
            relative length 1.0 and others len(waveform) / max_length.
            If None, every waveform is assumed to be full length (1.0).
        normalize : bool
            Unused. Kept only so that existing callers that pass it keep
            working.

        Returns
        -------
        dict
            Dictionary with a single key ``"tokens"`` holding the first value
            returned by ``hparams.test_search`` (the decoded hypotheses).
        """
        # A single un-batched waveform is promoted to a batch of one.
        if wavs.dim() == 1:
            wavs = wavs.unsqueeze(0)

        # Without explicit lengths there is no padding to ignore.
        if wav_lens is None:
            wav_lens = torch.ones(wavs.shape[0])

        wavs = wavs.to(self.device)
        wav_lens = wav_lens.to(self.device)

        # ``Pretrained`` stores the loaded modules in ``self.mods``;
        # ``self.modules`` is the built-in ``torch.nn.Module.modules`` method.
        encoded_outputs = self.mods.encoder_w2v2(wavs.detach())

        # No reference transcript exists at inference time, so only the
        # search-based decoding is run (no teacher-forced decoder pass).
        tokens, _, _, _ = self.hparams.test_search(encoded_outputs, wav_lens)

        return {"tokens": tokens}

    def classify_file(self, path):
        """Transcribes a single audio file.

        The method keeps the name ``classify_file`` for compatibility with
        existing callers, but it returns decoded hypotheses, not class
        posteriors.

        Arguments
        ---------
        path : str
            Path to the audio file to transcribe.

        Returns
        -------
        The ``"tokens"`` entry produced by ``encode_batch()`` for a batch
        that contains only this file.
        """
        waveform = self.load_audio(path)
        # Fake a batch:
        batch = waveform.unsqueeze(0)
        rel_length = torch.tensor([1.0])
        outputs = self.encode_batch(batch, rel_length)["tokens"]

        return outputs

    def forward(self, wavs, wav_lens=None):
        """Runs ``encode_batch()`` so the object can be called like a module."""
        return self.encode_batch(wavs=wavs, wav_lens=wav_lens)