azeus committed on
Commit
1e2e376
·
1 Parent(s): c00ec95

adding fb model

Browse files
Files changed (2) hide show
  1. app.py +116 -26
  2. requirements.txt +3 -0
app.py CHANGED
@@ -1,11 +1,78 @@
1
  import streamlit as st
2
  import numpy as np
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  # Page setup
5
  st.title("🎵 Music Genre Classifier")
6
- st.write("Upload an audio file to analyze its genre")
7
 
8
- # Create two columns for better layout
 
 
 
 
 
 
 
 
 
 
9
  col1, col2 = st.columns(2)
10
 
11
  with col1:
@@ -17,38 +84,61 @@ with col1:
17
  st.audio(audio_file)
18
  st.success("File uploaded successfully!")
19
 
20
- # Add a classify button
21
  if st.button("Classify Genre"):
22
- with st.spinner("Analyzing..."):
23
- # Simulate genre classification (we'll replace this with real model later)
24
- genres = ["Rock", "Pop", "Hip Hop", "Classical", "Jazz"]
25
- confidences = np.random.dirichlet(np.ones(5)) # Random probabilities that sum to 1
 
 
 
26
 
27
- # Show results
28
- st.write("### Genre Analysis Results:")
29
- for genre, confidence in zip(genres, confidences):
30
- st.write(f"{genre}: {confidence:.2%}")
 
 
 
31
 
32
- # Show top prediction
33
- top_genre = genres[np.argmax(confidences)]
34
- st.write(f"**Predicted Genre:** {top_genre}")
 
 
35
 
36
  with col2:
37
- # Display some tips and information
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  st.write("### Tips for best results:")
39
- st.write("- Upload files in MP3 or WAV format")
40
- st.write("- Ensure good audio quality")
41
- st.write("- Try to upload songs without too much background noise")
42
  st.write("- Ideal length: 10-30 seconds")
 
 
43
 
44
- # Add a sample counter
45
- if 'analyzed_count' not in st.session_state:
46
- st.session_state.analyzed_count = 0
47
-
48
- if audio_file is not None:
49
- st.session_state.analyzed_count += 1
50
- st.write(f"Songs analyzed this session: {st.session_state.analyzed_count}")
 
 
51
 
52
  # Footer
53
  st.markdown("---")
54
- st.write("Made with ❤️ using Streamlit")
 
1
  import streamlit as st
2
  import numpy as np
3
+ import torch
4
+ from transformers import Wav2Vec2Processor, Wav2Vec2Model
5
+ import torchaudio
6
+ import io
7
+
8
+
9
+ # Initialize model and processor
10
@st.cache_resource
def load_model():
    """Load the pretrained Wav2Vec2 processor and encoder.

    The function is wrapped in ``st.cache_resource`` so that the objects it
    returns are reused instead of being rebuilt each time the script runs.

    Returns:
        tuple: ``(processor, model)`` where ``processor`` is a
        ``Wav2Vec2Processor`` and ``model`` is a ``Wav2Vec2Model``, both
        loaded from the ``facebook/wav2vec2-base`` checkpoint.
    """
    # Both objects must come from the same checkpoint, so name it once.
    checkpoint = "facebook/wav2vec2-base"
    # Tuple elements are evaluated left to right: processor first, then model.
    return (
        Wav2Vec2Processor.from_pretrained(checkpoint),
        Wav2Vec2Model.from_pretrained(checkpoint),
    )
15
+
16
+
17
+ # Audio processing function
18
def process_audio(audio_file, processor, model):
    """Convert an uploaded audio file into one pooled Wav2Vec2 feature vector.

    Steps: read the bytes, decode them with torchaudio, downmix to mono,
    resample to 16 kHz, run the encoder without gradients, and average the
    last hidden state over the time axis.

    Args:
        audio_file: File-like object exposing ``read()`` (and ideally
            ``seek()``), e.g. the object returned by ``st.file_uploader``.
        processor: Callable that turns a 1-D waveform into model inputs
            (called with ``sampling_rate``, ``return_tensors`` and
            ``padding``).
        model: Encoder whose output exposes ``last_hidden_state``.

    Returns:
        numpy.ndarray: 1-D array with one value per hidden unit.
        NOTE(review): SimpleGenreClassifier multiplies this by a 768-row
        weight matrix, so the encoder hidden size is assumed to be 768.

    Raises:
        ValueError: If the upload is empty or decodes to zero samples.
    """
    target_rate = 16000

    # The same upload object is also handed to the audio player, so the
    # stream position may no longer be at the start; rewind when possible.
    if hasattr(audio_file, "seek"):
        audio_file.seek(0)
    audio_bytes = audio_file.read()
    if not audio_bytes:
        raise ValueError("Uploaded audio file is empty")

    waveform, sample_rate = torchaudio.load(io.BytesIO(audio_bytes))
    if waveform.numel() == 0:
        raise ValueError("Audio file contains no samples")

    # Downmix before resampling: the result is the same up to rounding
    # (both operations are linear) but only one channel gets resampled.
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    if sample_rate != target_rate:
        resampler = torchaudio.transforms.Resample(sample_rate, target_rate)
        waveform = resampler(waveform)

    # squeeze(0) drops only the channel axis; a bare squeeze() would also
    # collapse the time axis for a one-sample clip.
    inputs = processor(
        waveform.squeeze(0).numpy(),
        sampling_rate=target_rate,
        return_tensors="pt",
        padding=True,
    )
    with torch.no_grad():
        outputs = model(**inputs)

    # Mean-pool over time, then drop the batch axis of size one.
    return outputs.last_hidden_state.mean(dim=1).squeeze(0).numpy()
40
+
41
+
42
+ # Simple genre classifier (we'll use a basic classifier for demonstration)
43
class SimpleGenreClassifier:
    """Linear softmax classifier over pooled audio features (demo only).

    The weights are NOT trained: they are drawn from a standard normal
    distribution. They are drawn from a seeded generator so that every
    instance built with the same seed has identical weights; the app
    constructs a new instance on each script run, and unseeded weights would
    make the same audio produce a different prediction on every click.

    Args:
        seed: Seed for the weight generator. Defaults to 0.
        feature_dim: Length of the input feature vectors. Defaults to 768.
    """

    def __init__(self, seed=0, feature_dim=768):
        self.genres = ["Rock", "Pop", "Hip Hop", "Classical", "Jazz"]
        # Placeholder for learned weights; shape (feature_dim, n_genres).
        rng = np.random.default_rng(seed)
        self.weights = rng.standard_normal((feature_dim, len(self.genres)))

    def predict(self, features):
        """Return genre probabilities for ``features``.

        Args:
            features: Array whose last axis has ``feature_dim`` entries;
                either a single vector or a batch of vectors.

        Returns:
            numpy.ndarray: Probabilities over ``self.genres`` along the last
            axis; each row sums to 1.
        """
        logits = np.dot(np.asarray(features), self.weights)
        return self.softmax(logits)

    @staticmethod
    def softmax(x):
        """Numerically stable softmax along the last axis of ``x``."""
        x = np.asarray(x)
        # Subtracting the per-row maximum avoids overflow in exp().
        shifted = x - np.max(x, axis=-1, keepdims=True)
        exp_x = np.exp(shifted)
        return exp_x / exp_x.sum(axis=-1, keepdims=True)
59
+
60
 
61
  # Page setup
62
  st.title("🎵 Music Genre Classifier")
63
+ st.write("Upload an audio file to analyze its genre using Wav2Vec2")
64
 
65
+ # Load models
66
+ try:
67
+ with st.spinner("Loading models..."):
68
+ processor, wav2vec_model = load_model()
69
+ classifier = SimpleGenreClassifier()
70
+ st.success("Models loaded successfully!")
71
+ except Exception as e:
72
+ st.error(f"Error loading models: {str(e)}")
73
+ st.stop()
74
+
75
+ # Create two columns for layout
76
  col1, col2 = st.columns(2)
77
 
78
  with col1:
 
84
  st.audio(audio_file)
85
  st.success("File uploaded successfully!")
86
 
87
+ # Add classify button
88
  if st.button("Classify Genre"):
89
+ try:
90
+ with st.spinner("Analyzing audio..."):
91
+ # Extract features using Wav2Vec2
92
+ features = process_audio(audio_file, processor, wav2vec_model)
93
+
94
+ # Get genre predictions
95
+ probabilities = classifier.predict(features)
96
 
97
+ # Show results
98
+ st.write("### Genre Analysis Results:")
99
+ for genre, prob in zip(classifier.genres, probabilities):
100
+ # Create a progress bar for each genre
101
+ st.write(f"{genre}:")
102
+ st.progress(float(prob))
103
+ st.write(f"{prob:.2%}")
104
 
105
+ # Show top prediction
106
+ top_genre = classifier.genres[np.argmax(probabilities)]
107
+ st.write(f"**Predicted Genre:** {top_genre}")
108
+ except Exception as e:
109
+ st.error(f"Error during analysis: {str(e)}")
110
 
111
  with col2:
112
+ # Display information about the model
113
+ st.write("### About the Model:")
114
+ st.write("""
115
+ This classifier uses:
116
+ - Facebook's Wav2Vec2 for audio feature extraction
117
+ - Custom genre classification layer
118
+ - Pre-trained on speech recognition
119
+ """)
120
+
121
+ st.write("### Supported Genres:")
122
+ for genre in classifier.genres:
123
+ st.write(f"- {genre}")
124
+
125
+ # Add usage tips
126
  st.write("### Tips for best results:")
127
+ st.write("- Upload clear, high-quality audio")
 
 
128
  st.write("- Ideal length: 10-30 seconds")
129
+ st.write("- Avoid audio with multiple overlapping genres")
130
+ st.write("- Ensure minimal background noise")
131
 
132
+ # Update requirements.txt
133
+ if st.sidebar.checkbox("Show requirements.txt contents"):
134
+ st.sidebar.code("""
135
+ streamlit==1.31.0
136
+ torch==2.0.1
137
+ torchaudio==2.0.1
138
+ transformers==4.30.2
139
+ numpy==1.24.3
140
+ """)
141
 
142
  # Footer
143
  st.markdown("---")
144
+ st.write("Made with ❤️ using Streamlit and Hugging Face Transformers")
requirements.txt CHANGED
@@ -1,2 +1,5 @@
1
  streamlit==1.31.0
 
 
 
2
  numpy==1.24.3
 
1
  streamlit==1.31.0
2
+ torch==2.0.1
3
+ torchaudio==2.0.1
4
+ transformers==4.30.2
5
  numpy==1.24.3