freddyaboulton HF staff commited on
Commit
72c65b6
·
1 Parent(s): 93acf16
Files changed (2) hide show
  1. app.py +12 -0
  2. requirements.txt +2 -1
app.py CHANGED
@@ -14,12 +14,21 @@ from pydub import AudioSegment
14
  from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed
15
  from transformers.generation.streamers import BaseStreamer
16
  from huggingface_hub import InferenceClient
 
 
 
17
 
18
  device = "cuda:0" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
19
  torch_dtype = torch.float16 if device != "cpu" else torch.float32
20
 
21
  repo_id = "parler-tts/parler_tts_mini_v0.1"
22
 
 
 
 
 
 
 
23
  model = ParlerTTSForConditionalGeneration.from_pretrained(
24
  repo_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True
25
  ).to(device)
@@ -205,6 +214,9 @@ def generate_base(subject, setting, ):
205
  gr.Info("Story Generated", duration=3)
206
  story = response.choices[0].message.content
207
 
 
 
 
208
  play_steps_in_s = 4.0
209
  play_steps = int(frame_rate * play_steps_in_s)
210
  streamer = ParlerTTSStreamer(model, device=device, play_steps=play_steps)
 
14
  from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed
15
  from transformers.generation.streamers import BaseStreamer
16
  from huggingface_hub import InferenceClient
17
+ import nltk
18
+ nltk.download('punkt')
19
+
20
 
21
  device = "cuda:0" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
22
  torch_dtype = torch.float16 if device != "cpu" else torch.float32
23
 
24
  repo_id = "parler-tts/parler_tts_mini_v0.1"
25
 
26
+ jenny_repo_id = "ylacombe/parler-tts-mini-jenny-30H"
27
+
28
+ model = ParlerTTSForConditionalGeneration.from_pretrained(
29
+ jenny_repo_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True
30
+ ).to(device)
31
+
32
  model = ParlerTTSForConditionalGeneration.from_pretrained(
33
  repo_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True
34
  ).to(device)
 
214
  gr.Info("Story Generated", duration=3)
215
  story = response.choices[0].message.content
216
 
217
+ model_input = story.replace("\n", " ").strip()
218
+ model_input = nltk.sent_tokenize(model_input)
219
+
220
  play_steps_in_s = 4.0
221
  play_steps = int(frame_rate * play_steps_in_s)
222
  streamer = ParlerTTSStreamer(model, device=device, play_steps=play_steps)
requirements.txt CHANGED
@@ -1,3 +1,4 @@
1
  https://gradio-builds.s3.amazonaws.com/bed454c3d22cfacedc047eb3b0ba987b485ac3fd/gradio-4.40.0-py3-none-any.whl
2
  git+https://github.com/huggingface/parler-tts.git
3
- accelerate
 
 
1
  https://gradio-builds.s3.amazonaws.com/bed454c3d22cfacedc047eb3b0ba987b485ac3fd/gradio-4.40.0-py3-none-any.whl
2
  git+https://github.com/huggingface/parler-tts.git
3
+ accelerate
4
+ nltk