skytnt commited on
Commit
5bef524
·
1 Parent(s): 4789653
Files changed (2) hide show
  1. app.py +3 -1
  2. midi_model.py +3 -1
app.py CHANGED
@@ -53,10 +53,11 @@ def generate(model: MIDIModel, prompt=None, batch_size=1, max_len=512, temp=1.0,
53
  cur_len = input_tensor.shape[1]
54
  bar = tqdm.tqdm(desc="generating", total=max_len - cur_len)
55
  cache1 = DynamicCache()
 
56
  with bar:
57
  while cur_len < max_len:
58
  end = [False] * batch_size
59
- hidden = model.forward(input_tensor[:, -1:], cache=cache1)[:, -1]
60
  next_token_seq = None
61
  event_names = [""] * batch_size
62
  cache2 = DynamicCache()
@@ -110,6 +111,7 @@ def generate(model: MIDIModel, prompt=None, batch_size=1, max_len=512, temp=1.0,
110
  "constant", value=tokenizer.pad_id)
111
  next_token_seq = next_token_seq.unsqueeze(1)
112
  input_tensor = torch.cat([input_tensor, next_token_seq], dim=1)
 
113
  cur_len += 1
114
  bar.update(1)
115
  yield next_token_seq[:, 0].cpu().numpy()
 
53
  cur_len = input_tensor.shape[1]
54
  bar = tqdm.tqdm(desc="generating", total=max_len - cur_len)
55
  cache1 = DynamicCache()
56
+ past_len = 0
57
  with bar:
58
  while cur_len < max_len:
59
  end = [False] * batch_size
60
+ hidden = model.forward(input_tensor[:, past_len:], cache=cache1)[:, -1]
61
  next_token_seq = None
62
  event_names = [""] * batch_size
63
  cache2 = DynamicCache()
 
111
  "constant", value=tokenizer.pad_id)
112
  next_token_seq = next_token_seq.unsqueeze(1)
113
  input_tensor = torch.cat([input_tensor, next_token_seq], dim=1)
114
+ past_len = cur_len
115
  cur_len += 1
116
  bar.update(1)
117
  yield next_token_seq[:, 0].cpu().numpy()
midi_model.py CHANGED
@@ -160,10 +160,11 @@ class MIDIModel(nn.Module, PeftAdapterMixin):
160
  cur_len = input_tensor.shape[1]
161
  bar = tqdm.tqdm(desc="generating", total=max_len - cur_len)
162
  cache1 = DynamicCache()
 
163
  with bar:
164
  while cur_len < max_len:
165
  end = [False] * batch_size
166
- hidden = self.forward(input_tensor[:,-1:], cache=cache1)[:, -1]
167
  next_token_seq = None
168
  event_names = [""] * batch_size
169
  cache2 = DynamicCache()
@@ -210,6 +211,7 @@ class MIDIModel(nn.Module, PeftAdapterMixin):
210
  "constant", value=tokenizer.pad_id)
211
  next_token_seq = next_token_seq.unsqueeze(1)
212
  input_tensor = torch.cat([input_tensor, next_token_seq], dim=1)
 
213
  cur_len += 1
214
  bar.update(1)
215
 
 
160
  cur_len = input_tensor.shape[1]
161
  bar = tqdm.tqdm(desc="generating", total=max_len - cur_len)
162
  cache1 = DynamicCache()
163
+ past_len = 0
164
  with bar:
165
  while cur_len < max_len:
166
  end = [False] * batch_size
167
+ hidden = self.forward(input_tensor[:,past_len:], cache=cache1)[:, -1]
168
  next_token_seq = None
169
  event_names = [""] * batch_size
170
  cache2 = DynamicCache()
 
211
  "constant", value=tokenizer.pad_id)
212
  next_token_seq = next_token_seq.unsqueeze(1)
213
  input_tensor = torch.cat([input_tensor, next_token_seq], dim=1)
214
+ past_len = cur_len
215
  cur_len += 1
216
  bar.update(1)
217