"""Streamlit demo: music-genre "classification" on top of Wav2Vec2 embeddings.

The app lets a user upload an MP3/WAV clip, embeds it with the pretrained
``facebook/wav2vec2-base`` encoder (mean-pooled last hidden state) and feeds
the embedding to a linear layer followed by a softmax.

NOTE: the linear layer is *not* trained. Its weights are drawn from a seeded
random generator, so the reported probabilities are reproducible but carry no
musical meaning. Replace ``SimpleGenreClassifier.weights`` with trained
weights for real use.
"""
import io
from pathlib import Path

import numpy as np
import streamlit as st
import torch
import torchaudio
from transformers import (
    Wav2Vec2FeatureExtractor,
    Wav2Vec2Model,
    Wav2Vec2Processor,  # noqa: F401 - retained from the original import list
)

MODEL_NAME = "facebook/wav2vec2-base"
# Sampling rate passed to the feature extractor; audio is resampled to it.
TARGET_SAMPLE_RATE = 16000
# The wav2vec 2.0 convolutional front end has a receptive field of 400 samples
# (25 ms at 16 kHz); shorter inputs fail inside the model with an opaque
# convolution error, so they are rejected up front.
MIN_SAMPLES = 400


# Initialize model and processor
@st.cache_resource
def load_model():
    """Load the audio pre-processor and the Wav2Vec2 encoder once per process.

    Only audio features are needed, so the feature extractor is loaded
    directly instead of ``Wav2Vec2Processor``: the processor additionally
    requires tokenizer files, which are irrelevant here and which a
    pre-training-only checkpoint does not necessarily provide.

    Returns:
        A ``(processor, model)`` tuple; ``processor`` is callable on a 1-D
        waveform exactly like ``Wav2Vec2Processor`` is for audio input.
    """
    processor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_NAME)
    model = Wav2Vec2Model.from_pretrained(MODEL_NAME)
    model.eval()  # inference only; makes the no-dropout mode explicit
    return processor, model


# Audio processing function
def process_audio(audio_file, processor, model) -> np.ndarray:
    """Turn an uploaded audio file into a single Wav2Vec2 embedding vector.

    Args:
        audio_file: Binary file-like object (e.g. Streamlit's uploaded file).
        processor: Callable pre-processor returned by ``load_model``.
        model: Wav2Vec2 encoder returned by ``load_model``.

    Returns:
        1-D float array: the last hidden state averaged over time.

    Raises:
        ValueError: If the file is empty or the clip is too short to encode.
    """
    # The same upload object may already have been consumed (audio player,
    # earlier click); rewind so read() returns the full content.
    if hasattr(audio_file, "seek"):
        audio_file.seek(0)
    audio_bytes = audio_file.read()
    if not audio_bytes:
        raise ValueError("The uploaded audio file is empty.")

    # Decoding from memory gives the backend no file name to inspect, so pass
    # the extension of the uploaded file (if known) as a format hint.
    file_name = getattr(audio_file, "name", "") or ""
    format_hint = Path(file_name).suffix.lstrip(".").lower() or None
    waveform, sample_rate = torchaudio.load(
        io.BytesIO(audio_bytes), format=format_hint
    )

    # Resample if needed
    if sample_rate != TARGET_SAMPLE_RATE:
        resampler = torchaudio.transforms.Resample(sample_rate, TARGET_SAMPLE_RATE)
        waveform = resampler(waveform)

    # Convert to mono if stereo
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    if waveform.shape[-1] < MIN_SAMPLES:
        raise ValueError(
            f"Audio clip is too short: need at least {MIN_SAMPLES} samples "
            f"at {TARGET_SAMPLE_RATE} Hz, got {waveform.shape[-1]}."
        )

    # Process through Wav2Vec2. squeeze(0) drops only the channel axis.
    inputs = processor(
        waveform.squeeze(0).numpy(),
        sampling_rate=TARGET_SAMPLE_RATE,
        return_tensors="pt",
        padding=True,
    )
    with torch.no_grad():
        outputs = model(**inputs)

    # Mean-pool over the time axis, then drop the batch axis.
    return outputs.last_hidden_state.mean(dim=1).squeeze(0).numpy()


# Simple genre classifier (we'll use a basic classifier for demonstration)
class SimpleGenreClassifier:
    """Linear layer + softmax over a fixed list of genres (untrained demo).

    Args:
        feature_dim: Length of the feature vectors passed to ``predict``.
        seed: Seed for the placeholder weights. Streamlit re-runs the whole
            script on every interaction, so unseeded weights would make the
            same clip produce different predictions on every click.
    """

    def __init__(self, feature_dim=768, seed=0):
        self.genres = ["Rock", "Pop", "Hip Hop", "Classical", "Jazz"]
        # Simulated learned weights (in real application, these would be trained)
        rng = np.random.default_rng(seed)
        self.weights = rng.standard_normal((feature_dim, len(self.genres)))

    def predict(self, features):
        """Return genre probabilities, ordered like ``self.genres``."""
        # Simple linear classification
        logits = np.dot(features, self.weights)
        return self.softmax(logits)

    @staticmethod
    def softmax(x):
        """Numerically stable softmax along the last axis of ``x``."""
        x = np.asarray(x)
        # Subtracting the maximum avoids overflow in exp() without changing
        # the result.
        exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
        return exp_x / exp_x.sum(axis=-1, keepdims=True)


# Page setup
st.title("🎵 Music Genre Classifier")
st.write("Upload an audio file to analyze its genre using Wav2Vec2")

# Load models; without them nothing below can work, so stop the script on failure.
try:
    with st.spinner("Loading models..."):
        processor, wav2vec_model = load_model()
        # Size the linear layer from the encoder instead of hard-coding 768.
        classifier = SimpleGenreClassifier(
            feature_dim=wav2vec_model.config.hidden_size
        )
    st.success("Models loaded successfully!")
except Exception as e:
    st.error(f"Error loading models: {e}")
    st.stop()

# Create two columns for layout
col1, col2 = st.columns(2)

with col1:
    # File upload
    audio_file = st.file_uploader(
        "Upload an audio file (MP3, WAV)", type=['mp3', 'wav']
    )

    if audio_file is not None:
        # Display audio player
        st.audio(audio_file)
        st.success("File uploaded successfully!")

        # Add classify button
        if st.button("Classify Genre"):
            # UI boundary: report any decoding/model failure to the user
            # instead of crashing the page.
            try:
                with st.spinner("Analyzing audio..."):
                    # Extract features using Wav2Vec2
                    features = process_audio(audio_file, processor, wav2vec_model)

                    # Get genre predictions
                    probabilities = classifier.predict(features)

                # Show results
                st.write("### Genre Analysis Results:")
                for genre, prob in zip(classifier.genres, probabilities):
                    # Create a progress bar for each genre
                    st.write(f"{genre}:")
                    st.progress(float(prob))
                    st.write(f"{prob:.2%}")

                # Show top prediction
                top_genre = classifier.genres[int(np.argmax(probabilities))]
                st.write(f"**Predicted Genre:** {top_genre}")

            except Exception as e:
                st.error(f"Error during analysis: {e}")

with col2:
    # Display information about the model
    st.write("### About the Model:")
    st.write("""
    This classifier uses:
    - Facebook's Wav2Vec2 for audio feature extraction
    - Custom genre classification layer
    - Pre-trained on speech recognition
    """)

    st.write("### Supported Genres:")
    for genre in classifier.genres:
        st.write(f"- {genre}")

# Add usage tips
st.write("### Tips for best results:")
st.write("- Upload clear, high-quality audio")
st.write("- Ideal length: 10-30 seconds")
st.write("- Avoid audio with multiple overlapping genres")
st.write("- Ensure minimal background noise")

# Optionally show the pinned dependencies in the sidebar
if st.sidebar.checkbox("Show requirements.txt contents"):
    st.sidebar.code("""
    streamlit==1.31.0
    torch==2.0.1
    torchaudio==2.0.1
    transformers==4.30.2
    numpy==1.24.3
    """)

# Footer
st.markdown("---")
st.write("Made with ❤️ using Streamlit and Hugging Face Transformers")