Dionyssos committed
Commit 60fbcf9 · 1 Parent(s): 731cb10

fx k,v history after project

Files changed (2)
  1. audiocraft/lm.py +1 -1
  2. audiocraft/transformer.py +20 -16
audiocraft/lm.py CHANGED

@@ -147,7 +147,7 @@ class LMModel(nn.Module):
         super().__init__()
         self.cfg_coef = cfg_coef
 
-        self.n_draw = 8
+        self.n_draw = 2
         self.condition_provider = condition_provider
         self.fuser = fuser
         self.card = card  # 2048 ?
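The lm.py hunk only lowers a constant: `n_draw` drops from 8 to 2 in `LMModel.__init__`. Assuming `n_draw` is the number of candidate tokens sampled per decoding step (an assumption; this hunk does not show how the attribute is used), a minimal sketch of drawing that many samples from MusicGen-style logits:

```python
import torch

def draw_tokens(logits: torch.Tensor, n_draw: int = 2) -> torch.Tensor:
    """Hypothetical illustration of a per-step draw count.

    The commit itself only changes the constant from 8 to 2; the use of
    n_draw shown here is an assumption, not taken from the diff.
    logits: [B, K, T, card], e.g. K codebooks and card=2048.
    """
    B, K, T, card = logits.shape
    probs = torch.softmax(logits, dim=-1)
    flat = probs.reshape(-1, card)                       # [B*K*T, card]
    draws = torch.multinomial(flat, num_samples=n_draw)  # [B*K*T, n_draw]
    return draws.reshape(B, K, T, n_draw)

# Example: 1 batch, 4 codebooks, 1 timestep, card=2048 -> 2 draws per position
tok = draw_tokens(torch.randn(1, 4, 1, 2048), n_draw=2)
print(tok.shape)  # torch.Size([1, 4, 1, 2])
```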
audiocraft/transformer.py CHANGED

@@ -194,21 +194,12 @@ class StreamingMultiheadAttention(nn.Module):
 
             q, k, v = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k, v]]
         else:
-
+            # 1st projected makes k,v (instantaneous)
+            # 2nd cat
+
+
             # HISTORY - DIFFERENT FOR EACH TRANSF LAYER
-            if self.k_history is not None:
-                #
-                # pk.shape=torch.Size([2, 24, 3, 64]) k.shape=torch.Size([2, 24, 1, 64]) CONCAT
-                # has to be 4D with batch 1 due to single condition 3=seqlen
-                # 24 heads 64 dimofh
-                self.k_history = torch.cat([self.k_history, query], 2)
-                self.v_history = torch.cat([self.v_history, query], 2)
-            else:
-                # init on 1st token (for all 47 transf layers)
-                self.k_history = query
-                self.v_history = query
-
-
+
             projected = nn.functional.linear(query, self.in_proj_weight, self.in_proj_bias)
             if self.kv_repeat == 1:
                 # if time_dim == 2:
@@ -217,7 +208,21 @@ class StreamingMultiheadAttention(nn.Module):
                 # bound_layout = "b t p h d"
                 packed = rearrange(projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads)
                 q, k, v = ops.unbind(packed, dim=2)
+
+            if self.k_history is not None:
+                #
+                # pk.shape=torch.Size([2, 24, 3, 64]) k.shape=torch.Size([2, 24, 1, 64]) CONCAT
+                # has to be 4D with batch 1 due to single condition 3=seqlen
+                # 24 heads 64 dimofh
+                self.k_history = torch.cat([self.k_history, k], 2)
+                self.v_history = torch.cat([self.v_history, v], 2)
+            else:
+                # init on 1st token (for all 47 transf layers)
+                self.k_history = k
+                self.v_history = v
 
+            k = self.k_history
+            v = self.v_history
 
 
 
@@ -235,8 +240,7 @@ class StreamingMultiheadAttention(nn.Module):
             # k, v = self._complete_kv(k, v)
             # print(k.sum(), v.sum(), k.shape, v.shape,'ATTNext')
 
-            if self.attention_as_float32:
-                q, k, v = [x.float() for x in [q, k, v]]
+            print(f'{self.attention_as_float32=}')
             if self.memory_efficient:
                 # print('EVER IN MEMORY EFFICIENT A')
 
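The transformer.py change is the fix named in the commit message: the per-layer k/v history is now built after the input projection, so the cache stores the projected k and v (shaped like [batch, heads, seq, head_dim] per the diff's comments, grown along dim=2) instead of the raw `query` tensor. A standalone sketch of that caching pattern follows; the class and variable names are illustrative, not the repo's `StreamingMultiheadAttention` API.

```python
import torch

class KVCacheSketch:
    """Minimal sketch of the per-layer k/v history pattern in this commit.

    Illustrative only: shapes follow the comments in the diff
    ([batch, heads, seq, head_dim], concatenated along dim=2).
    """
    def __init__(self):
        self.k_history = None
        self.v_history = None

    def update(self, k: torch.Tensor, v: torch.Tensor):
        # k, v are the *projected* tensors for the current step,
        # e.g. [2, 24, 1, 64]; the pre-fix code cached the raw query instead.
        if self.k_history is None:
            # init on the first token
            self.k_history, self.v_history = k, v
        else:
            # grow along the time axis (dim=2):
            # [2, 24, 3, 64] cat [2, 24, 1, 64] -> [2, 24, 4, 64]
            self.k_history = torch.cat([self.k_history, k], dim=2)
            self.v_history = torch.cat([self.v_history, v], dim=2)
        # attention for the new step then runs against the full history
        return self.k_history, self.v_history


cache = KVCacheSketch()
for _ in range(3):
    k_t = torch.randn(2, 24, 1, 64)
    v_t = torch.randn(2, 24, 1, 64)
    k, v = cache.update(k_t, v_t)
print(k.shape)  # torch.Size([2, 24, 3, 64])
```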