import gradio as gr
import torch
from transformers import pipeline
import logging
import sys
from audio_processing import AudioProcessor
import spaces
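# AudioProcessor comes from this Space's own audio_processing module (not shown
# here). From its usage in process_audio() below, process_audio(path, translate)
# is expected to return (language_segments, final_segments), where each segment
# is a dict with 'start', 'end', 'language', 'text' and, when translating,
# a 'translated' key.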
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)
def load_qa_model():
"""Load question-answering model"""
try:
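        # Note: meta-llama/Meta-Llama-3-8B-Instruct is a gated checkpoint on the
        # Hugging Face Hub; the Space needs an access token (e.g. an HF_TOKEN
        # secret) for an account that has accepted the licence, or this load
        # will fail.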
qa_pipeline = pipeline(
"text-generation",
model="meta-llama/Meta-Llama-3-8B-Instruct",
model_kwargs={"torch_dtype": torch.bfloat16},
device_map="auto",
)
return qa_pipeline
except Exception as e:
logger.error(f"Failed to load Q&A model: {str(e)}")
return None
def load_summarization_model():
"""Load summarization model"""
try:
summarizer = pipeline(
"summarization",
model="sshleifer/distilbart-cnn-12-6",
device=0 if torch.cuda.is_available() else -1
)
return summarizer
except Exception as e:
logger.error(f"Failed to load summarization model: {str(e)}")
return None
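# Optional sketch: the handlers below reload their pipelines on every button
# click. Wrapping the loaders in functools.lru_cache keeps one instance per
# process. Caveat (assumption): it is not verified here that a cached model
# survives across ZeroGPU (@spaces.GPU) allocations, so treat this as a
# starting point rather than a drop-in fix.
from functools import lru_cache

load_qa_model_cached = lru_cache(maxsize=1)(load_qa_model)
load_summarization_model_cached = lru_cache(maxsize=1)(load_summarization_model)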
@spaces.GPU(duration=60)
def process_audio(audio_file, translate=False):
"""Process audio file"""
try:
processor = AudioProcessor()
language_segments, final_segments = processor.process_audio(audio_file, translate)
# Format output
transcription = ""
full_text = ""
# Add language detection information
for segment in language_segments:
transcription += f"Language: {segment['language']}\n"
transcription += f"Time: {segment['start']:.2f}s - {segment['end']:.2f}s\n\n"
# Add transcription/translation information
transcription += "Transcription with language detection:\n\n"
for segment in final_segments:
transcription += f"[{segment['start']:.2f}s - {segment['end']:.2f}s] ({segment['language']}):\n"
transcription += f"Original: {segment['text']}\n"
if translate and 'translated' in segment:
transcription += f"Translated: {segment['translated']}\n"
full_text += segment['translated'] + " "
else:
full_text += segment['text'] + " "
transcription += "\n"
return transcription, full_text
except Exception as e:
logger.error(f"Audio processing failed: {str(e)}")
raise gr.Error(f"Processing failed: {str(e)}")
@spaces.GPU(duration=60)
def summarize_text(text):
"""Summarize text"""
try:
summarizer = load_summarization_model()
if summarizer is None:
return "Summarization model could not be loaded."
        # truncation=True keeps the input within the model's maximum length
        # (1024 tokens for DistilBART); longer transcripts would otherwise fail.
        summary = summarizer(text, max_length=150, min_length=50, do_sample=False, truncation=True)[0]['summary_text']
return summary
except Exception as e:
logger.error(f"Summarization failed: {str(e)}")
return "Error occurred during summarization."
@spaces.GPU(duration=60)
def answer_question(context, question):
"""Answer questions about the text"""
try:
qa_pipeline = load_qa_model()
if qa_pipeline is None:
return "Q&A model could not be loaded."
messages = [
{"role": "system", "content": "You are a helpful assistant who can answer questions based on the given context."},
{"role": "user", "content": f"Context: {context}\n\nQuestion: {question}"}
]
        # With chat-style message input, 'generated_text' holds the full
        # conversation (a list of message dicts); the model's answer is the
        # content of the last message.
        response = qa_pipeline(messages, max_new_tokens=256)[0]['generated_text'][-1]['content']
return response
except Exception as e:
logger.error(f"Q&A failed: {str(e)}")
return f"Error occurred during Q&A process: {str(e)}"
# Create Gradio interface
with gr.Blocks() as iface:
gr.Markdown("# Automatic Speech Recognition for Indic Languages")
with gr.Row():
with gr.Column():
audio_input = gr.Audio(type="filepath")
translate_checkbox = gr.Checkbox(label="Enable Translation")
process_button = gr.Button("Process Audio")
with gr.Column():
transcription_output = gr.Textbox(label="Transcription/Translation", lines=10)
full_text_output = gr.Textbox(label="Full Text", lines=5)
with gr.Row():
with gr.Column():
summarize_button = gr.Button("Summarize")
summary_output = gr.Textbox(label="Summary", lines=3)
with gr.Column():
question_input = gr.Textbox(label="Ask a question about the transcription")
answer_button = gr.Button("Get Answer")
answer_output = gr.Textbox(label="Answer", lines=3)
# Set up event handlers
process_button.click(
process_audio,
inputs=[audio_input, translate_checkbox],
outputs=[transcription_output, full_text_output]
)
summarize_button.click(
summarize_text,
inputs=[full_text_output],
outputs=[summary_output]
)
answer_button.click(
answer_question,
inputs=[full_text_output, question_input],
outputs=[answer_output]
)
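    # Note: the summarization and Q&A handlers read from the "Full Text" box,
    # so "Process Audio" must run first to populate it; editing that box by
    # hand also changes what gets summarized or answered.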
# Add system information
gr.Markdown(f"""
## System Information
- Device: {"CUDA" if torch.cuda.is_available() else "CPU"}
- CUDA Available: {"Yes" if torch.cuda.is_available() else "No"}
## Features
- Automatic language detection
- High-quality transcription using MMS
- Optional translation to English
- Text summarization
- Question answering
""")
if __name__ == "__main__":
    iface.launch()