import os

import gradio as gr
import librosa
import librosa.display
import matplotlib.pyplot as plt
import noisereduce as nr
import numpy as np
import pandas as pd
import requests
import torch
import torch.nn.functional as F
from torchaudio.functional import resample
from transformers import ASTFeatureExtractor

from config import Config
from model import BirdAST

# AST feature extractor (expects 16 kHz mono audio in [-1, 1])
FEATURE_EXTRACTOR = ASTFeatureExtractor()


def plot_mel(sr, x):
    """Plot a mel spectrogram of the input audio."""
    mel_spec = librosa.feature.melspectrogram(y=x, sr=sr, n_mels=224, fmax=10000)
    mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)
    # Normalize the spectrogram to [0, 1] for display
    mel_spec_db = (mel_spec_db - mel_spec_db.min()) / (mel_spec_db.max() - mel_spec_db.min())

    fig, ax = plt.subplots(nrows=1, ncols=1)
    librosa.display.specshow(mel_spec_db, sr=sr, x_axis='time', y_axis='mel', fmin=0, fmax=10000, ax=ax)
    return fig


def plot_wave(sr, x):
    """Plot the original and noise-reduced waveforms."""
    ry = nr.reduce_noise(y=x, sr=sr)
    fig, ax = plt.subplots(2, 1, figsize=(12, 8))

    # Plot the original waveform
    librosa.display.waveshow(x, sr=sr, ax=ax[0])
    ax[0].set(title='Original Waveform')
    ax[0].set_xlabel('Time (s)')
    ax[0].set_ylabel('Amplitude')

    # Plot the noise-reduced waveform
    librosa.display.waveshow(ry, sr=sr, ax=ax[1])
    ax[1].set(title='Noise Reduced Waveform')
    ax[1].set_xlabel('Time (s)')
    ax[1].set_ylabel('Amplitude')

    plt.tight_layout()
    return fig


def predict(audio, start, end):
    sr, x = audio

    # Gradio delivers int16 PCM; convert to mono float32 in [-1, 1]
    x = np.asarray(x, dtype=np.float32)
    if x.ndim > 1:
        x = x.mean(axis=1)
    if np.abs(x).max() > 1.0:
        x = x / 32768.0

    duration = x.shape[0] / sr
    if start >= end:
        raise gr.Error(f"`start` ({start}s) must be smaller than `end` ({end}s)")
    if start * sr > x.shape[0]:
        raise gr.Error(f"`start` ({start}s) must be smaller than the audio duration ({duration:.0f}s)")
    if end * sr > x.shape[0]:
        end = duration

    # Restrict the analysis to the selected [start, end] window
    x = x[int(start * sr):int(end * sr)]

    res = preprocess_for_inference(x, sr)
    fig_spec = plot_mel(sr, x)
    fig_wave = plot_wave(sr, x)

    # Outputs: class table, species table, waveform plot, spectrogram plot
    return res, res, fig_wave, fig_spec


def download_model(url, model_path):
    if not os.path.exists(model_path):
        response = requests.get(url)
        response.raise_for_status()  # Ensure the request was successful
        with open(model_path, 'wb') as f:
            f.write(response.content)


# Model URL and path
model_url = 'https://huggingface.co/shiyi-li/BirdAST/resolve/main/BirdAST_Baseline_fold_1.pth'
model_path = 'BirdAST_Baseline_fold_1.pth'
download_model(model_url, model_path)

# Load the weights into the BirdAST architecture
eval_model = BirdAST(Config().backbone_name, Config().n_classes, n_mlp_layers=1, activation='silu')
state_dict = torch.load(model_path, map_location='cpu')
eval_model.load_state_dict(state_dict)

# Set to evaluation mode
eval_model.eval()

# Load the species_id -> scientific name mapping
label_mapping = pd.read_csv('label_mapping.csv')
species_id_to_name = {row['species_id']: row['scientific_name'] for _, row in label_mapping.iterrows()}


def preprocess_for_inference(audio_arr, sr):
    """Run the model on an audio array and return the top-10 (species, probability) pairs."""
    # The AST feature extractor expects 16 kHz input; resample if needed
    if sr != 16000:
        audio_arr = resample(torch.from_numpy(audio_arr), sr, 16000).numpy()
        sr = 16000

    spec = FEATURE_EXTRACTOR(audio_arr, sampling_rate=sr, padding="max_length", return_tensors="pt")
    input_values = spec['input_values']

    results = []
    with torch.no_grad():
        # Run the model and convert logits to class probabilities
        output = eval_model(input_values)
        predict_score = F.softmax(output['logits'], dim=1)
        # Keep the top-10 classes by probability
        topk_values, topk_indices = torch.topk(predict_score, 10, dim=1)

        # Map class indices to species names and probabilities
        for idx, score in zip(topk_indices[0], topk_values[0]):
            species_name = species_id_to_name[idx.item()]
            probability = score.item()
            results.append([species_name, probability])

    return results


DESCRIPTION = """
Bird audio classification using SOTA Voice of Jungle Technology.
"""

css = """
.number-input {
    height: 100%;
    padding-bottom: 60px; /* Adjust the value as needed for more or less space */
}
.full-height {
    height: 100%;
}
.column-container {
    height: 100%;
}
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown("# Bird audio classification")
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column(elem_classes="column-container"):
            start_time_input = gr.Number(label="Start Time", value=0, elem_classes="number-input full-height")
            end_time_input = gr.Number(label="End Time", value=1, elem_classes="number-input full-height")
        with gr.Column():
            audio_input = gr.Audio(label="Input Audio", elem_classes="full-height")

    with gr.Row():
        raw_class_output = gr.Dataframe(headers=["class", "score"], row_count=10, label="Class Prediction")
        species_output = gr.Dataframe(headers=["class", "score"], row_count=10, label="Species Prediction")

    with gr.Row():
        waveform_output = gr.Plot(label="Waveform")
        spectrogram_output = gr.Plot(label="Spectrogram")

    gr.Examples(
        examples=[
            ["312_Cissopis_leverinia_1.wav", 0, 5],
            ["1094_Pionus_fuscus_2.wav", 0, 10],
        ],
        inputs=[audio_input, start_time_input, end_time_input],
    )

    gr.Button("Predict").click(
        predict,
        inputs=[audio_input, start_time_input, end_time_input],
        outputs=[raw_class_output, species_output, waveform_output, spectrogram_output],
    )

demo.launch(share=True)
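
# A minimal local smoke test (a sketch, not part of the app): call `predict` directly on one
# of the bundled example clips without launching the UI. Assumes the example WAV is present
# in the working directory; librosa is already imported above. Uncomment to try:
#
#   x, sr = librosa.load("312_Cissopis_leverinia_1.wav", sr=None)   # float32 in [-1, 1]
#   class_table, species_table, fig_wave, fig_spec = predict((sr, x), start=0, end=5)
#   print(species_table[:3])   # top-3 [scientific_name, probability] rows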