breadlicker45 committed (verified)
Commit f9dccaf · Parent: 407a995

Update midi_model.py

Files changed (1): midi_model.py (+151, -50)
midi_model.py CHANGED
@@ -7,22 +7,26 @@ import torch.nn as nn
import torch.nn.functional as F
import tqdm
from peft import PeftConfig, LoraModel, load_peft_weights, set_peft_model_state_dict
- from transformers import LlamaModel, LlamaConfig, DynamicCache, PretrainedConfig, PreTrainedModel
+ from transformers import LlamaModel, Phi3Model
+ from transformers import LlamaConfig, Phi3Config
+ from transformers import DynamicCache, PretrainedConfig, PreTrainedModel

from midi_tokenizer import MIDITokenizerV1, MIDITokenizerV2, MIDITokenizer

config_name_list = ["tv1-medium", "tv2-medium", "tv2o-medium", "tv2-large", "tv2o-large"]

-
class MIDIModelConfig(PretrainedConfig):
    model_type = "midi_model"

    def __init__(self,
                 tokenizer: Union[MIDITokenizerV1, MIDITokenizerV2, Dict]=None,
-                  net_config: Union[LlamaConfig, Dict]=None,
-                  net_token_config: Union[LlamaConfig, Dict]=None,
+                  net_config: Union[LlamaConfig, Phi3Config, Dict]=None,
+                  net_token_config: Union[LlamaConfig, Phi3Config, Dict]=None,
+                  model_type: str = "llama",
                 **kwargs):
        super().__init__(**kwargs)
+         self.model_type = model_type
+
        if tokenizer:
            if isinstance(tokenizer, dict):
                self.tokenizer = MIDITokenizer(tokenizer["version"])
@@ -31,52 +35,72 @@ class MIDIModelConfig(PretrainedConfig):
                self.tokenizer = tokenizer
        else:
            self.tokenizer = MIDITokenizer()
+
        if net_config:
            if isinstance(net_config, dict):
-                 self.net_config = LlamaConfig(**net_config)
+                 self.net_config = LlamaConfig(**net_config) if model_type == "llama" else Phi3Config(**net_config)
            else:
                self.net_config = net_config
        else:
-             self.net_config = LlamaConfig()
+             self.net_config = LlamaConfig() if model_type == "llama" else Phi3Config()
+
        if net_token_config:
            if isinstance(net_token_config, dict):
-                 self.net_token_config = LlamaConfig(**net_token_config)
+                 self.net_token_config = LlamaConfig(**net_token_config) if model_type == "llama" else Phi3Config(**net_token_config)
            else:
                self.net_token_config = net_token_config
        else:
-             self.net_token_config = LlamaConfig()
+             self.net_token_config = LlamaConfig() if model_type == "llama" else Phi3Config()
+
        self.n_embd = self.net_token_config.hidden_size

    def to_dict(self) -> Dict[str, Any]:
        d = super().to_dict()
        d["tokenizer"] = self.tokenizer.to_dict()
+         d["model_type"] = self.model_type
        return d

    def __str__(self):
        d = {
+             "model_type": self.model_type,
            "net": self.net_config.to_json_string(use_diff=False),
            "net_token": self.net_token_config.to_json_string(use_diff=False)
        }
        return json.dumps(d, indent=4)

    @staticmethod
-     def get_config(tokenizer_ver="v2", optimise_midi=True, n_layer=12, n_head=16, n_embd=1024, n_inner=4096):
+     def get_config(tokenizer_ver="v2", optimise_midi=True, n_layer=12, n_head=16, n_embd=1024, n_inner=4096, model_type="llama"):
        tokenizer = MIDITokenizer(tokenizer_ver)
        tokenizer.set_optimise_midi(optimise_midi)
-         net_config = LlamaConfig(vocab_size=tokenizer.vocab_size,
-                                  hidden_size=n_embd, num_attention_heads=n_head,
-                                  num_hidden_layers=n_layer, intermediate_size=n_inner,
-                                  pad_token_id=tokenizer.pad_id, max_position_embeddings=4096,
-                                  use_cache=False)
-         net_token_config = LlamaConfig(vocab_size=tokenizer.vocab_size,
-                                        hidden_size=n_embd, num_attention_heads=n_head // 4,
-                                        num_hidden_layers=n_layer // 4, intermediate_size=n_inner // 4,
-                                        pad_token_id=tokenizer.pad_id, max_position_embeddings=4096,
-                                        use_cache=False)
-         return MIDIModelConfig(tokenizer, net_config, net_token_config)
+
+         config_class = LlamaConfig if model_type == "llama" else Phi3Config
+
+         net_config = config_class(
+             vocab_size=tokenizer.vocab_size,
+             hidden_size=n_embd,
+             num_attention_heads=n_head,
+             num_hidden_layers=n_layer,
+             intermediate_size=n_inner,
+             pad_token_id=tokenizer.pad_id,
+             max_position_embeddings=4096,
+             use_cache=False
+         )
+
+         net_token_config = config_class(
+             vocab_size=tokenizer.vocab_size,
+             hidden_size=n_embd,
+             num_attention_heads=n_head // 4,
+             num_hidden_layers=n_layer // 4,
+             intermediate_size=n_inner // 4,
+             pad_token_id=tokenizer.pad_id,
+             max_position_embeddings=4096,
+             use_cache=False
+         )
+
+         return MIDIModelConfig(tokenizer, net_config, net_token_config, model_type=model_type)

    @staticmethod
-     def from_name(name="tv2o-medium"):
+     def from_name(name="tv2o-medium", model_type="llama"):
        tv, size = name.split("-")
        tv = tv[1:]
        if tv[-1] == "o":
@@ -84,26 +108,45 @@ class MIDIModelConfig(PretrainedConfig):
            tv = tv[:-1]
        else:
            o = False
+
        if tv not in ["v1", "v2"]:
            raise ValueError(f"Unknown tokenizer version {tv}")
+
        if size == "medium":
-             return MIDIModelConfig.get_config(tokenizer_ver=tv, optimise_midi=o,
-                                               n_layer=12, n_head=16, n_embd=1024, n_inner=4096)
+             return MIDIModelConfig.get_config(
+                 tokenizer_ver=tv,
+                 optimise_midi=o,
+                 n_layer=12,
+                 n_head=16,
+                 n_embd=1024,
+                 n_inner=4096,
+                 model_type=model_type
+             )
        elif size == "large":
-             return MIDIModelConfig.get_config(tokenizer_ver=tv, optimise_midi=o,
-                                               n_layer=24, n_head=16, n_embd=1024, n_inner=4096)
+             return MIDIModelConfig.get_config(
+                 tokenizer_ver=tv,
+                 optimise_midi=o,
+                 n_layer=24,
+                 n_head=16,
+                 n_embd=1024,
+                 n_inner=4096,
+                 model_type=model_type
+             )
        else:
            raise ValueError(f"Unknown model size {size}")

-
class MIDIModel(PreTrainedModel):
    config_class = MIDIModelConfig

    def __init__(self, config: MIDIModelConfig, *args, **kwargs):
        super(MIDIModel, self).__init__(config, *args, **kwargs)
        self.tokenizer = config.tokenizer
-         self.net = LlamaModel(config.net_config)
-         self.net_token = LlamaModel(config.net_token_config)
+
+         # Initialize the appropriate model type
+         model_class = LlamaModel if config.model_type == "llama" else Phi3Model
+         self.net = model_class(config.net_config)
+         self.net_token = model_class(config.net_token_config)
+
        self.lm_head = nn.Linear(config.n_embd, self.tokenizer.vocab_size, bias=False)

    def load_merge_lora(self, model_id):
@@ -115,62 +158,97 @@ class MIDIModel(PreTrainedModel):

    def forward_token(self, hidden_state=None, x=None, cache=None):
        """
-
        :param hidden_state: (batch_size, n_embd)
        :param x: (batch_size, token_sequence_length)
        :param cache: Cache
        :return: (batch_size, 1 + token_sequence_length, vocab_size)
        """
        if hidden_state is not None:
-             #if you use cache, you don't need to pass in hidden_state
            hidden_state = hidden_state.unsqueeze(1)  # (batch_size, 1, n_embd)
        if x is not None:
            x = self.net_token.embed_tokens(x)
            if hidden_state is not None:
                x = torch.cat([hidden_state, x], dim=1)
            hidden_state = x
-         hidden_state = self.net_token.forward(inputs_embeds=hidden_state,
-                                               past_key_values=cache,
-                                               use_cache=cache is not None).last_hidden_state
+         hidden_state = self.net_token.forward(
+             inputs_embeds=hidden_state,
+             past_key_values=cache,
+             use_cache=cache is not None
+         ).last_hidden_state
        return self.lm_head(hidden_state)

-     def forward(self, x, cache = None):
+     def forward(self, x, cache=None):
        """
        :param x: (batch_size, midi_sequence_length, token_sequence_length)
        :param cache: Cache
        :return: hidden (batch_size, midi_sequence_length, n_embd)
        """
-
-         # merge token sequence
        x = self.net.embed_tokens(x)
        x = x.sum(dim=-2)
-         x = self.net.forward(inputs_embeds=x,
-                              past_key_values=cache,
-                              use_cache=cache is not None)
+         x = self.net.forward(
+             inputs_embeds=x,
+             past_key_values=cache,
+             use_cache=cache is not None
+         )
        return x.last_hidden_state

    def sample_top_p_k(self, probs, p, k, generator=None):
+         """
+         Sample from top-p and top-k filtered probability distribution
+
+         :param probs: probability distribution
+         :param p: top-p threshold
+         :param k: top-k threshold
+         :param generator: random number generator
+         :return: sampled token indices
+         """
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
        probs_sum = torch.cumsum(probs_sort, dim=-1)
        mask = probs_sum - probs_sort > p
        probs_sort[mask] = 0.0
+
        mask = torch.zeros(probs_sort.shape[-1], device=probs_sort.device)
        mask[:k] = 1
        probs_sort = probs_sort * mask
+
        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
        shape = probs_sort.shape
-         next_token = torch.multinomial(probs_sort.reshape(-1, shape[-1]),
-                                        num_samples=1, generator=generator).reshape(*shape[:-1], 1)
+
+         next_token = torch.multinomial(
+             probs_sort.reshape(-1, shape[-1]),
+             num_samples=1,
+             generator=generator
+         ).reshape(*shape[:-1], 1)
+
        next_token = torch.gather(probs_idx, -1, next_token).reshape(*shape[:-1])
        return next_token

    @torch.inference_mode()
    def generate(self, prompt=None, batch_size=1, max_len=512, temp=1.0, top_p=0.98, top_k=20, generator=None):
+         """
+         Generate MIDI sequences
+
+         :param prompt: optional input prompt
+         :param batch_size: number of sequences to generate
+         :param max_len: maximum sequence length
+         :param temp: temperature for sampling
+         :param top_p: top-p threshold for sampling
+         :param top_k: top-k threshold for sampling
+         :param generator: random number generator
+         :return: generated sequences
+         """
        tokenizer = self.tokenizer
        max_token_seq = tokenizer.max_token_seq
+
+         # Initialize input tensor
        if prompt is None:
-             input_tensor = torch.full((1, max_token_seq), tokenizer.pad_id, dtype=torch.long, device=self.device)
-             input_tensor[0, 0] = tokenizer.bos_id  # bos
+             input_tensor = torch.full(
+                 (1, max_token_seq),
+                 tokenizer.pad_id,
+                 dtype=torch.long,
+                 device=self.device
+             )
+             input_tensor[0, 0] = tokenizer.bos_id
            input_tensor = input_tensor.unsqueeze(0)
            input_tensor = torch.cat([input_tensor] * batch_size, dim=0)
        else:
@@ -181,16 +259,22 @@ class MIDIModel(PreTrainedModel):
                prompt = np.repeat(prompt, repeats=batch_size, axis=0)
            elif len(prompt.shape) != 3 or prompt.shape[0] != batch_size:
                raise ValueError(f"invalid shape for prompt, {prompt.shape}")
+
            prompt = prompt[..., :max_token_seq]
            if prompt.shape[-1] < max_token_seq:
-                 prompt = np.pad(prompt, ((0, 0), (0, 0), (0, max_token_seq - prompt.shape[-1])),
-                                 mode="constant", constant_values=tokenizer.pad_id)
+                 prompt = np.pad(
+                     prompt,
+                     ((0, 0), (0, 0), (0, max_token_seq - prompt.shape[-1])),
+                     mode="constant",
+                     constant_values=tokenizer.pad_id
+                 )
            input_tensor = torch.from_numpy(prompt).to(dtype=torch.long, device=self.device)

        cur_len = input_tensor.shape[1]
        bar = tqdm.tqdm(desc="generating", total=max_len - cur_len)
        cache1 = DynamicCache()
        past_len = 0
+
        with bar:
            while cur_len < max_len:
                end = [False] * batch_size
@@ -198,12 +282,19 @@ class MIDIModel(PreTrainedModel):
                next_token_seq = None
                event_names = [""] * batch_size
                cache2 = DynamicCache()
+
                for i in range(max_token_seq):
-                     mask = torch.zeros((batch_size, tokenizer.vocab_size), dtype=torch.int64, device=self.device)
+                     mask = torch.zeros(
+                         (batch_size, tokenizer.vocab_size),
+                         dtype=torch.int64,
+                         device=self.device
+                     )
+
                    for b in range(batch_size):
                        if end[b]:
                            mask[b, tokenizer.pad_id] = 1
                            continue
+
                        if i == 0:
                            mask[b, list(tokenizer.event_ids.values()) + [tokenizer.eos_id]] = 1
                        else:
@@ -212,15 +303,19 @@ class MIDIModel(PreTrainedModel):
                                mask[b, tokenizer.pad_id] = 1
                                continue
                            mask[b, tokenizer.parameter_ids[param_names[i - 1]]] = 1
+
                    mask = mask.unsqueeze(1)
                    x = next_token_seq
+
                    if i != 0:
-                         # cached
+                         # Use cache for non-first tokens
                        hidden = None
                        x = x[:, -1:]
+
                    logits = self.forward_token(hidden, x, cache=cache2)[:, -1:]
                    scores = torch.softmax(logits / temp, dim=-1) * mask
                    samples = self.sample_top_p_k(scores, top_p, top_k, generator=generator)
+
                    if i == 0:
                        next_token_seq = samples
                        for b in range(batch_size):
@@ -237,8 +332,13 @@ class MIDIModel(PreTrainedModel):
                            break

                if next_token_seq.shape[1] < max_token_seq:
-                     next_token_seq = F.pad(next_token_seq, (0, max_token_seq - next_token_seq.shape[1]),
-                                            "constant", value=tokenizer.pad_id)
+                     next_token_seq = F.pad(
+                         next_token_seq,
+                         (0, max_token_seq - next_token_seq.shape[1]),
+                         "constant",
+                         value=tokenizer.pad_id
+                     )
+
                next_token_seq = next_token_seq.unsqueeze(1)
                input_tensor = torch.cat([input_tensor, next_token_seq], dim=1)
                past_len = cur_len
@@ -247,4 +347,5 @@ class MIDIModel(PreTrainedModel):

                if all(end):
                    break
-         return input_tensor.cpu().numpy()
+
+         return input_tensor.cpu().numpy()
 
 
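For reference, a minimal usage sketch (not part of this commit) of the new model_type switch, following the updated MIDIModelConfig.from_name and MIDIModel.__init__ signatures in the diff above. Per the config logic, "llama" keeps the original backbone and any other value (e.g. "phi3") routes to Phi3Config/Phi3Model; the variable names below are illustrative.

# Illustrative sketch only — assumes this repo's midi_model.py and midi_tokenizer.py
# are importable and that the installed transformers version provides Phi3Model.
from midi_model import MIDIModel, MIDIModelConfig

# "llama" keeps the LlamaModel backbone; any other value selects the Phi3 classes
# for both the sequence network and the token network.
config = MIDIModelConfig.from_name("tv2o-medium", model_type="phi3")
model = MIDIModel(config)
print(config)  # __str__ now reports "model_type" alongside the net configs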
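And a similarly hedged sketch of sampling with the updated generate() shown above; since generate() returns input_tensor.cpu().numpy(), the result is a NumPy array of token ids of shape (batch_size, generated_length, max_token_seq).

# Illustrative sketch only — continues the hypothetical `model` from the sketch above.
import numpy as np

out = model.generate(prompt=None, batch_size=1, max_len=512,
                     temp=1.0, top_p=0.98, top_k=20)
assert isinstance(out, np.ndarray)
print(out.shape)  # (batch_size, generated_length, tokenizer.max_token_seq)
# Decoding the token ids back into a MIDI file is handled by the repo's MIDI
# tokenizer (midi_tokenizer.py), not shown here.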