Spaces:
Sleeping
Sleeping
Hugo Flores Garcia
commited on
Commit
·
63015d5
1
Parent(s):
1a5973b
looks like it's working?
Browse files- conf/lora/ella-baila-sola.yaml +10 -0
- conf/lora/lora-is-this-charlie-parker.yaml +8 -2
- conf/lora/lora.yaml +16 -0
- conf/vampnet.yml +3 -3
- env/setup.py +0 -29
- scripts/exp/train.py +4 -4
- scripts/utils/split.py +51 -0
- vampnet/modules/__init__.py +1 -1
- vampnet/modules/transformer.py +1 -1
conf/lora/ella-baila-sola.yaml
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
$include:
|
2 |
+
- conf/lora.yml
|
3 |
+
|
4 |
+
fine_tune: True
|
5 |
+
|
6 |
+
train/AudioLoader.sources:
|
7 |
+
- /media/CHONK/hugo/spotdl/subsets/this-is-charlie-parker/Eslabon Armado - Ella Baila Sola.mp3
|
8 |
+
|
9 |
+
val/AudioLoader.sources:
|
10 |
+
- /media/CHONK/hugo/spotdl/subsets/this-is-charlie-parker/Eslabon Armado - Ella Baila Sola.mp3
|
conf/lora/lora-is-this-charlie-parker.yaml
CHANGED
@@ -1,4 +1,10 @@
|
|
1 |
$include:
|
2 |
-
- conf/
|
3 |
|
4 |
-
fine_tune: True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
$include:
|
2 |
+
- conf/lora.yml
|
3 |
|
4 |
+
fine_tune: True
|
5 |
+
|
6 |
+
train/AudioLoader.sources:
|
7 |
+
- /media/CHONK/hugo/spotdl/subsets/this-is-charlie-parker/Charlie Parker - Yardbird Suite.mp3
|
8 |
+
|
9 |
+
val/AudioLoader.sources:
|
10 |
+
- /media/CHONK/hugo/spotdl/subsets/this-is-charlie-parker/Charlie Parker - Yardbird Suite.mp3
|
conf/lora/lora.yaml
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
$include:
|
2 |
+
- conf/vampnet.yml
|
3 |
+
|
4 |
+
fine_tune: True
|
5 |
+
|
6 |
+
train/AudioDataset.n_examples: 10000000
|
7 |
+
|
8 |
+
val/AudioDataset.n_examples: 10
|
9 |
+
|
10 |
+
|
11 |
+
NoamScheduler.warmup: 250
|
12 |
+
|
13 |
+
epoch_length: 100
|
14 |
+
save_audio_epochs: 2
|
15 |
+
|
16 |
+
AdamW.lr: 0.0001
|
conf/vampnet.yml
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
|
2 |
-
codec_ckpt: /
|
3 |
save_path: ckpt
|
4 |
max_epochs: 1000
|
5 |
epoch_length: 1000
|
@@ -11,8 +11,8 @@ suffix_amt: 0.0
|
|
11 |
prefix_dropout: 0.1
|
12 |
suffix_dropout: 0.1
|
13 |
|
14 |
-
batch_size:
|
15 |
-
num_workers:
|
16 |
|
17 |
# Optimization
|
18 |
detect_anomaly: false
|
|
|
1 |
|
2 |
+
codec_ckpt: /home/hugo/descript/vampnet/models/spotdl/codec.pth
|
3 |
save_path: ckpt
|
4 |
max_epochs: 1000
|
5 |
epoch_length: 1000
|
|
|
11 |
prefix_dropout: 0.1
|
12 |
suffix_dropout: 0.1
|
13 |
|
14 |
+
batch_size: 8
|
15 |
+
num_workers: 10
|
16 |
|
17 |
# Optimization
|
18 |
detect_anomaly: false
|
env/setup.py
CHANGED
@@ -11,36 +11,7 @@ def run(cmd):
|
|
11 |
return subprocess.check_output(shlex.split(cmd)).decode("utf-8")
|
12 |
|
13 |
|
14 |
-
print("1. Setting up Google Cloud access")
|
15 |
-
print("---------------------------------")
|
16 |
-
gcloud_authorized = "gs://research-data-raw" in run("gsutil ls")
|
17 |
-
if not gcloud_authorized:
|
18 |
-
run("gcloud auth login")
|
19 |
|
20 |
-
run("gcloud config set project lyrebird-research")
|
21 |
-
run("gcloud auth configure-docker")
|
22 |
-
|
23 |
-
print()
|
24 |
-
print("2. Setting up Github access")
|
25 |
-
print("---------------------------")
|
26 |
-
|
27 |
-
lines = textwrap.wrap(
|
28 |
-
"First, let's get your Github token, so all "
|
29 |
-
"packages can be installed. Create one by going to your "
|
30 |
-
"Github profile -> Developer settings -> Personal access tokens -> "
|
31 |
-
"Generate new token. Copy the token below."
|
32 |
-
)
|
33 |
-
[print(l) for l in lines]
|
34 |
-
|
35 |
-
GITHUB_TOKEN = input("\nGithub token: ") or "undefined"
|
36 |
-
|
37 |
-
print()
|
38 |
-
print("3. Setting up Jupyter and Tensorboard")
|
39 |
-
print("-------------------------------------")
|
40 |
-
|
41 |
-
JUPYTER_TOKEN = input("Password for Jupyter server (default:password): ") or "password"
|
42 |
-
JUPYTER_PORT = input("Jupyter port to run on (default:8888): ") or "8888"
|
43 |
-
TENSORBOARD_PORT = input("Tensorboard port to run on (default:6006): ") or "6006"
|
44 |
|
45 |
print()
|
46 |
print("4. Setting up paths.")
|
|
|
11 |
return subprocess.check_output(shlex.split(cmd)).decode("utf-8")
|
12 |
|
13 |
|
|
|
|
|
|
|
|
|
|
|
14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
print()
|
17 |
print("4. Setting up paths.")
|
scripts/exp/train.py
CHANGED
@@ -248,12 +248,12 @@ def train(
|
|
248 |
save_path: str = "ckpt",
|
249 |
max_epochs: int = int(100e3),
|
250 |
epoch_length: int = 1000,
|
251 |
-
save_audio_epochs: int =
|
252 |
save_epochs: list = [10, 50, 100, 200, 300, 400,],
|
253 |
batch_size: int = 48,
|
254 |
grad_acc_steps: int = 1,
|
255 |
-
val_idx: list = [0, 1, 2, 3, 4],
|
256 |
-
num_workers: int =
|
257 |
detect_anomaly: bool = False,
|
258 |
grad_clip_val: float = 5.0,
|
259 |
prefix_amt: float = 0.0,
|
@@ -530,7 +530,7 @@ def train(
|
|
530 |
|
531 |
accel.unwrap(model).metadata = metadata
|
532 |
accel.unwrap(model).save_to_folder(
|
533 |
-
f"{save_path}/{tag}", model_extra
|
534 |
)
|
535 |
|
536 |
def save_sampled(self, z):
|
|
|
248 |
save_path: str = "ckpt",
|
249 |
max_epochs: int = int(100e3),
|
250 |
epoch_length: int = 1000,
|
251 |
+
save_audio_epochs: int = 2,
|
252 |
save_epochs: list = [10, 50, 100, 200, 300, 400,],
|
253 |
batch_size: int = 48,
|
254 |
grad_acc_steps: int = 1,
|
255 |
+
val_idx: list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
|
256 |
+
num_workers: int = 10,
|
257 |
detect_anomaly: bool = False,
|
258 |
grad_clip_val: float = 5.0,
|
259 |
prefix_amt: float = 0.0,
|
|
|
530 |
|
531 |
accel.unwrap(model).metadata = metadata
|
532 |
accel.unwrap(model).save_to_folder(
|
533 |
+
f"{save_path}/{tag}", model_extra,
|
534 |
)
|
535 |
|
536 |
def save_sampled(self, z):
|
scripts/utils/split.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
import random
|
3 |
+
import shutil
|
4 |
+
|
5 |
+
import argbind
|
6 |
+
|
7 |
+
from audiotools.core import util
|
8 |
+
|
9 |
+
|
10 |
+
@argbind.bind(without_prefix=True)
|
11 |
+
def train_test_split(
|
12 |
+
audio_folder: str = ".",
|
13 |
+
test_size: float = 0.2,
|
14 |
+
seed: int = 42,
|
15 |
+
):
|
16 |
+
audio_files = util.find_audio(audio_folder)
|
17 |
+
|
18 |
+
# split according to test_size
|
19 |
+
n_test = int(len(audio_files) * test_size)
|
20 |
+
n_train = len(audio_files) - n_test
|
21 |
+
|
22 |
+
# shuffle
|
23 |
+
random.seed(seed)
|
24 |
+
random.shuffle(audio_files)
|
25 |
+
|
26 |
+
train_files = audio_files[:n_train]
|
27 |
+
test_files = audio_files[n_train:]
|
28 |
+
|
29 |
+
|
30 |
+
print(f"Train files: {len(train_files)}")
|
31 |
+
print(f"Test files: {len(test_files)}")
|
32 |
+
continue_ = input("Continue [yn]? ") or "n"
|
33 |
+
|
34 |
+
if continue_ != "y":
|
35 |
+
return
|
36 |
+
|
37 |
+
for split, files in (
|
38 |
+
("train", train_files), ("test", test_files)
|
39 |
+
):
|
40 |
+
for file in files:
|
41 |
+
out_file = Path(file).parent / split / Path(file).name
|
42 |
+
out_file.parent.mkdir(exist_ok=True, parents=True)
|
43 |
+
shutil.copy(file, out_file)
|
44 |
+
|
45 |
+
|
46 |
+
|
47 |
+
if __name__ == "__main__":
|
48 |
+
args = argbind.parse_args()
|
49 |
+
|
50 |
+
with argbind.scope(args):
|
51 |
+
train_test_split()
|
vampnet/modules/__init__.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
import audiotools
|
2 |
|
3 |
audiotools.ml.BaseModel.INTERN += ["vampnet.modules.**"]
|
4 |
-
audiotools.ml.BaseModel.EXTERN += ["einops", "flash_attn.flash_attention"]
|
|
|
1 |
import audiotools
|
2 |
|
3 |
audiotools.ml.BaseModel.INTERN += ["vampnet.modules.**"]
|
4 |
+
audiotools.ml.BaseModel.EXTERN += ["einops", "flash_attn.flash_attention", "loralib"]
|
vampnet/modules/transformer.py
CHANGED
@@ -14,7 +14,7 @@ from .layers import FiLM
|
|
14 |
from .layers import SequentialWithFiLM
|
15 |
from .layers import WNConv1d
|
16 |
|
17 |
-
LORA_R =
|
18 |
|
19 |
|
20 |
class RMSNorm(nn.Module):
|
|
|
14 |
from .layers import SequentialWithFiLM
|
15 |
from .layers import WNConv1d
|
16 |
|
17 |
+
LORA_R = 8
|
18 |
|
19 |
|
20 |
class RMSNorm(nn.Module):
|