Hugo Flores Garcia committed on
Commit 1a5973b · 1 Parent(s): fa490b8
conf/{interface-c2f-exp.yml → interface/interface-c2f-exp.yml} RENAMED
File without changes
conf/{interface-jazzpop.yml → interface/interface-jazzpop.yml} RENAMED
File without changes
conf/{interface-maestro.yml → interface/interface-maestro.yml} RENAMED
File without changes
conf/{interface-spotdl.yml → interface/interface-spotdl.yml} RENAMED
File without changes
conf/lora/lora-is-this-charlie-parker.yaml ADDED
@@ -0,0 +1,4 @@
+$include:
+  - conf/vampnet.yml
+
+fine_tune: True
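A note on how this config is meant to be consumed (a sketch under assumptions, since the wiring is not shown in this commit): the training script is driven by argbind-style YAML configs, where $include merges conf/vampnet.yml as the base and fine_tune: True overrides the new train() keyword added in scripts/exp/train.py below. Roughly, with the import path and flag usage being assumptions:

import argbind
from scripts.exp.train import train  # hypothetical import path for the bound train() function

train = argbind.bind(train, without_prefix=True)  # assumes bare keys like "fine_tune" in the YAML
# the config would typically be selected on the command line,
# e.g. --args.load conf/lora/lora-is-this-charlie-parker.yaml
args = argbind.parse_args()
with argbind.scope(args):
    train()  # receives fine_tune=True from the config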
scripts/exp/train.py CHANGED
@@ -260,6 +260,7 @@ def train(
     suffix_amt: float = 0.0,
     prefix_dropout: float = 0.1,
     suffix_dropout: float = 0.1,
+    fine_tune: bool = False,
     quiet: bool = False,
 ):
     assert codec_ckpt is not None, "codec_ckpt is required"
@@ -310,6 +311,11 @@ def train(
 
     criterion = CrossEntropyLoss()
 
+    if fine_tune:
+        import loralib as lora
+        lora.mark_only_lora_as_trainable(model)
+
+
     class Trainer(at.ml.BaseTrainer):
         _last_grad_norm = 0.0
 
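With fine_tune enabled, loralib's mark_only_lora_as_trainable() freezes every parameter except the low-rank adapter weights injected in vampnet/modules/transformer.py below, so only the LoRA matrices receive gradient updates. A minimal sketch of how a fine-tuned checkpoint could then be kept small (not part of this commit; file names are placeholders):

import torch
import loralib as lora

# after training, save only the LoRA parameters (the A/B matrices)
torch.save(lora.lora_state_dict(model), "lora_only.pth")

# to restore: load the original pretrained weights first, then the LoRA deltas
model.load_state_dict(torch.load("pretrained.pth"), strict=False)
model.load_state_dict(torch.load("lora_only.pth"), strict=False)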
vampnet/modules/transformer.py CHANGED
@@ -5,6 +5,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
+import loralib as lora
 
 from .base import VampBase
 from .activations import get_activation
@@ -13,6 +14,8 @@ from .layers import FiLM
 from .layers import SequentialWithFiLM
 from .layers import WNConv1d
 
+LORA_R = 4
+
 
 class RMSNorm(nn.Module):
     def __init__(self, hidden_size: int, eps=1e-6):
@@ -86,9 +89,9 @@ class MultiHeadRelativeAttention(nn.Module):
         self.attention_max_distance = attention_max_distance
 
         # Create linear query, key, value projections
-        self.w_qs = nn.Linear(d_model, d_model, bias=False)
+        self.w_qs = lora.Linear(d_model, d_model, bias=False, r=LORA_R)
         self.w_ks = nn.Linear(d_model, d_model, bias=False)
-        self.w_vs = nn.Linear(d_model, d_model, bias=False)
+        self.w_vs = lora.Linear(d_model, d_model, bias=False, r=LORA_R)
 
         # Create linear final output projection
         self.fc = nn.Linear(d_model, d_model, bias=False)
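Only the query and value projections are swapped for lora.Linear here; the key and output projections stay as frozen nn.Linear layers, in line with the common LoRA recipe of adapting W_q and W_v. A lora.Linear keeps the original d_model x d_model weight frozen and learns a rank-r update, so with r = LORA_R = 4 the trainable parameters per adapted projection drop from d_model**2 to 2 * d_model * r. A rough illustration (d_model = 1280 is an arbitrary example value, not taken from this repo):

import loralib as lora

d_model, r = 1280, 4
proj = lora.Linear(d_model, d_model, bias=False, r=r)
lora.mark_only_lora_as_trainable(proj)  # freezes the base weight, keeps lora_A/lora_B trainable

trainable = sum(p.numel() for p in proj.parameters() if p.requires_grad)
total = sum(p.numel() for p in proj.parameters())
print(trainable, total)  # 2 * d_model * r  vs  d_model**2 + 2 * d_model * r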