How to get the discrete codes correctly

#6
by dathudeptrai - opened

I am trying to get the discrete codes the right way, but the faiss index seems to be wrong somehow.

import torch
from transformers import HubertModel
from datasets import load_dataset
import faiss
import numpy as np


def load_index(index_path):
    index: faiss.IndexPreTransform = faiss.read_index(index_path)

    #Make sure we have access to the ivf subindex. We'll need it to get the centroids (clusters)
    index_ivf = faiss.extract_index_ivf(index)

    return index, index_ivf

def get_centroids_index(xq, index, index_ivf):
    ''' Get centroids '''
    #Get OPQ matrix
    opq_mt = faiss.downcast_VectorTransform(index.chain.at(0))
    #Apply pre-transform to query
    xq_t = opq_mt.apply_py(xq)
    #Get centroids C and distances DC on a pre-transformed index
    DC,C = index_ivf.quantizer.search(xq_t, 1)
    return DC, C


class Hubert2Unit(torch.nn.Module):
    def __init__(
        self,
        model_name="",
        kmean_path="",
        dtype=torch.float32,
        device="cuda:0",
    ):
        super(Hubert2Unit, self).__init__()
        self.model = HubertModel.from_pretrained("utter-project/mHuBERT-147").eval()
        self.model.to(dtype=torch.float32, device=device)  # trained with float32
        self.index, self.index_ivf = load_index("mhubert147_faiss.index")

    def zero_mean_unit_var_norm(
        self, input_values, wav_lengths, padding_value: float = 0.0
    ):
        """
        Every array in the list is normalized to have zero mean and unit variance
        """
        if wav_lengths is not None:
            normed_input_values = []

            for vector, length in zip(input_values, wav_lengths):
                normed_slice = (vector - vector[:length].mean()) / torch.sqrt(vector[:length].var() + 1e-7)
                if length < normed_slice.shape[0]:
                    normed_slice[length:] = padding_value

                normed_input_values.append(normed_slice)
        else:
            normed_input_values = [(x - x.mean()) / torch.sqrt(x.var() + 1e-7) for x in input_values]

        return torch.stack(normed_input_values, dim=0)

    def forward(self, wav, wav_lengths, do_normalize=True):
        with torch.no_grad():
            if do_normalize:
                input_values = self.zero_mean_unit_var_norm(wav, wav_lengths)
            else:
                input_values = wav.clone()
            # calculate the attention_mask from wav_lengths
            attention_mask = torch.arange(
                input_values.size(1), 
                device=input_values.device)[None, :] < wav_lengths[:, None]
            attention_mask = attention_mask.long()
            hidden_states = self.model(
                input_values, 
                attention_mask=attention_mask,
                output_hidden_states=True
            ).hidden_states[9]  # 9th layer of encoder block.
            hidden_states = hidden_states.reshape(hidden_states.size(0) * hidden_states.size(1), -1)
            hidden_states_cpu = hidden_states.float().detach().cpu().numpy()
            _, C = get_centroids_index(hidden_states_cpu, self.index, self.index_ivf)
            C = C.reshape(wav.shape[0], -1)
            n_unique_codes = len(np.unique(C))

        return C, n_unique_codes


ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation", trust_remote_code=True)
hubert = Hubert2Unit()
wav = ds[0]["audio"]["array"]
wav = torch.tensor(wav).to("cuda:0").unsqueeze(0).float()
lengths = torch.tensor([wav.shape[1]]).to("cuda:0")
C, n = hubert(wav, lengths)

@mzboito Thank you in advance.

UTTER - Unified Transcription and Translation for Extended Reality org
edited Jul 3, 2024

Hi! Thanks for moving this thread to a dedicated issue.
The source of your issue is very likely a mismatch between the trained faiss index and the mHuBERT-147 model you are using.

This index (https://huggingface.co/utter-project/mHuBERT-147/blob/main/mhubert147_faiss.index) was trained on the output of the 9th layer of the 2nd-iteration mHuBERT-147 (https://huggingface.co/utter-project/mHuBERT-147-base-2nd-iter), in order to generate targets for the mHuBERT-147 3rd-iteration training.
If you feed mHuBERT-147 (3rd iteration) features into it, it will not cluster them well, because it was trained on the output of a different model.

Basically, there are two settings in which you might want to use faiss:

  1. If you want to continue pretraining mHuBERT-147 (3rd iteration), you should extract features for your speech from the 9th layer of the 2nd iteration, and then generate the indices with the faiss index you are already using (https://huggingface.co/utter-project/mHuBERT-147/blob/main/mhubert147_faiss.index). This should work.

  2. If you want to discretize features from the 3rd iteration (mHuBERT-147) with faiss, then you need to train a new index on your target data. You can check our training recommendations here: https://github.com/utter-project/mHuBERT-147-scripts

I hope that was understandable!

@mzboito Thank you so much, it seems to be correct now. Just to make sure everything matches: are do_normalize=True and hidden_states[9] correct (rather than False or hidden_states[10])? I ask because len(hidden_states) is 13, not 12.

UTTER - Unified Transcription and Translation for Extended Reality org

Yes, do_normalize=True for everything.
Regarding the layer: I did feature extraction with fairseq, not HF, so I'm not 100% sure, but it should be [9] if the length is 13.

That is because the forward for feature extraction takes output_layer - 1: https://github.com/utter-project/fairseq/blob/3fb951a8658b81f09011fc2e9e5fe4c2e818a304/fairseq/models/hubert/hubert.py#L470
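
The index arithmetic can be checked with a toy model. The sketch below makes two assumptions: that HF returns the encoder input followed by each layer's output (which would explain the length of 13 for 12 layers), and that the fairseq encoder stops after the layer whose 0-based index is output_layer - 1. The functions `hf_hidden_states` and `fairseq_extract` are hypothetical stand-ins, not the real library code.

```python
def hf_hidden_states(x, layers):
    """Mimic output_hidden_states=True: encoder input, then each layer's output."""
    states = [x]
    for layer in layers:
        x = layer(x)
        states.append(x)
    return tuple(states)


def fairseq_extract(x, layers, output_layer):
    """Mimic stopping after the layer with 0-based index output_layer - 1."""
    tgt_layer = output_layer - 1
    for i, layer in enumerate(layers):
        x = layer(x)
        if i == tgt_layer:
            break
    return x


# Toy "layers": layer k appends its 1-based position k to a list.
layers = [lambda x, k=k: x + [k] for k in range(1, 13)]

states = hf_hidden_states([], layers)
from_fairseq = fairseq_extract([], layers, output_layer=9)

print(len(states))                # 13 entries for 12 layers
print(states[9] == from_fairseq)  # True: both are the output of the 9th layer
```

Under those assumptions, fairseq's output_layer=9 and HF's hidden_states[9] pick the same tensor, while hidden_states[10] would be one layer too deep.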

mzboito changed discussion status to closed
