gabrielchua commited on
Commit
506f934
Β·
1 Parent(s): 14ff1d7

update app

Browse files
Files changed (3) hide show
  1. constants.py +1 -3
  2. requirements.txt +9 -3
  3. utils.py +15 -52
constants.py CHANGED
@@ -23,11 +23,9 @@ ERROR_MESSAGE_TOO_LONG = "The total content is too long. Please ensure the combi
23
 
24
  # Fireworks API-related constants
25
  FIREWORKS_API_KEY = os.getenv("FIREWORKS_API_KEY")
26
- FIREWORKS_BASE_URL = "https://api.fireworks.ai/inference/v1"
27
  FIREWORKS_MAX_TOKENS = 16_384
28
  FIREWORKS_MODEL_ID = "accounts/fireworks/models/llama-v3p1-405b-instruct"
29
  FIREWORKS_TEMPERATURE = 0.1
30
- FIREWORKS_JSON_RETRY_ATTEMPTS = 3
31
 
32
  # MeloTTS
33
  MELO_API_NAME = "/synthesize"
@@ -80,7 +78,7 @@ UI_DESCRIPTION = """
80
  Generate Podcasts from PDFs using open-source AI.
81
 
82
  Built with:
83
- - [Llama 3.1 405B πŸ¦™](https://huggingface.co/meta-llama/Llama-3.1-405B) via [Fireworks AI πŸŽ†](https://fireworks.ai/)
84
  - [MeloTTS 🐚](https://huggingface.co/myshell-ai/MeloTTS-English)
85
  - [Bark 🐢](https://huggingface.co/suno/bark)
86
  - [Jina Reader πŸ”](https://jina.ai/reader/)
 
23
 
24
  # Fireworks API-related constants
25
  FIREWORKS_API_KEY = os.getenv("FIREWORKS_API_KEY")
 
26
  FIREWORKS_MAX_TOKENS = 16_384
27
  FIREWORKS_MODEL_ID = "accounts/fireworks/models/llama-v3p1-405b-instruct"
28
  FIREWORKS_TEMPERATURE = 0.1
 
29
 
30
  # MeloTTS
31
  MELO_API_NAME = "/synthesize"
 
78
  Generate Podcasts from PDFs using open-source AI.
79
 
80
  Built with:
81
+ - [Llama 3.1 405B πŸ¦™](https://huggingface.co/meta-llama/Llama-3.1-405B) via [Fireworks AI πŸŽ†](https://fireworks.ai/) and [Instructor πŸ“](https://github.com/instructor-ai/instructor)
82
  - [MeloTTS 🐚](https://huggingface.co/myshell-ai/MeloTTS-English)
83
  - [Bark 🐢](https://huggingface.co/suno/bark)
84
  - [Jina Reader πŸ”](https://jina.ai/reader/)
requirements.txt CHANGED
@@ -13,11 +13,13 @@ click==8.1.7
13
  contourpy==1.3.0
14
  cycler==0.12.1
15
  distro==1.9.0
 
16
  einops==0.8.0
17
  encodec==0.1.1
18
  fastapi==0.115.0
19
  ffmpy==0.4.0
20
  filelock==3.16.1
 
21
  fonttools==4.54.1
22
  frozenlist==1.4.1
23
  fsspec==2024.9.0
@@ -28,10 +30,13 @@ granian==1.4.0
28
  h11==0.14.0
29
  httpcore==1.0.5
30
  httpx==0.27.2
 
 
31
  huggingface-hub==0.25.1
32
  idna==3.10
33
  importlib_metadata==8.5.0
34
  importlib_resources==6.4.5
 
35
  Jinja2==3.1.4
36
  jiter==0.5.0
37
  jmespath==1.0.1
@@ -55,8 +60,8 @@ pandas==2.2.3
55
  pillow==10.4.0
56
  promptic==0.7.5
57
  psutil==5.9.8
58
- pydantic==2.7.0
59
- pydantic_core==2.18.1
60
  pydub==0.25.1
61
  Pygments==2.18.0
62
  pyparsing==3.1.4
@@ -85,7 +90,7 @@ spaces==0.30.2
85
  starlette==0.38.6
86
  suno-bark @ git+https://github.com/suno-ai/bark.git@f4f32d4cd480dfec1c245d258174bc9bde3c2148
87
  sympy==1.13.3
88
- tenacity==8.3.0
89
  tiktoken==0.7.0
90
  tokenizers==0.20.0
91
  tomlkit==0.12.0
@@ -100,5 +105,6 @@ urllib3==2.2.3
100
  uvicorn==0.31.0
101
  uvloop==0.18.0
102
  websockets==12.0
 
103
  yarl==1.13.1
104
  zipp==3.20.2
 
13
  contourpy==1.3.0
14
  cycler==0.12.1
15
  distro==1.9.0
16
+ docstring_parser==0.16
17
  einops==0.8.0
18
  encodec==0.1.1
19
  fastapi==0.115.0
20
  ffmpy==0.4.0
21
  filelock==3.16.1
22
+ fireworks-ai==0.15.6
23
  fonttools==4.54.1
24
  frozenlist==1.4.1
25
  fsspec==2024.9.0
 
30
  h11==0.14.0
31
  httpcore==1.0.5
32
  httpx==0.27.2
33
+ httpx-sse==0.4.0
34
+ httpx-ws==0.6.2
35
  huggingface-hub==0.25.1
36
  idna==3.10
37
  importlib_metadata==8.5.0
38
  importlib_resources==6.4.5
39
+ instructor==1.6.2
40
  Jinja2==3.1.4
41
  jiter==0.5.0
42
  jmespath==1.0.1
 
60
  pillow==10.4.0
61
  promptic==0.7.5
62
  psutil==5.9.8
63
+ pydantic==2.9.2
64
+ pydantic_core==2.23.4
65
  pydub==0.25.1
66
  Pygments==2.18.0
67
  pyparsing==3.1.4
 
90
  starlette==0.38.6
91
  suno-bark @ git+https://github.com/suno-ai/bark.git@f4f32d4cd480dfec1c245d258174bc9bde3c2148
92
  sympy==1.13.3
93
+ tenacity==9.0.0
94
  tiktoken==0.7.0
95
  tokenizers==0.20.0
96
  tomlkit==0.12.0
 
105
  uvicorn==0.31.0
106
  uvloop==0.18.0
107
  websockets==12.0
108
+ wsproto==1.2.0
109
  yarl==1.13.1
110
  zipp==3.20.2
utils.py CHANGED
@@ -6,6 +6,9 @@ Functions:
6
  - call_llm: Call the LLM with the given prompt and dialogue format.
7
  - parse_url: Parse the given URL and return the text content.
8
  - generate_podcast_audio: Generate audio for podcast using TTS or advanced audio models.
 
 
 
9
  """
10
 
11
  # Standard library imports
@@ -13,21 +16,19 @@ import time
13
  from typing import Any, Union
14
 
15
  # Third-party imports
 
16
  import requests
17
  from bark import SAMPLE_RATE, generate_audio, preload_models
 
18
  from gradio_client import Client
19
- from openai import OpenAI
20
- from pydantic import ValidationError
21
  from scipy.io.wavfile import write as write_wav
22
 
23
  # Local imports
24
  from constants import (
25
  FIREWORKS_API_KEY,
26
- FIREWORKS_BASE_URL,
27
  FIREWORKS_MODEL_ID,
28
  FIREWORKS_MAX_TOKENS,
29
  FIREWORKS_TEMPERATURE,
30
- FIREWORKS_JSON_RETRY_ATTEMPTS,
31
  MELO_API_NAME,
32
  MELO_TTS_SPACES_ID,
33
  MELO_RETRY_ATTEMPTS,
@@ -38,8 +39,11 @@ from constants import (
38
  )
39
  from schema import ShortDialogue, MediumDialogue
40
 
41
- # Initialize clients
42
- fw_client = OpenAI(base_url=FIREWORKS_BASE_URL, api_key=FIREWORKS_API_KEY)
 
 
 
43
  hf_client = Client(MELO_TTS_SPACES_ID)
44
 
45
  # Download and load all models for Bark
@@ -53,51 +57,13 @@ def generate_script(
53
  ) -> Union[ShortDialogue, MediumDialogue]:
54
  """Get the dialogue from the LLM."""
55
 
56
- # Call the LLM
57
- response = call_llm(system_prompt, input_text, output_model)
58
- response_json = response.choices[0].message.content
59
-
60
- # Validate the response
61
- for attempt in range(FIREWORKS_JSON_RETRY_ATTEMPTS):
62
- try:
63
- first_draft_dialogue = output_model.model_validate_json(response_json)
64
- break
65
- except ValidationError as e:
66
- if attempt == FIREWORKS_JSON_RETRY_ATTEMPTS - 1: # Last attempt
67
- raise ValueError(
68
- f"Failed to parse dialogue JSON after {FIREWORKS_JSON_RETRY_ATTEMPTS} attempts: {e}"
69
- ) from e
70
- error_message = (
71
- f"Failed to parse dialogue JSON (attempt {attempt + 1}): {e}"
72
- )
73
- # Re-call the LLM with the error message
74
- system_prompt_with_error = f"{system_prompt}\n\nPlease return a VALID JSON object. This was the earlier error: {error_message}"
75
- response = call_llm(system_prompt_with_error, input_text, output_model)
76
- response_json = response.choices[0].message.content
77
- first_draft_dialogue = output_model.model_validate_json(response_json)
78
 
79
  # Call the LLM a second time to improve the dialogue
80
- system_prompt_with_dialogue = f"{system_prompt}\n\nHere is the first draft of the dialogue you provided:\n\n{first_draft_dialogue}."
 
81
 
82
- # Validate the response
83
- for attempt in range(FIREWORKS_JSON_RETRY_ATTEMPTS):
84
- try:
85
- response = call_llm(
86
- system_prompt_with_dialogue,
87
- "Please improve the dialogue. Make it more natural and engaging.",
88
- output_model,
89
- )
90
- final_dialogue = output_model.model_validate_json(
91
- response.choices[0].message.content
92
- )
93
- break
94
- except ValidationError as e:
95
- if attempt == FIREWORKS_JSON_RETRY_ATTEMPTS - 1: # Last attempt
96
- raise ValueError(
97
- f"Failed to improve dialogue after {FIREWORKS_JSON_RETRY_ATTEMPTS} attempts: {e}"
98
- ) from e
99
- error_message = f"Failed to improve dialogue (attempt {attempt + 1}): {e}"
100
- system_prompt_with_dialogue += f"\n\nPlease return a VALID JSON object. This was the earlier error: {error_message}"
101
  return final_dialogue
102
 
103
 
@@ -111,10 +77,7 @@ def call_llm(system_prompt: str, text: str, dialogue_format: Any) -> Any:
111
  model=FIREWORKS_MODEL_ID,
112
  max_tokens=FIREWORKS_MAX_TOKENS,
113
  temperature=FIREWORKS_TEMPERATURE,
114
- response_format={
115
- "type": "json_object",
116
- "schema": dialogue_format.model_json_schema(),
117
- },
118
  )
119
  return response
120
 
 
6
  - call_llm: Call the LLM with the given prompt and dialogue format.
7
  - parse_url: Parse the given URL and return the text content.
8
  - generate_podcast_audio: Generate audio for podcast using TTS or advanced audio models.
9
+ - _use_suno_model: Generate advanced audio using Bark.
10
+ - _use_melotts_api: Generate audio using TTS model.
11
+ - _get_melo_tts_params: Get TTS parameters based on speaker and language.
12
  """
13
 
14
  # Standard library imports
 
16
  from typing import Any, Union
17
 
18
  # Third-party imports
19
+ import instructor
20
  import requests
21
  from bark import SAMPLE_RATE, generate_audio, preload_models
22
+ from fireworks.client import Fireworks
23
  from gradio_client import Client
 
 
24
  from scipy.io.wavfile import write as write_wav
25
 
26
  # Local imports
27
  from constants import (
28
  FIREWORKS_API_KEY,
 
29
  FIREWORKS_MODEL_ID,
30
  FIREWORKS_MAX_TOKENS,
31
  FIREWORKS_TEMPERATURE,
 
32
  MELO_API_NAME,
33
  MELO_TTS_SPACES_ID,
34
  MELO_RETRY_ATTEMPTS,
 
39
  )
40
  from schema import ShortDialogue, MediumDialogue
41
 
42
+ # Initialize Fireworks client, with Instructor patch
43
+ fw_client = Fireworks(api_key=FIREWORKS_API_KEY)
44
+ fw_client = instructor.from_fireworks(fw_client)
45
+
46
+ # Initialize Hugging Face client
47
  hf_client = Client(MELO_TTS_SPACES_ID)
48
 
49
  # Download and load all models for Bark
 
57
  ) -> Union[ShortDialogue, MediumDialogue]:
58
  """Get the dialogue from the LLM."""
59
 
60
+ # Call the LLM for the first time
61
+ first_draft_dialogue = call_llm(system_prompt, input_text, output_model)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
  # Call the LLM a second time to improve the dialogue
64
+ system_prompt_with_dialogue = f"{system_prompt}\n\nHere is the first draft of the dialogue you provided:\n\n{first_draft_dialogue.model_dump_json()}."
65
+ final_dialogue = call_llm(system_prompt_with_dialogue, "Please improve the dialogue. Make it more natural and engaging.", output_model)
66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  return final_dialogue
68
 
69
 
 
77
  model=FIREWORKS_MODEL_ID,
78
  max_tokens=FIREWORKS_MAX_TOKENS,
79
  temperature=FIREWORKS_TEMPERATURE,
80
+ response_model=dialogue_format,
 
 
 
81
  )
82
  return response
83