import logging
import sys
from functools import lru_cache

import gradio as gr
import spaces
import torch
from transformers import pipeline

# AudioProcessor (local module) is expected to expose
# process_audio(path, translate) -> (language_segments, final_segments).
from audio_processing import AudioProcessor

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)
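
# Note: `spaces` matters on Hugging Face ZeroGPU Spaces, where GPU-bound
# functions are normally wrapped with the @spaces.GPU decorator. It is imported
# here but never applied, so on ZeroGPU hardware the model calls below may not
# get a GPU allocated.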

@lru_cache(maxsize=1)
def load_qa_model():
    """Load the question-answering model (cached so the weights load only once)."""
    try:
        qa_pipeline = pipeline(
            "text-generation",
            model="meta-llama/Meta-Llama-3-8B-Instruct",
            model_kwargs={"torch_dtype": torch.bfloat16},
            device_map="auto",
        )
        return qa_pipeline
    except Exception as e:
        logger.exception(f"Failed to load Q&A model: {e}")
        return None
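
# Note: meta-llama/Meta-Llama-3-8B-Instruct is a gated checkpoint: loading it
# requires accepting Meta's license on the Hugging Face Hub and authenticating
# (e.g. via an HF_TOKEN secret on Spaces). Without access, load_qa_model()
# logs the error and returns None.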

@lru_cache(maxsize=1)
def load_summarization_model():
    """Load the summarization model (cached so the weights load only once)."""
    try:
        summarizer = pipeline(
            "summarization",
            model="sshleifer/distilbart-cnn-12-6",
            device=0 if torch.cuda.is_available() else -1
        )
        return summarizer
    except Exception as e:
        logger.exception(f"Failed to load summarization model: {e}")
        return None
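
# Note: sshleifer/distilbart-cnn-12-6 is BART-based with a 1024-token input
# limit, which is why summarize_text() below passes truncation=True.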

def process_audio(audio_file, translate=False):
    """Process an audio file and return (formatted transcription, plain full text)."""
    if audio_file is None:
        raise gr.Error("Please upload or record an audio file first.")
    try:
        processor = AudioProcessor()
        language_segments, final_segments = processor.process_audio(audio_file, translate)

        # Format output
        transcription = ""
        full_text = ""

        # Add language detection information
        for segment in language_segments:
            transcription += f"Language: {segment['language']}\n"
            transcription += f"Time: {segment['start']:.2f}s - {segment['end']:.2f}s\n\n"

        # Add transcription/translation information
        transcription += "Transcription with language detection:\n\n"
        for segment in final_segments:
            transcription += f"[{segment['start']:.2f}s - {segment['end']:.2f}s] ({segment['language']}):\n"
            transcription += f"Original: {segment['text']}\n"
            if translate and 'translated' in segment:
                transcription += f"Translated: {segment['translated']}\n"
                full_text += segment['translated'] + " "
            else:
                full_text += segment['text'] + " "
            transcription += "\n"

        return transcription, full_text
    except Exception as e:
        logger.exception(f"Audio processing failed: {e}")
        raise gr.Error(f"Processing failed: {e}")
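
# For reference, one translated segment renders in the formatted output as:
#   [0.00s - 3.50s] (hi):
#   Original: <source-language text>
#   Translated: <English text>
# (timestamps and language code here are illustrative placeholders).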

def summarize_text(text):
    """Summarize text."""
    try:
        summarizer = load_summarization_model()
        if summarizer is None:
            return "Summarization model could not be loaded."
        # truncation=True clips inputs to the model's 1024-token window
        # instead of raising on long transcriptions.
        summary = summarizer(text, max_length=150, min_length=50,
                             do_sample=False, truncation=True)[0]['summary_text']
        return summary
    except Exception as e:
        logger.exception(f"Summarization failed: {e}")
        return "Error occurred during summarization."

def answer_question(context, question):
    """Answer a question about the given context."""
    try:
        qa_pipeline = load_qa_model()
        if qa_pipeline is None:
            return "Q&A model could not be loaded."
        messages = [
            {"role": "system", "content": "You are a helpful assistant who can answer questions based on the given context."},
            {"role": "user", "content": f"Context: {context}\n\nQuestion: {question}"}
        ]
        response = qa_pipeline(messages, max_new_tokens=256)
        # 'generated_text' holds the whole chat; the assistant's reply is the last message.
        return response[0]['generated_text'][-1]['content']
    except Exception as e:
        logger.exception(f"Q&A failed: {e}")
        return f"Error occurred during Q&A process: {e}"

# Create Gradio interface
with gr.Blocks() as iface:
    gr.Markdown("# Automatic Speech Recognition for Indic Languages")

    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(type="filepath")
            translate_checkbox = gr.Checkbox(label="Enable Translation")
            process_button = gr.Button("Process Audio")
        with gr.Column():
            transcription_output = gr.Textbox(label="Transcription/Translation", lines=10)
            full_text_output = gr.Textbox(label="Full Text", lines=5)

    with gr.Row():
        with gr.Column():
            summarize_button = gr.Button("Summarize")
            summary_output = gr.Textbox(label="Summary", lines=3)
        with gr.Column():
            question_input = gr.Textbox(label="Ask a question about the transcription")
            answer_button = gr.Button("Get Answer")
            answer_output = gr.Textbox(label="Answer", lines=3)

    # Set up event handlers
    process_button.click(
        process_audio,
        inputs=[audio_input, translate_checkbox],
        outputs=[transcription_output, full_text_output]
    )
    summarize_button.click(
        summarize_text,
        inputs=[full_text_output],
        outputs=[summary_output]
    )
    answer_button.click(
        answer_question,
        inputs=[full_text_output, question_input],
        outputs=[answer_output]
    )

    # Add system information
    gr.Markdown(f"""
    ## System Information
    - Device: {"CUDA" if torch.cuda.is_available() else "CPU"}
    - CUDA Available: {"Yes" if torch.cuda.is_available() else "No"}

    ## Features
    - Automatic language detection
    - High-quality transcription using MMS
    - Optional translation to English
    - Text summarization
    - Question answering
    """)
if __name__ == "__main__":
    iface.launch()