# Spaces: Running
import ast
import io
import json
import os
import random
import re
import subprocess
import tempfile
from datetime import timedelta

import gradio as gr
import groq
import numpy as np
import soundfile as sf
from groq import Groq
# setup groq
# Shared client used by the chat and file speech-to-text handlers; the key is read
# from the Groq_Api_Key environment variable (None when the variable is unset).
client = Groq(api_key=os.environ.get("Groq_Api_Key"))
def _voice_api_key(api_key=None):
    """Return the explicit key when given, else the Groq_Api_Key environment variable."""
    return api_key or os.environ.get("Groq_Api_Key")


def _transcribe_recording(audio, api_key=None):
    """Transcribe a recorded ``(sample_rate, samples)`` tuple; errors propagate.

    Returns an empty string when ``audio`` is None.
    """
    if audio is None:
        return ""
    sample_rate, samples = audio[0], audio[1]
    # Encode the raw samples as an in-memory WAV file; the model supports mp3, mp4,
    # mpeg, mpga, m4a, wav, and webm uploads.
    wav_buffer = io.BytesIO()
    sf.write(wav_buffer, samples, sample_rate, format="wav")
    wav_buffer.seek(0)
    voice_client = groq.Client(api_key=_voice_api_key(api_key))
    # Use Distil-Whisper English powered by Groq for transcription
    return voice_client.audio.transcriptions.create(
        model="distil-whisper-large-v3-en",
        file=("audio.wav", wav_buffer),
        response_format="text",
    )


def _answer_transcription(transcription, api_key=None):
    """Ask the chat model to answer ``transcription``; errors propagate."""
    if not transcription:
        return "No transcription available. Please try speaking again."
    voice_client = groq.Client(api_key=_voice_api_key(api_key))
    # Use Llama 3 70B powered by Groq for text generation
    completion = voice_client.chat.completions.create(
        model="llama3-70b-8192",
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": transcription},
        ],
    )
    return completion.choices[0].message.content


def transcribe_audio(audio, api_key=None):
    """Transcribe a recording, returning an ``"Error in transcription: ..."`` string on failure.

    Args:
        audio: ``(sample_rate, samples)`` tuple, or None.
        api_key: Groq API key; falls back to the Groq_Api_Key environment variable.
    """
    try:
        return _transcribe_recording(audio, api_key)
    except Exception as e:
        return f"Error in transcription: {str(e)}"


def generate_response(transcription, api_key=None):
    """Answer ``transcription``, returning an ``"Error in response generation: ..."`` string on failure.

    Args:
        transcription: Text to send to the chat model; falsy values yield a
            "no transcription" message without calling the API.
        api_key: Groq API key; falls back to the Groq_Api_Key environment variable.
    """
    try:
        return _answer_transcription(transcription, api_key)
    except Exception as e:
        return f"Error in response generation: {str(e)}"


def process_audio(audio, api_key=None):
    """Transcribe a recording and generate the assistant's reply.

    Args:
        audio: ``(sample_rate, samples)`` tuple from the recorder, or None.
        api_key: Groq API key; when empty the Groq_Api_Key environment variable is
            used, so the handler also works when only ``audio`` is supplied.

    Returns:
        ``(transcription, response)`` strings for the two output textboxes.
    """
    key = _voice_api_key(api_key)
    if not key:
        return "Please enter your Groq API key.", "API key is required."
    # Call the private helpers directly: the names transcribe_audio and
    # generate_response are rebound further down in this file to handlers with
    # different signatures, so a global lookup here would hit the wrong function.
    try:
        transcription = _transcribe_recording(audio, key)
    except Exception as e:
        # Do not forward the error text to the chat model as if the user had said it.
        return (
            f"Error in transcription: {str(e)}",
            "No transcription available. Please try speaking again.",
        )
    try:
        response = _answer_transcription(transcription, key)
    except Exception as e:
        response = f"Error in response generation: {str(e)}"
    return transcription, response
def _parse_error_payload(text):
    """Parse the ``{...}`` fragment of an error string; return None when unparsable."""
    # Try strict JSON first, then a Python-literal parse for single-quoted payloads.
    # Blindly swapping ' for " corrupts any message that contains an apostrophe.
    for parser in (json.loads, ast.literal_eval):
        try:
            return parser(text)
        except (ValueError, SyntaxError, TypeError):
            continue
    return None


def _groq_error_message(e):
    """Return the ``error.message`` text carried by exception ``e``, or None."""
    # NOTE(review): Groq API errors are expected to expose the decoded response on a
    # `body` attribute - confirm against the installed SDK; absent attributes are skipped.
    payloads = [getattr(e, "body", None)]
    detail = e.args[0] if e.args else None
    if isinstance(detail, str):
        # Use regex to extract the JSON part of the string
        found = re.search(r'(\{.*\})', detail)
        if found:
            payloads.append(_parse_error_payload(found.group(1)))
    else:
        payloads.append(detail)
    for payload in payloads:
        if isinstance(payload, dict):
            error = payload.get("error")
            if isinstance(error, dict) and "message" in error:
                return error["message"]
    return None


def handle_groq_error(e, model_name):
    """Convert a Groq API exception into a ``gr.Error`` for the UI; always raises.

    Args:
        e: The exception raised by the Groq client.
        model_name: Model in use when the error occurred (currently unused).

    Raises:
        gr.Error: With the API's own message for rate-limit errors when one can be
            extracted, otherwise with a generic "Error during Groq API call" message.
    """
    message = _groq_error_message(e)
    if isinstance(e, groq.RateLimitError) and message is not None:
        raise gr.Error(message) from e
    # Also covers rate-limit errors without a parsable message, which previously
    # returned silently and let the caller continue as if nothing had failed.
    raise gr.Error(f"Error during Groq API call: {e}") from e
# llms
# Upper bound for randomly drawn seeds: the largest value of a signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max
# Maximum value of the "Max Tokens" slider for each selectable chat model.
_MODEL_MAX_TOKENS = {
    "llama3-70b-8192": 8192,
    "llama3-8b-8192": 8192,
    "gemma-7b-it": 8192,
    "gemma2-9b-it": 8192,
    "mixtral-8x7b-32768": 32768,
}


def update_max_tokens(model):
    """Return a Gradio update that sets the max-tokens slider maximum for ``model``.

    Unknown models yield an empty ``gr.update()`` (leave the slider untouched)
    instead of an implicit ``None``.
    """
    limit = _MODEL_MAX_TOKENS.get(model)
    if limit is None:
        return gr.update()
    return gr.update(maximum=limit)
def create_history_messages(history):
    """Convert chat history into a chronologically ordered list of role/content dicts.

    Args:
        history: Sequence of ``(user_message, assistant_message)`` pairs, or of
            ``{"role", "content"}`` dicts; None is treated as empty.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts in conversation order:
        each user turn is immediately followed by its assistant reply.
    """
    messages = []
    for entry in history or []:
        if isinstance(entry, dict):
            # Already role-tagged; copy only the two fields the chat API needs.
            messages.append({"role": entry["role"], "content": entry["content"]})
            continue
        user_text, assistant_text = entry[0], entry[1]
        # Interleave the turns; emitting every user message first and every
        # assistant message afterwards scrambles the conversation.
        if user_text is not None:
            messages.append({"role": "user", "content": user_text})
        if assistant_text is not None:
            messages.append({"role": "assistant", "content": assistant_text})
    return messages
def generate_response(prompt, history, model, temperature, max_tokens, top_p, seed):
    """Stream a chat completion, yielding the reply accumulated so far.

    Args:
        prompt: The new user message.
        history: Previous turns, converted with ``create_history_messages``.
        model: Groq chat model name.
        temperature, max_tokens, top_p: Sampling settings forwarded to the API.
        seed: Sampling seed; ``0`` draws a random seed in ``[1, MAX_SEED]``.

    Yields:
        The response text accumulated so far, once per chunk that carries content.

    Raises:
        gr.Error: Via ``handle_groq_error`` when the Groq API call fails.
    """
    messages = create_history_messages(history)
    messages.append({"role": "user", "content": prompt})

    if seed == 0:
        seed = random.randint(1, MAX_SEED)

    try:
        stream = client.chat.completions.create(
            messages=messages,
            model=model,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            seed=seed,
            stop=None,
            stream=True,
        )
        response = ""
        for chunk in stream:
            # Skip chunks without choices instead of failing on choices[0].
            if not chunk.choices:
                continue
            delta_content = chunk.choices[0].delta.content
            if delta_content is not None:
                response += delta_content
                yield response
    except groq.APIError as e:
        # `Groq.GroqApiException` does not exist, so the old except clause raised
        # AttributeError itself; groq.APIError is the SDK's API error base class.
        handle_groq_error(e, model)
# speech to text
# Upload extensions accepted by the file transcription/translation tabs.
ALLOWED_FILE_EXTENSIONS = ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm"]
# Files larger than this are downsampled, and split if they are still too large.
MAX_FILE_SIZE_MB = 25
# Size of each piece when a downsampled file still exceeds MAX_FILE_SIZE_MB.
CHUNK_SIZE_MB = 25

# Display name -> language code offered in the transcription "Language" dropdown.
LANGUAGE_CODES = {
    "English": "en",
    "Chinese": "zh",
    "German": "de",
    "Spanish": "es",
    "Russian": "ru",
    "Korean": "ko",
    "French": "fr",
    "Japanese": "ja",
    "Portuguese": "pt",
    "Turkish": "tr",
    "Polish": "pl",
    "Catalan": "ca",
    "Dutch": "nl",
    "Arabic": "ar",
    "Swedish": "sv",
    "Italian": "it",
    "Indonesian": "id",
    "Hindi": "hi",
    "Finnish": "fi",
    "Vietnamese": "vi",
    "Hebrew": "he",
    "Ukrainian": "uk",
    "Greek": "el",
    "Malay": "ms",
    "Czech": "cs",
    "Romanian": "ro",
    "Danish": "da",
    "Hungarian": "hu",
    "Tamil": "ta",
    "Norwegian": "no",
    "Thai": "th",
    "Urdu": "ur",
    "Croatian": "hr",
    "Bulgarian": "bg",
    "Lithuanian": "lt",
    "Latin": "la",
    "Māori": "mi",
    "Malayalam": "ml",
    "Welsh": "cy",
    "Slovak": "sk",
    "Telugu": "te",
    "Persian": "fa",
    "Latvian": "lv",
    "Bengali": "bn",
    "Serbian": "sr",
    "Azerbaijani": "az",
    "Slovenian": "sl",
    "Kannada": "kn",
    "Estonian": "et",
    "Macedonian": "mk",
    "Breton": "br",
    "Basque": "eu",
    "Icelandic": "is",
    "Armenian": "hy",
    "Nepali": "ne",
    "Mongolian": "mn",
    "Bosnian": "bs",
    "Kazakh": "kk",
    "Albanian": "sq",
    "Swahili": "sw",
    "Galician": "gl",
    "Marathi": "mr",
    "Panjabi": "pa",
    "Sinhala": "si",
    "Khmer": "km",
    "Shona": "sn",
    "Yoruba": "yo",
    "Somali": "so",
    "Afrikaans": "af",
    "Occitan": "oc",
    "Georgian": "ka",
    "Belarusian": "be",
    "Tajik": "tg",
    "Sindhi": "sd",
    "Gujarati": "gu",
    "Amharic": "am",
    "Yiddish": "yi",
    "Lao": "lo",
    "Uzbek": "uz",
    "Faroese": "fo",
    "Haitian": "ht",
    "Pashto": "ps",
    "Turkmen": "tk",
    "Norwegian Nynorsk": "nn",
    "Maltese": "mt",
    "Sanskrit": "sa",
    "Luxembourgish": "lb",
    "Burmese": "my",
    "Tibetan": "bo",
    "Tagalog": "tl",
    "Malagasy": "mg",
    "Assamese": "as",
    "Tatar": "tt",
    "Hawaiian": "haw",
    "Lingala": "ln",
    "Hausa": "ha",
    "Bashkir": "ba",
    # The label used to be the raw code "jw"; the submitted value is unchanged.
    "Javanese": "jw",
    "Sundanese": "su",
}
def split_audio(audio_file_path, chunk_size_mb):
    """Split a file into consecutive byte chunks written next to the original.

    The split is by byte offset only; it does not look at audio frame boundaries.

    Args:
        audio_file_path: Path of the file to split.
        chunk_size_mb: Maximum chunk size in MB (1 MB = 1024 * 1024 bytes); may be
            fractional.

    Returns:
        Paths of the chunk files in order, named ``<stem>_partNNN.mp3``; an empty
        list for an empty input file.

    Raises:
        ValueError: If ``chunk_size_mb`` does not amount to at least one byte.
    """
    chunk_size = int(chunk_size_mb * 1024 * 1024)  # Convert MB to whole bytes
    if chunk_size <= 0:
        raise ValueError(f"chunk_size_mb must be positive, got {chunk_size_mb!r}")

    stem = os.path.splitext(audio_file_path)[0]
    chunks = []
    with open(audio_file_path, "rb") as source:
        pieces = iter(lambda: source.read(chunk_size), b"")
        for file_number, piece in enumerate(pieces, start=1):
            # Pad file number for correct ordering
            chunk_name = f"{stem}_part{file_number:03}.mp3"
            with open(chunk_name, "wb") as chunk_file:
                chunk_file.write(piece)
            chunks.append(chunk_name)
    return chunks
def merge_audio(chunks, output_file_path):
    """Concatenate audio chunks with ffmpeg's concat demuxer, then delete the chunks.

    Args:
        chunks: Paths of the chunk files, in playback order.
        output_file_path: Destination file; overwritten if it exists (``-y``).

    Raises:
        gr.Error: If there is nothing to merge, ffmpeg cannot be started, or ffmpeg
            exits with a non-zero status. The chunk files are kept in that case.
    """
    if not chunks:
        raise gr.Error("Error during audio merging: no audio chunks to merge.")

    # A per-call list file: a fixed "temp_list.txt" in the working directory would
    # be shared (and clobbered) by concurrent requests.
    list_file = tempfile.NamedTemporaryFile(
        "w", suffix=".txt", delete=False, encoding="utf-8"
    )
    try:
        with list_file:
            for chunk in chunks:
                # Absolute paths, because relative entries are resolved against the
                # list file's directory; single quotes are escaped for the list syntax.
                escaped = os.path.abspath(chunk).replace("'", "'\\''")
                list_file.write(f"file '{escaped}'\n")
        subprocess.run(
            [
                "ffmpeg",
                "-f",
                "concat",
                "-safe", "0",
                "-i",
                list_file.name,
                "-c",
                "copy",
                "-y",
                output_file_path,
            ],
            check=True,
        )
    except (subprocess.CalledProcessError, OSError) as e:
        # OSError covers a missing ffmpeg executable.
        raise gr.Error(f"Error during audio merging: {e}") from e
    finally:
        if os.path.exists(list_file.name):
            os.remove(list_file.name)

    for chunk in chunks:
        os.remove(chunk)
# Checks file extension, size, and downsamples or splits if needed.
def check_file(audio_file_path):
    """Validate an uploaded audio file and shrink it when it is too large.

    Args:
        audio_file_path: Path of the uploaded file.

    Returns:
        ``(path, None)`` when a single file (original or downsampled) can be sent,
        or ``(list_of_chunk_paths, "split")`` when the downsampled file still
        exceeds ``MAX_FILE_SIZE_MB`` and was split into chunks.

    Raises:
        gr.Error: If no file was given, the extension is not allowed, or the
            ffmpeg downsampling step fails.
    """
    if not audio_file_path:
        raise gr.Error("Please upload an audio file.")

    file_size_mb = os.path.getsize(audio_file_path) / (1024 * 1024)
    # splitext looks only at the final path component, so dots in directory names
    # and extension-less files are handled correctly.
    file_extension = os.path.splitext(audio_file_path)[1].lstrip(".").lower()

    if file_extension not in ALLOWED_FILE_EXTENSIONS:
        raise gr.Error(f"Invalid file type (.{file_extension}). Allowed types: {', '.join(ALLOWED_FILE_EXTENSIONS)}")

    if file_size_mb <= MAX_FILE_SIZE_MB:
        return audio_file_path, None

    gr.Warning(
        f"File size too large ({file_size_mb:.2f} MB). Attempting to downsample to 16kHz MP3 128kbps. Maximum size allowed: {MAX_FILE_SIZE_MB} MB"
    )
    output_file_path = os.path.splitext(audio_file_path)[0] + "_downsampled.mp3"
    try:
        subprocess.run(
            [
                "ffmpeg",
                "-i",
                audio_file_path,
                "-ar",
                "16000",
                "-ab",
                "128k",
                "-ac",
                "1",
                "-f",
                "mp3",
                "-y",
                output_file_path,
            ],
            check=True,
        )
    except (subprocess.CalledProcessError, OSError) as e:
        # OSError covers a missing ffmpeg executable.
        raise gr.Error(f"Error during downsampling: {e}") from e

    # Check size after downsampling
    downsampled_size_mb = os.path.getsize(output_file_path) / (1024 * 1024)
    if downsampled_size_mb > MAX_FILE_SIZE_MB:
        gr.Warning(f"File still too large after downsampling ({downsampled_size_mb:.2f} MB). Splitting into {CHUNK_SIZE_MB} MB chunks.")
        return split_audio(output_file_path, CHUNK_SIZE_MB), "split"
    return output_file_path, None
def _transcribe_file(path, model, prompt, language, auto_detect_language):
    """Send one audio file to the transcription endpoint and return its text."""
    request = {
        "model": model,
        "prompt": prompt,
        "response_format": "text",
        "temperature": 0.0,
    }
    # Leave the language out entirely to request auto-detection.
    if not auto_detect_language:
        request["language"] = language
    with open(path, "rb") as audio_file:
        return client.audio.transcriptions.create(
            file=(os.path.basename(path), audio_file.read()),
            **request,
        )


def transcribe_audio(audio_file_path, model, prompt, language, auto_detect_language):
    """Transcribe an uploaded audio file, chunk by chunk when it had to be split.

    Args:
        audio_file_path: Path of the uploaded file.
        model: Groq speech-to-text model name.
        prompt: Optional context or spelling hints.
        language: Language code, ignored when ``auto_detect_language`` is set.
        auto_detect_language: Let the model detect the language.

    Returns:
        ``(transcription, merged_audio_path)``. ``merged_audio_path`` is None unless
        the file was split; then it is the merge of the chunks that were transcribed.

    Raises:
        gr.Error: For invalid files or Groq API failures. A rate limit hit after at
            least one chunk succeeded returns the partial result with a warning.
    """
    processed_path, split_status = check_file(audio_file_path)

    if split_status != "split":
        try:
            transcription = _transcribe_file(
                processed_path, model, prompt, language, auto_detect_language
            )
        except groq.APIError as e:
            handle_groq_error(e, model)
            # Never fall through to an implicit None if the handler returns.
            raise gr.Error(f"Error during Groq API call: {e}") from e
        return transcription, None

    texts = []
    processed_chunks = []
    for i, chunk_path in enumerate(processed_path):
        try:
            texts.append(
                _transcribe_file(chunk_path, model, prompt, language, auto_detect_language)
            )
        except groq.RateLimitError as e:  # Handle rate limit error
            if not processed_chunks:
                handle_groq_error(e, model)
                return "Transcription failed due to API limits.", None
            # Keep what was already transcribed instead of discarding it.
            gr.Warning(f"API limit reached during chunk {i+1}. Returning processed chunks only.")
            break
        except groq.APIError as e:
            handle_groq_error(e, model)
            raise gr.Error(f"Error during Groq API call: {e}") from e
        processed_chunks.append(chunk_path)

    # A per-call directory keeps concurrent requests from overwriting each other's
    # merged file while preserving the download name.
    merged_path = os.path.join(tempfile.mkdtemp(), "merged_output.mp3")
    merge_audio(processed_chunks, merged_path)
    return "".join(texts), merged_path
def _translate_file(path, model, prompt):
    """Send one audio file to the translation endpoint and return the English text."""
    with open(path, "rb") as audio_file:
        return client.audio.translations.create(
            file=(os.path.basename(path), audio_file.read()),
            model=model,
            prompt=prompt,
            response_format="text",
            temperature=0.0,
        )


def translate_audio(audio_file_path, model, prompt):
    """Translate an uploaded audio file to English text, chunk by chunk if split.

    Args:
        audio_file_path: Path of the uploaded file.
        model: Groq speech-to-text model name.
        prompt: Optional context or spelling hints.

    Returns:
        The translation, or an ``"API limit reached. Partial translation: ..."``
        string when a rate limit interrupts a split file after some chunks succeeded.

    Raises:
        gr.Error: For invalid files or Groq API failures.
    """
    processed_path, split_status = check_file(audio_file_path)

    if split_status != "split":
        try:
            return _translate_file(processed_path, model, prompt)
        except groq.APIError as e:
            # `Groq.GroqApiException` does not exist; groq.APIError is the SDK's
            # API error base class.
            handle_groq_error(e, model)
            raise gr.Error(f"Error during Groq API call: {e}") from e

    parts = []
    try:
        for chunk_path in processed_path:
            try:
                parts.append(_translate_file(chunk_path, model, prompt))
            except groq.RateLimitError as e:
                if not parts:
                    handle_groq_error(e, model)
                partial = "".join(parts)
                return f"API limit reached. Partial translation: {partial}"
            except groq.APIError as e:
                handle_groq_error(e, model)
                raise gr.Error(f"Error during Groq API call: {e}") from e
        return "".join(parts)
    finally:
        # The chunk files are not used after translation; remove them so each
        # request does not leave its chunks on disk.
        for chunk_path in processed_path:
            if os.path.exists(chunk_path):
                os.remove(chunk_path)
def _process_voice_recording(audio):
    """Run the voice assistant on a recording using the server-side Groq key.

    ``process_audio`` takes ``(audio, api_key)`` but the voice tab has no key field,
    so the key comes from the same ``Groq_Api_Key`` environment variable as the
    shared client.
    """
    return process_audio(audio, os.environ.get("Groq_Api_Key"))


# NOTE(review): the container nesting below was reconstructed from the statement
# order - confirm the layout against the running app.
with gr.Blocks(theme="Hev832/niceandsimple") as interface:
    gr.Markdown(
        """
# Groq API UI
Inference by Groq API
If you are having API Rate Limit issues, you can retry later based on the [rate limits](https://console.groq.com/docs/rate-limits) or <a href="https://huggingface.co/spaces/Nick088/Fast-Subtitle-Maker?duplicate=true" style="display: inline-block;margin-top: .5em;margin-right: .25em;" target="_blank"> <img style="margin-bottom: 0em;display: inline;margin-top: -.25em;" src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a> with <a href=https://console.groq.com/keys>your own API Key</a> </p>
Hugging Face Space by [Nick088](https://linktr.ee/Nick088)
<br> <a href="https://discord.gg/osai"> <img src="https://img.shields.io/discord/1198701940511617164?color=%23738ADB&label=Discord&style=for-the-badge" alt="Discord"> </a>
"""
    )
    with gr.Tabs():
        with gr.TabItem("Speech To Text"):
            with gr.Tabs():
                with gr.TabItem("Transcription"):
                    gr.Markdown("Transcript audio from files to text!")
                    with gr.Row():
                        audio_input = gr.File(
                            type="filepath", label="Upload File containing Audio", file_types=[f".{ext}" for ext in ALLOWED_FILE_EXTENSIONS]
                        )
                        model_choice_transcribe = gr.Dropdown(
                            choices=["whisper-large-v3"],
                            value="whisper-large-v3",
                            label="Model",
                        )
                    with gr.Row():
                        transcribe_prompt = gr.Textbox(
                            label="Prompt (Optional)",
                            info="Specify any context or spelling corrections.",
                        )
                        with gr.Column():
                            language = gr.Dropdown(
                                choices=[(lang, code) for lang, code in LANGUAGE_CODES.items()],
                                value="en",
                                label="Language",
                            )
                            auto_detect_language = gr.Checkbox(label="Auto Detect Language")
                    transcribe_button = gr.Button("Transcribe")
                    transcription_output = gr.Textbox(label="Transcription")
                    merged_audio_output = gr.File(label="Merged Audio (if chunked)")
                    transcribe_button.click(
                        transcribe_audio,
                        inputs=[audio_input, model_choice_transcribe, transcribe_prompt, language, auto_detect_language],
                        outputs=[transcription_output, merged_audio_output],
                    )
                with gr.TabItem("Translation"):
                    gr.Markdown("Transcript audio from files and translate them to English text!")
                    with gr.Row():
                        audio_input_translate = gr.File(
                            type="filepath", label="Upload File containing Audio", file_types=[f".{ext}" for ext in ALLOWED_FILE_EXTENSIONS]
                        )
                        model_choice_translate = gr.Dropdown(
                            choices=["whisper-large-v3"],
                            value="whisper-large-v3",
                            label="Audio Speech Recognition (ASR) Model",
                        )
                    with gr.Row():
                        translate_prompt = gr.Textbox(
                            label="Prompt (Optional)",
                            info="Specify any context or spelling corrections.",
                        )
                    translate_button = gr.Button("Translate")
                    translation_output = gr.Textbox(label="Translation")
                    translate_button.click(
                        translate_audio,
                        inputs=[audio_input_translate, model_choice_translate, translate_prompt],
                        outputs=translation_output,
                    )
        with gr.TabItem("LLMs"):
            with gr.Tab("Chat"):
                with gr.Row():
                    with gr.Column(scale=1, min_width=250):
                        model = gr.Dropdown(
                            choices=[
                                "llama3-70b-8192",
                                "llama3-8b-8192",
                                "mixtral-8x7b-32768",
                                "gemma-7b-it",
                                "gemma2-9b-it",
                            ],
                            value="llama3-70b-8192",
                            label="Model",
                        )
                        temperature = gr.Slider(
                            minimum=0.0,
                            maximum=1.0,
                            step=0.01,
                            value=0.5,
                            label="Temperature",
                            info="Controls diversity of the generated text. Lower is more deterministic, higher is more creative.",
                        )
                        max_tokens = gr.Slider(
                            minimum=1,
                            maximum=8192,
                            step=1,
                            value=4096,
                            label="Max Tokens",
                            info="The maximum number of tokens that the model can process in a single response.<br>Maximums: 8k for gemma 7b it, gemma2 9b it, llama 8b & 70b, 32k for mixtral 8x7b.",
                        )
                        top_p = gr.Slider(
                            minimum=0.0,
                            maximum=1.0,
                            step=0.01,
                            value=0.5,
                            label="Top P",
                            info="A method of text generation where a model will only consider the most probable next tokens that make up the probability p.",
                        )
                        seed = gr.Number(
                            precision=0, value=42, label="Seed", info="A starting point to initiate generation, use 0 for random"
                        )
                        # Registered once; a second identical listener only repeated
                        # the same slider update.
                        model.change(update_max_tokens, inputs=[model], outputs=max_tokens)
                    with gr.Column(scale=1, min_width=400):
                        chatbot = gr.ChatInterface(
                            fn=generate_response,
                            chatbot=None,
                            additional_inputs=[
                                model,
                                temperature,
                                max_tokens,
                                top_p,
                                seed,
                            ],
                        )
            with gr.Tab("Voice-Powered AI Assistant"):
                # Distinct names so the file-transcription components above are not
                # rebound by this tab.
                with gr.Row():
                    voice_audio_input = gr.Audio(label="Speak!", type="numpy")
                with gr.Row():
                    voice_transcription_output = gr.Textbox(label="Transcription")
                    response_output = gr.Textbox(label="AI Assistant Response")
                submit_button = gr.Button("Process", variant="primary")
                submit_button.click(
                    _process_voice_recording,
                    inputs=[voice_audio_input],
                    outputs=[voice_transcription_output, response_output],
                )

interface.launch(share=True)