gastonduault commited on
Commit
d77d9c5
·
1 Parent(s): 02e980a

add predict example

Browse files
.idea/.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Default ignored files
2
+ /shelf/
3
+ /workspace.xml
4
+ # Editor-based HTTP Client requests
5
+ /httpRequests/
6
+ # Datasource local storage ignored files
7
+ /dataSources/
8
+ /dataSources.local.xml
.idea/inspectionProfiles/Project_Default.xml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <profile version="1.0">
3
+ <option name="myName" value="Project Default" />
4
+ <inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
5
+ <inspection_tool class="PyInterpreterInspection" enabled="false" level="WARNING" enabled_by_default="false" />
6
+ <inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
7
+ <option name="ignoredPackages">
8
+ <value>
9
+ <list size="4">
10
+ <item index="0" class="java.lang.String" itemvalue="json" />
11
+ <item index="1" class="java.lang.String" itemvalue="customtkinter" />
12
+ <item index="2" class="java.lang.String" itemvalue="pytest" />
13
+ <item index="3" class="java.lang.String" itemvalue="csv" />
14
+ </list>
15
+ </value>
16
+ </option>
17
+ </inspection_tool>
18
+ <inspection_tool class="SqlDialectInspection" enabled="false" level="WARNING" enabled_by_default="false" />
19
+ <inspection_tool class="SqlNoDataSourceInspection" enabled="false" level="WARNING" enabled_by_default="false" />
20
+ </profile>
21
+ </component>
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <settings>
3
+ <option name="USE_PROJECT_PROFILE" value="false" />
4
+ <version value="1.0" />
5
+ </settings>
6
+ </component>
.idea/material_theme_project_new.xml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="MaterialThemeProjectNewConfig">
4
+ <option name="metadata">
5
+ <MTProjectMetadataState>
6
+ <option name="migrated" value="true" />
7
+ <option name="pristineConfig" value="false" />
8
+ <option name="userId" value="78c19164:19192ff6e6a:-7ffe" />
9
+ </MTProjectMetadataState>
10
+ </option>
11
+ </component>
12
+ </project>
.idea/misc.xml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="Black">
4
+ <option name="sdkName" value="Python 3.12 (astronomy-research-my3dsky-ec1ed317f901)" />
5
+ </component>
6
+ <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.12 (astronomy-research-my3dsky-ec1ed317f901)" project-jdk-type="Python SDK" />
7
+ </project>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectModuleManager">
4
+ <modules>
5
+ <module fileurl="file://$PROJECT_DIR$/.idea/music-classifier.iml" filepath="$PROJECT_DIR$/.idea/music-classifier.iml" />
6
+ </modules>
7
+ </component>
8
+ </project>
.idea/music-classifier.iml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <module type="PYTHON_MODULE" version="4">
3
+ <component name="NewModuleRootManager">
4
+ <content url="file://$MODULE_DIR$" />
5
+ <orderEntry type="inheritedJdk" />
6
+ <orderEntry type="sourceFolder" forTests="false" />
7
+ </component>
8
+ </module>
.idea/vcs.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="VcsDirectoryMappings">
4
+ <mapping directory="" vcs="Git" />
5
+ </component>
6
+ </project>
predict-example.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor
2
+ from datasets import load_dataset
3
+ import numpy as np
4
+ import librosa
5
+ import torch
6
+
7
+ # Paths
8
+ MODEL_DIR = "./wav2vec_trained_model"
9
+
10
+ # Load the dataset
11
+ dataset = load_dataset("lewtun/music_genres_small")
12
+
13
+ # Retrieve the label names
14
+ genre_mapping = {}
15
+ for example in dataset["train"]:
16
+ genre_id = example["genre_id"]
17
+ genre = example["genre"]
18
+ if genre_id not in genre_mapping:
19
+ genre_mapping[genre_id] = genre
20
+ if len(genre_mapping) == 9:
21
+ break
22
+
23
+ print(f"Loading model from {MODEL_DIR}...\n")
24
+ model = Wav2Vec2ForSequenceClassification.from_pretrained("gastoooon/music-classifier")
25
+ feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-large")
26
+
27
+ # Function for preprocessing audio for prediction
28
+ def preprocess_audio(audio_path, target_length=16000 * 180): # 30 seconds at 16kHz
29
+ audio_array, sampling_rate = librosa.load(audio_path, sr=16000)
30
+
31
+ if len(audio_array) > target_length:
32
+ audio_array = audio_array[:target_length]
33
+ else:
34
+ padding = target_length - len(audio_array)
35
+ audio_array = np.pad(audio_array, (0, padding), "constant")
36
+
37
+ inputs = feature_extractor(audio_array, sampling_rate=16000, return_tensors="pt", padding=True)
38
+ return inputs
39
+
40
+
41
+ # Path to your audio file
42
+ audio_path = "./Nirvana - Come As You Are.wav"
43
+
44
+
45
+ # Preprocess audio
46
+ inputs = preprocess_audio(audio_path)
47
+
48
+ # Predict
49
+ with torch.no_grad():
50
+ logits = model(**inputs).logits
51
+ predicted_class = torch.argmax(logits, dim=-1).item()
52
+
53
+ # Output the result
54
+ print(f"song analized:{audio_path}")
55
+ print(f"Predicted genre: {genre_mapping[predicted_class]}")
56
+