import os

import matplotlib
import numpy
import soundfile as sf
from matplotlib import pyplot as plt
from matplotlib import cm

# Backend selection is kept at the same point in the import sequence as before.
matplotlib.use("tkAgg")
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from tqdm import tqdm

from Preprocessing.ProsodicConditionExtractor import ProsodicConditionExtractor


class Visualizer:
    """
    Projects embeddings extracted from audio files to 2D and draws them as scatter plots.

    Embeddings are produced by a ProsodicConditionExtractor and are reduced to
    two dimensions with t-SNE and, optionally, PCA.
    """

    def __init__(self, sr=48000, device="cpu"):
        """
        Args:
            sr: The sampling rate of the audios you want to visualize.
            device: Passed on to the ProsodicConditionExtractor, also when the
                    extractor has to be rebuilt for a different sampling rate.
        """
        self.tsne = TSNE(n_jobs=-1)
        self.pca = PCA(n_components=2)
        self.pros_cond_ext = ProsodicConditionExtractor(sr=sr, device=device)
        self.sr = sr
        # Remembered so that a rebuilt extractor ends up on the same device.
        self.device = device

    def visualize_speaker_embeddings(self, label_to_filepaths, title_of_plot, save_file_path=None, include_pca=True, legend=True):
        """
        Extract one embedding per audio file and plot the 2D projections.

        Args:
            label_to_filepaths: Mapping from a label to an iterable of audio file
                                paths that belong to this label.
            title_of_plot: Title of the figure. If include_pca is set, " t-SNE"
                           and " PCA" are appended to tell the two figures apart.
            save_file_path: Where to save the figure. If None, the figures are
                            shown interactively instead. If include_pca is set,
                            the PCA figure is written to exactly this path and
                            the t-SNE figure to a sibling path with "_tsne"
                            inserted before the file extension, so that neither
                            figure overwrites the other.
            include_pca: Whether to produce a PCA figure in addition to t-SNE.
            legend: Whether to draw a legend.

        Raises:
            ValueError: If no embedding could be extracted, i.e. no files were
                        given or all of them were skipped for being too short.
        """
        label_list = list()
        embedding_list = list()
        for label in tqdm(label_to_filepaths):
            for filepath in tqdm(label_to_filepaths[label]):
                wave, sr = sf.read(filepath)
                if len(wave) / sr < 1:
                    # audios shorter than one second are skipped
                    continue
                if self.sr != sr:
                    print("One of the Audios you included doesn't match the sampling rate of this visualizer object, "
                          "creating a new condition extractor. Results will be correct, but if there are too many cases "
                          "of changing samplingrate, this will run very slowly.")
                    self.pros_cond_ext = ProsodicConditionExtractor(sr=sr, device=self.device)
                    self.sr = sr
                # NOTE(review): .numpy() presumably needs a CPU tensor - confirm that the
                # extractor returns one when the device is not "cpu".
                embedding_list.append(self.pros_cond_ext.extract_condition_from_reference_wave(wave).squeeze().numpy())
                label_list.append(label)

        if not embedding_list:
            # Fail with a clear message instead of an obscure error from the projection.
            raise ValueError("No embeddings could be extracted: no audio files were given "
                             "or all of them are shorter than one second.")

        embeddings_as_array = numpy.array(embedding_list)

        if include_pca:
            tsne_title = title_of_plot + " t-SNE"
            # Both figures used to go to the same file, so the t-SNE figure got overwritten.
            tsne_save_path = self._insert_before_extension(save_file_path, "_tsne")
        else:
            tsne_title = title_of_plot
            tsne_save_path = save_file_path

        dimensionality_reduced_embeddings_tsne = self.tsne.fit_transform(embeddings_as_array)
        self._plot_embeddings(projected_data=dimensionality_reduced_embeddings_tsne,
                              labels=label_list,
                              title=tsne_title,
                              save_file_path=tsne_save_path,
                              legend=legend)

        if include_pca:
            dimensionality_reduced_embeddings_pca = self.pca.fit_transform(embeddings_as_array)
            self._plot_embeddings(projected_data=dimensionality_reduced_embeddings_pca,
                                  labels=label_list,
                                  title=title_of_plot + " PCA",
                                  save_file_path=save_file_path,
                                  legend=legend)

    @staticmethod
    def _insert_before_extension(path, suffix):
        """
        Return path with suffix inserted before the file extension.

        Anything that is not a text path (None, file-like objects, bytes paths)
        is returned unchanged, because no sibling path can be derived from it.
        """
        if not isinstance(path, (str, os.PathLike)):
            return path
        path_as_string = os.fspath(path)
        if not isinstance(path_as_string, str):
            return path
        stem, extension = os.path.splitext(path_as_string)
        return stem + suffix + extension

    def _plot_embeddings(self, projected_data, labels, title, save_file_path, legend):
        """
        Draw one scatter plot of 2D points, coloured by label.

        Args:
            projected_data: Sequence of points, where index 0 of a point is its
                            x coordinate and index 1 is its y coordinate.
            labels: One label per point, in the same order as projected_data.
            title: Title of the figure.
            save_file_path: Where to save the figure. If None, it is shown instead.
            legend: Whether to draw a legend.
        """
        # First-occurrence order instead of set order, so that the colour of a label
        # and the order of the legend do not change between runs.
        unique_labels = list(dict.fromkeys(labels))
        colors = cm.gist_rainbow(numpy.linspace(0, 1, len(unique_labels)))
        label_to_color = dict(zip(unique_labels, colors))

        labels_to_points_x = {label: list() for label in unique_labels}
        labels_to_points_y = {label: list() for label in unique_labels}
        for index, label in enumerate(labels):
            labels_to_points_x[label].append(projected_data[index][0])
            labels_to_points_y[label].append(projected_data[index][1])

        fig, ax = plt.subplots()
        for label in unique_labels:
            x = numpy.array(labels_to_points_x[label])
            y = numpy.array(labels_to_points_y[label])
            # A 1D RGBA array is ambiguous for scatter: with exactly four points it is
            # read as four values to colormap. A 2D array with one row is always one colour.
            ax.scatter(x=x,
                       y=y,
                       c=label_to_color[label].reshape(1, -1),
                       label=label,
                       alpha=0.9)
        if legend:
            ax.legend()
        fig.tight_layout()
        ax.axis('off')
        fig.subplots_adjust(top=0.9, bottom=0.0, right=1.0, left=0.0)
        ax.set_title(title)
        if save_file_path is not None:
            fig.savefig(save_file_path)
        else:
            plt.show()
        plt.close(fig)