import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from transformers import ASTConfig, ASTFeatureExtractor, ASTModel

# Feature extractor built with the library's default ASTFeatureExtractor
# settings; shared by every call to birdast_seq_preprocess.
BirdAST_FEATURE_EXTRACTOR = ASTFeatureExtractor()

DEFAULT_SR = 16_000
DEFAULT_BACKBONE = "MIT/ast-finetuned-audioset-10-10-0.4593"
DEFAULT_N_CLASSES = 728
DEFAULT_ACTIVATION = "silu"
DEFAULT_N_MLP_LAYERS = 1


def birdast_seq_preprocess(audio_array, sr=DEFAULT_SR):
    """Convert one raw audio clip into the spectrogram input for BirdAST.

    Args:
        audio_array: np.ndarray of shape (n_samples,) holding the recording.
        sr: Sampling rate of ``audio_array`` in Hz (default: 16_000).

    Returns:
        torch.Tensor: the feature extractor's ``input_values`` for this clip
        with the leading batch dimension squeezed away.

    Notes:
        1. The audio array should be normalized to [-1, 1].
        2. The audio length should be 10 seconds (or 10.24 seconds).
           Longer audio will be truncated.
    """
    features = BirdAST_FEATURE_EXTRACTOR(
        audio_array,
        sampling_rate=sr,
        padding="max_length",
        return_tensors="pt",
    )
    # With return_tensors="pt" the value is already a tensor. torch.as_tensor
    # reuses it, whereas torch.tensor(<tensor>) copies and emits a UserWarning.
    spectrogram = torch.as_tensor(features["input_values"]).squeeze(0)
    return spectrogram


def birdast_seq_inference(
    model_weights,
    spectrogram,
    device='cpu',
    backbone_name=DEFAULT_BACKBONE,
    n_classes=DEFAULT_N_CLASSES,
    activation=DEFAULT_ACTIVATION,
    n_mlp_layers=DEFAULT_N_MLP_LAYERS,
):
    """Run a BirdAST ensemble (one checkpoint per model) on a batch of spectrograms.

    A single BirdAST instance is built once. Each checkpoint in
    ``model_weights`` is then loaded into it in turn and evaluated on the
    same input.

    Args:
        model_weights: Iterable of checkpoint paths (state dicts for BirdAST).
        spectrogram: torch.Tensor of shape (batch_size, n_frames, n_mels).
        device: Device to run inference on, e.g. 'cpu', 'cuda', 'cuda:1'
            (str or torch.device; default: 'cpu').
        backbone_name: Name of the AST backbone
            (default: 'MIT/ast-finetuned-audioset-10-10-0.4593').
        n_classes: Number of output classes (default: 728).
        activation: MLP activation, one of 'relu', 'silu', 'gelu' (default: 'silu').
        n_mlp_layers: Number of hidden MLP layers (default: 1).

    Returns:
        np.ndarray of softmax probabilities with shape
        (n_models, batch_size, n_classes).

    Raises:
        ValueError: if ``model_weights`` is empty.
    """
    model_weights = list(model_weights)
    if not model_weights:
        # Fail before building (and possibly downloading) the backbone.
        raise ValueError("model_weights must contain at least one checkpoint path")

    model = BirdAST(
        backbone_name=backbone_name,
        n_classes=n_classes,
        n_mlp_layers=n_mlp_layers,
        activation=activation,
    )
    # Loop-invariant setup. load_state_dict copies values into the existing
    # parameters, so the model's device and eval mode survive each reload.
    model.to(device)
    model.eval()
    spectrogram = spectrogram.to(device)

    predict_collects = []
    for _weight in model_weights:
        # NOTE(security): torch.load unpickles the checkpoint file; only load
        # checkpoints from trusted sources.
        model.load_state_dict(torch.load(_weight, map_location=device))
        with torch.no_grad():
            logits = model(spectrogram)['logits']
            predictions = F.softmax(logits, dim=1)
        # Always bring results back to host memory. Checking only
        # device == 'cuda' missed 'cuda:N', torch.device objects, 'mps', etc.,
        # which made the final .numpy() fail. .cpu() is a no-op on CPU tensors.
        predict_collects.append(predictions.cpu())

    return torch.stack(predict_collects).numpy()


class SelfAttentionPooling(nn.Module):
    """
    Implementation of SelfAttentionPooling
    Original Paper: Self-Attention Encoding and Pooling for Speaker Recognition
    https://arxiv.org/pdf/2008.01077v1.pdf

    Collapses a sequence of hidden vectors into one vector by a learned,
    softmax-normalized weighted sum over the time axis.
    """

    def __init__(self, input_dim):
        super().__init__()
        # Scores each time step with a single scalar.
        self.W = nn.Linear(input_dim, 1)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, batch_rep):
        """
        input:
            batch_rep : size (N, T, H), N: batch size, T: sequence length, H: Hidden dimension

        attention_weight:
            att_w : size (N, T, 1)

        return:
            utter_rep: size (N, H)
        """
        # (N, T, 1) -> (N, T) for a softmax over T, then back to (N, T, 1)
        # so the weights broadcast across the hidden dimension.
        att_w = self.softmax(self.W(batch_rep).squeeze(-1)).unsqueeze(-1)
        utter_rep = torch.sum(batch_rep * att_w, dim=1)
        return utter_rep


class BirdAST(nn.Module):
    """AST backbone + self-attention pooling + MLP classification head.

    Args:
        backbone_name: Hugging Face name/path of the pretrained AST model.
        n_classes: Number of output classes.
        n_mlp_layers: Number of (Linear(hidden, hidden), activation) pairs
            before the final Linear(hidden, n_classes).
        activation: 'relu', 'silu' or 'gelu'.

    Raises:
        ValueError: if ``activation`` is not one of the supported names.
    """

    # Supported activation names -> module classes (instantiated on demand).
    _ACTIVATIONS = {'relu': nn.ReLU, 'silu': nn.SiLU, 'gelu': nn.GELU}

    def __init__(self, backbone_name, n_classes, n_mlp_layers=1, activation='silu'):
        super().__init__()

        # pre-trained backbone
        backbone_config = ASTConfig.from_pretrained(backbone_name)
        self.ast = ASTModel.from_pretrained(backbone_name, config=backbone_config)
        self.hidden_size = backbone_config.hidden_size

        # set activation function
        activation_cls = self._ACTIVATIONS.get(activation)
        if activation_cls is None:
            raise ValueError("Unsupported activation function. Choose 'relu', 'silu' or 'gelu'")
        self.activation = activation_cls()

        # define self-attention pooling layer
        self.sa_pool = SelfAttentionPooling(self.hidden_size)

        # Define MLP layers with activation. The layer order fixes the
        # state-dict keys (mlp.0, mlp.2, ...) that saved checkpoints rely on,
        # so keep it unchanged.
        layers = []
        for _ in range(n_mlp_layers):
            layers.append(nn.Linear(self.hidden_size, self.hidden_size))
            layers.append(self.activation)
        layers.append(nn.Linear(self.hidden_size, n_classes))
        self.mlp = nn.Sequential(*layers)

    def forward(self, spectrogram):
        """Classify a batch of spectrograms.

        Args:
            spectrogram: (batch_size, n_frames, n_mels) tensor, i.e. AST's
                ``input_values`` layout (batch_size, max_length, num_mel_bins).

        Returns:
            dict with key 'logits': (batch_size, n_classes) unnormalized scores.
        """
        ast_output = self.ast(spectrogram, output_hidden_states=False)
        hidden_state = ast_output.last_hidden_state  # (batch, seq_len, hidden)
        pool_output = self.sa_pool(hidden_state)     # (batch, hidden)
        logits = self.mlp(pool_output)
        return {'logits': logits}


if __name__ == '__main__':
    import numpy as np
    import matplotlib.pyplot as plt

    # Example usage of BirdAST_Seq.
    # Create a random 10-second clip at DEFAULT_SR, normalized to [-1, 1] as
    # the preprocessing expects. The old 160_000 * 10 samples was 100 s.
    audio_array = np.random.uniform(-1.0, 1.0, DEFAULT_SR * 10)

    # Preprocess audio array
    spectrogram = birdast_seq_preprocess(audio_array)

    model_weights_dir = '/workspace/voice_of_jungle/training_logs'

    # Load model weights
    model_weights = [
        f'{model_weights_dir}/BirdAST_SeqPool_GroupKFold_fold_{i}.pth'
        for i in range(5)
    ]

    # Perform inference
    predictions = birdast_seq_inference(model_weights, spectrogram.unsqueeze(0))

    # Plot predictions (class-probability curve of the single clip per model)
    fig, ax = plt.subplots()
    for i, pred in enumerate(predictions):
        ax.plot(pred[0], label=f'model_{i}')
    ax.legend()
    fig.savefig('test_BirdAST_Seq.png')

    print("Inference completed successfully!")