Spaces:
Sleeping
Sleeping
Hugo Flores
committed on
Commit
·
260b46d
1
Parent(s):
3d08285
add a coarse2fine eval script
Browse files- requirements.txt +1 -0
- scripts/exp/c2f_eval.py +100 -0
- setup.py +1 -0
requirements.txt
CHANGED
@@ -27,3 +27,4 @@ tensorboardX
|
|
27 |
gradio
|
28 |
einops
|
29 |
flash-attn
|
|
|
|
27 |
gradio
|
28 |
einops
|
29 |
flash-attn
|
30 |
+
frechet_audio_distance
|
scripts/exp/c2f_eval.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
import os
|
3 |
+
from functools import partial
|
4 |
+
|
5 |
+
from frechet_audio_distance import FrechetAudioDistance
|
6 |
+
import pandas
|
7 |
+
import argbind
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
+
import audiotools
|
11 |
+
from audiotools import AudioSignal
|
12 |
+
|
13 |
+
@argbind.bind(without_prefix=True)
def eval(
    exp_dir: str = None,
    baseline_key: str = "reconstructed",
    audio_ext: str = ".wav",
):
    """Score every condition folder in ``exp_dir`` against a baseline folder.

    Each immediate sub-directory of ``exp_dir`` is treated as one condition.
    The sub-directory named ``baseline_key`` is the reference; every other
    sub-directory is compared to it file by file (files are paired by stem).

    Two CSV files are written into ``exp_dir``:

    * ``metrics-all.csv``   -- one row per (condition, file) pair.
    * ``metrics-stats.csv`` -- mean / count / std of each metric per condition.

    Args:
        exp_dir: Directory holding one sub-directory per condition.
        baseline_key: Name of the sub-directory used as the reference.
        audio_ext: File extension (including the dot) of the audio files.

    Raises:
        AssertionError: If ``exp_dir`` is missing or does not exist, if the
            baseline folder is absent, if a condition has a different number
            of files than the baseline, if paired files have different stems,
            or if no file pair was scored at all.
    """
    assert exp_dir is not None
    exp_dir = Path(exp_dir)
    assert exp_dir.exists(), f"exp_dir {exp_dir} does not exist"

    # set up our metrics
    sisdr_loss = audiotools.metrics.distance.SISDRLoss()
    stft_loss = audiotools.metrics.spectral.MultiScaleSTFTLoss()
    mel_loss = audiotools.metrics.spectral.MelSpectrogramLoss()
    frechet = FrechetAudioDistance(
        use_pca=False,
        use_activation=False,
        verbose=False,
    )
    visqol = partial(audiotools.metrics.quality.visqol, mode="audio")

    # figure out what conditions we have
    conditions = [d.name for d in exp_dir.iterdir() if d.is_dir()]

    assert baseline_key in conditions, f"baseline_key {baseline_key} not found in {exp_dir}"
    conditions.remove(baseline_key)

    print(f"Found {len(conditions)} conditions in {exp_dir}")
    print(f"conditions: {conditions}")

    baseline_dir = exp_dir / baseline_key
    # glob() yields files in arbitrary order; sort so that the baseline and
    # condition lists line up by file name when zipped below.
    baseline_files = sorted(baseline_dir.glob(f"*{audio_ext}"))

    metrics = []
    for condition in conditions:
        cond_dir = exp_dir / condition
        cond_files = sorted(cond_dir.glob(f"*{audio_ext}"))

        # FAD is computed once per condition (directory vs. directory) and
        # repeated on every row of that condition.
        print("computing fad")
        frechet_score = frechet.score(baseline_dir, cond_dir)

        # make sure we have the same number of files
        assert len(baseline_files) == len(cond_files), f"number of files in {baseline_dir} and {cond_dir} do not match. {len(baseline_files)} vs {len(cond_files)}"

        pbar = tqdm(zip(baseline_files, cond_files), total=len(baseline_files))
        for baseline_file, cond_file in pbar:
            assert baseline_file.stem == cond_file.stem, f"baseline file {baseline_file} and cond file {cond_file} do not match"
            pbar.set_description(baseline_file.stem)

            # load the files
            baseline_sig = AudioSignal(baseline_file)
            cond_sig = AudioSignal(cond_file)

            # compute the metrics
            try:
                vsq = visqol(baseline_sig, cond_sig)
            except Exception as err:
                # Best effort: ViSQOL failures must not abort the whole run.
                # NOTE(review): the 0.0 fallback is kept from the original and
                # is included in the per-condition mean -- confirm intended.
                tqdm.write(f"visqol failed for {baseline_file.stem}: {err}")
                vsq = 0.0

            # Store plain Python floats so pandas can aggregate the columns
            # numerically and the CSVs contain numbers rather than object
            # reprs. Assumes each loss returns a single-element value.
            metrics.append({
                "sisdr": float(sisdr_loss(baseline_sig, cond_sig)),
                "stft": float(stft_loss(baseline_sig, cond_sig)),
                "mel": float(mel_loss(baseline_sig, cond_sig)),
                "frechet": frechet_score,
                "visqol": vsq,
                "condition": condition,
                "file": baseline_file.stem,
            })

    # Without any scored pair there is nothing to aggregate; fail with a
    # clear message instead of an IndexError on metrics[0].
    assert metrics, f"no {audio_ext} file pairs were scored in {exp_dir}"

    metric_keys = [k for k in metrics[0] if k not in ("condition", "file")]

    # Build the frame once and reuse it for both CSVs.
    df = pandas.DataFrame(metrics)

    stats = [
        df.groupby(['condition'])[mk].agg(['mean', 'count', 'std'])
        for mk in metric_keys
    ]
    stats = pandas.concat(stats, axis=1)
    stats.to_csv(exp_dir / "metrics-stats.csv")

    df.to_csv(exp_dir / "metrics-all.csv", index=False)
|
94 |
+
|
95 |
+
|
96 |
+
if __name__ == "__main__":
    # Parse the command-line arguments with argbind, then call the bound
    # evaluation entry point inside that argument scope.
    cli_args = argbind.parse_args()
    with argbind.scope(cli_args):
        eval()
|
setup.py
CHANGED
@@ -38,5 +38,6 @@ setup(
|
|
38 |
"torchmetrics>=0.7.3",
|
39 |
"einops",
|
40 |
"flash-attn",
|
|
|
41 |
],
|
42 |
)
|
|
|
38 |
"torchmetrics>=0.7.3",
|
39 |
"einops",
|
40 |
"flash-attn",
|
41 |
+
"frechet_audio_distance"
|
42 |
],
|
43 |
)
|