File size: 3,880 Bytes
7bcf8d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5442f52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f53272
5442f52
 
 
 
 
 
 
 
 
 
 
 
4090e0d
7bcf8d7
 
 
 
 
 
 
 
 
 
408e3fc
 
 
 
7bcf8d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1b7f5da
7bcf8d7
 
 
6f53272
 
 
7bcf8d7
 
 
 
 
 
 
 
 
 
 
a45002a
 
7bcf8d7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import librosa
from transformers import Wav2Vec2ForCTC, AutoProcessor
import torch
import json

from huggingface_hub import hf_hub_download
from torchaudio.models.decoder import ctc_decoder

# Sampling rate passed to both the audio loader and the model processor.
ASR_SAMPLING_RATE = 16_000


def _load_asr_languages(path):
    """Read the language table at *path* into an ``{iso: name}`` dict.

    Each non-blank line is split at its first space: the text before it
    becomes the key and the remainder, with surrounding whitespace (including
    the line terminator) removed, becomes the value. Blank lines are skipped.

    Raises:
        ValueError: if a non-blank line contains no space separator.
    """
    languages = {}
    # Explicit encoding so the table is read the same way on every platform.
    with open(path, encoding="utf-8") as table:
        for line in table:
            if not line.strip():
                continue
            iso, name = line.split(" ", 1)
            # Strip the trailing newline so labels built from the name
            # (e.g. "eng (English)") do not carry a stray line break.
            languages[iso] = name.strip()
    return languages


# Maps language code -> language name, as listed in the bundled table.
ASR_LANGUAGES = _load_asr_languages("data/asr/all_langs.tsv")

MODEL_ID = "facebook/mms-1b-all"

# Loaded once at import time and shared by every transcribe() call.
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)


# lm_decoding_config = {}
# lm_decoding_configfile = hf_hub_download(
#     repo_id="facebook/mms-cclms",
#     filename="decoding_config.json",
#     subfolder="mms-1b-all",
# )

# with open(lm_decoding_configfile) as f:
#     lm_decoding_config = json.loads(f.read())

# # allow language model decoding for "eng"

# decoding_config = lm_decoding_config["eng"]

# lm_file = hf_hub_download(
#     repo_id="facebook/mms-cclms",
#     filename=decoding_config["lmfile"].rsplit("/", 1)[1],
#     subfolder=decoding_config["lmfile"].rsplit("/", 1)[0],
# )
# token_file = hf_hub_download(
#     repo_id="facebook/mms-cclms",
#     filename=decoding_config["tokensfile"].rsplit("/", 1)[1],
#     subfolder=decoding_config["tokensfile"].rsplit("/", 1)[0],
# )
# lexicon_file = None
# if decoding_config["lexiconfile"] is not None:
#     lexicon_file = hf_hub_download(
#         repo_id="facebook/mms-cclms",
#         filename=decoding_config["lexiconfile"].rsplit("/", 1)[1],
#         subfolder=decoding_config["lexiconfile"].rsplit("/", 1)[0],
#     )
    
# beam_search_decoder = ctc_decoder(
#     lexicon=lexicon_file,
#     tokens=token_file,
#     lm=lm_file,
#     nbest=1,
#     beam_size=500,
#     beam_size_token=50,
#     lm_weight=float(decoding_config["lmweight"]),
#     word_score=float(decoding_config["wordscore"]),
#     sil_score=float(decoding_config["silweight"]),
#     blank_token="<s>",
# )


def _select_device():
    """Return the torch device to run on: CUDA, else Apple MPS, else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # Older torch builds have no `torch.backends.mps` attribute at all.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available() and mps.is_built():
        return torch.device("mps")
    return torch.device("cpu")


def transcribe(
    audio_source=None, microphone=None, file_upload=None, lang="eng (English)"
):
    """Transcribe an audio file in the requested language.

    Args:
        audio_source: Label of the selected input source. If it contains
            "upload" (case-insensitive), ``file_upload`` is used; otherwise
            ``microphone`` is used.
        microphone: Path of the recorded audio, or a dict holding that path
            under the ``"name"`` key.
        file_upload: Path of the uploaded audio file.
        lang: Language label whose first whitespace-separated token is the
            language code, e.g. ``"eng (English)"``.

    Returns:
        The decoded transcription, or a string starting with ``"ERROR:"``
        when no audio or no language was provided.
    """
    if isinstance(microphone, dict):
        # HACK: microphone variable is a dict when running on examples
        microphone = microphone["name"]

    use_upload = "upload" in str(audio_source or "").lower()
    audio_fp = file_upload if use_upload else microphone
    if audio_fp is None:
        return "ERROR: You have to either use the microphone or upload an audio file"

    # Validate the language before doing any expensive audio/model work, so an
    # empty selection is reported instead of raising.
    lang_fields = lang.split() if isinstance(lang, str) else []
    if not lang_fields:
        return "ERROR: You have to select a language"
    lang_code = lang_fields[0]

    # librosa.load returns (samples, sampling_rate); only the samples are used.
    audio_samples = librosa.load(audio_fp, sr=ASR_SAMPLING_RATE, mono=True)[0]

    # Switch both the tokenizer and the model adapter to the target language.
    processor.tokenizer.set_target_lang(lang_code)
    model.load_adapter(lang_code)

    inputs = processor(
        audio_samples, sampling_rate=ASR_SAMPLING_RATE, return_tensors="pt"
    )

    device = _select_device()
    model.to(device)
    inputs = inputs.to(device)

    with torch.no_grad():
        logits = model(**inputs).logits

    # Greedy decoding: argmax over the last logits dimension for the first
    # (only) batch item. Language-model beam search is not used here.
    ids = torch.argmax(logits, dim=-1)[0]
    return processor.decode(ids)


# Example inputs for the demo. Each row has four values; presumably they line
# up positionally with transcribe()'s parameters -- confirm against the UI code.
ASR_EXAMPLES = [
    [
        None,
        "assets/english.mp3",
        None,
        "eng (English)",
    ],
    # Currently disabled examples:
    # [None, "assets/tamil.mp3", None, "tam (Tamil)"],
    # [None, "assets/burmese.mp3", None, "mya (Burmese)"],
]

# Markdown-formatted note about the demo's decoding. The link targets the
# canonical Hugging Face model page rather than a third-party mirror domain.
ASR_NOTE = """
The above demo doesn't use beam-search decoding using a language model. 
Checkout the instructions [here](https://huggingface.co/facebook/mms-1b-all) on how to run LM decoding for better accuracy.
"""