zzk1st commited on
Commit
03adfb9
·
1 Parent(s): 78d395e

Fixed multi-user

Browse files
Files changed (11) hide show
  1. APIs.py +4 -4
  2. README.md +16 -6
  3. code_generator.py +1 -3
  4. config.yaml +0 -4
  5. pipeline.py +7 -6
  6. scripts/kill_services.py +1 -6
  7. services.py +2 -2
  8. ui_client.py +16 -8
  9. utils.py +10 -1
  10. voice_presets.py +1 -1
  11. wavjourney_cli.py +1 -1
APIs.py CHANGED
@@ -6,6 +6,7 @@ import pyloudnorm as pyln
6
  from scipy.io.wavfile import write
7
  import torchaudio
8
  from retrying import retry
 
9
 
10
 
11
  os.environ['OPENBLAS_NUM_THREADS'] = '1'
@@ -14,10 +15,9 @@ SAMPLE_RATE = 32000
14
 
15
  with open('config.yaml', 'r') as file:
16
  config = yaml.safe_load(file)
17
- service_port = config['Service-Port']
 
18
  enable_sr = config['Speech-Restoration']['Enable']
19
- localhost_addr = '0.0.0.0'
20
-
21
 
22
  def LOUDNESS_NORM(audio, sr=32000, volumn=-25):
23
  # peak normalize audio to -1 dB
@@ -148,7 +148,7 @@ def TTA(text, length=5, volume=-35, out_wav='out.wav'):
148
 
149
 
150
  @retry(stop_max_attempt_number=5, wait_fixed=2000)
151
- def TTS(text, speaker='news_anchor', volume=-20, out_wav='out.wav', enhanced=enable_sr, speaker_id='', speaker_npz=''):
152
  url = f'http://{localhost_addr}:{service_port}/generate_speech'
153
  data = {
154
  'text': f'{text}',
 
6
  from scipy.io.wavfile import write
7
  import torchaudio
8
  from retrying import retry
9
+ from utils import get_service_port, get_service_url
10
 
11
 
12
  os.environ['OPENBLAS_NUM_THREADS'] = '1'
 
15
 
16
  with open('config.yaml', 'r') as file:
17
  config = yaml.safe_load(file)
18
+ service_port = get_service_port()
19
+ localhost_addr = get_service_url()
20
  enable_sr = config['Speech-Restoration']['Enable']
 
 
21
 
22
  def LOUDNESS_NORM(audio, sr=32000, volumn=-25):
23
  # peak normalize audio to -1 dB
 
148
 
149
 
150
  @retry(stop_max_attempt_number=5, wait_fixed=2000)
151
+ def TTS(text, volume=-20, out_wav='out.wav', enhanced=enable_sr, speaker_id='', speaker_npz=''):
152
  url = f'http://{localhost_addr}:{service_port}/generate_speech'
153
  data = {
154
  'text': f'{text}',
README.md CHANGED
@@ -8,7 +8,7 @@ pinned: false
8
  license: cc-by-nc-nd-4.0
9
  ---
10
  # <span style="color: blue;">🎵</span> WavJourney: Compositional Audio Creation with LLMs
11
- [![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2307.14335) [![GitHub Stars](https://img.shields.io/github/stars/Audio-AGI/WavJourney?style=social)](https://github.com/Audio-AGI/WavJourney/) [![githubio](https://img.shields.io/badge/GitHub.io-Demo_Page-blue?logo=Github&style=flat-square)](https://audio-agi.github.io/WavJourney_demopage/)
12
 
13
 
14
  This repository contains the official implementation of ["WavJourney: Compositional Audio Creation with Large Language Models"](https://audio-agi.github.io/WavJourney_demopage/WavJourney_arXiv.pdf).
@@ -32,14 +32,24 @@ bash ./scripts/EnvsSetup.sh
32
  conda activate WavJourney
33
  ```
34
 
35
- 3. Set your `OpenAI-Key` in `config.yaml` for accessing [GPT-4 API](https://platform.openai.com/account/api-keys) [[Guidance](https://help.openai.com/en/articles/7102672-how-can-i-access-gpt-4)]. Please make sure the 'Service-Port' is not occupied. You can also modify the configuration, check the details described in the configuration file.
36
-
37
- 3. Pre-download the models (might take some time):
38
  ```bash
39
  python scripts/download_models.py
40
  ```
41
 
42
- 5. Start Python API services (e.g., Text-to-Speech, Text-to-Audio)
 
 
 
 
 
 
 
 
 
 
 
43
  ```bash
44
  bash scripts/start_services.sh
45
  ```
@@ -51,7 +61,7 @@ bash scripts/start_ui.sh
51
 
52
  ## Commandline Usage
53
  ```bash
54
- python wavjourney_cli.py -f --input-text "Generate a one-minute introduction to quantum mechanics"
55
  ```
56
 
57
 
 
8
  license: cc-by-nc-nd-4.0
9
  ---
10
  # <span style="color: blue;">🎵</span> WavJourney: Compositional Audio Creation with LLMs
11
+ [![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2307.14335) [![GitHub Stars](https://img.shields.io/github/stars/Audio-AGI/WavJourney?style=social)](https://github.com/Audio-AGI/WavJourney/) [![githubio](https://img.shields.io/badge/GitHub.io-Demo_Page-blue?logo=Github&style=flat-square)](https://audio-agi.github.io/WavJourney_demopage/) [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/Audio-AGI/WavJourney)
12
 
13
 
14
  This repository contains the official implementation of ["WavJourney: Compositional Audio Creation with Large Language Models"](https://audio-agi.github.io/WavJourney_demopage/WavJourney_arXiv.pdf).
 
32
  conda activate WavJourney
33
  ```
34
 
35
+ 3. (Optional) You can modify the default configuration in `config.yaml`, check the details described in the configuration file.
36
+ 4. Pre-download the models (might take some time):
 
37
  ```bash
38
  python scripts/download_models.py
39
  ```
40
 
41
+ 5. Set the WAVJOURNEY_OPENAI_KEY in the environment variable for accessing [GPT-4 API](https://platform.openai.com/account/api-keys) [[Guidance](https://help.openai.com/en/articles/7102672-how-can-i-access-gpt-4)]
42
+ ```bash
43
+ export WAVJOURNEY_OPENAI_KEY=your_openai_key_here
44
+ ```
45
+
46
+ 6. Set environment variables for using API services
47
+ ```bash
48
+ export WAVJOURNEY_SERVICE_PORT=8021 WAVJOURNEY_SERVICE_URL=127.0.0.1
49
+ ```
50
+
51
+
52
+ 7. Start Python API services (e.g., Text-to-Speech, Text-to-Audio)
53
  ```bash
54
  bash scripts/start_services.sh
55
  ```
 
61
 
62
  ## Commandline Usage
63
  ```bash
64
+ python wavjourney_cli.py -f --input-text "Generate a one-minute introduction to quantum mechanics"
65
  ```
66
 
67
 
code_generator.py CHANGED
@@ -113,10 +113,8 @@ class AudioCodeGenerator:
113
  return wav_filename
114
 
115
  header = f'''
116
- import sys
117
- sys.path.append('../AudioJourney')
118
-
119
  import os
 
120
  import datetime
121
 
122
  from APIs import TTM, TTS, TTA, MIX, CAT, COMPUTE_LEN
 
113
  return wav_filename
114
 
115
  header = f'''
 
 
 
116
  import os
117
+ import sys
118
  import datetime
119
 
120
  from APIs import TTM, TTS, TTA, MIX, CAT, COMPUTE_LEN
config.yaml CHANGED
@@ -15,7 +15,3 @@ Speech-Restoration:
15
  Voice-Parser:
16
  # HuBERT
17
  device: 'cpu'
18
-
19
- Service-Port: 8021
20
-
21
- OpenAI-Key: ''
 
15
  Voice-Parser:
16
  # HuBERT
17
  device: 'cpu'
 
 
 
 
pipeline.py CHANGED
@@ -120,6 +120,7 @@ def init_session(session_id=''):
120
  # create the paths
121
  os.makedirs(utils.get_session_voice_preset_path(session_id))
122
  os.makedirs(utils.get_session_audio_path(session_id))
 
123
  return session_id
124
 
125
  @retry(stop_max_attempt_number=3)
@@ -142,7 +143,6 @@ def input_text_to_json_script_with_retry(complete_prompt_path, api_key):
142
 
143
  # Step 1: input_text to json
144
  def input_text_to_json_script(input_text, output_path, api_key):
145
- print('Step 1: Writing audio script with LLM ...')
146
  input_text = maybe_get_content_from_file(input_text)
147
  text_to_audio_script_prompt = get_file_content('prompts/text_to_json.prompt')
148
  prompt = f'{text_to_audio_script_prompt}\n\nInput text: {input_text}\n\nScript:\n'
@@ -155,7 +155,6 @@ def input_text_to_json_script(input_text, output_path, api_key):
155
 
156
  # Step 2: json to char-voice map
157
  def json_script_to_char_voice_map(json_script, voices, output_path, api_key):
158
- print('Step 2: Parsing character voice with LLM...')
159
  json_script_content = maybe_get_content_from_file(json_script)
160
  prompt = get_file_content('prompts/audio_script_to_character_voice_map.prompt')
161
  presets_str = '\n'.join(f"{preset['id']}: {preset['desc']}" for preset in voices.values())
@@ -172,7 +171,6 @@ def json_script_to_char_voice_map(json_script, voices, output_path, api_key):
172
 
173
  # Step 3: json to py code
174
  def json_script_and_char_voice_map_to_audio_gen_code(json_script_filename, char_voice_map_filename, output_path, result_filename):
175
- print('Step 3: Compiling audio script to Python program ...')
176
  audio_code_generator = AudioCodeGenerator()
177
  code = audio_code_generator.parse_and_generate(
178
  json_script_filename,
@@ -184,14 +182,14 @@ def json_script_and_char_voice_map_to_audio_gen_code(json_script_filename, char_
184
 
185
  # Step 4: py code to final wav
186
  def audio_code_gen_to_result(audio_gen_code_path):
187
- print('Step 4: Start running Python program ...')
188
  audio_gen_code_filename = audio_gen_code_path / 'audio_generation.py'
189
- os.system(f'python {audio_gen_code_filename}')
190
 
191
  # Function call used by Gradio: input_text to json
192
  def generate_json_file(session_id, input_text, api_key):
193
  output_path = utils.get_session_path(session_id)
194
  # Step 1
 
195
  return input_text_to_json_script(input_text, output_path, api_key)
196
 
197
  # Function call used by Gradio: json to result wav
@@ -201,13 +199,16 @@ def generate_audio(session_id, json_script, api_key):
201
  voices = voice_presets.get_merged_voice_presets(session_id)
202
 
203
  # Step 2
 
204
  char_voice_map = json_script_to_char_voice_map(json_script, voices, output_path, api_key)
205
  # Step 3
206
  json_script_filename = output_path / 'audio_script.json'
207
  char_voice_map_filename = output_path / 'character_voice_map.json'
208
  result_wav_basename = f'res_{session_id}'
 
209
  json_script_and_char_voice_map_to_audio_gen_code(json_script_filename, char_voice_map_filename, output_path, result_wav_basename)
210
  # Step 4
 
211
  audio_code_gen_to_result(output_path)
212
 
213
  result_wav_filename = output_audio_path / f'{result_wav_basename}.wav'
@@ -217,4 +218,4 @@ def generate_audio(session_id, json_script, api_key):
217
  # Convenient function call used by wavjourney_cli
218
  def full_steps(session_id, input_text, api_key):
219
  json_script = generate_json_file(session_id, input_text, api_key)
220
- return generate_audio(session_id, json_script, api_key)
 
120
  # create the paths
121
  os.makedirs(utils.get_session_voice_preset_path(session_id))
122
  os.makedirs(utils.get_session_audio_path(session_id))
123
+ print(f'New session created, session_id={session_id}')
124
  return session_id
125
 
126
  @retry(stop_max_attempt_number=3)
 
143
 
144
  # Step 1: input_text to json
145
  def input_text_to_json_script(input_text, output_path, api_key):
 
146
  input_text = maybe_get_content_from_file(input_text)
147
  text_to_audio_script_prompt = get_file_content('prompts/text_to_json.prompt')
148
  prompt = f'{text_to_audio_script_prompt}\n\nInput text: {input_text}\n\nScript:\n'
 
155
 
156
  # Step 2: json to char-voice map
157
  def json_script_to_char_voice_map(json_script, voices, output_path, api_key):
 
158
  json_script_content = maybe_get_content_from_file(json_script)
159
  prompt = get_file_content('prompts/audio_script_to_character_voice_map.prompt')
160
  presets_str = '\n'.join(f"{preset['id']}: {preset['desc']}" for preset in voices.values())
 
171
 
172
  # Step 3: json to py code
173
  def json_script_and_char_voice_map_to_audio_gen_code(json_script_filename, char_voice_map_filename, output_path, result_filename):
 
174
  audio_code_generator = AudioCodeGenerator()
175
  code = audio_code_generator.parse_and_generate(
176
  json_script_filename,
 
182
 
183
  # Step 4: py code to final wav
184
  def audio_code_gen_to_result(audio_gen_code_path):
 
185
  audio_gen_code_filename = audio_gen_code_path / 'audio_generation.py'
186
+ os.system(f'PYTHONPATH=. python {audio_gen_code_filename}')
187
 
188
  # Function call used by Gradio: input_text to json
189
  def generate_json_file(session_id, input_text, api_key):
190
  output_path = utils.get_session_path(session_id)
191
  # Step 1
192
+ print(f'session_id={session_id}, Step 1: Writing audio script with LLM ...')
193
  return input_text_to_json_script(input_text, output_path, api_key)
194
 
195
  # Function call used by Gradio: json to result wav
 
199
  voices = voice_presets.get_merged_voice_presets(session_id)
200
 
201
  # Step 2
202
+ print(f'session_id={session_id}, Step 2: Parsing character voice with LLM...')
203
  char_voice_map = json_script_to_char_voice_map(json_script, voices, output_path, api_key)
204
  # Step 3
205
  json_script_filename = output_path / 'audio_script.json'
206
  char_voice_map_filename = output_path / 'character_voice_map.json'
207
  result_wav_basename = f'res_{session_id}'
208
+ print(f'session_id={session_id}, Step 3: Compiling audio script to Python program ...')
209
  json_script_and_char_voice_map_to_audio_gen_code(json_script_filename, char_voice_map_filename, output_path, result_wav_basename)
210
  # Step 4
211
+ print(f'session_id={session_id}, Step 4: Start running Python program ...')
212
  audio_code_gen_to_result(output_path)
213
 
214
  result_wav_filename = output_audio_path / f'{result_wav_basename}.wav'
 
218
  # Convenient function call used by wavjourney_cli
219
  def full_steps(session_id, input_text, api_key):
220
  json_script = generate_json_file(session_id, input_text, api_key)
221
+ return generate_audio(session_id, json_script, api_key)
scripts/kill_services.py CHANGED
@@ -1,12 +1,7 @@
1
- import yaml
2
  import os
3
 
4
- # Read the YAML file
5
- with open('config.yaml', 'r') as file:
6
- config = yaml.safe_load(file)
7
-
8
  # Extract values for each application
9
- service_port = config['Service-Port']
10
 
11
  # Execute the commands
12
  os.system(f'kill $(lsof -t -i :{service_port})')
 
 
1
  import os
2
 
 
 
 
 
3
  # Extract values for each application
4
+ service_port = os.environ.get('WAVJOURNEY_SERVICE_PORT')
5
 
6
  # Execute the commands
7
  os.system(f'kill $(lsof -t -i :{service_port})')
services.py CHANGED
@@ -6,7 +6,7 @@ import torch
6
  import torchaudio
7
  from torchaudio.transforms import SpeedPerturbation
8
  from APIs import WRITE_AUDIO, LOUDNESS_NORM
9
- from utils import fade
10
  from flask import Flask, request, jsonify
11
 
12
  with open('config.yaml', 'r') as file:
@@ -226,5 +226,5 @@ def parse_voice():
226
 
227
 
228
  if __name__ == '__main__':
229
- service_port = config['Service-Port']
230
  app.run(debug=False, port=service_port)
 
6
  import torchaudio
7
  from torchaudio.transforms import SpeedPerturbation
8
  from APIs import WRITE_AUDIO, LOUDNESS_NORM
9
+ from utils import fade, get_service_port
10
  from flask import Flask, request, jsonify
11
 
12
  with open('config.yaml', 'r') as file:
 
226
 
227
 
228
  if __name__ == '__main__':
229
+ service_port = get_service_port()
230
  app.run(debug=False, port=service_port)
ui_client.py CHANGED
@@ -41,9 +41,15 @@ def convert_char_voice_map_to_md(char_voice_map):
41
  return table_txt
42
 
43
 
 
 
 
 
 
 
44
  def generate_script_fn(instruction, _state: gr.State):
45
  try:
46
- session_id = _state['session_id']
47
  api_key = utils.get_api_key()
48
  json_script = generate_json_file(session_id, instruction, api_key)
49
  table_text = convert_json_to_md(json_script)
@@ -130,12 +136,14 @@ def textbox_listener(textbox_input):
130
 
131
 
132
  def get_voice_preset_to_list(state: gr.State):
133
- if state.__class__ == dict:
134
- session_id = state['session_id']
 
 
135
  else:
136
- session_id = state.value['session_id']
137
  voice_presets = load_voice_presets_metadata(
138
- utils.get_session_voice_preset_path(session_id),
139
  safe_if_metadata_not_exist=True
140
  )
141
  dataframe = []
@@ -192,7 +200,7 @@ def add_voice_preset(vp_id, vp_desc, file, ui_state, added_voice_preset):
192
  else:
193
  count: int = added_voice_preset['count']
194
  # check if greater than 3
195
- session_id = ui_state['session_id']
196
  file_path = file.name
197
  print(f'session {session_id}, id {id}, desc {vp_desc}, file {file_path}')
198
  # Do adding ...
@@ -398,7 +406,7 @@ with gr.Blocks(css=css) as interface:
398
 
399
  system_voice_presets = get_system_voice_presets()
400
  # State
401
- ui_state = gr.State(value={'session_id': pipeline.init_session()})
402
  selected_voice_presets = gr.State(value={'selected_voice_preset': None})
403
  added_voice_preset_state = gr.State(value={'added_file': None, 'count': 0})
404
  # UI Component
@@ -557,4 +565,4 @@ with gr.Blocks(css=css) as interface:
557
  # print_state_btn = gr.Button(value='Print State')
558
  # print_state_btn.click(fn=lambda state, state2: print(state, state2), inputs=[ui_state, selected_voice_presets])
559
  interface.queue(concurrency_count=5, max_size=20)
560
- interface.launch()
 
41
  return table_txt
42
 
43
 
44
+ def get_or_create_session_from_state(ui_state):
45
+ if 'session_id' not in ui_state:
46
+ ui_state['session_id'] = pipeline.init_session()
47
+ return ui_state['session_id']
48
+
49
+
50
  def generate_script_fn(instruction, _state: gr.State):
51
  try:
52
+ session_id = get_or_create_session_from_state(_state)
53
  api_key = utils.get_api_key()
54
  json_script = generate_json_file(session_id, instruction, api_key)
55
  table_text = convert_json_to_md(json_script)
 
136
 
137
 
138
  def get_voice_preset_to_list(state: gr.State):
139
+ if state.__class__ == gr.State:
140
+ state = state.value
141
+ if 'session_id' in state:
142
+ path = utils.get_session_voice_preset_path(state['session_id'])
143
  else:
144
+ path = ''
145
  voice_presets = load_voice_presets_metadata(
146
+ path,
147
  safe_if_metadata_not_exist=True
148
  )
149
  dataframe = []
 
200
  else:
201
  count: int = added_voice_preset['count']
202
  # check if greater than 3
203
+ session_id = get_or_create_session_from_state(ui_state)
204
  file_path = file.name
205
  print(f'session {session_id}, id {id}, desc {vp_desc}, file {file_path}')
206
  # Do adding ...
 
406
 
407
  system_voice_presets = get_system_voice_presets()
408
  # State
409
+ ui_state = gr.State({})
410
  selected_voice_presets = gr.State(value={'selected_voice_preset': None})
411
  added_voice_preset_state = gr.State(value={'added_file': None, 'count': 0})
412
  # UI Component
 
565
  # print_state_btn = gr.Button(value='Print State')
566
  # print_state_btn.click(fn=lambda state, state2: print(state, state2), inputs=[ui_state, selected_voice_presets])
567
  interface.queue(concurrency_count=5, max_size=20)
568
+ interface.launch()
utils.py CHANGED
@@ -65,6 +65,15 @@ def fade(audio_data, fade_duration=2, sr=32000):
65
  # config = yaml.safe_load(file)
66
  # return config['OpenAI-Key'] if 'OpenAI-Key' in config else None
67
 
 
 
 
 
 
 
 
 
68
  def get_api_key():
69
- api_key = os.environ.get('OPENAI_KEY')
70
  return api_key
 
 
65
  # config = yaml.safe_load(file)
66
  # return config['OpenAI-Key'] if 'OpenAI-Key' in config else None
67
 
68
+ def get_service_port():
69
+ service_port = os.environ.get('WAVJOURNEY_SERVICE_PORT')
70
+ return service_port
71
+
72
+ def get_service_url():
73
+ service_url = os.environ.get('WAVJOURNEY_SERVICE_URL')
74
+ return service_url
75
+
76
  def get_api_key():
77
+ api_key = os.environ.get('WAVJOURNEY_OPENAI_KEY')
78
  return api_key
79
+
voice_presets.py CHANGED
@@ -11,7 +11,7 @@ def save_voice_presets_metadata(voice_presets_path, metadata):
11
  json.dump(metadata, f, indent=4)
12
 
13
  def load_voice_presets_metadata(voice_presets_path, safe_if_metadata_not_exist=False):
14
- metadata_full_path = voice_presets_path / 'metadata.json'
15
 
16
  if safe_if_metadata_not_exist:
17
  if not os.path.exists(metadata_full_path):
 
11
  json.dump(metadata, f, indent=4)
12
 
13
  def load_voice_presets_metadata(voice_presets_path, safe_if_metadata_not_exist=False):
14
+ metadata_full_path = Path(voice_presets_path) / 'metadata.json'
15
 
16
  if safe_if_metadata_not_exist:
17
  if not os.path.exists(metadata_full_path):
wavjourney_cli.py CHANGED
@@ -24,4 +24,4 @@ if args.full:
24
  pipeline.full_steps(session_id, input_text, api_key)
25
  end_time = time.time()
26
 
27
- print(f"WavJourney took {end_time - start_time:.2f} seconds to complete.")
 
24
  pipeline.full_steps(session_id, input_text, api_key)
25
  end_time = time.time()
26
 
27
+ print(f"WavJourney took {end_time - start_time:.2f} seconds to complete.")