Hugo Flores Garcia commited on
Commit
9fbfaa6
·
1 Parent(s): 57047e5
conf/interface-c2f-exp.yml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ Interface.coarse_ckpt: /runs/c2f-exp-03.22.23/ckpt/mask/epoch=400/vampnet/weights.pth
2
+ Interface.coarse2fine_ckpt: runs/c2f-exp-03.22.23/ckpt/mask/epoch=400/vampnet/weights.pth
3
+ Interface.codec_ckpt: /runs/codec-ckpt/codec.pth
4
+ Interface.coarse_chunk_size_s: 5
5
+ Interface.coarse2fine_chunk_size_s: 3
scripts/exp/{c2f_eval.py → eval.py} RENAMED
@@ -13,7 +13,7 @@ from audiotools import AudioSignal
13
  @argbind.bind(without_prefix=True)
14
  def eval(
15
  exp_dir: str = None,
16
- baseline_key: str = "reconstructed",
17
  audio_ext: str = ".wav",
18
  ):
19
  assert exp_dir is not None
@@ -27,7 +27,7 @@ def eval(
27
  frechet = FrechetAudioDistance(
28
  use_pca=False,
29
  use_activation=False,
30
- verbose=False
31
  )
32
  visqol = partial(audiotools.metrics.quality.visqol, mode="audio")
33
 
@@ -48,7 +48,7 @@ def eval(
48
  cond_dir = exp_dir / condition
49
  cond_files = list(cond_dir.glob(f"*{audio_ext}"))
50
 
51
- print(f"computing fad")
52
  frechet_score = frechet.score(baseline_dir, cond_dir)
53
 
54
  # make sure we have the same number of files
 
13
  @argbind.bind(without_prefix=True)
14
  def eval(
15
  exp_dir: str = None,
16
+ baseline_key: str = "baseline",
17
  audio_ext: str = ".wav",
18
  ):
19
  assert exp_dir is not None
 
27
  frechet = FrechetAudioDistance(
28
  use_pca=False,
29
  use_activation=False,
30
+ verbose=True
31
  )
32
  visqol = partial(audiotools.metrics.quality.visqol, mode="audio")
33
 
 
48
  cond_dir = exp_dir / condition
49
  cond_files = list(cond_dir.glob(f"*{audio_ext}"))
50
 
51
+ print(f"computing fad for {baseline_dir} and {cond_dir}")
52
  frechet_score = frechet.score(baseline_dir, cond_dir)
53
 
54
  # make sure we have the same number of files
scripts/utils/vamp_folder.py CHANGED
@@ -5,17 +5,30 @@ from tqdm import tqdm
5
  import torch
6
 
7
  from vampnet.interface import Interface
 
8
 
9
  Interface = argbind.bind(Interface)
10
 
 
 
 
 
 
 
 
 
 
 
11
  def baseline(sig, interface):
12
  return sig
13
 
 
14
  def reconstructed(sig, interface):
15
  return interface.to_signal(
16
  interface.encode(sig)
17
  )
18
 
 
19
  def coarse2fine(sig, interface):
20
  z = interface.encode(sig)
21
  z = z[:, :interface.c2f.n_conditioning_codebooks, :]
@@ -23,6 +36,18 @@ def coarse2fine(sig, interface):
23
  z = interface.coarse_to_fine(z)
24
  return interface.to_signal(z)
25
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  def one_codebook(sig, interface):
27
  z = interface.encode(sig)
28
 
@@ -38,6 +63,7 @@ def one_codebook(sig, interface):
38
 
39
  return interface.to_signal(zv)
40
 
 
41
  def four_codebooks_downsampled_4x(sig, interface):
42
  zv = interface.coarse_vamp_v2(
43
  sig, downsample_factor=4
@@ -45,6 +71,7 @@ def four_codebooks_downsampled_4x(sig, interface):
45
  zv = interface.coarse_to_fine(zv)
46
  return interface.to_signal(zv)
47
 
 
48
  def two_codebooks_downsampled_4x(sig, interface):
49
  z = interface.encode(sig)
50
 
@@ -60,6 +87,7 @@ def two_codebooks_downsampled_4x(sig, interface):
60
 
61
  return interface.to_signal(zv)
62
 
 
63
  def four_codebooks_downsampled_8x(sig, interface):
64
  zv = interface.coarse_vamp_v2(
65
  sig, downsample_factor=8
@@ -68,7 +96,7 @@ def four_codebooks_downsampled_8x(sig, interface):
68
  return interface.to_signal(zv)
69
 
70
 
71
- SAMPLE_CONDS ={
72
  "baseline": baseline,
73
  "reconstructed": reconstructed,
74
  "coarse2fine": coarse2fine,
@@ -78,6 +106,12 @@ SAMPLE_CONDS ={
78
  "four_codebooks_downsampled_8x": four_codebooks_downsampled_8x,
79
  }
80
 
 
 
 
 
 
 
81
 
82
  @argbind.bind(without_prefix=True)
83
  def main(
@@ -86,7 +120,10 @@ def main(
86
  ],
87
  output_dir: str = "./samples",
88
  max_excerpts: int = 5000,
 
 
89
  ):
 
90
  interface = Interface()
91
 
92
  output_dir = Path(output_dir)
@@ -102,6 +139,8 @@ def main(
102
  without_replacement=True,
103
  )
104
 
 
 
105
  for i in tqdm(range(max_excerpts)):
106
  sig = dataset[i]["signal"]
107
 
 
5
  import torch
6
 
7
  from vampnet.interface import Interface
8
+ import audiotools as at
9
 
10
  Interface = argbind.bind(Interface)
11
 
12
+ # condition wrapper for printing
13
+ def condition(cond):
14
+ def wrapper(sig, interface):
15
+ print(f"Condition: {cond.__name__}")
16
+ sig = cond(sig, interface)
17
+ print(f"Condition: {cond.__name__} (done)\n")
18
+ return sig
19
+ return wrapper
20
+
21
+ @condition
22
  def baseline(sig, interface):
23
  return sig
24
 
25
+ @condition
26
  def reconstructed(sig, interface):
27
  return interface.to_signal(
28
  interface.encode(sig)
29
  )
30
 
31
+ @condition
32
  def coarse2fine(sig, interface):
33
  z = interface.encode(sig)
34
  z = z[:, :interface.c2f.n_conditioning_codebooks, :]
 
36
  z = interface.coarse_to_fine(z)
37
  return interface.to_signal(z)
38
 
39
+ @condition
40
+ def coarse2fine_argmax(sig, interface):
41
+ z = interface.encode(sig)
42
+ z = z[:, :interface.c2f.n_conditioning_codebooks, :]
43
+
44
+ z = interface.coarse_to_fine(z,
45
+ sample="argmax", sampling_steps=1,
46
+ temperature=1.0
47
+ )
48
+ return interface.to_signal(z)
49
+
50
+ @condition
51
  def one_codebook(sig, interface):
52
  z = interface.encode(sig)
53
 
 
63
 
64
  return interface.to_signal(zv)
65
 
66
+ @condition
67
  def four_codebooks_downsampled_4x(sig, interface):
68
  zv = interface.coarse_vamp_v2(
69
  sig, downsample_factor=4
 
71
  zv = interface.coarse_to_fine(zv)
72
  return interface.to_signal(zv)
73
 
74
+ @condition
75
  def two_codebooks_downsampled_4x(sig, interface):
76
  z = interface.encode(sig)
77
 
 
87
 
88
  return interface.to_signal(zv)
89
 
90
+ @condition
91
  def four_codebooks_downsampled_8x(sig, interface):
92
  zv = interface.coarse_vamp_v2(
93
  sig, downsample_factor=8
 
96
  return interface.to_signal(zv)
97
 
98
 
99
+ COARSE_SAMPLE_CONDS ={
100
  "baseline": baseline,
101
  "reconstructed": reconstructed,
102
  "coarse2fine": coarse2fine,
 
106
  "four_codebooks_downsampled_8x": four_codebooks_downsampled_8x,
107
  }
108
 
109
+ C2F_SAMPLE_CONDS = {
110
+ "baseline": baseline,
111
+ "reconstructed": reconstructed,
112
+ "coarse2fine": coarse2fine,
113
+ "coarse2fine_argmax": coarse2fine_argmax,
114
+ }
115
 
116
  @argbind.bind(without_prefix=True)
117
  def main(
 
120
  ],
121
  output_dir: str = "./samples",
122
  max_excerpts: int = 5000,
123
+ exp_type: str = "coarse",
124
+ seed: int = 0,
125
  ):
126
+ at.util.seed(seed)
127
  interface = Interface()
128
 
129
  output_dir = Path(output_dir)
 
139
  without_replacement=True,
140
  )
141
 
142
+ SAMPLE_CONDS = COARSE_SAMPLE_CONDS if exp_type == "coarse" else C2F_SAMPLE_CONDS
143
+
144
  for i in tqdm(range(max_excerpts)):
145
  sig = dataset[i]["signal"]
146