---
datasets:
- lewtun/music_genres_small
base_model:
- facebook/wav2vec2-large
metrics:
- accuracy
- f1
tags:
- audio
- music
- classification
- Wav2Vec2
pipeline_tag: audio-classification
---

# Music Genre Classification Model 🎶

This model classifies the genre of a music track from its raw audio signal (`.wav` files). It was fine-tuned from **[Wav2Vec2](https://huggingface.co/facebook/wav2vec2-large)** (`facebook/wav2vec2-large`) on the **[music_genres_small](https://huggingface.co/datasets/lewtun/music_genres_small)** dataset.

A **GitHub** repository provides a web interface, backed by a Flask API, for testing the model: **[music-classifier repository](https://github.com/gastonduault/Music-Classifier)**.

## Metrics

- **Validation Accuracy**: 75%
- **F1 Score**: 74%
- **Validation Loss**: 0.77
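
The card does not state how the F1 score was averaged across the 10 genres. As a hedged illustration only, the sketch below computes accuracy and a macro-averaged F1 from lists of true and predicted class indices. Macro averaging over every label that occurs in either list is an assumption here (it mirrors scikit-learn's `f1_score(..., average="macro")` default label set); the function name `accuracy_and_macro_f1` is hypothetical.

```python
def accuracy_and_macro_f1(y_true, y_pred):
    """Return (accuracy, macro-F1) for two equal-length sequences of class indices."""
    if len(y_true) != len(y_pred) or len(y_true) == 0:
        raise ValueError("y_true and y_pred must be non-empty and the same length")
    pairs = list(zip(y_true, y_pred))
    accuracy = sum(t == p for t, p in pairs) / len(pairs)

    # Assumption: macro averaging over every label that occurs in y_true or y_pred
    labels = sorted(set(y_true) | set(y_pred))
    f1_scores = []
    for c in labels:
        tp = sum(t == c and p == c for t, p in pairs)
        fp = sum(t != c and p == c for t, p in pairs)
        fn = sum(t == c and p != c for t, p in pairs)
        # The denominator is positive because c occurs in at least one of the sequences
        f1_scores.append(2 * tp / (2 * tp + fp + fn))
    return accuracy, sum(f1_scores) / len(f1_scores)


acc, f1 = accuracy_and_macro_f1([0, 0, 1, 1], [0, 1, 1, 1])
print(f"accuracy={acc:.2f}, macro F1={f1:.2f}")
```

A weighted average (per-class F1 weighted by class frequency) would give a different number on an imbalanced validation set, so the reported 74% should be read with that uncertainty in mind.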

## Example Usage

```python
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor
import librosa
import torch

# Mapping from class index to genre label
genre_mapping = {
    0: "Electronic",
    1: "Rock",
    2: "Punk",
    3: "Experimental",
    4: "Hip-Hop",
    5: "Folk",
    6: "Chiptune / Glitch",
    7: "Instrumental",
    8: "Pop",
    9: "International",
}

# Load the fine-tuned classifier and the feature extractor of the base model
model = Wav2Vec2ForSequenceClassification.from_pretrained("gastonduault/music-classifier")
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-large")


# Load an audio file, resample it to 16 kHz (the rate Wav2Vec2 expects) and build model inputs
def preprocess_audio(audio_path):
    audio_array, sampling_rate = librosa.load(audio_path, sr=16000)
    return feature_extractor(audio_array, sampling_rate=sampling_rate, return_tensors="pt", padding=True)


# Path to your audio file
audio_path = "./Nirvana - Come As You Are.wav"

# Preprocess audio
inputs = preprocess_audio(audio_path)

# Predict: take the class with the highest logit
with torch.no_grad():
    logits = model(**inputs).logits
predicted_class = torch.argmax(logits, dim=-1).item()

# Output the result
print(f"Song analyzed: {audio_path}")
print(f"Predicted genre: {genre_mapping[predicted_class]}")
```
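
`argmax` returns only the single most likely class. To see how confident the model is, the logits can be turned into probabilities with a softmax and ranked. The following is a minimal plain-Python sketch, not part of the original project: `top_k_genres` and `GENRES` are hypothetical names, and the label order is assumed to match `genre_mapping` in the example above.

```python
import math

# Assumed to follow the same index order as genre_mapping
GENRES = [
    "Electronic", "Rock", "Punk", "Experimental", "Hip-Hop",
    "Folk", "Chiptune / Glitch", "Instrumental", "Pop", "International",
]


def top_k_genres(logits, k=3):
    """Return the k most likely (genre, probability) pairs for one row of logits."""
    m = max(logits)  # subtract the max before exponentiating for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(GENRES[i], probs[i]) for i in ranked[:k]]
```

For example, pass `logits[0].tolist()` from the prediction step and print each pair. With PyTorch alone, `torch.softmax(logits, dim=-1)` followed by `torch.topk(probs, k=3)` produces the same ranking.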