Hugo Flores committed on
Commit
260b46d
·
1 Parent(s): 3d08285

add a coarse2fine eval script

Browse files
Files changed (3) hide show
  1. requirements.txt +1 -0
  2. scripts/exp/c2f_eval.py +100 -0
  3. setup.py +1 -0
requirements.txt CHANGED
@@ -27,3 +27,4 @@ tensorboardX
27
  gradio
28
  einops
29
  flash-attn
 
 
27
  gradio
28
  einops
29
  flash-attn
30
+ frechet_audio_distance
scripts/exp/c2f_eval.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import os
3
+ from functools import partial
4
+
5
+ from frechet_audio_distance import FrechetAudioDistance
6
+ import pandas
7
+ import argbind
8
+ from tqdm import tqdm
9
+
10
+ import audiotools
11
+ from audiotools import AudioSignal
12
+
13
@argbind.bind(without_prefix=True)
def eval(
    exp_dir: str = None,
    baseline_key: str = "reconstructed",
    audio_ext: str = ".wav",
):
    """Score every condition folder of an experiment against a baseline folder.

    Each sub-directory of ``exp_dir`` is treated as one condition. The
    sub-directory named ``baseline_key`` is the reference; every other one is
    compared against it file by file (files are paired by sorted name and must
    share the same stem).

    Two CSV files are written into ``exp_dir``:

    * ``metrics-all.csv``: one row per (condition, file) pair.
    * ``metrics-stats.csv``: mean / count / std of every metric per
      condition, with the columns grouped under the metric name.

    Args:
        exp_dir: Path to the experiment directory. Required.
        baseline_key: Name of the sub-directory holding the reference audio.
        audio_ext: Suffix used to glob audio files in each sub-directory.

    Raises:
        AssertionError: If ``exp_dir`` is missing or does not exist, if the
            baseline folder is not found, if a condition has a different
            number of files than the baseline, or if paired files have
            different stems.
    """
    assert exp_dir is not None
    exp_dir = Path(exp_dir)
    assert exp_dir.exists(), f"exp_dir {exp_dir} does not exist"

    # set up our metrics
    sisdr_loss = audiotools.metrics.distance.SISDRLoss()
    stft_loss = audiotools.metrics.spectral.MultiScaleSTFTLoss()
    mel_loss = audiotools.metrics.spectral.MelSpectrogramLoss()
    frechet = FrechetAudioDistance(
        use_pca=False,
        use_activation=False,
        verbose=False
    )
    visqol = partial(audiotools.metrics.quality.visqol, mode="audio")

    # figure out what conditions we have
    conditions = [d.name for d in exp_dir.iterdir() if d.is_dir()]

    assert baseline_key in conditions, f"baseline_key {baseline_key} not found in {exp_dir}"
    conditions.remove(baseline_key)

    print(f"Found {len(conditions)} conditions in {exp_dir}")
    print(f"conditions: {conditions}")

    baseline_dir = exp_dir / baseline_key
    # glob() yields files in arbitrary order; sort so that the pairing with
    # the condition files below is by name rather than by directory order.
    baseline_files = sorted(baseline_dir.glob(f"*{audio_ext}"))

    metrics = []
    for condition in conditions:
        cond_dir = exp_dir / condition
        cond_files = sorted(cond_dir.glob(f"*{audio_ext}"))

        # make sure we have the same number of files (checked before the
        # FAD computation so a mismatch fails fast)
        assert len(baseline_files) == len(cond_files), f"number of files in {baseline_dir} and {cond_dir} do not match. {len(baseline_files)} vs {len(cond_files)}"

        print(f"computing fad")
        frechet_score = frechet.score(baseline_dir, cond_dir)

        pbar = tqdm(zip(baseline_files, cond_files), total=len(baseline_files))
        for baseline_file, cond_file in pbar:
            assert baseline_file.stem == cond_file.stem, f"baseline file {baseline_file} and cond file {cond_file} do not match"
            pbar.set_description(baseline_file.stem)

            # load the files
            baseline_sig = AudioSignal(baseline_file)
            cond_sig = AudioSignal(cond_file)

            # compute the metrics
            # visqol is best-effort: a failure is recorded as 0.0 instead of
            # aborting the run. Only Exception is caught so Ctrl-C still
            # interrupts the script.
            try:
                vsq = float(visqol(baseline_sig, cond_sig))
            except Exception as err:
                tqdm.write(f"visqol failed for {baseline_file.stem}: {err}")
                vsq = 0.0

            # float() turns the loss outputs into plain numbers so pandas can
            # aggregate them; assumes each loss returns a single-element
            # value (TODO confirm against the audiotools version in use).
            metrics.append({
                "sisdr": float(sisdr_loss(baseline_sig, cond_sig)),
                "stft": float(stft_loss(baseline_sig, cond_sig)),
                "mel": float(mel_loss(baseline_sig, cond_sig)),
                "frechet": frechet_score,
                "visqol": vsq,
                "condition": condition,
                "file": baseline_file.stem,
            })

    if not metrics:
        # nothing was scored (no conditions or no audio files); there is
        # nothing to aggregate or write.
        print(f"No metrics computed for {exp_dir}; nothing written")
        return

    df = pandas.DataFrame(metrics)
    metric_keys = [k for k in metrics[0] if k not in ("condition", "file")]

    # One mean/count/std triple per metric. `keys=` labels each triple with
    # its metric name so the columns of the CSV are unambiguous.
    by_condition = df.groupby(['condition'])
    stats = pandas.concat(
        [by_condition[mk].agg(['mean', 'count', 'std']) for mk in metric_keys],
        axis=1,
        keys=metric_keys,
    )
    stats.to_csv(exp_dir / "metrics-stats.csv")

    df.to_csv(exp_dir / "metrics-all.csv", index=False)
94
+
95
+
96
if __name__ == "__main__":
    # Parse the command-line arguments and run the evaluation with them
    # bound through argbind's scope.
    with argbind.scope(argbind.parse_args()):
        eval()
setup.py CHANGED
@@ -38,5 +38,6 @@ setup(
38
  "torchmetrics>=0.7.3",
39
  "einops",
40
  "flash-attn",
 
41
  ],
42
  )
 
38
  "torchmetrics>=0.7.3",
39
  "einops",
40
  "flash-attn",
41
+ "frechet_audio_distance"
42
  ],
43
  )