sarinam's picture
Initial commit
dcd7182
# This code is based on the descriptions in https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/processing/PLDA_LDA.py
from pathlib import Path
from speechbrain.processing.PLDA_LDA import PLDA, StatObject_SB, Ndx, fast_PLDA_scoring
import numpy as np
import torch
class PLDAModel:
    """PLDA back-end for scoring speaker embeddings, built on speechbrain's PLDA tools.

    The model parameters (``mean``, ``F``, ``Sigma``) are either loaded from a
    cache directory or trained from embeddings and then written to that directory.
    """

    # Attribute name -> file name used when caching the PLDA parameters on disk.
    _PARAM_FILES = {
        'mean': 'plda_mean.npy',
        'F': 'plda_F.npy',
        'Sigma': 'plda_Sigma.npy',
    }

    def __init__(self, train_embeddings, results_path: Path = None, rank_f: int = 100):
        """Load cached PLDA parameters, or train new ones if they are not all cached.

        Args:
            train_embeddings: object exposing ``speaker_vectors`` (a torch tensor,
                one embedding per row) and ``speakers`` (labels, presumably aligned
                with those rows). Only used when training is needed.
            results_path: directory holding / receiving ``plda_mean.npy``,
                ``plda_F.npy`` and ``plda_Sigma.npy``. ``str`` is accepted.
                ``None`` disables caching entirely.
            rank_f: rank of the speaker subspace passed to speechbrain's ``PLDA``
                when training. Ignored when the parameters are loaded from disk.
        """
        self.mean, self.F, self.Sigma = None, None, None
        if results_path is not None:
            results_path = Path(results_path)

        files_exist = False
        if results_path is not None and results_path.exists():
            files_exist = self.load_parameters(results_path)

        if not files_exist:
            self._train_plda(train_embeddings, rank_f=rank_f)
            # Caching is optional: only persist when a location was provided.
            if results_path is not None:
                self.save_parameters(results_path)

    def compute_distance(self, enrollment_vectors, trial_vectors):
        """Score every enrollment vector against every trial vector with PLDA.

        Args:
            enrollment_vectors: torch tensor with one embedding per row (must be 2-D).
            trial_vectors: torch tensor with one embedding per row (must be 2-D).

        Returns:
            The ``scoremat`` attribute of the object returned by speechbrain's
            ``fast_PLDA_scoring``.
            NOTE(review): despite the method name these are PLDA scores
            (presumably log-likelihood ratios, higher = more similar), not
            distances. Confirm before treating smaller values as "closer".
        """
        enrol_vecs = enrollment_vectors.cpu().numpy()
        en_sets, en_s, en_stat0 = self._get_vector_stats(enrol_vecs, sg_tag='en')
        en_stat = StatObject_SB(modelset=en_sets, segset=en_sets, start=en_s, stop=en_s,
                                stat0=en_stat0, stat1=enrol_vecs)

        trial_vecs = trial_vectors.cpu().numpy()
        te_sets, te_s, te_stat0 = self._get_vector_stats(trial_vecs, sg_tag='te')
        te_stat = StatObject_SB(modelset=te_sets, segset=te_sets, start=te_s, stop=te_s,
                                stat0=te_stat0, stat1=trial_vecs)

        # No explicit trial mask is given, so every enrollment id is paired with every trial id.
        ndx = Ndx(models=en_sets, testsegs=te_sets)
        scores_plda = fast_PLDA_scoring(en_stat, te_stat, ndx, self.mean, self.F, self.Sigma)
        return scores_plda.scoremat

    def save_parameters(self, filename):
        """Write ``mean``, ``F`` and ``Sigma`` as ``.npy`` files into directory ``filename``.

        The directory, and any missing parents, is created when needed.

        Args:
            filename: target directory (``Path`` or ``str``).
        """
        filename = Path(filename)
        # The files are written *inside* ``filename``, so that directory itself must exist.
        filename.mkdir(parents=True, exist_ok=True)
        for attr, file_name in self._PARAM_FILES.items():
            np.save(filename / file_name, getattr(self, attr))

    def load_parameters(self, dir_path):
        """Load whichever PLDA parameter files exist in ``dir_path`` onto ``self``.

        Args:
            dir_path: directory previously written by ``save_parameters``.

        Returns:
            True only if all three parameter files were found and loaded. Otherwise
            False; any files that were present are still loaded.
        """
        dir_path = Path(dir_path)
        files_exist = True
        for attr, file_name in self._PARAM_FILES.items():
            param_file = dir_path / file_name
            if param_file.is_file():
                setattr(self, attr, np.load(param_file))
            else:
                files_exist = False
        return files_exist

    def _train_plda(self, train_embeddings, rank_f=100):
        """Fit speechbrain's PLDA on the training embeddings and store mean/F/Sigma.

        Args:
            train_embeddings: see ``__init__``.
            rank_f: speaker-subspace rank forwarded to ``PLDA``.
        """
        vectors = train_embeddings.speaker_vectors.to(torch.float64)
        speakers = train_embeddings.speakers
        # Model ids carry the speaker label; segment ids are unique per vector.
        modelset = np.array([f'md{speaker}' for speaker in speakers], dtype="|O")
        segset, s, stat0 = self._get_vector_stats(vectors, sg_tag='sg')
        xvectors_stat = StatObject_SB(modelset=modelset, segset=segset, start=s, stop=s,
                                      stat0=stat0, stat1=vectors.cpu().numpy())
        plda = PLDA(rank_f=rank_f)
        plda.plda(xvectors_stat)
        self.mean = plda.mean
        self.F = plda.F
        self.Sigma = plda.Sigma

    def _get_vector_stats(self, vectors, sg_tag='sg'):
        """Build the per-vector bookkeeping arrays passed to ``StatObject_SB``.

        Args:
            vectors: 2-D array or tensor, one embedding per row.
            sg_tag: prefix for the generated segment ids.

        Returns:
            segset: object array of ids ``f'{sg_tag}{i}'``, one per row.
            s: object array of ``None`` (used for both ``start`` and ``stop``).
            stat0: float array of shape ``(N, 1)`` filled with 1.0.
        """
        n_vectors, _ = vectors.shape  # unpacking also rejects non-2-D input
        segset = np.array([f'{sg_tag}{i}' for i in range(n_vectors)], dtype="|O")
        s = np.full(n_vectors, None, dtype=object)
        stat0 = np.ones((n_vectors, 1))
        return segset, s, stat0