File size: 3,880 Bytes
7bcf8d7 5442f52 6f53272 5442f52 4090e0d 7bcf8d7 408e3fc 7bcf8d7 1b7f5da 7bcf8d7 6f53272 7bcf8d7 a45002a 7bcf8d7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import librosa
from transformers import Wav2Vec2ForCTC, AutoProcessor
import torch
import json
from huggingface_hub import hf_hub_download
from torchaudio.models.decoder import ctc_decoder
# Sampling rate (Hz) that audio is resampled to before it is fed to the model.
ASR_SAMPLING_RATE = 16_000

# Maps a language code (first field of each line) to its display name
# (remainder of the line), as listed in data/asr/all_langs.tsv.
ASR_LANGUAGES = {}
# Explicit UTF-8 so language names decode identically regardless of the
# platform's default locale encoding.
with open("data/asr/all_langs.tsv", encoding="utf-8") as f:
    for line in f:
        # Drop the line terminator so names do not carry a trailing newline.
        line = line.rstrip("\r\n")
        if not line.strip():
            # Tolerate blank lines (e.g. a trailing empty line) instead of
            # failing on the tuple unpacking below.
            continue
        # NOTE(review): the file is named .tsv but fields are split on a
        # single space -- confirm the delimiter against the data file.
        iso, name = line.split(" ", 1)
        ASR_LANGUAGES[iso] = name

# Checkpoint identifier; the processor and the CTC model are both loaded
# once at import time and shared by every transcribe() call.
MODEL_ID = "facebook/mms-1b-all"
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
# lm_decoding_config = {}
# lm_decoding_configfile = hf_hub_download(
# repo_id="facebook/mms-cclms",
# filename="decoding_config.json",
# subfolder="mms-1b-all",
# )
# with open(lm_decoding_configfile) as f:
# lm_decoding_config = json.loads(f.read())
# # allow language model decoding for "eng"
# decoding_config = lm_decoding_config["eng"]
# lm_file = hf_hub_download(
# repo_id="facebook/mms-cclms",
# filename=decoding_config["lmfile"].rsplit("/", 1)[1],
# subfolder=decoding_config["lmfile"].rsplit("/", 1)[0],
# )
# token_file = hf_hub_download(
# repo_id="facebook/mms-cclms",
# filename=decoding_config["tokensfile"].rsplit("/", 1)[1],
# subfolder=decoding_config["tokensfile"].rsplit("/", 1)[0],
# )
# lexicon_file = None
# if decoding_config["lexiconfile"] is not None:
# lexicon_file = hf_hub_download(
# repo_id="facebook/mms-cclms",
# filename=decoding_config["lexiconfile"].rsplit("/", 1)[1],
# subfolder=decoding_config["lexiconfile"].rsplit("/", 1)[0],
# )
# beam_search_decoder = ctc_decoder(
# lexicon=lexicon_file,
# tokens=token_file,
# lm=lm_file,
# nbest=1,
# beam_size=500,
# beam_size_token=50,
# lm_weight=float(decoding_config["lmweight"]),
# word_score=float(decoding_config["wordscore"]),
# sil_score=float(decoding_config["silweight"]),
# blank_token="<s>",
# )
def _select_device():
    """Return the preferred torch device: CUDA first, then MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # Older torch builds have no `torch.backends.mps` attribute at all.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available() and mps.is_built():
        return torch.device("mps")
    return torch.device("cpu")


def transcribe(
    audio_source=None, microphone=None, file_upload=None, lang="eng (English)"
):
    """Transcribe one audio clip with the module-level MMS model.

    Args:
        audio_source: Label of the chosen input. If its lowercased text
            contains "upload", ``file_upload`` is used; otherwise
            ``microphone`` is used.
        microphone: Audio path from the microphone input, or a dict holding
            the path under the "name" key.
        file_upload: Audio path from the upload input.
        lang: Language label whose first whitespace-separated token is the
            language code, e.g. "eng (English)".

    Returns:
        The decoded transcription, or a string starting with "ERROR:" when
        no audio or no language was provided.
    """
    if isinstance(microphone, dict):
        # HACK: microphone variable is a dict when running on examples
        microphone = microphone["name"]

    use_upload = "upload" in str(audio_source or "").lower()
    audio_fp = file_upload if use_upload else microphone
    if audio_fp is None:
        return "ERROR: You have to either use the microphone or upload an audio file"

    # Validate the language before doing any expensive audio/model work.
    lang_parts = (lang or "").split()
    if not lang_parts:
        return "ERROR: You have to select a language"
    lang_code = lang_parts[0]

    # librosa.load returns (samples, rate); only the samples are needed
    # because the rate is forced to ASR_SAMPLING_RATE and audio is mono.
    audio_samples = librosa.load(audio_fp, sr=ASR_SAMPLING_RATE, mono=True)[0]

    # Point both the tokenizer and the model adapter at the requested language.
    processor.tokenizer.set_target_lang(lang_code)
    model.load_adapter(lang_code)

    inputs = processor(
        audio_samples, sampling_rate=ASR_SAMPLING_RATE, return_tensors="pt"
    )

    device = _select_device()
    model.to(device)
    inputs = inputs.to(device)

    with torch.no_grad():
        outputs = model(**inputs).logits

    # Greedy decoding: take the arg-max token per frame for the first (only)
    # item in the batch. Language-model beam-search decoding is currently
    # disabled for every language, including "eng".
    ids = torch.argmax(outputs, dim=-1)[0]
    return processor.decode(ids)
# Example inputs for the demo. Each row has four entries which appear to
# mirror transcribe()'s parameters in order: audio_source, microphone,
# file_upload, lang -- TODO confirm against the UI code that consumes this.
# The Tamil and Burmese rows are currently disabled.
ASR_EXAMPLES = [
[None, "assets/english.mp3", None, "eng (English)"],
# [None, "assets/tamil.mp3", None, "tam (Tamil)"],
# [None, "assets/burmese.mp3", None, "mya (Burmese)"],
]
# Markdown-formatted note explaining that the demo decodes without a language
# model and linking to LM-decoding instructions. Presumably rendered beneath
# the demo by the UI code that imports this module -- confirm against caller.
ASR_NOTE = """
The above demo doesn't use beam-search decoding using a language model.
Checkout the instructions [here](https://huggingface.co/facebook/mms-1b-all) on how to run LM decoding for better accuracy.
"""
|