bofenghuang
commited on
Commit
·
f1dd1bd
1
Parent(s):
16000ae
updt example
Browse files
README.md
CHANGED
@@ -130,23 +130,26 @@ import torchaudio
|
|
130 |
|
131 |
from transformers import AutoModelForCTC, Wav2Vec2ProcessorWithLM
|
132 |
|
133 |
-
|
|
|
|
|
134 |
processor_with_lm = Wav2Vec2ProcessorWithLM.from_pretrained("bhuang/asr-wav2vec2-french")
|
|
|
135 |
|
136 |
wav_path = "example.wav" # path to your audio file
|
137 |
waveform, sample_rate = torchaudio.load(wav_path)
|
138 |
waveform = waveform.squeeze(axis=0) # mono
|
139 |
|
140 |
# resample
|
141 |
-
if sample_rate !=
|
142 |
-
resampler = torchaudio.transforms.Resample(sample_rate,
|
143 |
waveform = resampler(waveform)
|
144 |
|
145 |
# normalize
|
146 |
-
input_dict = processor_with_lm(waveform, sampling_rate=
|
147 |
|
148 |
with torch.inference_mode():
|
149 |
-
logits = model(input_dict.input_values.to(
|
150 |
|
151 |
predicted_sentence = processor_with_lm.batch_decode(logits.cpu().numpy()).text[0]
|
152 |
```
|
@@ -159,23 +162,26 @@ import torchaudio
|
|
159 |
|
160 |
from transformers import AutoModelForCTC, Wav2Vec2Processor
|
161 |
|
162 |
-
|
|
|
|
|
163 |
processor = Wav2Vec2Processor.from_pretrained("bhuang/asr-wav2vec2-french")
|
|
|
164 |
|
165 |
wav_path = "example.wav" # path to your audio file
|
166 |
waveform, sample_rate = torchaudio.load(wav_path)
|
167 |
waveform = waveform.squeeze(axis=0) # mono
|
168 |
|
169 |
# resample
|
170 |
-
if sample_rate !=
|
171 |
-
resampler = torchaudio.transforms.Resample(sample_rate,
|
172 |
waveform = resampler(waveform)
|
173 |
|
174 |
# normalize
|
175 |
-
input_dict = processor(waveform, sampling_rate=
|
176 |
|
177 |
with torch.inference_mode():
|
178 |
-
logits = model(input_dict.input_values.to(
|
179 |
|
180 |
# decode
|
181 |
predicted_ids = torch.argmax(logits, dim=-1)
|
|
|
130 |
|
131 |
from transformers import AutoModelForCTC, Wav2Vec2ProcessorWithLM
|
132 |
|
133 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
134 |
+
|
135 |
+
model = AutoModelForCTC.from_pretrained("bhuang/asr-wav2vec2-french").to(device)
|
136 |
processor_with_lm = Wav2Vec2ProcessorWithLM.from_pretrained("bhuang/asr-wav2vec2-french")
|
137 |
+
model_sample_rate = processor_with_lm.feature_extractor.sampling_rate
|
138 |
|
139 |
wav_path = "example.wav" # path to your audio file
|
140 |
waveform, sample_rate = torchaudio.load(wav_path)
|
141 |
waveform = waveform.squeeze(axis=0) # mono
|
142 |
|
143 |
# resample
|
144 |
+
if sample_rate != model_sample_rate:
|
145 |
+
resampler = torchaudio.transforms.Resample(sample_rate, model_sample_rate)
|
146 |
waveform = resampler(waveform)
|
147 |
|
148 |
# normalize
|
149 |
+
input_dict = processor_with_lm(waveform, sampling_rate=model_sample_rate, return_tensors="pt")
|
150 |
|
151 |
with torch.inference_mode():
|
152 |
+
logits = model(input_dict.input_values.to(device)).logits
|
153 |
|
154 |
predicted_sentence = processor_with_lm.batch_decode(logits.cpu().numpy()).text[0]
|
155 |
```
|
|
|
162 |
|
163 |
from transformers import AutoModelForCTC, Wav2Vec2Processor
|
164 |
|
165 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
166 |
+
|
167 |
+
model = AutoModelForCTC.from_pretrained("bhuang/asr-wav2vec2-french").to(device)
|
168 |
processor = Wav2Vec2Processor.from_pretrained("bhuang/asr-wav2vec2-french")
|
169 |
+
model_sample_rate = processor.feature_extractor.sampling_rate
|
170 |
|
171 |
wav_path = "example.wav" # path to your audio file
|
172 |
waveform, sample_rate = torchaudio.load(wav_path)
|
173 |
waveform = waveform.squeeze(axis=0) # mono
|
174 |
|
175 |
# resample
|
176 |
+
if sample_rate != model_sample_rate:
|
177 |
+
resampler = torchaudio.transforms.Resample(sample_rate, model_sample_rate)
|
178 |
waveform = resampler(waveform)
|
179 |
|
180 |
# normalize
|
181 |
+
input_dict = processor(waveform, sampling_rate=model_sample_rate, return_tensors="pt")
|
182 |
|
183 |
with torch.inference_mode():
|
184 |
+
logits = model(input_dict.input_values.to(device)).logits
|
185 |
|
186 |
# decode
|
187 |
predicted_ids = torch.argmax(logits, dim=-1)
|