sanchit-gandhi commited on
Commit
6f5cea7
·
1 Parent(s): 5039fa6

generation logic

Browse files
Files changed (1) hide show
  1. app.py +15 -3
app.py CHANGED
@@ -149,10 +149,22 @@ class ParlerTTSStreamer(BaseStreamer):
149
  # send the input_ids to the correct device
150
  input_ids = input_ids.to(self.audio_encoder.device)
151
 
152
- output_values = self.audio_encoder.decode(
153
- input_ids,
154
- audio_scales=[None],
 
155
  )
 
 
 
 
 
 
 
 
 
 
 
156
  audio_values = output_values.audio_values[0, 0]
157
  return audio_values.cpu().float().numpy()
158
 
 
149
  # send the input_ids to the correct device
150
  input_ids = input_ids.to(self.audio_encoder.device)
151
 
152
+ decode_sequentially = (
153
+ self.generation_config.bos_token_id in input_ids
154
+ or self.generation_config.pad_token_id in input_ids
155
+ or self.generation_config.eos_token_id in input_ids
156
  )
157
+ if not decode_sequentially:
158
+ output_values = self.audio_encoder.decode(
159
+ input_ids,
160
+ audio_scales=[None],
161
+ )
162
+ else:
163
+ sample = input_ids[:, 0]
164
+ sample_mask = (sample >= self.audio_encoder.config.codebook_size).sum(dim=(0, 1)) == 0
165
+ sample = sample[:, :, sample_mask]
166
+ output_values = self.audio_encoder.decode(sample[None, ...], [None])
167
+
168
  audio_values = output_values.audio_values[0, 0]
169
  return audio_values.cpu().float().numpy()
170